Skip to content

Commit

Permalink
Snippet term score (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
fulmicoton authored Sep 16, 2018
1 parent 10f6c07 commit 5449ec3
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 8 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ winapi = "0.2"

[dev-dependencies]
rand = "0.5"
maplit = "1"

[profile.release]
opt-level = 3
Expand Down
6 changes: 6 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,8 @@ extern crate tempdir;
extern crate tempfile;
extern crate uuid;



#[cfg(test)]
#[macro_use]
extern crate matches;
Expand All @@ -162,6 +164,10 @@ extern crate winapi;
#[cfg(test)]
extern crate rand;

#[cfg(test)]
#[macro_use]
extern crate maplit;

#[cfg(all(test, feature = "unstable"))]
extern crate test;

Expand Down
50 changes: 50 additions & 0 deletions src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,53 @@ pub use self::scorer::ConstScorer;
pub use self::scorer::Scorer;
pub use self::term_query::TermQuery;
pub use self::weight::Weight;

#[cfg(test)]
mod tests {
use Index;
use schema::{SchemaBuilder, TEXT};
use query::QueryParser;
use Term;
use std::collections::BTreeSet;

#[test]
fn test_query_terms() {
let mut schema_builder = SchemaBuilder::default();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
let query_parser = QueryParser::for_index(&index, vec![text_field]);
let term_a = Term::from_field_text(text_field, "a");
let term_b = Term::from_field_text(text_field, "b");
{
let mut terms_set: BTreeSet<Term> = BTreeSet::new();
query_parser.parse_query("a").unwrap().query_terms(&mut terms_set);
let terms: Vec<&Term> = terms_set.iter().collect();
assert_eq!(vec![&term_a], terms);
}
{
let mut terms_set: BTreeSet<Term> = BTreeSet::new();
query_parser.parse_query("a b").unwrap().query_terms(&mut terms_set);
let terms: Vec<&Term> = terms_set.iter().collect();
assert_eq!(vec![&term_a, &term_b], terms);
}
{
let mut terms_set: BTreeSet<Term> = BTreeSet::new();
query_parser.parse_query("\"a b\"").unwrap().query_terms(&mut terms_set);
let terms: Vec<&Term> = terms_set.iter().collect();
assert_eq!(vec![&term_a, &term_b], terms);
}
{
let mut terms_set: BTreeSet<Term> = BTreeSet::new();
query_parser.parse_query("a a a a a").unwrap().query_terms(&mut terms_set);
let terms: Vec<&Term> = terms_set.iter().collect();
assert_eq!(vec![&term_a], terms);
}
{
let mut terms_set: BTreeSet<Term> = BTreeSet::new();
query_parser.parse_query("a -b").unwrap().query_terms(&mut terms_set);
let terms: Vec<&Term> = terms_set.iter().collect();
assert_eq!(vec![&term_a, &term_b], terms);
}
}
}
112 changes: 104 additions & 8 deletions src/snippet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,15 @@ impl SnippetGenerator {
let terms_text: BTreeMap<String, f32> = terms
.into_iter()
.filter(|term| term.field() == field)
.map(|term| (term.text().to_string(), 1f32))
.flat_map(|term| {
let doc_freq = searcher.doc_freq(&term);
let score = 1f32 / (1f32 + doc_freq as f32);
if doc_freq > 0 {
Some((term.text().to_string(), score))
} else {
None
}
})
.collect();
let tokenizer = searcher.index().tokenizer_for_field(field)?;
Ok(SnippetGenerator {
Expand All @@ -263,6 +271,11 @@ impl SnippetGenerator {
self.max_num_chars = max_num_chars;
}

#[cfg(test)]
pub fn terms_text(&self) -> &BTreeMap<String, f32> {
&self.terms_text
}

/// Generates a snippet for the given `Document`.
///
/// This method extract the text associated to the `SnippetGenerator`'s field
Expand Down Expand Up @@ -293,7 +306,7 @@ impl SnippetGenerator {
mod tests {
use super::{search_fragments, select_best_fragment_combination};
use query::QueryParser;
use schema::{IndexRecordOption, SchemaBuilder, TextFieldIndexing, TextOptions};
use schema::{IndexRecordOption, SchemaBuilder, TextFieldIndexing, TextOptions, TEXT};
use std::collections::BTreeMap;
use std::iter::Iterator;
use tokenizer::{box_tokenizer, SimpleTokenizer};
Expand All @@ -315,24 +328,67 @@ to the project are from community members.[15]
Rust won first place for "most loved programming language" in the Stack Overflow Developer
Survey in 2016, 2017, and 2018."#;



#[test]
fn test_snippet() {
let boxed_tokenizer = box_tokenizer(SimpleTokenizer);
let mut terms = BTreeMap::new();
terms.insert(String::from("rust"), 1.0);
terms.insert(String::from("language"), 0.9);
let terms = btreemap! {
String::from("rust") => 1.0,
String::from("language") => 0.9
};
let fragments = search_fragments(&*boxed_tokenizer, TEST_TEXT, &terms, 100);
assert_eq!(fragments.len(), 7);
{
let first = fragments.iter().nth(0).unwrap();
let first = &fragments[0];
assert_eq!(first.score, 1.9);
assert_eq!(first.stop_offset, 89);
}
let snippet = select_best_fragment_combination(&fragments[..], &TEST_TEXT);
assert_eq!(snippet.fragments, "Rust is a systems programming language sponsored by Mozilla which\ndescribes it as a \"safe".to_owned());
assert_eq!(snippet.to_html(), "<b>Rust</b> is a systems programming <b>language</b> sponsored by Mozilla which\ndescribes it as a &quot;safe".to_owned())
assert_eq!(snippet.fragments, "Rust is a systems programming language sponsored by \
Mozilla which\ndescribes it as a \"safe");
assert_eq!(snippet.to_html(), "<b>Rust</b> is a systems programming <b>language</b> \
sponsored by Mozilla which\ndescribes it as a &quot;safe")
}


#[test]
fn test_snippet_scored_fragment() {
let boxed_tokenizer = box_tokenizer(SimpleTokenizer);
{
let terms = btreemap! {
String::from("rust") =>1.0f32,
String::from("language") => 0.9f32
};
let fragments = search_fragments(&*boxed_tokenizer, TEST_TEXT, &terms, 20);
{
let first = &fragments[0];
assert_eq!(first.score, 1.0);
assert_eq!(first.stop_offset, 17);
}
let snippet = select_best_fragment_combination(&fragments[..], &TEST_TEXT);
assert_eq!(snippet.to_html(), "<b>Rust</b> is a systems")
}
let boxed_tokenizer = box_tokenizer(SimpleTokenizer);
{
let terms = btreemap! {
String::from("rust") =>0.9f32,
String::from("language") => 1.0f32
};
let fragments = search_fragments(&*boxed_tokenizer, TEST_TEXT, &terms, 20);
//assert_eq!(fragments.len(), 7);
{
let first = &fragments[0];
assert_eq!(first.score, 0.9);
assert_eq!(first.stop_offset, 17);
}
let snippet = select_best_fragment_combination(&fragments[..], &TEST_TEXT);
assert_eq!(snippet.to_html(), "programming <b>language</b>")
}

}


#[test]
fn test_snippet_in_second_fragment() {
let boxed_tokenizer = box_tokenizer(SimpleTokenizer);
Expand Down Expand Up @@ -439,6 +495,46 @@ Survey in 2016, 2017, and 2018."#;
assert_eq!(snippet.to_html(), "");
}


#[test]
fn test_snippet_generator_term_score() {
let mut schema_builder = SchemaBuilder::default();
let text_field = schema_builder.add_text_field("text", TEXT);
let schema = schema_builder.build();
let index = Index::create_in_ram(schema);
{
// writing the segment
let mut index_writer = index.writer_with_num_threads(1, 40_000_000).unwrap();
index_writer.add_document(doc!(text_field => "a"));
index_writer.add_document(doc!(text_field => "a"));
index_writer.add_document(doc!(text_field => "a b"));
index_writer.commit().unwrap();
index.load_searchers().unwrap();
}
let searcher = index.searcher();
let query_parser = QueryParser::for_index(&index, vec![text_field]);
{
let query = query_parser.parse_query("e").unwrap();
let snippet_generator = SnippetGenerator::new(&searcher, &*query, text_field).unwrap();
assert!(snippet_generator.terms_text().is_empty());
}
{
let query = query_parser.parse_query("a").unwrap();
let snippet_generator = SnippetGenerator::new(&searcher, &*query, text_field).unwrap();
assert_eq!(&btreemap!("a".to_string() => 0.25f32), snippet_generator.terms_text());
}
{
let query = query_parser.parse_query("a b").unwrap();
let snippet_generator = SnippetGenerator::new(&searcher, &*query, text_field).unwrap();
assert_eq!(&btreemap!("a".to_string() => 0.25f32, "b".to_string() => 0.5), snippet_generator.terms_text());
}
{
let query = query_parser.parse_query("a b c").unwrap();
let snippet_generator = SnippetGenerator::new(&searcher, &*query, text_field).unwrap();
assert_eq!(&btreemap!("a".to_string() => 0.25f32, "b".to_string() => 0.5), snippet_generator.terms_text());
}
}

#[test]
fn test_snippet_generator() {
let mut schema_builder = SchemaBuilder::default();
Expand Down

0 comments on commit 5449ec3

Please sign in to comment.