core: splitter improvements and fixes (#2715)
* splitter fixes

* filter whitespace only or empty chunks
spolu authored Nov 30, 2023
1 parent 994a78f commit 7df2145
Showing 1 changed file with 70 additions and 33 deletions.
103 changes: 70 additions & 33 deletions core/src/data_sources/splitter.rs
@@ -80,8 +80,11 @@ async fn split_text(
// Construct valid decoded chunks.
let mut splits: Vec<TokenizedText> = vec![];

+    // We attempt to split in a balanced manner to avoid trailing small chunks.
+    let target_chunk_size = encoded.len() / (encoded.len() / max_chunk_size + 1);
+
while encoded.len() > 0 {
-        let mut current_chunk_size = cmp::min(max_chunk_size, encoded.len());
+        let mut current_chunk_size = cmp::min(target_chunk_size, encoded.len());
let tokenized_chunk = &encoded[0..current_chunk_size];

match decode_chunk_with_remainder(&embedder, tokenized_chunk).await {
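
For context on this first hunk, a minimal sketch (not part of the commit; the balanced_target helper name is illustrative) of how the new target size behaves: with 12 tokens and max_chunk_size = 10, naive splitting yields chunks of 10 and 2 tokens, while the balanced target 12 / (12 / 10 + 1) = 6 yields two chunks of 6, since the while loop caps each chunk at the target size.

fn balanced_target(encoded_len: usize, max_chunk_size: usize) -> usize {
    // Same arithmetic as the target_chunk_size line added above.
    encoded_len / (encoded_len / max_chunk_size + 1)
}

fn main() {
    assert_eq!(balanced_target(12, 10), 6); // 6 + 6 instead of 10 + 2
    assert_eq!(balanced_target(26, 8), 6); // ~6-token chunks instead of 8 + 8 + 8 + 2
    assert_eq!(balanced_target(8, 10), 8); // a single chunk when the input already fits
}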
@@ -164,34 +167,37 @@ impl TokenizedSection {

let mut sections: Vec<TokenizedSection> = vec![];

-        // Create new children for content to enforce the invariant that content nodes are leaf
-        // nodes. If content overflows max_chunk_size, we split in multiple nodes to enforce the
-        // invariant that any content node fit in a `max_chunk_size`.
+        // Create new children for content if a prefix is present, to enforce the invariant that
+        // content nodes are leaf nodes. Even if no prefix is present, content that overflows
+        // max_chunk_size is split into multiple nodes to enforce the invariant that any content
+        // node fits in a `max_chunk_size`.
if let Some(c) = content.as_ref() {
-            let effective_max_chunk_size = max_chunk_size - prefixes_tokens_count;
-
-            let splits = match c.tokens.len() > effective_max_chunk_size {
-                true => split_text(&embedder, effective_max_chunk_size, &c.text).await?,
-                false => vec![c.clone()],
-            };
-
-            // Prepend to childrens the splits of the content with no additional prefixes (they
-            // will inherit the current section prefixes whose content will be removed).
-            sections.extend(
-                splits
-                    .into_iter()
-                    .map(|t| TokenizedSection {
-                        max_chunk_size,
-                        prefixes: prefixes.clone(),
-                        tokens_count: prefixes_tokens_count + t.tokens.len(),
-                        content: Some(t),
-                        sections: vec![],
-                    })
-                    .collect::<Vec<_>>(),
-            );
+            if (c.tokens.len() + prefixes_tokens_count) > max_chunk_size || prefix.is_some() {
+                let effective_max_chunk_size = max_chunk_size - prefixes_tokens_count;
+
+                let splits = match c.tokens.len() > effective_max_chunk_size {
+                    true => split_text(&embedder, effective_max_chunk_size, &c.text).await?,
+                    false => vec![c.clone()],
+                };
+
+                // Prepend to children the splits of the content with no additional prefixes (they
+                // will inherit the current section prefixes whose content will be removed).
+                sections.extend(
+                    splits
+                        .into_iter()
+                        .map(|t| TokenizedSection {
+                            max_chunk_size,
+                            prefixes: prefixes.clone(),
+                            tokens_count: prefixes_tokens_count + t.tokens.len(),
+                            content: Some(t),
+                            sections: vec![],
+                        })
+                        .collect::<Vec<_>>(),
+                );

-            // Remove the content from the current section.
-            content = None;
+                // Remove the content from the current section.
+                content = None;
+            }
}

sections.extend(
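
A minimal sketch (not part of the commit; needs_child_split is a hypothetical helper) of the condition this hunk introduces: content is pushed down into child leaf sections only when a prefix is present or when the prefixed content would overflow max_chunk_size; otherwise it stays on the current section.

fn needs_child_split(
    content_tokens: usize,
    prefixes_tokens: usize,
    max_chunk_size: usize,
    has_prefix: bool,
) -> bool {
    // Mirrors: (c.tokens.len() + prefixes_tokens_count) > max_chunk_size || prefix.is_some()
    (content_tokens + prefixes_tokens) > max_chunk_size || has_prefix
}

fn main() {
    assert!(needs_child_split(100, 10, 64, false)); // overflow: split content into child leaves
    assert!(needs_child_split(10, 5, 64, true)); // prefix present: move content to a child leaf
    assert!(!needs_child_split(10, 5, 64, false)); // small and prefix-less: content stays put
}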
@@ -428,9 +434,12 @@ impl Splitter for BaseV0Splitter {
let tokenized_section =
TokenizedSection::from(&embedder, max_chunk_size, vec![], &section).await?;

+        // We filter out whitespace-only or empty strings, which can occur if the sections
+        // passed in have empty or whitespace-only content.
Ok(tokenized_section
.chunks()
.into_iter()
+            .filter(|t| t.text.trim().len() > 0)
.map(|t| t.text)
.collect())
}
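
A minimal standalone sketch (not part of the commit) of the filter added in this hunk: chunk texts that are empty or whitespace-only are dropped before being returned.

fn main() {
    let chunks = vec![
        "pc..".to_string(),
        "   ".to_string(),
        "".to_string(),
        "\n\t".to_string(),
    ];
    // Same predicate as the added .filter(|t| t.text.trim().len() > 0) line.
    let kept: Vec<String> = chunks.into_iter().filter(|t| t.trim().len() > 0).collect();
    assert_eq!(kept, vec!["pc..".to_string()]);
}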
@@ -491,11 +500,11 @@ mod tests {
let cases: [(String, usize); 2] = [
(
"a random document string with no double space".repeat(10),
-                10,
+                12,
),
(
"a random document \nstring WITH double spaces".repeat(8),
-                10,
+                12,
),
];

@@ -804,13 +813,41 @@ mod tests {
splitted.join("|"),
vec![
"pc..".to_string(),
"pp0c0........+-+-+-+-+-+-+-+-+-+-+-+-+-++-+-+-+-+-".to_string(),
"pp0+-++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-++-".to_string(),
"pp0+-+-+-+-+-++-+-+-+-+-+-+-c01......".to_string(),
"pp0p02c02....".to_string(),
"pp0c0........+-+-+-+-+-+-+-+-+-+-+-+-+-++-".to_string(),
"pp0+-+-+-+-+-++-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-".to_string(),
"pp0++-+-+-+-+-+-++-+-+-+-+-+-+-".to_string(),
"pp0c01......p02c02....".to_string(),
"pp1c1p10c10........".to_string(),
]
.join("|")
)
}

+    #[tokio::test]
+    async fn test_splitter_v0_unaligned_content() {
+        let section = Section {
+            prefix: None,
+            content: Some("asdjqweiozclknasidjhqlkdnaljch\n".to_string()),
+            sections: vec![],
+        };
+
+        let provider_id = ProviderID::OpenAI;
+        let model_id = "text-embedding-ada-002";
+        let credentials = Credentials::from([("OPENAI_API_KEY".to_string(), "abc".to_string())]);
+
+        let splitted = splitter(SplitterID::BaseV0)
+            .split(credentials, provider_id, model_id, 8, section)
+            .await
+            .unwrap();
+
+        assert_eq!(
+            splitted.join("|"),
+            vec![
+                "asdjqweioz".to_string(),
+                "clknasidjh".to_string(),
+                "qlkdnaljch\n".to_string(),
+            ]
+            .join("|")
+        )
+    }
}
