Skip to content

Commit

Permalink
Support deserializing tokenizer in vaporetto_tantivy
Browse files Browse the repository at this point in the history
  • Loading branch information
vbkaisetsu committed Nov 6, 2024
1 parent fa65215 commit 773a4bd
Showing 1 changed file with 56 additions and 13 deletions.
69 changes: 56 additions & 13 deletions vaporetto_tantivy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,25 @@ pub struct VaporettoTokenizer {
postfilters: Vec<Arc<dyn SentenceFilter>>,
}

fn build_post_filters(
wsconst: &str,
) -> Result<Vec<Arc<dyn SentenceFilter>>, Box<dyn std::error::Error>> {
let mut postfilters: Vec<Arc<dyn SentenceFilter>> = vec![Arc::new(SplitLinebreaksFilter)];
for c in wsconst.chars() {
postfilters.push(match c {
'D' => Arc::new(KyteaWsConstFilter::new(CharacterType::Digit)),
'R' => Arc::new(KyteaWsConstFilter::new(CharacterType::Roman)),
'H' => Arc::new(KyteaWsConstFilter::new(CharacterType::Hiragana)),
'T' => Arc::new(KyteaWsConstFilter::new(CharacterType::Katakana)),
'K' => Arc::new(KyteaWsConstFilter::new(CharacterType::Kanji)),
'O' => Arc::new(KyteaWsConstFilter::new(CharacterType::Other)),
'G' => Arc::new(ConcatGraphemeClustersFilter),
_ => return Err("Could not parse a wsconst value".into()),
});
}
Ok(postfilters)
}

impl VaporettoTokenizer {
/// Creates a new VaporettoTokenizer.
///
Expand All @@ -82,25 +101,49 @@ impl VaporettoTokenizer {
/// - the model is invalid, or
/// - `wsconst` contains an invalid character type.
pub fn new(model: Model, wsconst: &str) -> Result<Self, Box<dyn std::error::Error>> {
let mut postfilters: Vec<Arc<dyn SentenceFilter>> = vec![Arc::new(SplitLinebreaksFilter)];
for c in wsconst.chars() {
postfilters.push(match c {
'D' => Arc::new(KyteaWsConstFilter::new(CharacterType::Digit)),
'R' => Arc::new(KyteaWsConstFilter::new(CharacterType::Roman)),
'H' => Arc::new(KyteaWsConstFilter::new(CharacterType::Hiragana)),
'T' => Arc::new(KyteaWsConstFilter::new(CharacterType::Katakana)),
'K' => Arc::new(KyteaWsConstFilter::new(CharacterType::Kanji)),
'O' => Arc::new(KyteaWsConstFilter::new(CharacterType::Other)),
'G' => Arc::new(ConcatGraphemeClustersFilter),
_ => return Err("Could not parse a wsconst value".into()),
});
}
let postfilters = build_post_filters(wsconst)?;
Ok(Self {
predictor: Arc::new(Predictor::new(model, false)?),
prefilter: KyteaFullwidthFilter,
postfilters,
})
}

/// Creates a new VaporettoTokenizer from a serialized predictor.
///
/// On success, returns a tuple of the tokenizer and the remaining slice of `data`.
///
/// # Arguments
///
/// * `data` - Serialized data of Vaporetto.
/// * `wsconst` - Character types that the tokenizer does not segment.
///   D: Digit, R: Roman, H: Hiragana, T: Katakana, K: Kanji, O: Other,
///   G: Grapheme cluster.
///
/// # Errors
///
/// An error is returned when
/// - the data is invalid, or
/// - `wsconst` contains an invalid character type.
///
/// # Safety
///
/// The given data must be a correct predictor exported by
/// [`vaporetto::Predictor::serialize_to_vec()`] function.
pub unsafe fn deserialize_unchecked<'a>(
    data: &'a [u8],
    wsconst: &str,
) -> Result<(Self, &'a [u8]), Box<dyn std::error::Error>> {
    // `wsconst` is checked first, so an invalid specifier is reported before the
    // unchecked deserialization is attempted.
    let postfilters = build_post_filters(wsconst)?;

    // SAFETY: the caller of this function promises (see `# Safety`) that `data` is a
    // correctly serialized predictor; that promise is forwarded unchanged to the
    // unchecked deserializer.
    let (predictor, rest) = unsafe { Predictor::deserialize_from_slice_unchecked(data) }?;

    Ok((
        Self {
            predictor: Arc::new(predictor),
            prefilter: KyteaFullwidthFilter,
            postfilters,
        },
        rest,
    ))
}
}

pub struct VaporettoTokenStream<'a> {
Expand Down

0 comments on commit 773a4bd

Please sign in to comment.