@@ -265,73 +265,71 @@ mod bpe_tests {
         }
     }
 
+    fn test_bpe() -> Bpe {
+        Bpe::new(
+            [
+                "<unk>", //
+                "a", "b", "c", "d", //
+                "ab", "ac", "ad", "bd", //
+                "bcd",
+            ],
+            [
+                0., //
+                1., 1., 1., 1., //
+                1.1, 1.2, 1.3, 1.4, //
+                10.,
+            ],
+            [false; 10],
+            0,
+        )
+    }
+
     #[test]
     fn test_bpe_new() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
-        assert_eq!(bpe.vocab_size(), 5);
+        let bpe = test_bpe();
+        assert_eq!(bpe.vocab_size(), 10);
     }
 
     #[test]
-    fn test_bpe_encode() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
+    fn test_bpe_unk_token() {
+        let bpe = test_bpe();
+        assert_eq!(bpe.unk_token(), 0);
+    }
 
-        let encoded: Vec<_> = bpe.encode("abc").into_iter().collect();
-        assert_eq!(encoded.len(), 2); // Should merge "ab" and leave "c"
-        assert_eq!(encoded[0], 3); // Assuming "ab" is assigned token ID 3
-        assert_eq!(encoded[1], 2); // Assuming "c" is assigned token ID 2
+    #[test]
+    fn test_bpe_encode() {
+        let bpe = test_bpe();
+        let encoded: Vec<_> = bpe.encode("abd").into_iter().collect();
+        assert_eq!(encoded, [1, 8]); // Should merge "bd" and leave "a"
     }
 
     #[test]
     fn test_bpe_decode() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
-
-        assert_eq!(bpe.decode(3), b"ab");
-        assert_eq!(bpe.decode(2), b"c");
+        let bpe = test_bpe();
+        assert_eq!(bpe.decode(3), b"c");
+        assert_eq!(bpe.decode(6), b"ac");
+        assert_eq!(bpe.decode(9), b"bcd");
+        assert_eq!(bpe.decode(0), b"<unk>");
     }
 
     #[test]
     fn test_bpe_encode_decode() {
-        let vocabs = vec!["a", "b", "c", "ab", "bc"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 2.0];
-        let is_byte = vec![false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
+        let bpe = test_bpe();
 
-        let text = "abcbc ";
+        let text = "abcdx ";
 
         let encoded: Vec<_> = bpe.encode(text).into_iter().collect();
-        let decoded: Vec<u8> = encoded
+        assert_eq!(encoded, [5, 3, 4, 0]);
+
+        let decoded: Vec<_> = encoded
             .iter()
             .flat_map(|&t| bpe.decode(t).iter().copied())
             .collect();
-        assert_eq!(String::from_utf8(decoded).unwrap(), text);
-    }
-
-    #[test]
-    fn test_bpe_unk_token() {
-        let vocabs = vec!["a", "b", "c"];
-        let scores = vec![1.0, 1.0, 1.0];
-        let is_byte = vec![false, false, false];
-        let unk_token = 100;
-        let bpe = Bpe::new(vocabs, scores, is_byte, unk_token);
-
-        assert_eq!(bpe.unk_token(), unk_token);
+        assert_eq!(std::str::from_utf8(&decoded), Ok("abcd<unk>"));
     }
 
     #[test]
     fn test_bpe_inaccessible() {
-        let vocabs = vec!["a", "b", "c", "ab", "bcd", "d"];
-        let scores = vec![1.0, 1.0, 1.0, 2.0, 1.5, 1.0];
-        let is_byte = vec![false, false, false, false, false, false];
-        let bpe = Bpe::new(vocabs, scores, is_byte, 0);
-
+        let bpe = test_bpe();
         let inaccessible = bpe.inaccessible();
         println!("Inaccessible tokens: {:?}", inaccessible);
 
@@ -342,8 +340,9 @@ mod bpe_tests {
         );
 
         // 'bcd' cannot be formed by merging other tokens, so it should be inaccessible
-        assert!(
-            inaccessible.contains_key("bcd"),
+        assert_eq!(
+            inaccessible.get("bcd"),
+            Some(&9),
             "Token 'bcd' should be inaccessible"
         );
 
@@ -356,17 +355,12 @@ mod bpe_tests {
 
     #[test]
     fn test_bpe_with_byte_tokens() {
-        let vocabs = vec!["a", "b", "<0x41>", "<0x42>"];
-        let scores = vec![1.0, 1.0, 1.0, 1.0];
-        let is_byte = vec![false, false, true, true];
+        let vocabs = ["a", "b", "<0x41>", "<0x42>"];
+        let scores = [1.0, 1.0, 1.0, 1.0];
+        let is_byte = [false, false, true, true];
         let bpe = Bpe::new(vocabs, scores, is_byte, 0);
 
-        let input = "aAB";
-        let encoded: Vec<_> = bpe.encode(input).into_iter().collect();
-        println!("Input: {:?}", input);
-        println!("Encoded tokens: {:?}", encoded);
-        println!("Vocabulary size: {}", bpe.vocab_size());
-
-        assert_eq!(encoded.len(), 3, "Expected 3 tokens for input 'aAB'");
+        let encoded: Vec<_> = bpe.encode("aAB").into_iter().collect();
+        assert_eq!(encoded, [0, 2, 3], "Expected 3 tokens for input 'aAB'");
     }
 }
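
For reference, the `Bpe` surface these tests exercise can be read off the call sites in the diff. The skeleton below is a minimal sketch of that interface, not the crate's actual definition; the generic `IntoIterator` parameters and the concrete `f32`, `u32`, and `HashMap` types are assumptions inferred from how the tests call it.

```rust
// Sketch of the Bpe interface implied by the tests above; signatures are
// inferred from call sites and the concrete types are assumptions.
use std::collections::HashMap;

pub struct Bpe {
    // internal state omitted
}

impl Bpe {
    // Parallel vocab/score/is-byte sequences plus the id of the <unk> token.
    pub fn new<'a>(
        vocabs: impl IntoIterator<Item = &'a str>,
        scores: impl IntoIterator<Item = f32>,
        is_byte: impl IntoIterator<Item = bool>,
        unk_token: u32,
    ) -> Self {
        todo!()
    }

    /// Number of entries in the vocabulary.
    pub fn vocab_size(&self) -> usize {
        todo!()
    }

    /// Id of the <unk> token passed to `new`.
    pub fn unk_token(&self) -> u32 {
        todo!()
    }

    /// Token ids for `text`; input with no matching token falls back to <unk>.
    pub fn encode(&self, text: &str) -> Vec<u32> {
        todo!()
    }

    /// Raw bytes for a single token id.
    pub fn decode(&self, token: u32) -> &[u8] {
        todo!()
    }

    /// Tokens that encoding can never produce, mapped to their ids.
    pub fn inaccessible(&self) -> HashMap<String, u32> {
        todo!()
    }
}
```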