Makes decode and decode_batch work on borrowed content. (#1251)
* Makes `decode` and `decode_batch` work on borrowed content.

* Make `decode_batch` work with borrowed content.

* Fix lint.

* Attempt to map it into Node.

* Second attempt.

* Step by step.

* One more step.

* Fix lint.

* Please ...

* Removing collect.

* Revert "Removing collect."

This reverts commit 2f7ec04.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
mfuntowicz and Narsil authored May 17, 2023
1 parent cefc41e commit b4fcc9c
Showing 5 changed files with 17 additions and 13 deletions.
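For callers of the Rust crate, the visible effect is that `decode` now borrows the ids rather than taking ownership of a `Vec<u32>`, so the `.to_vec()` copies removed below are no longer needed. A minimal caller-side sketch under that assumption; the `tokenizer.json` path and the example sentence are hypothetical:

use tokenizers::Tokenizer;

fn main() -> tokenizers::Result<()> {
    // Hypothetical local file; any serialized tokenizer works here.
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let encoding = tokenizer.encode("Hello, y'all!", false)?;

    // `decode` now accepts a borrowed `&[u32]`, so the slice returned by
    // `get_ids()` can be passed directly, without an intermediate `.to_vec()`.
    let decoded = tokenizer.decode(encoding.get_ids(), true)?;
    println!("{}", decoded);
    Ok(())
}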
7 changes: 5 additions & 2 deletions bindings/node/native/src/tasks/tokenizer.rs
@@ -106,14 +106,17 @@ impl Task for DecodeTask {
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode(ids.to_vec(), *skip_special_tokens)
+                .decode(ids.as_slice(), *skip_special_tokens)
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Single),
             DecodeTask::Batch(worker, ids, skip_special_tokens) => worker
                 .tokenizer
                 .read()
                 .unwrap()
-                .decode_batch(ids.to_vec(), *skip_special_tokens)
+                .decode_batch(
+                    &ids.iter().map(|v| v.as_slice()).collect::<Vec<&[u32]>>(),
+                    *skip_special_tokens,
+                )
                 .map_err(|e| format!("{}", e))
                 .map(DecodeOutput::Batch),
         }
5 changes: 3 additions & 2 deletions bindings/python/src/tokenizer.rs
@@ -1009,7 +1009,7 @@ impl PyTokenizer {
     #[pyo3(signature = (ids, skip_special_tokens = true))]
     #[pyo3(text_signature = "(self, ids, skip_special_tokens=True)")]
     fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> PyResult<String> {
-        ToPyResult(self.tokenizer.decode(ids, skip_special_tokens)).into()
+        ToPyResult(self.tokenizer.decode(&ids, skip_special_tokens)).into()
     }

/// Decode a batch of ids back to their corresponding string
@@ -1032,7 +1032,8 @@ impl PyTokenizer {
         skip_special_tokens: bool,
     ) -> PyResult<Vec<String>> {
         py.allow_threads(|| {
-            ToPyResult(self.tokenizer.decode_batch(sequences, skip_special_tokens)).into()
+            let slices = sequences.iter().map(|v| &v[..]).collect::<Vec<&[u32]>>();
+            ToPyResult(self.tokenizer.decode_batch(&slices, skip_special_tokens)).into()
         })
     }

2 changes: 1 addition & 1 deletion tokenizers/src/cli.rs
@@ -59,7 +59,7 @@ fn shell(vocab: &str, merges: &str) -> Result<()> {
         println!("Offsets:\t{:?}", encoded.get_offsets());
         println!(
             "Decoded:\t{}",
-            tokenizer.decode(encoded.get_ids().to_vec(), true).unwrap()
+            tokenizer.decode(encoded.get_ids(), true).unwrap()
         );
         println!("Tokenized in {:?}", elapsed);
     }
8 changes: 4 additions & 4 deletions tokenizers/src/tokenizer/mod.rs
@@ -795,12 +795,12 @@ where
     }

     /// Decode the given ids, back to a String
-    pub fn decode(&self, ids: Vec<u32>, skip_special_tokens: bool) -> Result<String> {
+    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
         let tokens = ids
-            .into_iter()
+            .iter()
             .filter_map(|id| {
                 self.added_vocabulary
-                    .id_to_token(id, &self.model)
+                    .id_to_token(*id, &self.model)
                     .filter(|token| {
                         !skip_special_tokens || !self.added_vocabulary.is_special_token(token)
                     })
@@ -1008,7 +1008,7 @@ where
     /// Decode all sentences in parallel
     pub fn decode_batch(
         &self,
-        sentences: Vec<Vec<u32>>,
+        sentences: &[&[u32]],
         skip_special_tokens: bool,
     ) -> Result<Vec<String>>
     where
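With the new signature, `decode_batch` borrows a slice of id slices (`&[&[u32]]`) instead of owning a `Vec<Vec<u32>>`. Callers that still hold owned vectors can re-borrow them without copying any ids, which is the same pattern the Node and Python bindings above use. A minimal sketch of that pattern; the `decode_all` helper is hypothetical:

use tokenizers::Tokenizer;

// Hypothetical helper: decode a batch of owned id vectors with the new API.
fn decode_all(tokenizer: &Tokenizer, batches: &[Vec<u32>]) -> tokenizers::Result<Vec<String>> {
    // Borrow each owned Vec<u32> as a &[u32]; no id data is copied.
    let slices: Vec<&[u32]> = batches.iter().map(|v| v.as_slice()).collect();
    // `decode_batch` now takes `&[&[u32]]` instead of `Vec<Vec<u32>>`.
    tokenizer.decode_batch(&slices, true)
}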
8 changes: 4 additions & 4 deletions tokenizers/tests/documentation.rs
@@ -54,7 +54,7 @@ fn load_tokenizer() {
     assert_eq!(encodings.get_ids(), ids);
     assert_eq!(encodings.get_tokens(), tokens);

-    let decoded = tokenizer.decode(ids, false).unwrap();
+    let decoded = tokenizer.decode(&ids, false).unwrap();
     assert_eq!(decoded, example);
 }

@@ -347,7 +347,7 @@ fn pipeline() -> tokenizers::Result<()> {
     // [1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2]

     let decoded = tokenizer.decode(
-        vec![1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
+        &[1, 27253, 16, 93, 11, 5097, 5, 7961, 5112, 6218, 0, 35, 2],
         true,
     )?;
     println!("{}", decoded);
@@ -435,7 +435,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     println!("{:?}", output.get_tokens());
     // ["[CLS]", "welcome", "to", "the", "[UNK]", "tok", "##eni", "##zer", "##s", "library", ".", "[SEP]"]

-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     println!("{}", decoded);
     // "welcome to the tok ##eni ##zer ##s library ."
     // END bert_test_decoding
@@ -451,7 +451,7 @@ fn pipeline_bert() -> tokenizers::Result<()> {
     use tokenizers::decoders::wordpiece::WordPiece as WordPieceDecoder;

     bert_tokenizer.with_decoder(WordPieceDecoder::default());
-    let decoded = bert_tokenizer.decode(output.get_ids().to_vec(), true)?;
+    let decoded = bert_tokenizer.decode(output.get_ids(), true)?;
     // "welcome to the tokenizers library."
     // END bert_proper_decoding
     assert_eq!(decoded, "welcome to the tokenizers library.");
