diff --git a/litgpt/generate/base.py b/litgpt/generate/base.py
index 2536b15613..273fe7497d 100644
--- a/litgpt/generate/base.py
+++ b/litgpt/generate/base.py
@@ -77,7 +77,7 @@ def next_token(
     model: GPT,
     input_pos: torch.Tensor,
     x: torch.Tensor,
-    input_pos_maxp1: Optional[int] = None,
+    input_pos_maxp1: Optional[torch.Tensor] = None,
     **sample_kwargs: Dict[str, Any],
 ) -> torch.Tensor:
     logits = model(x, input_pos, input_pos_maxp1=input_pos_maxp1)
@@ -174,7 +174,7 @@ def generate_fn(
     token = prompt
     prefill_token = True
     input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
-    input_pos_maxp1 = prompt_size
+    input_pos_maxp1 = torch.tensor(prompt_size, device=device)
 
     for current_idx in range(max_returned_tokens - prompt_size):
         # Generate the token
@@ -222,7 +222,7 @@ def generate_fn(
             input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
         else:
             input_pos.add_(1)
-        input_pos_maxp1 += 1
+        input_pos_maxp1.add_(1)
 
     # Yield any remaining tokens
     if yielded_idx < len(tokens):
diff --git a/litgpt/model.py b/litgpt/model.py
index bff11ccb6f..a89070d8bb 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -87,7 +87,7 @@ def forward(
         self,
         idx: torch.Tensor,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[int] = None,
+        input_pos_maxp1: Optional[torch.Tensor] = None,
         lm_head_chunk_size: int = 0,
     ) -> Union[torch.Tensor, List[torch.Tensor]]:
         """
@@ -283,7 +283,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[int] = None,
+        input_pos_maxp1: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """
         Non-parallel residual       Parallel residual
@@ -351,7 +351,7 @@ def forward(
         sin: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
         input_pos: Optional[torch.Tensor] = None,
-        input_pos_maxp1: Optional[int] = None,
+        input_pos_maxp1: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         # Notation:
         # - B | batch size
diff --git a/tests/test_model.py b/tests/test_model.py
index 21095e9f2c..e8a110a409 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -914,7 +914,7 @@ def test_against_original_salamandra(model_name, device, dtype):
     ours_y = ours_model(x)
     theirs_y = theirs_model(x)["logits"].to(dtype)  # HF converts logits to float
     torch.testing.assert_close(ours_y, theirs_y)
-    
+
 
 @torch.inference_mode()
 @pytest.mark.parametrize("model_name", ("SmolLM2-135M", "SmolLM2-360M", "SmolLM2-1.7B"))
@@ -1380,7 +1380,7 @@ def test_forward_with_without_input_pos_maxp1():
     model.set_kv_cache(batch_size)
     idx = torch.randint(0, config.padded_vocab_size, (1, 10))
    input_pos = torch.arange(1, 11)
-    input_pos_maxp1 = 11
+    input_pos_maxp1 = torch.tensor(11)
     logits_with_maxp1 = model(idx, input_pos, input_pos_maxp1=input_pos_maxp1)
     logits_no_maxp1 = model(idx, input_pos)
     torch.testing.assert_close(logits_with_maxp1, logits_no_maxp1)
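
For context, the pattern this diff adopts: input_pos_maxp1 becomes a 0-dim device tensor that is bumped in place with add_(1) each decode step, instead of a Python int rebuilt every iteration. Below is a minimal standalone sketch of that decode-loop bookkeeping; it is not litgpt's actual API (the model call is elided, and prompt_size/device are assumed placeholders).

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
prompt_size = 8

# Prefill: positions 0..prompt_size-1; the "max position + 1" counter
# starts at prompt_size and lives on the same device as the KV cache.
input_pos = torch.arange(0, prompt_size, device=device, dtype=torch.int64)
input_pos_maxp1 = torch.tensor(prompt_size, device=device)

for step in range(4):
    # model(x, input_pos, input_pos_maxp1=input_pos_maxp1) would go here;
    # input_pos_maxp1 bounds how much of the KV cache the model reads.
    if step == 0:
        # After prefill, switch to single-token decoding.
        input_pos = torch.tensor([prompt_size], device=device, dtype=torch.int64)
    else:
        input_pos.add_(1)  # in-place: no new tensor allocated per step
    input_pos_maxp1.add_(1)  # advance the counter on device as well

Keeping both counters as tensors mutated in place means the loop allocates nothing per step and never round-trips a Python scalar to the device, which is the usual motivation for replacing an int argument with a tensor in generation code paths.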