From c7f14df464ca96f4071f307df7830644eb0b6197 Mon Sep 17 00:00:00 2001 From: ZJaume Date: Thu, 3 Oct 2024 14:55:11 +0000 Subject: [PATCH] Update languagemodel tests --- src/languagemodel.rs | 83 ++++++++++++++++++++++++-------------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/src/languagemodel.rs b/src/languagemodel.rs index b08f2a3..f02f291 100644 --- a/src/languagemodel.rs +++ b/src/languagemodel.rs @@ -58,6 +58,7 @@ impl ModelNgram { } } + /// Load the model from plain text for a subset of languages pub fn from_text_langs( model_dir: &Path, model_type: OrderNgram, @@ -80,6 +81,7 @@ impl ModelNgram { Ok(model) } + /// Load the model from plain text for all languages pub fn from_text_all(model_dir: &Path, model_type: OrderNgram) -> Result { let mut model = ModelNgram { dic: HashMap::default(), @@ -112,6 +114,7 @@ impl ModelNgram { Ok(model) } + /// Parse the ngram file, compute probabilities and insert into the model fn read_model(&mut self, p: &Path, langcode: &Lang) -> Result<()> { // Read the language model file to a string all at once let modelfile = @@ -289,47 +292,51 @@ impl Index for Model { #[cfg(test)] mod tests { use super::*; - use std::collections::HashMap; - use std::thread; + + use tempfile::NamedTempFile; #[test] fn test_langs() { + let tempf = NamedTempFile::new().unwrap(); + let temppath = tempf.into_temp_path(); let modelpath = Path::new("./LanguageModels"); - let wordmodel = ModelNgram::from_text(&modelpath, OrderNgram::Word); - let path = Path::new("wordict.ser"); - wordmodel.save(path); - - let charmodel = ModelNgram::from_text(&modelpath, OrderNgram::Quadgram); - let path = Path::new("gramdict.ser"); - charmodel.save(path); - - let char_handle = thread::spawn(move || { - let path = Path::new("gramdict.ser"); - ModelNgram::from_bin(path) - }); - - let word_handle = thread::spawn(move || { - let path = Path::new("wordict.ser"); - ModelNgram::from_bin(path) - }); - - // let word_model = word_handle.join().unwrap(); - let char_model = char_handle.join().unwrap(); - - // failing because original HeLI is using a java float - // instead of a double for accumulating frequencies - let mut expected = HashMap::default(); - expected.insert(Lang::Cat, 3.4450269f32); - expected.insert(Lang::Epo, 4.5279417f32); - expected.insert(Lang::Ext, 2.5946937f32); - expected.insert(Lang::Gla, 4.7058706f32); - expected.insert(Lang::Glg, 2.3187783f32); - expected.insert(Lang::Grn, 2.9653773f32); - expected.insert(Lang::Nhn, 4.774119f32); - expected.insert(Lang::Que, 3.8074818f32); - expected.insert(Lang::Spa, 2.480955f32); - - let probs = char_model.dic.get("ación").unwrap(); - assert_eq!(probs, &expected); + + let model = ModelNgram::from_text(&modelpath, + OrderNgram::Quingram, + None).unwrap(); + // let path = Path::new("gramdict.ser"); + model.save(&temppath).unwrap(); + let model = ModelNgram::from_bin(&temppath).unwrap(); + temppath.close().unwrap(); + + let mut expected = Vec::new(); + expected.push((Lang::ayr, 4.2863530f32)); + expected.push((Lang::cat, 3.3738296f32)); + expected.push((Lang::epo, 4.5279417f32)); + expected.push((Lang::ext, 2.5946038f32)); + expected.push((Lang::gla, 4.7052390f32)); + expected.push((Lang::glg, 2.3186955f32)); + expected.push((Lang::grn, 3.1885893f32)); + expected.push((Lang::kac, 5.5482570f32)); + expected.push((Lang::lmo, 5.2805230f32)); + expected.push((Lang::nhn, 5.0725970f32)); + expected.push((Lang::que, 3.8049161f32)); + expected.push((Lang::spa, 2.3922930f32)); + expected.push((Lang::vol, 5.1173210f32)); + + let mut probs = model.dic.get("ación") + .expect("Could not found the ngram in the model") + .clone(); + // round to less decimals to be a lit permissive + // as there are differences between java and rust + let round_to = 10000.0; + for i in expected.iter_mut() { + i.1 = (i.1 * round_to).round() / round_to; + } + for i in probs.iter_mut() { + i.1 = (i.1 * round_to).round() / round_to; + } + assert_eq!(&probs, &expected); + } }