Allowing extension of a previous cache. #17

Merged: 7 commits, merged on Dec 9, 2024
9 changes: 5 additions & 4 deletions Project.toml
@@ -18,9 +18,8 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
[weakdeps]
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"

[sources.HuggingFaceTokenizers]
rev = "main"
url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"
[sources]
HuggingFaceTokenizers = {rev = "main", url = "https://github.com/MurrellGroup/HuggingFaceTokenizers.jl"}

[extensions]
MetalExt = "Metal"
@@ -39,6 +38,8 @@ julia = "1.11"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"

[targets]
test = ["Test"]
test = ["Test", "Downloads", "JSON3"]
95 changes: 95 additions & 0 deletions examples/scratch.jl
@@ -0,0 +1,95 @@
#Pkg.add(["Flux", "JSON3", "UnicodePlots", "StatsBase"])
using Jjama3, Flux, StatsBase, UnicodePlots

#Init a tiny model
model = Transformer(
    22,    # vocab_size
    16*8,  # dim
    12,    # n_layers
    8,     # n_heads
    4,     # n_kv_heads
    8192,  # max_seq_len
    16*10, # ff_hidden_dim
)

#Make everything except the RoPE trainable
Jjama3.Flux.@layer Jjama3.Transformer trainable=(tok_embeddings, layers, norm, output)
Jjama3.Flux.@layer Jjama3.Attention trainable=(wq, wk, wv, wo)
Jjama3.Flux.@layer Jjama3.TransformerBlock trainable=(attention, feed_forward, attention_norm, ffn_norm)

#Set up trivial tokenizer
AAs = collect(">ACDEFGHIKLMNPQRSTVWY.")

#Read data, remove X-containing sequences, and add start and end tokens
data = readlines("abs.txt")
data = [">"*d*"." for d in data if !(occursin("X", d))]

#Train the model
lr = 0.001f0
opt_state = Flux.setup(AdamW(lr), model)
losses = Float32[]
for i in 1:2001
    #Prep random batch
    train_toks = pad_and_batch(encode.((AAs, ), data[sample(1:length(data), 10, replace=false)]), 22);
    #Compute loss and gradients
    loss, grads = Flux.withgradient(model) do m
        forward_loss(m, train_toks[1:end-1,:], train_toks[2:end,:])
    end
    #Update weights
    Flux.update!(opt_state, model, grads[1])
    #Monitor
    push!(losses, loss)
    println(i, " ", loss)
    #Monitor sampling
    if mod(i, 100) == 1
        generate(model, encode(AAs, ">"),
            max_new_tokens=500,
            tokenizer_for_printing=AAs,
            end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
        println()
        display(lineplot(losses, width = 150, height = 30))
    end
    #Linear learning rate cooldown
    if i > 1500
        lr = max(lr - 0.001f0/(2000-1500), 0.0000001f0)
        Flux.adjust!(opt_state, lr)
    end
end

#Test sampling
for i in 1:10
    println(">", i)
    generate(model, encode(AAs, ">"),
        max_new_tokens=500,
        tokenizer_for_printing=AAs,
        end_token = 22, sampler = top_pk_sampler(p = 1.0f0, k = 22))
    println()
end

#Exporting the model
export_model(model, "tinyabllama.safetensors", type_convert = x -> Jjama3.SafeTensors.BFloat16.(x))

#Saving a config so that it loads correctly using the Jjama3 loader
using JSON3
config = Dict()
config[:model_type] = "llama"
config[:vocab_size] = 22
config[:hidden_size] = 16*8
config[:num_hidden_layers] = 12
config[:num_attention_heads] = 8
config[:num_key_value_heads] = 4
config[:max_position_embeddings] = 8192
config[:intermediate_size] = 16*10
config[:rms_norm_eps] = 1f-8
config[:rope_theta] = 500000f0
config[:tie_word_embeddings] = false
open("tinyabllama_config.json", "w") do f
    JSON3.pretty(f, JSON3.write(config))
    println(f)
end

#Load a trained model and test it
config = JSON3.read(read("tinyabllama_config.json", String))
model_weight_paths = ["tinyabllama.safetensors"]
model = load_llama3_from_safetensors(model_weight_paths, config)
@assert generate(model, encode(AAs, ">"), end_token = 22) == [1, 15, 19, 15, 11, 19, 15, 17, 7, 2, 5, 19, 10, 10, 14, 7, 2, 17, 19, 10, 19, 17, 3, 10, 2, 17, 7, 21, 18, 6, 18, 17, 21, 7, 9, 17, 20, 19, 16, 15, 2, 14, 7, 15, 7, 11, 5, 20, 12, 7, 20, 9, 17, 2, 21, 13, 7, 13, 18, 13, 21, 2, 15, 10, 11, 15, 7, 16, 19, 18, 12, 18, 18, 4, 18, 17, 18, 17, 18, 2, 21, 12, 5, 11, 16, 17, 11, 16, 17, 4, 4, 18, 2, 19, 21, 21, 3, 2, 16, 4, 16]
3 changes: 3 additions & 0 deletions src/Jjama3.jl
@@ -24,6 +24,7 @@ export RoPE
export Attention
export TransformerBlock
export Transformer
export unrope

include("model.jl")
export forward_loss
@@ -49,5 +50,7 @@ export llama3_assistant_prompt
export smollm2_instruct_prompt
export smollm2_assistant_prompt
export structured_choice
export pad_and_batch
export export_model

end
11 changes: 10 additions & 1 deletion src/cache.jl
@@ -3,7 +3,9 @@
cache_v::A
end

Flux.@layer KVCache
Base.copy(cache::KVCache) = KVCache(copy(cache.cache_k), copy(cache.cache_v))

Flux.@layer KVCache trainable=()

head_dim(cache::KVCache) = size(cache.cache_k, 1)
seq_length(cache::KVCache) = size(cache.cache_k, 2)
@@ -21,6 +23,13 @@
cache.cache_v = similar(cache.cache_v, head_dim(cache), seq_length, n_kv_heads(cache), batch_size) .= 0
end

function extend!(cache::KVCache, new_total_length::Int)
    old_cache = copy(cache)
    config!(cache, seq_length=new_total_length)
    cache.cache_k[:, 1:seq_length(old_cache), :, :] .= old_cache.cache_k
    cache.cache_v[:, 1:seq_length(old_cache), :, :] .= old_cache.cache_v
end

clear!(cache::KVCache) = config!(cache, seq_length=0)

function update!(cache::KVCache, start_pos::Int, xk::AbstractArray, xv::AbstractArray)
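
A rough sketch (not from this PR) of what extend! provides: grow a cache to a longer total length while keeping the keys and values already stored in it. The direct two-array KVCache construction, the (head_dim, seq_len, n_kv_heads, batch) layout, and the unqualified names are assumptions based on the surrounding code; qualify them as Jjama3.KVCache / Jjama3.extend! if they turn out not to be exported.

using Jjama3

cache = KVCache(randn(Float32, 8, 4, 2, 1), randn(Float32, 8, 4, 2, 1))  # 4 cached positions
old_k = copy(cache.cache_k)
extend!(cache, 16)                       # grow to a total length of 16 positions
size(cache.cache_k, 2)                   # 16; the new positions are zero-filled
cache.cache_k[:, 1:4, :, :] == old_k     # true: the original 4 positions are preserved
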
40 changes: 31 additions & 9 deletions src/layers.jl
@@ -40,16 +40,14 @@
sin::A
end

Flux.@layer RoPE
Flux.@layer RoPE trainable=()

Base.getindex(rope::RoPE, i) = RoPE(rope.cos[:,i,:,:], rope.sin[:,i,:,:])

function apply_scaling!(freqs::AbstractVector; scale_factor=8)
#Hard-coded - should move these to the main model struct and grab them from the config.
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192
###
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
for (i, freq) in enumerate(freqs)
@@ -68,15 +66,15 @@

function RoPE(
dim::Int, end_pos::Int;
theta::T=10000f0, use_scaled=true, scale_factor=8,
theta::T=10000f0, use_scaled=true, scale_factor=8, start_pos=0
) where T
freqs = 1f0 ./ (theta .^ (T.(0:2:dim-1)[1:dim÷2] ./ dim))
use_scaled && apply_scaling!(freqs; scale_factor)
freqs_complex = cis.(T.(0:end_pos-1) * freqs')
freqs_complex = cis.(T.(start_pos:end_pos-1) * freqs')
cos = permutedims(real(freqs_complex), (2, 1)) # (head_dim/2, seq_len)
sin = permutedims(imag(freqs_complex), (2, 1))
cos = reshape(cos, (dim÷2, end_pos, 1, 1))
sin = reshape(sin, (dim÷2, end_pos, 1, 1))
cos = reshape(cos, (dim÷2, end_pos - start_pos, 1, 1))
sin = reshape(sin, (dim÷2, end_pos - start_pos, 1, 1))
return RoPE(cos, sin)
end
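
A small sketch (not from this PR) of the new start_pos keyword: the table now covers positions start_pos:end_pos-1, so its length dimension is end_pos - start_pos rather than end_pos.

using Jjama3

rope = RoPE(64, 128; theta = 10000f0, use_scaled = false, start_pos = 96)
size(rope.cos)   # (32, 32, 1, 1), i.e. (dim ÷ 2, end_pos - start_pos, 1, 1)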

@@ -93,6 +91,15 @@
)
end

function unrope(rope, x)
    head_dim = size(x, 1)
    x1 = x[1:head_dim÷2, :, :, :]
    x2 = x[head_dim÷2+1:end, :, :, :]
    return vcat(
        x1 .* rope.cos .+ x2 .* rope.sin,
        x2 .* rope.cos .- x1 .* rope.sin
    )
end
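
A quick numerical check (not from this PR) that the cos/sin combination in unrope undoes a RoPE rotation of a feature pair; the forward rotation below is written out by hand and is an assumption about how the RoPE functor applies it.

θ = 0.3f0
x1, x2 = 1.2f0, -0.7f0
y1 = x1*cos(θ) - x2*sin(θ)     # assumed forward rotation, first half of the pair
y2 = x2*cos(θ) + x1*sin(θ)     # assumed forward rotation, second half
y1*cos(θ) + y2*sin(θ) ≈ x1     # true: matches unrope's first vcat argument
y2*cos(θ) - y1*sin(θ) ≈ x2     # true: matches unrope's second vcat argument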

struct Attention{Q<:AnyDense,K<:AnyDense,V<:AnyDense,O<:AnyDense,C<:KVCache}
wq::Q
@@ -195,12 +202,13 @@
Flux.@layer TransformerBlock trainable=(attention,)


struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE}
mutable struct Transformer{E<:Flux.Embedding,B<:Tuple{Vararg{TransformerBlock}},N<:RMSNorm,O<:Dense,R<:RoPE}
tok_embeddings::E
layers::B
norm::N
output::O
rope::R
pos::Int
end

Flux.@layer Transformer trainable=(layers,)
@@ -218,6 +226,20 @@
layers = Tuple(TransformerBlock(dim, n_heads, n_kv_heads, ff_hidden_dim; norm_eps=norm_eps, qkv_bias=qkv_bias) for _ in 1:n_layers)
norm = RMSNorm(dim, eps=norm_eps)
output = Dense(dim => vocab_size, bias=false)
#This should probably be generated to a sane length, and then extended in the forward pass if needed.
rope = RoPE(dim ÷ n_heads, max_seq_len * 2; theta=rope_theta, use_scaled=use_scaled_rope, scale_factor=scale_factor)
Transformer(tok_embeddings, layers, norm, output, rope)
Transformer(tok_embeddings, layers, norm, output, rope, 0)
end


function clear_cache!(model::Transformer)
    model.pos = 0
    for layer in model.layers
        clear!(layer.attention.cache)
    end
end

config_cache!(model::Transformer, seq_length) = for layer in model.layers config!(layer.attention.cache, seq_length = seq_length) end

extend_cache!(model::Transformer, seq_length) = for layer in model.layers extend!(layer.attention.cache, seq_length + model.pos) end

67 changes: 34 additions & 33 deletions src/model.jl
@@ -1,58 +1,59 @@
#Note about output layer being tied to embedding: https://github.com/meta-llama/llama-models/issues/172

function create_mask(h::AbstractArray{T}) where T<:AbstractFloat
Flux.Zygote.ignore() do
function create_mask(h::AbstractArray{T}; precached_size = 0) where T<:AbstractFloat
Flux.ChainRulesCore.ignore_derivatives() do
dim, seqlen, batch = size(h)
mask = similar(h, seqlen, seqlen)
mask .= T(-Inf)
mask = tril(mask, -1) #This is swapped because we're using the slightly more efficient dim setup
if precached_size > 0
pad = similar(h, precached_size, seqlen)
pad .= T(0.0)
mask = vcat(pad, mask)
end
return mask
end
end

function (model::Transformer)(tokens::AbstractArray{Int}, start_pos::Int=0)
function (model::Transformer)(tokens::AbstractArray{Int})
h = model.tok_embeddings(tokens) # Embedding: (dim, seq_len, batch)
rope = model.rope[start_pos+1:start_pos+size(tokens, 1)]
mask = create_mask(h)
rope = model.rope[model.pos+1:model.pos+size(tokens, 1)]
if size(h, 2) == 1
mask = create_mask(h)
else
mask = create_mask(h; precached_size = model.pos)
end
for layer in model.layers
h = layer(h, start_pos, rope, mask)
h = layer(h, model.pos, rope, mask)
end
h = model.norm(h)
output = model.output(h)
model.pos += size(tokens, 1)
return output
end

function masked_agg(ce, mask)
    if mask !== nothing
        ce = ce .* mask
    end
    return sum(ce)/sum(mask)
end
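
A quick illustration (not from this PR) of the aggregation masked_agg performs when a mask is supplied: a mean of the per-position losses over the positions the mask keeps.

ce   = [1.0 2.0 3.0]          # per-position cross-entropy values
mask = [1.0 0.0 1.0]          # keep positions 1 and 3
sum(ce .* mask) / sum(mask)   # 2.0, the mean over the kept positions only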

function forward_loss(model::Transformer, inputs::AbstractArray,
targets::AbstractArray; ignore_index::Int=-100,
mask = :auto)
seqlen = size(inputs, 1) #(seq_len, batch)
h = model.tok_embeddings(inputs) # (dim, seq_len, batch)
rope = model.rope[1:seqlen]
mask = create_mask(h)
for layer in model.layers
h = layer(h, 0, rope, mask)
targets::AbstractArray; clear_cache = true, loss_mask = nothing)
if clear_cache
Flux.ChainRulesCore.ignore_derivatives() do
clear_cache!(model)
end
end
h = model.norm(h)
logits = model.output(h)
# Need to reshape to (vocab_size, seq_len * batch)
logits_2d = reshape(logits, size(logits,1), :)
targets_1d = reshape(targets, :)
# Mask out ignored indices - will handle this later.
# Note: this is not the autoregressive mask, but the mask for the loss function
#=
mask = targets_1d .!= ignore_index
if any(mask)
loss = Flux.logitcrossentropy(
logits_2d[:, mask],
targets_1d[mask]
)
logits = model(inputs)
vocab_size = size(model.tok_embeddings.weight, 2)
gt = Flux.onehotbatch(targets, 1:vocab_size)
if loss_mask !== nothing
loss = Flux.logitcrossentropy(logits, gt, agg = x -> masked_agg(x, loss_mask))
else
loss = zero(Float32)
loss = Flux.logitcrossentropy(logits, gt)
end
=#
vocab_size = size(model.tok_embeddings.weight, 2)
gt = Flux.onehotbatch(targets_1d, 1:vocab_size)
loss = Flux.logitcrossentropy(logits_2d, gt)
return loss
end

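A sketch (not from this PR) of the new loss_mask keyword, reusing the batching from examples/scratch.jl. Here pad_id is a hypothetical padding token id, and the (1, seq_len, batch) mask shape is an assumption chosen so that it broadcasts against the un-aggregated cross-entropy that masked_agg receives.

inputs  = train_toks[1:end-1, :]
targets = train_toks[2:end, :]
# 1 where the target is a real token, 0 where it is padding
loss_mask = reshape(Float32.(targets .!= pad_id), 1, size(targets)...)
loss = forward_loss(model, inputs, targets; loss_mask = loss_mask)
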
37 changes: 12 additions & 25 deletions src/sampling.jl
@@ -19,41 +19,28 @@
sampler::Function=argmax_sampler,
tokenizer_for_printing = nothing,
end_token = 128010,
clear_cache = true,
pos_offset = 0,
device = identity
) where T
current_len = length(initial_tokens)
tokens = vcat(initial_tokens, similar(initial_tokens, max_new_tokens))

for layer in model.layers
config!(layer.attention.cache, seq_length = current_len + max_new_tokens)
if clear_cache
clear_cache!(model)
config_cache!(model, current_len + max_new_tokens)
else
extend_cache!(model, current_len + max_new_tokens)
end

input_tokens = device(reshape(initial_tokens, :, 1)) # (seq_len, batch=1)
logits = model(input_tokens, 0)
start_pos = current_len

# Generate new tokens one at a time
logits = model(input_tokens)
for _ in 1:max_new_tokens
# If sequence is empty or we want to process just the last token
if start_pos == 0
input_tokens = device(reshape([128001], :, 1)) # Use start of text token if empty
else
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
end
# Get logits for next token
logits = model(input_tokens, start_pos)
# Sample next token (logits are size vocab × 1 × 1)
input_tokens = device(reshape([tokens[current_len]], :, 1)) # Just the last token
logits = model(input_tokens)
next_token = sampler(logits[:, end, 1])
current_len += 1
tokens[current_len] = next_token
!isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token]))
!isnothing(tokenizer_for_printing) && print(decode(tokenizer_for_printing, [next_token], skip_special_tokens = false))
next_token == end_token && break
start_pos += 1
end
# Clear KV caches
for layer in model.layers
clear!(layer.attention.cache)
end
return tokens[1:current_len]
end

end
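
A rough sketch (not from this PR) of what clear_cache = false enables: continuing generation on top of an already-filled cache instead of re-processing the whole prompt. It uses the toy model and AAs tokenizer from examples/scratch.jl, relies on model.pos counting the tokens actually pushed through the cache, and glosses over edge cases (for example, an empty unseen slice).

clear_cache!(model)
first_part = generate(model, encode(AAs, ">"), max_new_tokens = 30, end_token = 22)
unseen = first_part[model.pos+1:end]       # tokens sampled but not yet fed through the model
more = generate(model, unseen, max_new_tokens = 30, end_token = 22,
                clear_cache = false)       # extend the existing cache rather than reset it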