# Patch summary (commentary placed ahead of the first `diff --git` header).
#
# Purpose: pass token offsets across the FFI boundary as a flat array of
# integers instead of an array of two-field tuples.
#
# src/lib.rs
#   - `Buffer.offsets` changes from `*mut (usize, usize)` to `*mut usize`.
#   - `encode`: when `options.return_offsets` is set, the offsets from
#     `encoding.get_offsets()` are copied into a new Vec created with capacity
#     `len * 2`, pushing `.0` then `.1` of every pair. The existing context
#     lines then shrink it, take its raw pointer and `mem::forget` it.
#   - `free_buffer`: adds a branch that, when `buf.offsets` is non-null,
#     rebuilds a Vec with `Vec::from_raw_parts(buf.offsets, buf.len*2, buf.len*2)`.
#
# tokenizer.go
#   - `offsetVecToSlice` now takes `*C.size_t` and `tokenLength`, views
#     `tokenLength*2` values, and pairs consecutive values (index `counter`,
#     `counter+1`, advancing by 2) into one `Offset` per token.
#
# tokenizers.h
#   - Removes `struct Tuple` and its comment; `Buffer.offsets` becomes `size_t *`.
#
# NOTE(review): `free_buffer` passes `buf.len*2` as both length and capacity.
#   That assumes `shrink_to_fit` left capacity exactly equal to length --
#   confirm against the std docs for `shrink_to_fit` / `from_raw_parts`.
# NOTE(review): the context lines show `len: usize` in the Rust struct but
#   `uint32_t len` in the C header, and `tokens: *mut *mut libc::c_char` versus
#   `char *tokens` -- confirm the two declarations are meant to agree.
# NOTE(review): in this copy the patch is collapsed onto two lines and the two
#   `#include` context lines carry no header names, so it will not apply as-is.
diff --git a/src/lib.rs b/src/lib.rs index 9f629dbb..9825d3ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,7 @@ pub struct Buffer { special_tokens_mask: *mut u32, attention_mask: *mut u32, tokens: *mut *mut libc::c_char, - offsets: *mut (usize, usize), + offsets: *mut usize, len: usize, } @@ -126,9 +126,14 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, o std::mem::forget(vec_attention_mask); } - let mut offsets: *mut (usize, usize) = ptr::null_mut(); + let mut offsets: *mut usize = ptr::null_mut(); if options.return_offsets { - let mut vec_offsets = encoding.get_offsets().to_vec(); + let vec_offsets_tuples = encoding.get_offsets().to_vec(); + let mut vec_offsets = Vec::with_capacity(vec_offsets_tuples.len() * 2); + for i in vec_offsets_tuples { + vec_offsets.push(i.0); + vec_offsets.push(i.1); + } vec_offsets.shrink_to_fit(); offsets = vec_offsets.as_mut_ptr(); std::mem::forget(vec_offsets); @@ -193,6 +198,11 @@ pub extern "C" fn free_buffer(buf: Buffer) { Vec::from_raw_parts(buf.attention_mask, buf.len, buf.len); } } + if !buf.offsets.is_null() { + unsafe { + Vec::from_raw_parts(buf.offsets, buf.len*2, buf.len*2); + } + } if !buf.tokens.is_null() { unsafe { let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len); diff --git a/tokenizer.go b/tokenizer.go index a48b937c..80c0a17c 100644 --- a/tokenizer.go +++ b/tokenizer.go @@ -105,11 +105,14 @@ func uintVecToSlice(arrPtr *C.uint, len int) []uint32 { return slice } -func offsetVecToSlice(arrPtr *C.struct_Tuple, len int) []Offset { - arr := unsafe.Slice(arrPtr, len) - slice := make([]Offset, len) - for i, v := range arr { - slice[i] = Offset{uint(v.a), uint(v.b)} +func offsetVecToSlice(arrPtr *C.size_t, tokenLength int) []Offset { + arr := unsafe.Slice(arrPtr, tokenLength*2) + slice := make([]Offset, tokenLength) + counter := 0 + for i := 0; i < tokenLength; i++ { + offset := Offset{uint(arr[counter]), uint(arr[counter+1])} + slice[i] = offset + counter = counter + 
2 } return slice } diff --git a/tokenizers.h b/tokenizers.h index 67bb278f..db881da3 100644 --- a/tokenizers.h +++ b/tokenizers.h @@ -1,12 +1,6 @@ #include #include -// maps to a rust tuple of usize -struct Tuple { - size_t a; - size_t b; -}; - struct EncodeOptions { bool add_special_token; bool return_type_ids; @@ -26,7 +20,7 @@ struct Buffer { uint32_t *special_tokens_mask; uint32_t *attention_mask; char *tokens; - struct Tuple* offsets; + size_t *offsets; uint32_t len; };