From ba7b512b9b0c56ea832f07b81cb977cb8db4e8e0 Mon Sep 17 00:00:00 2001 From: Garvys Date: Fri, 16 Dec 2022 16:47:34 +0100 Subject: [PATCH 1/9] Lazy relabel fst --- Cargo.lock | 6 +- rustfst/src/algorithms/compose/matcher_fst.rs | 25 +++ rustfst/src/algorithms/mod.rs | 1 + rustfst/src/algorithms/relabel/mod.rs | 2 + rustfst/src/algorithms/relabel/relabel_fst.rs | 172 ++++++++++++++++++ .../src/algorithms/relabel/relabel_fst_op.rs | 62 +++++++ rustfst/src/algorithms/relabel_pairs.rs | 2 +- .../src/tests_openfst/algorithms/compose.rs | 1 + 8 files changed, 267 insertions(+), 4 deletions(-) create mode 100644 rustfst/src/algorithms/relabel/mod.rs create mode 100644 rustfst/src/algorithms/relabel/relabel_fst.rs create mode 100644 rustfst/src/algorithms/relabel/relabel_fst_op.rs diff --git a/Cargo.lock b/Cargo.lock index 1427e2fdd..6d49a23e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -696,7 +696,7 @@ dependencies = [ [[package]] name = "rustfst" -version = "0.13.0" +version = "0.13.1" dependencies = [ "anyhow", "bimap", @@ -726,7 +726,7 @@ dependencies = [ [[package]] name = "rustfst-cli" -version = "0.13.0" +version = "0.13.1" dependencies = [ "anyhow", "clap", @@ -740,7 +740,7 @@ dependencies = [ [[package]] name = "rustfst-ffi" -version = "0.13.0" +version = "0.13.1" dependencies = [ "anyhow", "downcast-rs", diff --git a/rustfst/src/algorithms/compose/matcher_fst.rs b/rustfst/src/algorithms/compose/matcher_fst.rs index abe2fad72..fcf9f7d5f 100644 --- a/rustfst/src/algorithms/compose/matcher_fst.rs +++ b/rustfst/src/algorithms/compose/matcher_fst.rs @@ -79,6 +79,7 @@ where let omatcher_data = M::create_data::(&fst, MatchType::MatchOutput)?; let mut add_on = (imatcher_data, omatcher_data); + LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?; LabelLookAheadRelabeler::relabel(fst2, &mut add_on, relabel_input)?; @@ -91,6 +92,30 @@ where w: PhantomData, }) } + + // // Construct a new Matcher Fst intended for LookAhead composition and relabel fst2 wrt to the first fst. + // pub fn new_with_relabeling_2>( + // mut fst: F, + // fst2: &F2, + // relabel_input: bool, + // ) -> Result { + // let imatcher_data = M::create_data::(&fst, MatchType::MatchInput)?; + // let omatcher_data = M::create_data::(&fst, MatchType::MatchOutput)?; + // + // let mut add_on = (imatcher_data, omatcher_data); + // + // LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?; + // LabelLookAheadRelabeler::relabel(fst2, &mut add_on, relabel_input)?; + // + // let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new)); + // + // let fst_add_on = FstAddOn::new(fst, add_on); + // Ok(Self { + // fst_add_on, + // matcher: PhantomData, + // w: PhantomData, + // }) + // } } impl, B: Borrow, M, T> CoreFst for MatcherFst { diff --git a/rustfst/src/algorithms/mod.rs b/rustfst/src/algorithms/mod.rs index 3d4fc45b1..d37ace4d6 100644 --- a/rustfst/src/algorithms/mod.rs +++ b/rustfst/src/algorithms/mod.rs @@ -50,6 +50,7 @@ mod partition; mod projection; mod push; mod queue; +pub mod relabel; /// Module providing functions to randomly generate paths through an Fst. A static and a delayed version are available. pub mod randgen; diff --git a/rustfst/src/algorithms/relabel/mod.rs b/rustfst/src/algorithms/relabel/mod.rs new file mode 100644 index 000000000..142ad1260 --- /dev/null +++ b/rustfst/src/algorithms/relabel/mod.rs @@ -0,0 +1,2 @@ +mod relabel_fst; +mod relabel_fst_op; diff --git a/rustfst/src/algorithms/relabel/relabel_fst.rs b/rustfst/src/algorithms/relabel/relabel_fst.rs new file mode 100644 index 000000000..964aa9d7d --- /dev/null +++ b/rustfst/src/algorithms/relabel/relabel_fst.rs @@ -0,0 +1,172 @@ +use crate::algorithms::lazy::{LazyFst, SimpleHashMapCache}; +use crate::fst_properties::FstProperties; +use crate::fst_traits::{AllocableFst, CoreFst, Fst, FstIterator, MutableFst, StateIterator}; +use crate::{Semiring, StateId, SymbolTable, TrsVec}; +use anyhow::{Result, Context}; +use std::borrow::Borrow; +use std::fmt::Debug; +use std::sync::Arc; +use crate::algorithms::relabel::relabel_fst_op::RelabelFstOp; +use crate::algorithms::relabel_pairs::iterator_to_hashmap; + +type InnerLazyFst = LazyFst, SimpleHashMapCache>; + +struct RelabelFst, B: Borrow>(InnerLazyFst); + +impl, B: Borrow> RelabelFst { + pub fn new(fst: B, ipairs: I, opairs: J) -> Result + where + I: IntoIterator, + J: IntoIterator, + { + let map_ilabels = iterator_to_hashmap(ipairs) + .with_context(|| format_err!("Error while creating the HashMap for ipairs"))?; + + let map_olabels = iterator_to_hashmap(opairs) + .with_context(|| format_err!("Error while creating the HashMap for opairs"))?; + + let isymt = fst.borrow().input_symbols().cloned(); + let osymt = fst.borrow().output_symbols().cloned(); + + let fst_op = RelabelFstOp::new(fst, map_ilabels, map_olabels); + let fst_cache = SimpleHashMapCache::default(); + Ok(RelabelFst(LazyFst::from_op_and_cache( + fst_op, fst_cache, isymt, osymt, + ))) + } + + pub fn compute + AllocableFst>(&self) -> Result { + self.0.compute() + } +} + +impl CoreFst for RelabelFst +where + W: Semiring, + F: Fst, + B: Borrow, +{ + type TRS = TrsVec; + + fn start(&self) -> Option { + self.0.start() + } + + fn final_weight(&self, state_id: StateId) -> Result> { + self.0.final_weight(state_id) + } + + unsafe fn final_weight_unchecked(&self, state_id: StateId) -> Option { + self.0.final_weight_unchecked(state_id) + } + + fn num_trs(&self, s: StateId) -> Result { + self.0.num_trs(s) + } + + unsafe fn num_trs_unchecked(&self, s: StateId) -> usize { + self.0.num_trs_unchecked(s) + } + + fn get_trs(&self, state_id: StateId) -> Result { + self.0.get_trs(state_id) + } + + unsafe fn get_trs_unchecked(&self, state_id: StateId) -> Self::TRS { + self.0.get_trs_unchecked(state_id) + } + + fn properties(&self) -> FstProperties { + self.0.properties() + } + + fn num_input_epsilons(&self, state: StateId) -> Result { + self.0.num_input_epsilons(state) + } + + fn num_output_epsilons(&self, state: StateId) -> Result { + self.0.num_output_epsilons(state) + } +} + +impl<'a, W, F, B> StateIterator<'a> for RelabelFst +where + W: Semiring, + F: Fst + 'a, + B: Borrow + 'a, +{ + type Iter = as StateIterator<'a>>::Iter; + + fn states_iter(&'a self) -> Self::Iter { + self.0.states_iter() + } +} + +impl<'a, W, F, B> FstIterator<'a, W> for RelabelFst +where + W: Semiring, + F: Fst + 'a, + B: Borrow + 'a, +{ + type FstIter = as FstIterator<'a, W>>::FstIter; + + fn fst_iter(&'a self) -> Self::FstIter { + self.0.fst_iter() + } +} + +impl Fst for RelabelFst +where + W: Semiring, + F: Fst + 'static, + B: Borrow + 'static, +{ + fn input_symbols(&self) -> Option<&Arc> { + self.0.input_symbols() + } + + fn output_symbols(&self) -> Option<&Arc> { + self.0.output_symbols() + } + + fn set_input_symbols(&mut self, symt: Arc) { + self.0.set_input_symbols(symt) + } + + fn set_output_symbols(&mut self, symt: Arc) { + self.0.set_output_symbols(symt) + } + + fn take_input_symbols(&mut self) -> Option> { + self.0.take_input_symbols() + } + + fn take_output_symbols(&mut self) -> Option> { + self.0.take_output_symbols() + } +} + +impl Debug for RelabelFst +where + W: Semiring, + F: Fst, + B: Borrow, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +#[cfg(test)] +mod test { + use crate::fst_impls::VectorFst; + use crate::semirings::TropicalWeight; + + use super::*; + + #[test] + fn test_replace_fst_sync() { + fn is_sync() {} + is_sync::, VectorFst<_>>>(); + } +} diff --git a/rustfst/src/algorithms/relabel/relabel_fst_op.rs b/rustfst/src/algorithms/relabel/relabel_fst_op.rs new file mode 100644 index 000000000..fc1931efb --- /dev/null +++ b/rustfst/src/algorithms/relabel/relabel_fst_op.rs @@ -0,0 +1,62 @@ +use std::borrow::Borrow; +use std::marker::PhantomData; + +use anyhow::Result; + +use crate::algorithms::lazy::FstOp; +use crate::fst_properties::FstProperties; +use crate::fst_traits::Fst; +use crate::{Semiring, TrsVec, StateId, Trs}; +use std::fmt::{Debug, Formatter}; +use std::collections::HashMap; +use std::sync::Arc; + +pub struct RelabelFstOp, B: Borrow> { + fst: B, + map_ilabels: HashMap, + map_olabels: HashMap, + ghost: PhantomData<(W, F)>, +} + +impl, B: Borrow> RelabelFstOp { + pub fn new(fst: B, map_ilabels: HashMap, map_olabels: HashMap) -> Self { + Self { + fst, map_ilabels, map_olabels, ghost: PhantomData + } + } +} + +impl, B: Borrow> Debug for RelabelFstOp { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } +} + +impl, B: Borrow> FstOp for RelabelFstOp { + fn compute_start(&self) -> Result> { + Ok(self.fst.borrow().start()) + } + + fn compute_trs(&self, id: u32) -> Result> { + let trs_original = self.fst.borrow().get_trs(id)?; + let mut trs = vec![]; + for tr in trs_original.trs() { + let mut new_tr = tr.clone(); + if let Some(new_ilabel) = self.map_ilabels.get(&tr.ilabel) { + new_tr.ilabel = *new_ilabel; + } + if let Some(new_olabel) = self.map_olabels.get(&tr.olabel) { + new_tr.olabel = *new_olabel; + } + } + Ok(TrsVec(Arc::new(trs))) + } + + fn compute_final_weight(&self, id: u32) -> Result> { + self.fst.borrow().final_weight(id) + } + + fn properties(&self) -> FstProperties { + unimplemented!() + } +} diff --git a/rustfst/src/algorithms/relabel_pairs.rs b/rustfst/src/algorithms/relabel_pairs.rs index fd342ecef..722bf222d 100644 --- a/rustfst/src/algorithms/relabel_pairs.rs +++ b/rustfst/src/algorithms/relabel_pairs.rs @@ -7,7 +7,7 @@ use crate::fst_traits::MutableFst; use crate::semirings::Semiring; use crate::StateId; -fn iterator_to_hashmap(pairs: I) -> Result> +pub fn iterator_to_hashmap(pairs: I) -> Result> where I: IntoIterator, { diff --git a/rustfst/src/tests_openfst/algorithms/compose.rs b/rustfst/src/tests_openfst/algorithms/compose.rs index 6c2959e31..7615ab309 100644 --- a/rustfst/src/tests_openfst/algorithms/compose.rs +++ b/rustfst/src/tests_openfst/algorithms/compose.rs @@ -121,6 +121,7 @@ fn do_test_compose_lookahead( where W: SerializableSemiring + WeightQuantize + WeaklyDivisibleSemiring, { + println!("Lookahead composition"); type TLaFst = MatcherFst< S, F, From 1de3858853ffd8175c74b0f43ab1d2d763bfbd0f Mon Sep 17 00:00:00 2001 From: Garvys Date: Mon, 19 Dec 2022 17:28:26 +0100 Subject: [PATCH 2/9] Implemement lazy lookahead relabeling --- .../src/algorithms/compose/label_reachable.rs | 42 +++++ .../push_weights_compose_filter.rs | 6 +- .../label_lookahead_relabeler.rs | 21 ++- .../lookahead_relabel_fst.rs | 171 ++++++++++++++++++ .../lookahead_relabel_fst_op.rs | 77 ++++++++ .../compose/lookahead_relabel/mod.rs | 4 + rustfst/src/algorithms/compose/matcher_fst.rs | 49 ++--- rustfst/src/algorithms/compose/mod.rs | 3 + rustfst/src/algorithms/lazy/lazy_fst.rs | 12 +- rustfst/src/algorithms/relabel/mod.rs | 2 + rustfst/src/algorithms/relabel/relabel_fst.rs | 8 +- .../src/algorithms/relabel/relabel_fst_op.rs | 24 ++- .../src/tests_openfst/algorithms/compose.rs | 159 +++++++++++++++- 13 files changed, 535 insertions(+), 43 deletions(-) create mode 100644 rustfst/src/algorithms/compose/lookahead_relabel/lookahead_relabel_fst.rs create mode 100644 rustfst/src/algorithms/compose/lookahead_relabel/lookahead_relabel_fst_op.rs create mode 100644 rustfst/src/algorithms/compose/lookahead_relabel/mod.rs diff --git a/rustfst/src/algorithms/compose/label_reachable.rs b/rustfst/src/algorithms/compose/label_reachable.rs index 12698419a..827b908b9 100644 --- a/rustfst/src/algorithms/compose/label_reachable.rs +++ b/rustfst/src/algorithms/compose/label_reachable.rs @@ -10,6 +10,7 @@ use crate::algorithms::{fst_convert_from_ref, tr_sort}; use crate::fst_impls::VectorFst; use crate::fst_properties::FstProperties; use crate::fst_traits::{CoreFst, ExpandedFst, Fst, MutableFst}; +use crate::prelude::compose::LookaheadRelabelFst; use crate::semirings::Semiring; use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED}; @@ -60,6 +61,19 @@ impl LabelReachableData { .or_insert_with(|| n as Label + 1) } + pub fn relabel_unmut(&self, label: Label) -> Result