suggest/
query.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use std::collections::HashSet;
6
7use crate::{LabeledTimingSample, Suggestion, SuggestionProvider, SuggestionProviderConstraints};
8
/// A query for suggestions to show in the address bar.
///
/// Derives `Default`, so the builder-style constructors on this type can fill in only the fields
/// they care about and leave the rest at their default values.
#[derive(Clone, Debug, Default, uniffi::Record)]
pub struct SuggestionQuery {
    /// The text to find suggestions for. `fts_query()` parses this into individual FTS terms.
    pub keyword: String,
    /// The suggestion providers this query is addressed to.
    pub providers: Vec<SuggestionProvider>,
    /// Optional provider-specific constraints. Defaults to `None` for UniFFI callers.
    #[uniffi(default = None)]
    pub provider_constraints: Option<SuggestionProviderConstraints>,
    /// Optional limit. Defaults to `None` for UniFFI callers.
    // NOTE(review): presumably the maximum number of suggestions to return — confirm against the
    // code that consumes this field.
    #[uniffi(default = None)]
    pub limit: Option<i32>,
}
19
/// A list of suggestions bundled with the timing samples recorded for them.
// NOTE(review): presumably the return value of a "query with metrics" style API — confirm against
// the caller that builds this record.
#[derive(uniffi::Record)]
pub struct QueryWithMetricsResult {
    /// The suggestions themselves.
    pub suggestions: Vec<Suggestion>,
    /// Samples for the `suggest.query_time` metric
    pub query_times: Vec<LabeledTimingSample>,
}
26
27impl SuggestionQuery {
28    // Builder style methods for creating queries (mostly used by the test code)
29
30    pub fn all_providers(keyword: &str) -> Self {
31        Self {
32            keyword: keyword.to_string(),
33            providers: Vec::from(SuggestionProvider::all()),
34            ..Self::default()
35        }
36    }
37
38    pub fn with_providers(keyword: &str, providers: Vec<SuggestionProvider>) -> Self {
39        Self {
40            keyword: keyword.to_string(),
41            providers,
42            ..Self::default()
43        }
44    }
45
46    pub fn all_providers_except(keyword: &str, provider: SuggestionProvider) -> Self {
47        Self::with_providers(
48            keyword,
49            SuggestionProvider::all()
50                .into_iter()
51                .filter(|p| *p != provider)
52                .collect(),
53        )
54    }
55
56    pub fn amp(keyword: &str) -> Self {
57        Self {
58            keyword: keyword.into(),
59            providers: vec![SuggestionProvider::Amp],
60            ..Self::default()
61        }
62    }
63
64    pub fn wikipedia(keyword: &str) -> Self {
65        Self {
66            keyword: keyword.into(),
67            providers: vec![SuggestionProvider::Wikipedia],
68            ..Self::default()
69        }
70    }
71
72    pub fn amo(keyword: &str) -> Self {
73        Self {
74            keyword: keyword.into(),
75            providers: vec![SuggestionProvider::Amo],
76            ..Self::default()
77        }
78    }
79
80    pub fn yelp(keyword: &str) -> Self {
81        Self {
82            keyword: keyword.into(),
83            providers: vec![SuggestionProvider::Yelp],
84            ..Self::default()
85        }
86    }
87
88    pub fn mdn(keyword: &str) -> Self {
89        Self {
90            keyword: keyword.into(),
91            providers: vec![SuggestionProvider::Mdn],
92            ..Self::default()
93        }
94    }
95
96    pub fn weather(keyword: &str) -> Self {
97        Self {
98            keyword: keyword.into(),
99            providers: vec![SuggestionProvider::Weather],
100            ..Self::default()
101        }
102    }
103
104    pub fn dynamic(keyword: &str, suggestion_types: &[&str]) -> Self {
105        Self {
106            keyword: keyword.into(),
107            providers: vec![SuggestionProvider::Dynamic],
108            provider_constraints: Some(SuggestionProviderConstraints {
109                dynamic_suggestion_types: Some(
110                    suggestion_types.iter().map(|s| s.to_string()).collect(),
111                ),
112                ..SuggestionProviderConstraints::default()
113            }),
114            ..Self::default()
115        }
116    }
117
118    pub fn limit(self, limit: i32) -> Self {
119        Self {
120            limit: Some(limit),
121            ..self
122        }
123    }
124
125    /// Create an FTS query term for our keyword(s)
126    pub(crate) fn fts_query(&self) -> FtsQuery<'_> {
127        FtsQuery::new(&self.keyword)
128    }
129}
130
/// A user-typed keyword string parsed into the pieces needed for an SQLite FTS `MATCH` query.
///
/// The lifetime `'a` ties the parsed terms to the keyword string they were sliced from.
pub struct FtsQuery<'a> {
    /// Argument for the FTS `MATCH` operator: every term wrapped in double quotes, joined by
    /// single spaces, with a trailing `*` when `is_prefix_query` is `true`. Holds `""` (an empty
    /// quoted string) when the keyword contains no terms.
    pub match_arg: String,
    /// Same as `match_arg`, but never carrying the trailing `*`.
    pub match_arg_without_prefix_match: String,
    /// Whether the last term is matched as a prefix rather than as a whole word.
    pub is_prefix_query: bool,
    /// The individual terms parsed out of the keyword, in input order.
    keyword_terms: Vec<&'a str>,
}

impl<'a> FtsQuery<'a> {
    /// Parses `keyword` into FTS terms and builds the corresponding `MATCH` arguments.
    fn new(keyword: &'a str) -> Self {
        // Parse the `keyword` field into a set of keywords.
        //
        // This is used when passing the keywords into an FTS search.  It:
        //   - Strips out any `():^*"` chars.  These are typically used for advanced searches, which
        //     we don't support and it would be weird to only support for FTS searches.
        //   - splits on whitespace to get a list of individual keywords
        let keywords = Self::split_terms(keyword);
        if keywords.is_empty() {
            return Self {
                keyword_terms: keywords,
                match_arg: String::from(r#""""#),
                match_arg_without_prefix_match: String::from(r#""""#),
                is_prefix_query: false,
            };
        }
        // Quote each term from `query` and join them together
        let mut sqlite_match = keywords
            .iter()
            .map(|term| format!(r#""{term}""#))
            .collect::<Vec<_>>()
            .join(" ");
        // If the input is > 3 characters, and there's no whitespace at the end.
        // We want to append a `*` char to the end to do a prefix match on it.
        //
        // Count Unicode scalar values rather than UTF-8 bytes: `str::len()` would let a
        // two-character non-ASCII keyword pass the "more than 3 characters" threshold.
        let total_chars: usize = keywords.iter().map(|term| term.chars().count()).sum();
        let query_ends_in_whitespace = keyword.ends_with(' ');
        let prefix_match = (total_chars > 3) && !query_ends_in_whitespace;
        let sqlite_match_without_prefix_match = sqlite_match.clone();
        if prefix_match {
            sqlite_match.push('*');
        }
        Self {
            keyword_terms: keywords,
            is_prefix_query: prefix_match,
            match_arg: sqlite_match,
            match_arg_without_prefix_match: sqlite_match_without_prefix_match,
        }
    }

    /// Try to figure out if a FTS match required stemming
    ///
    /// To test this, we have to try to mimic the SQLite FTS logic. This code doesn't do it
    /// perfectly, but it should return the correct result most of the time.
    ///
    /// Returns `false` when every keyword term is found verbatim among the words of `title`
    /// (the last term may match as a prefix when `is_prefix_query` is set), and `true`
    /// otherwise. The comparison ignores case on both sides, since SQLite's built-in FTS
    /// tokenizers fold case. A query with no terms never requires stemming.
    pub fn match_required_stemming(&self, title: &str) -> bool {
        let title = title.to_lowercase();
        let split_title = Self::split_terms(&title);
        // `saturating_sub` keeps this well-defined for an empty term list, where the closure
        // below never runs anyway.
        let last_index = self.keyword_terms.len().saturating_sub(1);

        let every_term_found = self
            .keyword_terms
            .iter()
            .enumerate()
            .all(|(i, keyword)| {
                // The title was lowercased above, so the keyword must be too; otherwise an
                // upper-case keyword could never equal its own title word.
                let keyword = keyword.to_lowercase();
                let prefix_allowed = self.is_prefix_query && i == last_index;
                split_title.iter().any(|title_word| {
                    if prefix_allowed {
                        title_word.starts_with(keyword.as_str())
                    } else {
                        *title_word == keyword.as_str()
                    }
                })
            });
        !every_term_found
    }

    /// Splits `phrase` on spaces and on the characters `():^*",`, dropping empty pieces.
    fn split_terms(phrase: &str) -> Vec<&str> {
        phrase
            .split([' ', '(', ')', ':', '^', '*', '"', ','])
            .filter(|s| !s.is_empty())
            .collect()
    }
}
206
/// Given a list of full keywords, create an FTS string to match against.
///
/// Each full keyword is split on whitespace and lowercased. The resulting terms are de-duped,
/// sorted, and joined with single spaces. Sorting makes the output deterministic: iterating the
/// de-duplication `HashSet` directly would yield the terms in an order that can differ between
/// runs. Returns an empty string when there are no terms.
pub fn full_keywords_to_fts_content<'a>(
    full_keywords: impl IntoIterator<Item = &'a str>,
) -> String {
    let unique_terms: HashSet<String> = full_keywords
        .into_iter()
        .flat_map(str::split_whitespace)
        .map(str::to_lowercase)
        .collect();
    let mut terms: Vec<String> = unique_terms.into_iter().collect();
    terms.sort_unstable();
    terms.join(" ")
}
227
#[cfg(test)]
mod test {
    use super::*;
    use std::collections::HashMap;

    /// Asserts that `input` is parsed into exactly the `expected` FTS terms.
    fn check_parse_keywords(input: &str, expected: Vec<&str>) {
        let suggestion_query = SuggestionQuery::all_providers(input);
        let parsed = suggestion_query.fts_query();
        assert_eq!(parsed.keyword_terms, expected);
    }

    #[test]
    fn test_quote() {
        let cases: Vec<(&str, Vec<&str>)> = vec![
            ("foo", vec!["foo"]),
            ("foo bar", vec!["foo", "bar"]),
            // Special chars should be stripped
            ("\"foo()* ^bar:\"", vec!["foo", "bar"]),
            // test some corner cases
            ("", vec![]),
            (" ", vec![]),
            ("   foo     bar       ", vec!["foo", "bar"]),
            ("foo:bar", vec!["foo", "bar"]),
        ];
        for (input, expected) in cases {
            check_parse_keywords(input, expected);
        }
    }

    /// Asserts that `input` produces exactly the `expected` FTS `MATCH` argument.
    fn check_fts_query(input: &str, expected: &str) {
        let suggestion_query = SuggestionQuery::all_providers(input);
        let parsed = suggestion_query.fts_query();
        assert_eq!(parsed.match_arg, expected);
    }

    #[test]
    fn test_fts_query() {
        let cases = [
            // String with < 3 chars shouldn't get a prefix query
            ("r", r#""r""#),
            ("ru", r#""ru""#),
            ("run", r#""run""#),
            // After 3 chars, we should append `*` to the last term to make it a prefix query
            ("runn", r#""runn"*"#),
            ("running", r#""running"*"#),
            // The total number of chars is counted, not the number of chars in the last term
            ("running s", r#""running" "s"*"#),
            // if the input ends in whitespace, then don't do a prefix query
            ("running ", r#""running""#),
            // Special chars are filtered out
            ("running*\"()^: s", r#""running" "s"*"#),
            ("running *\"()^: s", r#""running" "s"*"#),
            // Special chars shouldn't count towards the input size when deciding whether to do a
            // prefix query or not
            ("r():", r#""r""#),
            // Test empty strings
            ("", r#""""#),
            (" ", r#""""#),
            ("()", r#""""#),
        ];
        for (input, expected) in cases {
            check_fts_query(input, expected);
        }
    }

    #[test]
    fn test_fts_query_match_required_stemming() {
        // Each case is (keyword, title, whether the match required stemming).
        let cases = [
            // These don't require stemming, since each keyword matches a term in the title
            ("running shoes", "running shoes", false),
            ("running shoes", "new balance running shoes", false),
            // Case changes shouldn't matter
            ("running shoes", "Running Shoes", false),
            // This doesn't require stemming, since `:` is not part of the word
            ("running shoes", "Running: Shoes", false),
            // This requires the keywords to be stemmed in order to match
            ("run shoes", "running shoes", true),
            // This didn't require stemming, since the last keyword was a prefix match
            ("running sh", "running shoes", false),
            // This does require stemming (we know it wasn't a prefix match since there's not
            // enough characters).
            ("run", "running shoes", true),
        ];
        for (keyword, title, expected) in cases {
            assert_eq!(
                FtsQuery::new(keyword).match_required_stemming(title),
                expected,
                "keyword: {keyword:?}, title: {title:?}"
            );
        }
    }

    #[test]
    fn test_full_keywords_to_fts_content() {
        check_full_keywords_to_fts_content(["a", "b", "c"], "a b c");
        check_full_keywords_to_fts_content(["a", "b c"], "a b c");
        check_full_keywords_to_fts_content(["a", "b c a"], "a b c");
        check_full_keywords_to_fts_content(["a", "b C A"], "a b c");
    }

    /// Asserts that the FTS content built from `input` holds the same whitespace-separated terms,
    /// with the same multiplicities, as `expected`. The order of the terms is ignored.
    fn check_full_keywords_to_fts_content<const N: usize>(input: [&str; N], expected: &str) {
        /// Counts how many times each whitespace-separated term occurs in `text`.
        fn term_counts(text: &str) -> HashMap<&str, usize> {
            let mut counts = HashMap::new();
            for term in text.split_whitespace() {
                *counts.entry(term).or_default() += 1;
            }
            counts
        }

        let fts_content = full_keywords_to_fts_content(input);
        assert_eq!(term_counts(&fts_content), term_counts(expected));
    }
}
321}