diff --git a/core/src/data_sources/splitter.rs b/core/src/data_sources/splitter.rs
index 0134ae6209a8..0d3224924d77 100644
--- a/core/src/data_sources/splitter.rs
+++ b/core/src/data_sources/splitter.rs
@@ -80,8 +80,13 @@ async fn split_text(
     // Construct valid decoded chunks.
     let mut splits: Vec<TokenizedText> = vec![];
 
+    // We attempt to split in a balanced manner to avoid trailing small chunks.
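+    // For example, 10 tokens with a max_chunk_size of 8 yield a target of 10 / (10 / 8 + 1) = 5:
+    // two 5-token chunks instead of an 8-token chunk followed by a 2-token one.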
+    let target_chunk_size = encoded.len() / (encoded.len() / max_chunk_size + 1);
+
     while encoded.len() > 0 {
-        let mut current_chunk_size = cmp::min(max_chunk_size, encoded.len());
+        let mut current_chunk_size = cmp::min(target_chunk_size, encoded.len());
         let tokenized_chunk = &encoded[0..current_chunk_size];
 
         match decode_chunk_with_remainder(&embedder, tokenized_chunk).await {
@@ -164,34 +167,37 @@ impl TokenizedSection {
 
         let mut sections: Vec<TokenizedSection> = vec![];
 
-        // Create new children for content to enforce the invariant that content nodes are leaf
-        // nodes. If content overflows max_chunk_size, we split in multiple nodes to enforce the
-        // invariant that any content node fit in a `max_chunk_size`.
+        // Create new children for content if a prefix is present, to enforce the invariant that
+        // content nodes are leaf nodes. Even if no prefix is present, content that overflows
+        // max_chunk_size is split into multiple nodes to enforce the invariant that any content
+        // node fits in a `max_chunk_size`.
         if let Some(c) = content.as_ref() {
-            let effective_max_chunk_size = max_chunk_size - prefixes_tokens_count;
-
-            let splits = match c.tokens.len() > effective_max_chunk_size {
-                true => split_text(&embedder, effective_max_chunk_size, &c.text).await?,
-                false => vec![c.clone()],
-            };
-
-            // Prepend to childrens the splits of the content with no additional prefixes (they
-            // will inherit the current section prefixes whose content will be removed).
-            sections.extend(
-                splits
-                    .into_iter()
-                    .map(|t| TokenizedSection {
-                        max_chunk_size,
-                        prefixes: prefixes.clone(),
-                        tokens_count: prefixes_tokens_count + t.tokens.len(),
-                        content: Some(t),
-                        sections: vec![],
-                    })
-                    .collect::<Vec<_>>(),
-            );
+            if (c.tokens.len() + prefixes_tokens_count) > max_chunk_size || prefix.is_some() {
+                let effective_max_chunk_size = max_chunk_size - prefixes_tokens_count;
+
+                let splits = match c.tokens.len() > effective_max_chunk_size {
+                    true => split_text(&embedder, effective_max_chunk_size, &c.text).await?,
+                    false => vec![c.clone()],
+                };
+
+                // Prepend to children the splits of the content with no additional prefixes (they
+                // will inherit the current section prefixes whose content will be removed).
+                sections.extend(
+                    splits
+                        .into_iter()
+                        .map(|t| TokenizedSection {
+                            max_chunk_size,
+                            prefixes: prefixes.clone(),
+                            tokens_count: prefixes_tokens_count + t.tokens.len(),
+                            content: Some(t),
+                            sections: vec![],
+                        })
+                        .collect::<Vec<_>>(),
+                );
 
-            // Remove the content from the current section.
-            content = None;
+                // Remove the content from the current section.
+                content = None;
+            }
         }
 
         sections.extend(
@@ -428,9 +434,12 @@ impl Splitter for BaseV0Splitter {
         let tokenized_section =
             TokenizedSection::from(&embedder, max_chunk_size, vec![], &section).await?;
 
+        // We filter out whitespace-only or empty strings, which can occur when the section
+        // passed in has empty or whitespace-only content.
         Ok(tokenized_section
             .chunks()
             .into_iter()
+            .filter(|t| t.text.trim().len() > 0)
             .map(|t| t.text)
             .collect())
     }
@@ -491,11 +500,11 @@ mod tests {
         let cases: [(String, usize); 2] = [
             (
                 "a random document string with no double space".repeat(10),
-                10,
+                12,
             ),
             (
                 "a random document \nstring WITH double spaces".repeat(8),
-                10,
+                12,
            ),
        ];
 
@@ -804,13 +813,41 @@ mod tests {
             splitted.join("|"),
             vec![
                 "pc..".to_string(),
-                "pp0c0........+-+-+-+-+-+-+-+-+-+-+-+-+-++-+-+-+-+-".to_string(),
-                "pp0+-++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-++-".to_string(),
-                "pp0+-+-+-+-+-++-+-+-+-+-+-+-c01......".to_string(),
-                "pp0p02c02....".to_string(),
+                "pp0c0........+-+-+-+-+-+-+-+-+-+-+-+-+-++-".to_string(),
+                "pp0+-+-+-+-+-++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-".to_string(),
+                "pp0++-+-+-+-+-+-++-+-+-+-+-+-+-".to_string(),
+                "pp0c01......p02c02....".to_string(),
                 "pp1c1p10c10........".to_string(),
             ]
             .join("|")
         )
     }
+
+    #[tokio::test]
+    async fn test_splitter_v0_unaligned_content() {
+        let section = Section {
+            prefix: None,
+            content: Some("asdjqweiozclknasidjhqlkdnaljch\n".to_string()),
+            sections: vec![],
+        };
+
+        let provider_id = ProviderID::OpenAI;
+        let model_id = "text-embedding-ada-002";
+        let credentials = Credentials::from([("OPENAI_API_KEY".to_string(), "abc".to_string())]);
+
+        let splitted = splitter(SplitterID::BaseV0)
+            .split(credentials, provider_id, model_id, 8, section)
+            .await
+            .unwrap();
+
+        assert_eq!(
+            splitted.join("|"),
+            vec![
+                "asdjqweioz".to_string(),
+                "clknasidjh".to_string(),
+                "qlkdnaljch\n".to_string(),
+            ]
+            .join("|")
+        )
+    }
 }
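
A minimal sketch of the push-down rule introduced in the `TokenizedSection::from` hunk
above, for illustration only: content is moved into leaf children when a prefix is
present, or when the prefixes plus the content would overflow `max_chunk_size`. The
helper name `must_push_down` and the simplified scalar arguments are assumptions, not
code from this patch.

    // Mirrors the new guard:
    //   (c.tokens.len() + prefixes_tokens_count) > max_chunk_size || prefix.is_some()
    fn must_push_down(
        content_tokens: usize,
        prefixes_tokens: usize,
        max_chunk_size: usize,
        has_prefix: bool,
    ) -> bool {
        (content_tokens + prefixes_tokens) > max_chunk_size || has_prefix
    }

    fn main() {
        // Overflow: 300 content tokens plus 20 prefix tokens exceed a 256-token chunk.
        assert!(must_push_down(300, 20, 256, false));
        // Prefix present: content moves to a leaf child even though it would fit.
        assert!(must_push_down(10, 20, 256, true));
        // No prefix and no overflow: content stays on the current node.
        assert!(!must_push_down(10, 0, 256, false));
    }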