Skip to content

Commit

Permalink
Merge pull request #25 from MurrellGroup/sampfix
Browse files Browse the repository at this point in the history
Fixing sampling issue
  • Loading branch information
murrellb authored Dec 30, 2024
2 parents 90447e3 + a32d1d1 commit a148d42
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/sampling.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This generate function seems to do one unnecessary forward pass when switching from the forward pass over the initial sequence
# to the sampling of each token. But when I try and fix it, the model gets slightly dumber.
# Vibes feel like a shift-by-1 in the RoPE, or something similar. Need to investigate when I find time.
# Sample one token from the last position's logits and write it into `tokens`
# at the slot just past the model's current cache position (`model.pos + 1`).
# When a tokenizer is supplied, also stream the decoded token to stdout.
function nexttoken!(tokens, model, sampler, logits, tokenizer_for_printing)
    slot = model.pos + 1
    tokens[slot] = sampler(logits[:, end, 1])
    # Preserve the short-circuit result: `false` when printing is disabled.
    isnothing(tokenizer_for_printing) && return false
    print(decode(tokenizer_for_printing, [tokens[slot]], skip_special_tokens = false))
end

"""
generate(model, initial_tokens; max_new_tokens=100, sampler=top_pk_sampler(p=0.5f0, k=5), tokenizer_for_printing=tkn, end_token=128010)
Expand All @@ -24,24 +26,27 @@ function generate(
device = identity,
sdpa_func = sdpa
) where T
current_len = length(initial_tokens)
tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens))
if clear_cache
clear_cache!(model)
config_cache!(model, current_len + max_new_tokens)
config_cache!(model, length(initial_tokens) + max_new_tokens)
else
extend_cache!(model, current_len + max_new_tokens)
extend_cache!(model, length(initial_tokens) + max_new_tokens)
end
input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1)
logits = model(input_tokens, sdpa_func = sdpa_func)
for _ in 1:max_new_tokens
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
if max_new_tokens > 0
nexttoken!(tokens, model, sampler, logits, tokenizer_for_printing)
tokens[model.pos+1] == end_token && return tokens[1:model.pos+1]
else
return tokens
end
for _ in 1:max_new_tokens-1
input_tokens = device(reshape([tokens[model.pos+1]], :, 1)) # Just the last token
logits = model(input_tokens, sdpa_func = sdpa_func)
next_token = sampler(logits[:, end, 1])
current_len += 1
tokens[current_len] = next_token
!isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token], skip_special_tokens = false))
next_token == end_token && break
nexttoken!(tokens, model, sampler, logits, tokenizer_for_printing)
tokens[model.pos+1] == end_token && break
end
return tokens[1:current_len]
end
return tokens[1:model.pos+1]
end

0 comments on commit a148d42

Please sign in to comment.