Skip to content

Commit

Permalink
fix: avoid tuple struct in offsets
Browse files Browse the repository at this point in the history
  • Loading branch information
riccardopinosio committed Jul 19, 2024
1 parent ca3ae06 commit bf05314
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
16 changes: 13 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct Buffer {
special_tokens_mask: *mut u32,
attention_mask: *mut u32,
tokens: *mut *mut libc::c_char,
offsets: *mut (usize, usize),
offsets: *mut usize,
len: usize,
}

Expand Down Expand Up @@ -126,9 +126,14 @@ pub extern "C" fn encode(ptr: *mut libc::c_void, message: *const libc::c_char, o
std::mem::forget(vec_attention_mask);
}

let mut offsets: *mut (usize, usize) = ptr::null_mut();
let mut offsets: *mut usize = ptr::null_mut();
if options.return_offsets {
let mut vec_offsets = encoding.get_offsets().to_vec();
let vec_offsets_tuples = encoding.get_offsets().to_vec();
let mut vec_offsets = Vec::with_capacity(vec_offsets_tuples.len() * 2);
for i in vec_offsets_tuples {
vec_offsets.push(i.0);
vec_offsets.push(i.1);
}
vec_offsets.shrink_to_fit();
offsets = vec_offsets.as_mut_ptr();
std::mem::forget(vec_offsets);
Expand Down Expand Up @@ -193,6 +198,11 @@ pub extern "C" fn free_buffer(buf: Buffer) {
Vec::from_raw_parts(buf.attention_mask, buf.len, buf.len);
}
}
if !buf.offsets.is_null() {
unsafe {
Vec::from_raw_parts(buf.offsets, buf.len*2, buf.len*2);
}
}
if !buf.tokens.is_null() {
unsafe {
let strings = Vec::from_raw_parts(buf.tokens, buf.len, buf.len);
Expand Down
13 changes: 8 additions & 5 deletions tokenizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,14 @@ func uintVecToSlice(arrPtr *C.uint, len int) []uint32 {
return slice
}

func offsetVecToSlice(arrPtr *C.struct_Tuple, len int) []Offset {
arr := unsafe.Slice(arrPtr, len)
slice := make([]Offset, len)
for i, v := range arr {
slice[i] = Offset{uint(v.a), uint(v.b)}
// offsetVecToSlice converts a flat C array of size_t values into a slice of
// Offset. The array is read as tokenLength consecutive pairs, so it must hold
// at least tokenLength*2 elements. The values are copied into a freshly
// allocated Go slice.
func offsetVecToSlice(arrPtr *C.size_t, tokenLength int) []Offset {
	flat := unsafe.Slice(arrPtr, tokenLength*2)
	offsets := make([]Offset, tokenLength)
	for i := range offsets {
		// Elements 2*i and 2*i+1 of the flat array form the i-th pair.
		offsets[i] = Offset{uint(flat[2*i]), uint(flat[2*i+1])}
	}
	return offsets
}
Expand Down
8 changes: 1 addition & 7 deletions tokenizers.h
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
#include <stdbool.h>
#include <stdint.h>

// Pair of size_t values; maps to a Rust tuple of two usize, i.e. (usize, usize).
// NOTE(review): Rust does not guarantee the in-memory layout of tuples, so
// treating this struct as layout-compatible is an assumption — confirm on the
// Rust side before relying on it.
struct Tuple {
size_t a; // first element of the pair
size_t b; // second element of the pair
};

struct EncodeOptions {
bool add_special_token;
bool return_type_ids;
Expand All @@ -26,7 +20,7 @@ struct Buffer {
uint32_t *special_tokens_mask;
uint32_t *attention_mask;
char *tokens;
struct Tuple* offsets;
size_t *offsets;
uint32_t len;
};

Expand Down

0 comments on commit bf05314

Please sign in to comment.