diff --git a/litgpt/adapter.py b/litgpt/adapter.py index bef77ece1b..bc095a3ca9 100644 --- a/litgpt/adapter.py +++ b/litgpt/adapter.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, Optional, Tuple import torch import torch.nn as nn @@ -28,56 +28,27 @@ class Config(BaseConfig): class GPT(BaseModel): - """The implementation is identical to `litgpt.model.GPT` with the exception that - the `Block` saves the layer index and passes it down to the attention layer.""" - + # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. def __init__(self, config: Config) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config - self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = nn.Linear( + config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + h=nn.ModuleList( + Block(config, block_idx) + for block_idx in range(config.n_layer) + ), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.max_seq_length = self.config.block_size self.mask_cache: Optional[torch.Tensor] = None - - def forward( - self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 - ) -> Union[torch.Tensor, List[torch.Tensor]]: - T = idx.size(1) - if self.max_seq_length < T: - raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") - - if input_pos is not None: # use the kv cache - cos = self.cos.index_select(0, input_pos) - sin = self.sin.index_select(0, input_pos) - if self.mask_cache is None: - raise TypeError("You need to call `gpt.set_kv_cache()`") - mask = self.mask_cache.index_select(2, input_pos) - else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None - - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - if self.config.scale_embeddings: - x = x * (self.config.n_embd**0.5) - for block in self.transformer.h: - x = block(x, cos, sin, mask, input_pos) - x = self.transformer.ln_f(x) - if lm_head_chunk_size > 0: - # chunk the lm head logits to reduce the peak memory used by autograd - return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] - x = self.lm_head(x) # (b, t, vocab_size) - if self.config.final_logit_softcapping is not None: - x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping - return x + self.max_seq_length = self.config.block_size @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -91,30 +62,9 @@ def _init_weights(self, module: nn.Module) -> None: class Block(BaseBlock): - """The implementation is identical to `litgpt.model.Block` with the exception that - we replace the attention layer where adaption is implemented.""" - def __init__(self, config: Config, block_idx: int) -> None: - # Skip the parent class __init__ altogether and replace it to avoid useless allocations - nn.Module.__init__(self) - if not config.parallel_residual and config.shared_attention_norm: - raise NotImplementedError( - "No checkpoint amongst the ones we support uses this configuration:" - - " non-parallel residual and shared attention norm." - ) - self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + super().__init__(config, block_idx) self.attn = CausalSelfAttention(config, block_idx) - self.post_attention_norm = ( - config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() - ) - self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps) - self.mlp = config.mlp_class(config) - self.post_mlp_norm = ( - config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity() - ) - - self.config = config class CausalSelfAttention(BaseCausalSelfAttention): @@ -130,12 +80,6 @@ def __init__(self, config: Config, block_idx: int) -> None: self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) # kv cache for inference self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - self.block_idx = block_idx - self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 - ) - self.config = config def scaled_dot_product_attention( self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, mask: Optional[torch.Tensor] = None diff --git a/litgpt/adapter_v2.py b/litgpt/adapter_v2.py index 9b975260f0..e7a203ba6d 100644 --- a/litgpt/adapter_v2.py +++ b/litgpt/adapter_v2.py @@ -9,7 +9,7 @@ """ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Type, Optional import torch import torch.nn as nn @@ -17,10 +17,9 @@ import litgpt from litgpt.adapter import GPT as BaseModel -from litgpt.adapter import Block as BaseBlock +from litgpt.model import Block as BaseBlock from litgpt.adapter import CausalSelfAttention as BaseCausalSelfAttention from litgpt.adapter import Config as BaseConfig -from litgpt.model import KVCache from litgpt.scripts.convert_hf_checkpoint import qkv_reassemble from litgpt.utils import map_old_state_dict_weights @@ -64,54 +63,27 @@ def reset_parameters(self) -> None: class GPT(BaseModel): + # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. def __init__(self, config: Config) -> None: - # Skip the parent class __init__ altogether and replace it to avoid useless allocations nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config - self.lm_head = AdapterV2Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = AdapterV2Linear( + config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList(Block(config, i) for i in range(config.n_layer)), + h=nn.ModuleList( + Block(config, block_idx) + for block_idx in range(config.n_layer) + ), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.max_seq_length = self.config.block_size self.mask_cache: Optional[torch.Tensor] = None - - def forward( - self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 - ) -> Union[torch.Tensor, List[torch.Tensor]]: - T = idx.size(1) - if self.max_seq_length < T: - raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") - - if input_pos is not None: # use the kv cache - cos = self.cos.index_select(0, input_pos) - sin = self.sin.index_select(0, input_pos) - if self.mask_cache is None: - raise TypeError("You need to call `gpt.set_kv_cache()`") - mask = self.mask_cache.index_select(2, input_pos) - else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None - - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - if self.config.scale_embeddings: - x = x * (self.config.n_embd**0.5) - for block in self.transformer.h: - x = block(x, cos, sin, mask, input_pos) - x = self.transformer.ln_f(x) - if lm_head_chunk_size > 0: - # chunk the lm head logits to reduce the peak memory used by autograd - return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] - x = self.lm_head(x) # (b, t, vocab_size) - if self.config.final_logit_softcapping is not None: - x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping - return x + self.max_seq_length = self.config.block_size @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -131,61 +103,30 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class Block(BaseBlock): - """The implementation is identical to `litgpt.model.Block` with the exception that - we replace the attention layer where adaption is implemented.""" - def __init__(self, config: Config, block_idx: int) -> None: - # Skip the parent class __init__ altogether and replace it to avoid useless allocations - nn.Module.__init__(self) - if not config.parallel_residual and config.shared_attention_norm: - raise NotImplementedError( - "No checkpoint amongst the ones we support uses this configuration:" - " non-parallel residual and shared attention norm." - ) - self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + super().__init__(config, block_idx) self.attn = CausalSelfAttention(config, block_idx) - self.post_attention_norm = ( - config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() - ) - self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps) self.mlp = config.mlp_class(config) - self.post_mlp_norm = ( - config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity() - ) - - self.config = config class CausalSelfAttention(BaseCausalSelfAttention): """A modification of `litgpt.adapter.CausalSelfAttention` that uses the Adapter V2 Linear class""" + # Copy&paste from :class:`model.CausalSelfAttention` def __init__(self, config: Config, block_idx: int) -> None: - # Skip the parent class __init__ altogether and replace it to avoid useless allocations - nn.Module.__init__(self) - shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + super().__init__(config, block_idx) # key, query, value projections for all heads, but in a batch - self.qkv = AdapterV2Linear(in_features=config.n_embd, out_features=shape, bias=config.bias or config.attn_bias) + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + self.qkv = AdapterV2Linear( + in_features=config.n_embd, + out_features=shape, + bias=config.bias or config.attn_bias + ) # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` - self.proj = AdapterV2Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) - # disabled by default - self.kv_cache: Optional[KVCache] = None - - if block_idx >= config.adapter_start_layer: - # adapter embedding layer - self.adapter_wte = nn.Embedding(config.adapter_prompt_length, config.n_embd) - # gate for adaption - self.gating_factor = torch.nn.Parameter(torch.zeros(1, 1, config.n_head, 1)) - # kv cache for inference - self.adapter_kv_cache: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - self.block_idx = block_idx - self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 + self.proj = AdapterV2Linear( + config.head_size * config.n_head, config.n_embd, bias=config.bias ) - self.config = config - def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base and/or legacy checkpoints.""" mapping = { @@ -211,9 +152,12 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class GptNeoxMLP(litgpt.model.GptNeoxMLP): def __init__(self, config: Config) -> None: nn.Module.__init__(self) - self.fc = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) - + self.fc = AdapterV2Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = AdapterV2Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: @@ -231,10 +175,15 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class LLaMAMLP(litgpt.model.LLaMAMLP): def __init__(self, config: Config) -> None: nn.Module.__init__(self) - self.fc_1 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.fc_2 = AdapterV2Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.proj = AdapterV2Linear(config.intermediate_size, config.n_embd, bias=config.bias) - + self.fc_1 = AdapterV2Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.fc_2 = AdapterV2Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = AdapterV2Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: @@ -264,7 +213,6 @@ def __init__(self, config: Config) -> None: nn.Module.__init__(self) self.gate = AdapterV2Linear(config.n_embd, config.n_expert, bias=False) self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert)) - self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py index 866947beea..2536b15613 100644 --- a/litgpt/generate/base.py +++ b/litgpt/generate/base.py @@ -4,7 +4,7 @@ import time from pathlib import Path from pprint import pprint -from typing import Any, Literal, Optional, Tuple, List, Union, Iterator +from typing import Any, Literal, Optional, Tuple, List, Union, Iterator, Dict import warnings import lightning as L @@ -73,15 +73,23 @@ def sample( return torch.argmax(logits, dim=-1, keepdim=True) -def next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, **kwargs: Any) -> torch.Tensor: - logits = model(x, input_pos) - _next = sample(logits, **kwargs).to(dtype=torch.int64) +def next_token( + model: GPT, + input_pos: torch.Tensor, + x: torch.Tensor, + input_pos_maxp1: Optional[int] = None, + **sample_kwargs: Dict[str, Any], +) -> torch.Tensor: + logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1) + _next = sample(logits, **sample_kwargs).to(dtype=torch.int64) return _next + def batched_sample(logits: list[torch.Tensor], kwargs: list[dict]) -> torch.Tensor: assert len(logits) == len(kwargs), "logits and kwargs must have the same length." return torch.stack([sample(l, **sample_args).to(dtype=torch.int64) for sample_args, l in zip(kwargs, logits)], dim=0) + def batched_next_token(model: GPT, input_pos: torch.Tensor, x: torch.Tensor, kwargs: Union[dict, list[dict]]) -> torch.Tensor: # Where: # input_pos is a 1d tensor of shape [seq_length...] @@ -166,10 +174,19 @@ def generate_fn( token = prompt prefill_token = True input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64) + input_pos_maxp1 = prompt_size for current_idx in range(max_returned_tokens - prompt_size): # Generate the token - token = next_token(model, input_pos, token.view(1, -1), temperature=temperature, top_k=top_k, top_p=top_p) + token = next_token( + model, + input_pos, + token.view(1, -1), + input_pos_maxp1=input_pos_maxp1, + temperature=temperature, + top_k=top_k, + top_p=top_p, + ) tokens.append(token) int_token = token.item() @@ -205,6 +222,7 @@ def generate_fn( input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64) else: input_pos.add_(1) + input_pos_maxp1 += 1 # Yield any remaining tokens if yielded_idx < len(tokens): diff --git a/litgpt/lora.py b/litgpt/lora.py index beca761c48..8144695aaf 100644 --- a/litgpt/lora.py +++ b/litgpt/lora.py @@ -45,7 +45,7 @@ import math from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Dict, Tuple, Type, Union, Optional import torch import torch.nn as nn @@ -481,60 +481,31 @@ def mlp_class(self) -> Type: class GPT(BaseModel): + # Copy & paste from :class:`model.GPT`. Note that :class:`Block` is new here. def __init__(self, config: Config) -> None: nn.Module.__init__(self) assert config.padded_vocab_size is not None self.config = config - self.lm_head = LoRALinear( + self.lm_head = create_lora_linear( + config, config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias, - r=(config.lora_r if config.lora_head else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + use_r=config.lora_head, ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)), + h=nn.ModuleList( + Block(config, block_idx) + for block_idx in range(config.n_layer) + ), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.max_seq_length = self.config.block_size self.mask_cache: Optional[torch.Tensor] = None - - def forward( - self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None, lm_head_chunk_size: int = 0 - ) -> Union[torch.Tensor, List[torch.Tensor]]: - T = idx.size(1) - if self.max_seq_length < T: - raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") - - if input_pos is not None: # use the kv cache - cos = self.cos.index_select(0, input_pos) - sin = self.sin.index_select(0, input_pos) - if self.mask_cache is None: - raise TypeError("You need to call `gpt.set_kv_cache()`") - mask = self.mask_cache.index_select(2, input_pos) - else: - cos = self.cos[:T] - sin = self.sin[:T] - mask = None - - x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - if self.config.scale_embeddings: - x = x * (self.config.n_embd**0.5) - for block in self.transformer.h: - x = block(x, cos, sin, mask, input_pos) - x = self.transformer.ln_f(x) - if lm_head_chunk_size > 0: - # chunk the lm head logits to reduce the peak memory used by autograd - return [self.lm_head(x_i) for x_i in x.split(lm_head_chunk_size, dim=1)] - x = self.lm_head(x) # (b, t, vocab_size) - if self.config.final_logit_softcapping is not None: - x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping - return x + self.max_seq_length = self.config.block_size @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -555,33 +526,16 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class Block(BaseBlock): def __init__(self, config: Config, block_idx: int) -> None: - nn.Module.__init__(self) - if not config.parallel_residual and config.shared_attention_norm: - raise NotImplementedError( - "No checkpoint amongst the ones we support uses this configuration:" - " non-parallel residual and shared attention norm." - ) - self.norm_1 = config.norm_class(config.n_embd, eps=config.norm_eps) + super().__init__(config, block_idx) self.attn = CausalSelfAttention(config, block_idx) - self.post_attention_norm = ( - config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_attention_norm else nn.Identity() - ) - self.norm_2 = None if config.shared_attention_norm else config.norm_class(config.n_embd, eps=config.norm_eps) self.mlp = config.mlp_class(config) - self.post_mlp_norm = ( - config.norm_class(config.n_embd, eps=config.norm_eps) if config.post_mlp_norm else nn.Identity() - ) - - self.config = config class CausalSelfAttention(BaseCausalSelfAttention): def __init__(self, config: Config, block_idx: int) -> None: - # Skip the parent class __init__ altogether and replace it to avoid - # useless allocations - nn.Module.__init__(self) - shape = (config.n_head + 2 * config.n_query_groups) * config.head_size + super().__init__(config, block_idx) # key, query, value projections for all heads, but in a batch + shape = (config.n_head + 2 * config.n_query_groups) * config.head_size self.qkv = LoRAQKVLinear( in_features=config.n_embd, out_features=shape, @@ -596,23 +550,12 @@ def __init__(self, config: Config, block_idx: int) -> None: n_query_groups=config.n_query_groups, ) # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` - self.proj = LoRALinear( + self.proj = create_lora_linear( + config, config.head_size * config.n_head, config.n_embd, - bias=config.bias, - r=(config.lora_r if config.lora_projection else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + use_r=config.lora_projection, ) - # disabled by default - self.kv_cache: Optional[KVCache] = None - self.apply_sliding_window_attention = ( - config.sliding_window_size is not None and - block_idx % config.sliding_window_layer_stride == 0 - ) - - self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with base and/or legacy checkpoints.""" @@ -633,26 +576,36 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) +def create_lora_linear( + config: Config, + in_size: int, + out_size: int, + bias: Optional[Union[float, bool]] = None, + use_r: Optional[bool] = None, +) -> LoRALinear: + if bias is None: + bias = config.bias + if use_r is None: + use_r = config.lora_mlp + return LoRALinear( + in_size, + out_size, + bias=bias, + r=(config.lora_r if use_r else 0), + lora_alpha=config.lora_alpha, + lora_dropout=config.lora_dropout, + ) + + class GptNeoxMLP(litgpt.model.GptNeoxMLP): def __init__(self, config: Config) -> None: nn.Module.__init__(self) - self.fc = LoRALinear( - config.n_embd, - config.intermediate_size, - bias=config.bias, - r=(config.lora_r if config.lora_mlp else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + self.fc = create_lora_linear( + config, config.n_embd, config.intermediate_size ) - self.proj = LoRALinear( - config.intermediate_size, - config.n_embd, - bias=config.bias, - r=(config.lora_r if config.lora_mlp else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + self.proj = create_lora_linear( + config, config.intermediate_size, config.n_embd ) - self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: @@ -670,31 +623,15 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class LLaMAMLP(litgpt.model.LLaMAMLP): def __init__(self, config: Config) -> None: nn.Module.__init__(self) - self.fc_1 = LoRALinear( - config.n_embd, - config.intermediate_size, - bias=config.bias, - r=(config.lora_r if config.lora_mlp else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + self.fc_1 = create_lora_linear( + config, config.n_embd, config.intermediate_size ) - self.fc_2 = LoRALinear( - config.n_embd, - config.intermediate_size, - bias=config.bias, - r=(config.lora_r if config.lora_mlp else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + self.fc_2 = create_lora_linear( + config, config.n_embd, config.intermediate_size ) - self.proj = LoRALinear( - config.intermediate_size, - config.n_embd, - bias=config.bias, - r=(config.lora_r if config.lora_mlp else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + self.proj = create_lora_linear( + config, config.intermediate_size, config.n_embd ) - self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: @@ -722,16 +659,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class LLaMAMoE(litgpt.model.LLaMAMoE): def __init__(self, config: Config) -> None: nn.Module.__init__(self) - self.gate = LoRALinear( - config.n_embd, - config.n_expert, - bias=False, - r=(config.lora_r if config.lora_mlp else 0), - lora_alpha=config.lora_alpha, - lora_dropout=config.lora_dropout, + self.gate = create_lora_linear( + config, config.n_embd, config.n_expert, bias=False ) self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert)) - self.config = config def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: diff --git a/litgpt/model.py b/litgpt/model.py index cbdf2a4bdd..bff11ccb6f 100644 --- a/litgpt/model.py +++ b/litgpt/model.py @@ -7,10 +7,12 @@ """ import math -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional, Tuple, Union, List +from functools import partial import torch import torch.nn as nn +import torch.nn.functional as F from typing_extensions import Self from litgpt.config import Config @@ -23,16 +25,21 @@ def __init__(self, config: Config) -> None: assert config.padded_vocab_size is not None self.config = config - self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias) + self.lm_head = nn.Linear( + config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias + ) self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList(Block(config, block_idx) for block_idx in range(config.n_layer)), + h=nn.ModuleList( + Block(config, block_idx) + for block_idx in range(config.n_layer) + ), ln_f=config.norm_class(config.n_embd, eps=config.norm_eps), ) ) - self.max_seq_length = self.config.block_size self.mask_cache: Optional[torch.Tensor] = None + self.max_seq_length = self.config.block_size @property def max_seq_length(self) -> int: @@ -60,6 +67,8 @@ def max_seq_length(self, value: int) -> None: self.cos, self.sin = self.rope_cache(device=self.cos.device) # the mask and kv cache size will get updated on `set_kv_cache`. we cannot update it here because we don't know # if the kv cache is expected + if self.mask_cache is not None and self.mask_cache.shape[-1] < value: + print(f"Warning: KV cache has length {self.mask_cache.shape[-1]} < {value} = max_seq_length. Call 'set_kv_cache' before doing any forwards!") def reset_parameters(self) -> None: # Trigger resetting the rope-cache @@ -74,21 +83,41 @@ def _init_weights(self, module: nn.Module) -> None: elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, + idx: torch.Tensor, + input_pos: Optional[torch.Tensor] = None, + input_pos_maxp1: Optional[int] = None, + lm_head_chunk_size: int = 0, + ) -> Union[torch.Tensor, List[torch.Tensor]]: """ + If `input_pos` is provided, the KV cache uses K and V vectors for + positions smaller than entries in `input_pos`. For efficiency, pass + `input_pos_maxp1` as `max(input_pos) + 1` if already available from + your forward algorithm. This slices the KV cache buffers and speeds + up multi-head attention. + + Without `input_pos_maxp1`, the computation uses the full KV cache + (`max_seq_length`) with masking applied. Note that inferring + `input_pos_maxp1` from `input_pos` causes graph breaks and prevents + compilation. + Args: - idx (torch.Tensor): Input token indices, shape `(B, T)` - input_pos (torch.Tensor, optional): Contains input positions, - either with shape `(T,)` or `(B, T)`, if provided. This is used - for generative inference, where a KV cache is required. By - default, this assumes `input_dim == arange(T)` with all inputs - up to `T` provided upfront. + idx: Token indices of input sequences, shape `(B, T)`, where `B` + is batch size. + input_pos: Optional. Positions of input tokens. The default is + `arange(T)`. Can have shape `(T,)` or `(B, T)` (batched index). + input_pos_maxp1: Optional. See above. + lm_head_chunk_size: Optional. If `lm_head_chunk_size > 0`, the final + `lm_head` computation is done in chunks of this size. Returns: - torch.Tensor: Output (logits), shape `(B, T, config.padded_vocab_size)` + Logit outputs, shape `(B, T, config.padded_vocab_size)`. If + `lm_head_chunk_size > 0`, this is a list of chunks of shape + `(B, lm_head_chunk_size, config.padded_vocab_size)`, the final + entry can be shorter. + """ - if idx.dim() != 2: - raise ValueError(f"idx must have 2 dimensions, idx.shape = {idx.shape}") T = idx.size(1) if self.max_seq_length < T: raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.") @@ -101,31 +130,49 @@ def forward(self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None) - raise ValueError(f"input_pos.shape[-1] = {input_pos.shape[-1]} != {T} = idx.shape[1], must be the same") cos = batched_index_select(self.cos, 0, input_pos) sin = batched_index_select(self.sin, 0, input_pos) + if input_pos.dim() == 1: + cos = cos.unsqueeze(0) + sin = sin.unsqueeze(0) if self.mask_cache is None: raise TypeError("You need to call `gpt.set_kv_cache()`") mask = batched_index_select(self.mask_cache, 2, input_pos) if mask.dim() > 4: # the mask cache has a batch dim of 1 in addition to the one # we get if input_pos has a batch dimension - mask = mask.squeeze(1) + mask = mask.view(*(mask.shape[0:1] + mask.shape[2:])) + if input_pos_maxp1 is not None: + # Shorten final dimension so it just covers all `input_pos` entries + if input_pos_maxp1 > self.max_seq_length: + raise ValueError(f"Positions in 'input_pos' must be in [0,{self.max_seq_length})") + mask = mask[..., :input_pos_maxp1] else: # unsqueeze to have a batch dimension cos = self.cos[:T].unsqueeze(0) sin = self.sin[:T].unsqueeze(0) # `cos`, `sin` have shape (1, T, config.rope_n_elem) mask = None # defaults to causal mask + input_pos_maxp1 = None x = self.transformer.wte(idx) # token embeddings of shape (B, T, n_embd) if self.config.scale_embeddings: - x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype) + x = x * torch.tensor(self.config.n_embd ** 0.5, dtype=x.dtype) for block in self.transformer.h: - x = block(x, cos, sin, mask, input_pos) + x = block(x, cos, sin, mask, input_pos, input_pos_maxp1) x = self.transformer.ln_f(x) - x = self.lm_head(x) # (B, T, padded_vocab_size) - if self.config.final_logit_softcapping is not None: - x = do_softcapping(x, self.config.final_logit_softcapping) - return x + clamp_head = ( + partial(do_softcapping, thresh=self.config.final_logit_softcapping) + if self.config.final_logit_softcapping is not None + else nn.Identity() + ) + if lm_head_chunk_size > 0: + # chunk the lm head logits to reduce the peak memory used by autograd + return [ + clamp_head(self.lm_head(x_i)) + for x_i in x.split(lm_head_chunk_size, dim=1) + ] + else: + return clamp_head(self.lm_head(x)) # (B, T, padded_vocab_size) @classmethod def from_name(cls, name: str, **kwargs: Any) -> Self: @@ -204,7 +251,11 @@ def clear_kv_cache(self) -> None: class Block(nn.Module): - def __init__(self, config: Config, block_idx: int) -> None: + def __init__( + self, + config: Config, + block_idx: int, + ) -> None: super().__init__() if not config.parallel_residual and config.shared_attention_norm: raise NotImplementedError( @@ -232,6 +283,7 @@ def forward( sin: torch.Tensor, mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, + input_pos_maxp1: Optional[int] = None, ) -> torch.Tensor: """ Non-parallel residual Parallel residual @@ -255,7 +307,9 @@ def forward( """ x_normed = self.norm_1(x) - attention_output = self.attn(x_normed, cos, sin, mask, input_pos) + attention_output = self.attn( + x_normed, cos, sin, mask, input_pos, input_pos_maxp1 + ) attention_output = self.post_attention_norm(attention_output) if self.config.parallel_residual: @@ -278,16 +332,17 @@ def __init__(self, config: Config, block_idx: int) -> None: bias=config.bias or config.attn_bias, ) # output projection - # if `head_size` is explicitly specified in the config, `n_emd` might not be equal to `head_size * n_head` - self.proj = nn.Linear(config.head_size * config.n_head, config.n_embd, bias=config.bias) + self.proj = nn.Linear( + config.head_size * config.n_head, config.n_embd, bias=config.bias + ) # disabled by default self.kv_cache: Optional[KVCache] = None self.apply_sliding_window_attention = ( config.sliding_window_size is not None and block_idx % config.sliding_window_layer_stride == 0 ) - self.config = config + self.block_idx = block_idx def forward( self, @@ -296,6 +351,7 @@ def forward( sin: torch.Tensor, mask: Optional[torch.Tensor] = None, input_pos: Optional[torch.Tensor] = None, + input_pos_maxp1: Optional[int] = None, ) -> torch.Tensor: # Notation: # - B | batch size @@ -304,8 +360,11 @@ def forward( # - C* | attentions's embeddings size # - nh_(q,k,v) | number of heads for query, key and value # - hs | head size - - B, T, C = x.size() + head_size = self.config.head_size + n_head = self.config.n_head + n_query_groups = self.config.n_query_groups + rope_n_elem = self.config.rope_n_elem + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # Perform a single multiplication operation using a combined QKV matrix to calculate `query`, `key`, and `value` # instead of individually multiplying the input `x` with the respective weight matrices. @@ -313,16 +372,16 @@ def forward( # Define query, key and value sizes. # If grouped/multi query is enabled, these sizes are not equal (see the diagram in `lit_gpt/config.py::Config`). - query_size = self.config.n_head * self.config.head_size - key_size = value_size = self.config.n_query_groups * self.config.head_size + query_size = n_head * head_size + key_size = value_size = n_query_groups * head_size # Split qkv into query, key and value matrices. q, k, v = qkv.split((query_size, key_size, value_size), dim=-1) # 3x(B, T, C*) # To place the num_heads (nh) dimension right after the batch (B) dimension, the first step is to decouple the # embedding size (C) into num_heads (nh) and head_size (hs). - q = q.view(B, T, self.config.n_head, self.config.head_size) # (B, T, nh_q, hs) - k = k.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_k, hs) - v = v.view(B, T, self.config.n_query_groups, self.config.head_size) # (B, T, nh_v, hs) + q = q.view(B, T, n_head, head_size) # (B, T, nh_q, hs) + k = k.view(B, T, n_query_groups, head_size) # (B, T, nh_k, hs) + v = v.view(B, T, n_query_groups, head_size) # (B, T, nh_v, hs) # The tensors `query`, `key`, and `value` are now accurately structured: within each batch element (B), there are # multiple heads (nh), and within each head, there is a sequence of elements (T), each represented by a vector @@ -332,22 +391,28 @@ def forward( v = v.transpose(1, 2) # (B, nh_v, T, hs) # Unlike standard positional embeddings rotary embeddings must be applied at every layer. - q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin) - k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin) - q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_q, T, hs) - k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1) # (B, nh_k, T, hs) + q_roped = apply_rope(q[..., : rope_n_elem], cos, sin) + k_roped = apply_rope(k[..., : rope_n_elem], cos, sin) + q = torch.cat((q_roped, q[..., rope_n_elem :]), dim=-1) # (B, nh_q, T, hs) + k = torch.cat((k_roped, k[..., rope_n_elem :]), dim=-1) # (B, nh_k, T, hs) # Apply kv-cache during inference. if input_pos is not None: if not isinstance(self.kv_cache, KVCache): raise TypeError("You need to call `gpt.set_kv_cache()`") k, v = self.kv_cache(input_pos, k, v) + if input_pos_maxp1 is not None: + # Subselect along sequence dimension + k = k[..., :input_pos_maxp1, :] + v = v[..., :input_pos_maxp1, :] + # k, v: (B, nh_k, input_pos_maxp1, hs) + # If input_pos_maxp1 is None -> max_seq_length # Grouped queries: balance the number of heads across all three matrices. # NOTE: flash attention requires it in training mode. # Multi-query: this step can be skipped since there is only 1 head, allowing us to use broadcasting. - if self.config.n_query_groups != self.config.n_head and (input_pos is None or self.config.n_query_groups != 1): - q_per_kv = self.config.n_head // self.config.n_query_groups + if n_query_groups != n_head and (input_pos is None or n_query_groups != 1): + q_per_kv = n_head // n_query_groups k = k.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) v = v.repeat_interleave(q_per_kv, dim=1) # (B, nh_q, T, hs) @@ -365,6 +430,7 @@ def forward( if mask is None: mask = torch.ones(T, T, dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), float("-inf")) + mask = mask.view(1, 1, *mask.shape) sliding_window_bias = torch.ones_like(mask).tril(diagonal=-self.config.sliding_window_size) sliding_window_bias.masked_fill_(sliding_window_bias.bool(), float("-inf")) mask += sliding_window_bias @@ -375,7 +441,7 @@ def forward( y = self.scaled_dot_product_attention(q, k, v, mask) # Re-assemble all head outputs side by side. - y = y.reshape(B, T, self.config.head_size * self.config.n_head) + y = y.reshape(B, T, head_size * n_head) # Output projection. return self.proj(y) # (B, T, C) @@ -393,10 +459,10 @@ def scaled_dot_product_attention( mask = torch.ones(q.size(2), q.size(2), dtype=q.dtype, device=q.device).triu(diagonal=1) mask.masked_fill_(mask.bool(), torch.finfo(q.dtype).min) scores = scores + mask - scores = torch.nn.functional.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) + scores = F.softmax(scores, dim=-1, dtype=torch.float).to(dtype=q.dtype) y = scores @ v else: - y = torch.nn.functional.scaled_dot_product_attention( + y = F.scaled_dot_product_attention( q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None ) return y.transpose(1, 2) @@ -423,7 +489,7 @@ def build_kv_cache( ) return KVCache(k_shape, v_shape, device=device, dtype=dtype) - def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwargs: Any) -> None: + def _load_from_state_dict(self, state_dict: dict, prefix: str, *args: Any, **kwargs: Any) -> None: """For compatibility with legacy checkpoints.""" for attr in ("weight", "bias"): @@ -438,30 +504,38 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, *args: Any, **kwa class GptNeoxMLP(nn.Module): def __init__(self, config: Config) -> None: super().__init__() - self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) - + self.fc = nn.Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = nn.Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) self.config = config def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.fc(x) - x = torch.nn.functional.gelu(x, approximate=self.config.gelu_approximate) + x = F.gelu(x, approximate=self.config.gelu_approximate) return self.proj(x) class LLaMAMLP(nn.Module): def __init__(self, config: Config) -> None: super().__init__() - self.fc_1 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.fc_2 = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias) - self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias) - + self.fc_1 = nn.Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.fc_2 = nn.Linear( + config.n_embd, config.intermediate_size, bias=config.bias + ) + self.proj = nn.Linear( + config.intermediate_size, config.n_embd, bias=config.bias + ) self.config = config def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = torch.nn.functional.silu(x_fc_1) * x_fc_2 + x = F.silu(x_fc_1) * x_fc_2 return self.proj(x) @@ -469,7 +543,7 @@ class GemmaMLP(LLaMAMLP): def forward(self, x: torch.Tensor) -> torch.Tensor: x_fc_1 = self.fc_1(x) x_fc_2 = self.fc_2(x) - x = torch.nn.functional.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 + x = F.gelu(x_fc_1, approximate=self.config.gelu_approximate) * x_fc_2 return self.proj(x) @@ -478,7 +552,6 @@ def __init__(self, config: Config) -> None: super().__init__() self.gate = nn.Linear(config.n_embd, config.n_expert, bias=False) self.experts = nn.ModuleList(LLaMAMLP(config) for _ in range(config.n_expert)) - self.config = config def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -521,6 +594,7 @@ def build_rope_cache( Returns: Tuple[torch.Tensor, torch.Tensor]: Cosine and sine caches for RoPE. + Shapes are `(seq_len, n_elem)`. """ # Compute the inverse frequencies theta @@ -546,6 +620,15 @@ def build_rope_cache( # Calculate the product of position index and $\theta_i$ idx_theta = torch.outer(seq_idx, theta).repeat(1, 2) + # If `n_elem` is odd, the final dimension of `idx_theta` has size + # `n_elem + 1`, so need to cut something off. + # Due to a current bug in Hugging Face, in the case `n_elem == 1`, we leave + # `idx_theta`, `cos`, `sin` as is. Things work out in `apply_rope` due to + # broadcasting. If we shorten `idx_theta`, unit tests comparing to + # Hugging Face fail. + # https://github.com/huggingface/transformers/issues/35233 + if idx_theta.shape[-1] > n_elem > 1: + idx_theta = idx_theta[..., :n_elem] return torch.cos(idx_theta), torch.sin(idx_theta) @@ -620,18 +703,32 @@ def batched_index_copy_(t, dim, idx, val): def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: - # x: (B, nh, T, hs) - # sin, cos: (B, T, hs) or (1, T, hs) - head_size = x.size(-1) - x1 = x[..., : head_size // 2] # (B, nh, T, hs/2) - x2 = x[..., head_size // 2 :] # (B, nh, T, hs/2) - rotated = torch.cat((-x2, x1), dim=-1) # (B, nh, T, hs) - if cos.dim() > 1: - # batch dimensions must align - # sin/cos are (B, T, hs) so we unsqeeze -3 for nh - # we count from back because all of apply_rope does - cos = cos.unsqueeze(-3) - sin = sin.unsqueeze(-3) + """ + Applies RoPE transform to `x`. Note that `cos`, `sin` need to have a batch + dimension. + + Args: + x: Input tensor, `(B, ..., T, head_size)` + cos: Cached cosines, `(B, T, head_size)` or `(1, T, head_size)` + sin: Cached sines, `(B, T, head_size)` or `(1, T, head_size)` + + Returns: + Encoded tensor, `(B, ..., T, head_size)` + """ + if cos.dim() != 3: + raise ValueError(f"cos must be three-dimensional, but shape is {cos.shape}") + if cos.shape != sin.shape: + raise ValueError(f"cos, sin must have same shape, but cos.shape={cos.shape}, sin.shape={sin.shape}") + head_size_half = x.size(-1) // 2 + x1 = x[..., : head_size_half] # (B, ..., T, head_size/2) + x2 = x[..., head_size_half :] # (B, ..., T, head_size/2) + rotated = torch.cat((-x2, x1), dim=-1) # (B, ..., T, head_size) + dims_diff = x.dim() - cos.dim() + if dims_diff > 0: + # Ensure that shapes of `x`, `cos`, `sin` align + new_shape = cos.shape[0:1] + (1,) * dims_diff + cos.shape[1:] + cos = cos.view(*new_shape) + sin = sin.view(*new_shape) roped = (x * cos) + (rotated * sin) return roped.to(dtype=x.dtype) @@ -642,6 +739,10 @@ def do_softcapping(x: torch.Tensor, thresh: float) -> torch.Tensor: class KVCache(nn.Module): + """ + Buffers `k`, `v` have shape + `(batch_size, n_query_groups, max_seq_length, head_size)`. + """ def __init__( self, k_shape: Tuple[int, int, int, int], @@ -654,13 +755,28 @@ def __init__( self.register_buffer("v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False) def forward(self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Writes new values `k` and `v` into the cache at the positions specified + by `input_pos` along the sequence dimension (`max_seq_length`). The batch + size of `k` and `v` (`bs`) must be smaller or equal to `KVCache` batch + size. Returns the full buffers, adjusted to the batch size `bs`. + + Args: + input_pos: Position index, `(bs, T)` or `(T,)` + k: New values, `(bs, n_query_groups, T, head_size)` + v: New values, `(bs, n_query_groups, T, head_size)` + + Returns: + k_full, v_full, `(bs, n_query_groups, max_seq_length, head_size)` + + """ # move the buffer to the activation dtype for when AMP is used self.k = self.k.to(k.dtype) self.v = self.v.to(v.dtype) # update the cache - n = k.size(0) - k = batched_index_copy_(self.k[:n, ...], -2, input_pos, k) - v = batched_index_copy_(self.v[:n, ...], -2, input_pos, v) + bs = k.size(0) + k = batched_index_copy_(self.k[:bs, ...], -2, input_pos, k) + v = batched_index_copy_(self.v[:bs, ...], -2, input_pos, v) return k, v def reset_parameters(self) -> None: diff --git a/litgpt/utils.py b/litgpt/utils.py index 2180762617..60e7cd9034 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -358,7 +358,6 @@ def get_default_supported_precision(training: bool) -> str: Args: training: If True, returns '-mixed' version of the precision; if False, returns '-true' version. - use_mps: Flag to determine if MPS should be used when available. Returns: The default precision that is suitable for the task and is supported by the hardware. diff --git a/tests/test_model.py b/tests/test_model.py index abd1a767bf..21095e9f2c 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -1322,3 +1322,65 @@ def test_load_legacy_state_dict(): attention_2 = CausalSelfAttention(config=config, block_idx=0) attention_2.load_state_dict(state_dict) + +@pytest.mark.parametrize("n_query_groups", (1, 2, 4, 8)) +@torch.inference_mode() +def test_kv_cache_buffer_shape(n_query_groups): + batch_size = 3 + max_seq_length = 23 + config = Config( + block_size=25, + padded_vocab_size=5, + n_layer=2, + n_head=8, + n_embd=16, + n_query_groups=n_query_groups, + ) + model = GPT(config) + model.max_seq_length = max_seq_length + model.set_kv_cache(batch_size) + required_shape = (batch_size, n_query_groups, max_seq_length, config.head_size) + for block in model.transformer.h: + kv_cache = block.attn.kv_cache + assert kv_cache is not None + assert kv_cache.k.shape == required_shape + assert kv_cache.v.shape == required_shape + + +@pytest.mark.parametrize( + ("rotary_percentage", "final_dim"), + ((0.75, 3), (0.25, 2)) +) +@torch.inference_mode() +def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(rotary_percentage, final_dim): + batch_size = 3 + config = Config( + block_size=25, + padded_vocab_size=5, + n_layer=2, + n_head=4, + n_embd=16, + rotary_percentage=rotary_percentage, + ) + model = GPT(config) + required_shape = (config.block_size, final_dim) + assert model.cos.shape == required_shape + assert model.sin.shape == required_shape + +def test_forward_with_without_input_pos_maxp1(): + batch_size = 3 + config = Config( + block_size=25, + padded_vocab_size=5, + n_layer=2, + n_head=8, + n_embd=16, + ) + model = GPT(config) + model.set_kv_cache(batch_size) + idx = torch.randint(0, config.padded_vocab_size, (1, 10)) + input_pos = torch.arange(1, 11) + input_pos_maxp1 = 11 + logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1) + logits_no_maxp1 = model(idx, input_pos) + torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1) diff --git a/tests/test_rope.py b/tests/test_rope.py index 7293e52fa7..0aa10aeb58 100644 --- a/tests/test_rope.py +++ b/tests/test_rope.py @@ -13,7 +13,7 @@ @torch.inference_mode() def test_rope_gptneox(): bs, seq_len, n_head, n_embed = 1, 6, 2, 8 - head_size = n_embed // n_head + head_size = n_embed // n_head # 4 x = torch.randint(0, 10000, size=(bs, n_head, seq_len, head_size)).float() position_ids = torch.arange(seq_len).unsqueeze(0) @@ -21,9 +21,10 @@ def test_rope_gptneox(): theirs_cos, theirs_sin = theirs_rot_emb(x, position_ids) ours_cos_cached, ours_sin_cached = build_rope_cache(seq_len, head_size, device=x.device) - # their rope cache has 2 added dimensions and the cos/sin is duplicated - torch.testing.assert_close(ours_cos_cached, theirs_cos.squeeze()) - torch.testing.assert_close(ours_sin_cached, theirs_sin.squeeze()) + ours_cos_cached = ours_cos_cached.unsqueeze(0) + ours_sin_cached = ours_sin_cached.unsqueeze(0) + torch.testing.assert_close(ours_cos_cached, theirs_cos) + torch.testing.assert_close(ours_sin_cached, theirs_sin) ours_x_rope = apply_rope(x, ours_cos_cached, ours_sin_cached) theirs_x_rope, _ = apply_rotary_pos_emb_gptneo(x, x, theirs_cos, theirs_sin, position_ids) @@ -47,8 +48,10 @@ def test_rope_llama_2(): # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta) - torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos) - torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin) + ours_cos = ours_cos.unsqueeze(0) + ours_sin = ours_sin.unsqueeze(0) + torch.testing.assert_close(theirs_cos, ours_cos) + torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors @@ -86,8 +89,10 @@ def test_rope_llama_3(): # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta) - torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos) - torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin) + ours_cos = ours_cos.unsqueeze(0) + ours_sin = ours_sin.unsqueeze(0) + torch.testing.assert_close(theirs_cos, ours_cos) + torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors @@ -146,8 +151,10 @@ def test_rope_llama_3_1(): # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config) - torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos) - torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin) + ours_cos = ours_cos.unsqueeze(0) + ours_sin = ours_sin.unsqueeze(0) + torch.testing.assert_close(theirs_cos, ours_cos) + torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors @@ -206,8 +213,10 @@ def test_rope_llama_3_2(): # our rope ours_cos, ours_sin = build_rope_cache(seq_len, n_elem=head_dim, base=rope_theta, extra_config=our_rope_config) - torch.testing.assert_close(theirs_cos.squeeze(0), ours_cos) - torch.testing.assert_close(theirs_sin.squeeze(0), ours_sin) + ours_cos = ours_cos.unsqueeze(0) + ours_sin = ours_sin.unsqueeze(0) + torch.testing.assert_close(theirs_cos, ours_cos) + torch.testing.assert_close(theirs_sin, ours_sin) ################################## # Compare rotated tensors @@ -225,4 +234,24 @@ def test_rope_llama_3_2(): theirs_q_rot, theirs_k_rot = apply_rotary_pos_emb_llama(queries, keys, theirs_cos, theirs_sin) torch.testing.assert_close(theirs_q_rot, ours_q_rot) torch.testing.assert_close(theirs_k_rot, ours_k_rot) - + +@torch.inference_mode() +def test_rope_cos_sin_shapes_if_rope_n_elem_is_odd(): + bs, seq_len, n_head, n_embed = 1, 6, 2, 8 + head_size = n_embed // n_head # 4 + rotary_percentage = 0.75 + rope_n_elem = int(head_size * rotary_percentage) # 3 + ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem) + required_shape = (seq_len, rope_n_elem) + assert ours_cos.shape == required_shape + assert ours_sin.shape == required_shape + # Special case: If `rope_n_elem == 1`, the shape is extended. This is to + # accommodate a current bug in Hugging Face, ensuring that other unit tests + # pass. + # https://github.com/huggingface/transformers/issues/35233 + rotary_percentage = 0.25 + rope_n_elem = int(head_size * rotary_percentage) # 1 + ours_cos, ours_sin = build_rope_cache(seq_len, rope_n_elem) + required_shape = (seq_len, rope_n_elem + 1) + assert ours_cos.shape == required_shape + assert ours_sin.shape == required_shape