Query num_key_value_heads when available to support GQA models.
Models that implement GQA do not have the same number of heads for K/V as
for Query, so we need to query the `num_key_value_heads` attribute to see
whether the KV cache needs a different size.

Not all models have `num_key_value_heads` in their config (QWEN, for
example), and Phi has it but sets it to null, hence the code is structured
to handle both cases.
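For readers outside the diff, a minimal sketch of that fallback, assuming a
Hugging Face `transformers` config object (the Mistral model id below is
illustrative, not taken from this repo):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mistral-7B-v0.1")

# GQA models keep fewer K/V heads than query heads. Some configs omit the
# attribute entirely (e.g. QWEN), and Phi sets it to null, so both cases
# fall back to num_attention_heads.
kv_heads = getattr(config, "num_key_value_heads", None)
if kv_heads is None:
    kv_heads = config.num_attention_heads

# The per-head dimension is still derived from the full query head count.
head_dim = config.hidden_size // config.num_attention_heads
print(f"kv_heads={kv_heads}, head_dim={head_dim}")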
Stanley Winata committed Feb 2, 2024
1 parent 6f67a97 commit bc603f8
Showing 2 changed files with 4 additions and 3 deletions.
6 changes: 4 additions & 2 deletions python/turbine_models/custom_models/stateless_llama.py
@@ -128,8 +128,10 @@ def export_transformer_model(
     )
     # TODO: generate these values instead of magic numbers
     NUM_LAYERS = mod.config.num_hidden_layers
-    HEADS = mod.config.num_attention_heads
-    HIDDEN_DIM = int(mod.config.hidden_size / HEADS)
+    HEADS = getattr(mod.config, "num_key_value_heads", None)
+    if HEADS is None:
+        HEADS = mod.config.num_attention_heads
+    HIDDEN_DIM = int(mod.config.hidden_size / mod.config.num_attention_heads)
     BATCH_SIZE = 1
     MAX_STEP_SEQ = mod.config.max_position_embeddings - 1
     global_pkv = torch.zeros(
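Why this matters for the cache: the `global_pkv` allocation this hunk feeds
into is sized by HEADS, so a GQA model gets a proportionally smaller KV
cache. A hedged sketch with Mistral-7B-like numbers; the exact tensor layout
in stateless_llama.py is truncated above, so this shape is an assumption:

import torch

# Hypothetical GQA shapes: 32 layers, 8 K/V heads (vs 32 query heads),
# head dim 128, 4096-token context. The layout (layers, K/V pair, batch,
# heads, seq, head_dim) is illustrative, not the repo's exact one.
NUM_LAYERS, BATCH_SIZE, HEADS, MAX_STEP_SEQ, HIDDEN_DIM = 32, 1, 8, 4095, 128
global_pkv = torch.zeros(NUM_LAYERS, 2, BATCH_SIZE, HEADS, MAX_STEP_SEQ, HIDDEN_DIM)
# With 8 K/V heads instead of 32, this buffer is 4x smaller.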
@@ -127,7 +127,6 @@ def gen_external_params(
         auto_model=AutoModelForCausalLM,
         hf_auth_token=hf_auth_token,
     )
-    model_builder.build_model()

     if precision == "f16":
         model = model_builder.model.half()
