I can reproduce:
```
____________________________ test_litgpt_variants_kvcache[cpu-llama1-like] ____________________________

name = 'llama1-like', device = device(type='cpu')

    @skipif_not_pytorch_2_1
    @pytest.mark.parametrize(
        "name",
        (
            # TODO this seems flaky on CI - the cause is unclear
            # "gpt-neox-like",
            "llama1-like",
            "long-context-like",
            "llama2-like",
            "falcon-7b-like",
            "falcon-40b-like",
            "codellama2-like",
            pytest.param(
                "mixtral-like",
                marks=pytest.mark.xfail(raises=(NotImplementedError, TypeError), reason="topk and where", strict=True),
            ),
        ),
    )
    @pytest.mark.parametrize(
        "device",
        ("cpu", "cuda"),
    )
    def test_litgpt_variants_kvcache(name, device):
        import torch._dynamo  # this monkeypatches torch.manual_seed

        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available")

        device = torch.device(device)
        x = torch.randint(0, 200, (1, 2), device=device)
        config = litgpt_model.Config.from_name(name)

        with device:
            model = litgpt_model.GPT(config)
        model.max_seq_length = 3

        for p in model.parameters():
            p.requires_grad_(False)

        executors = nvfuserex if device.type == "cuda" else torchex
        executors = [sdpa_ex] + executors

        def sample(logits):
            return torch.argmax(logits[:, -1], dim=-1, keepdim=True)

        # the reference is 2 regular forward without the kv cache
        logits_1 = model(x)
        token_1 = sample(logits_1)
        logits_2 = model(torch.cat((x, token_1), dim=-1))

        with device:
            model.set_kv_cache(batch_size=1)
        tom = thunder.jit(model, executors=executors)  # , disable_torch_autograd_support=True

        # kv cache prefill
        thunder_logits_1 = tom(x, torch.tensor([0, 1], device=device))
        thunder_token_1 = sample(thunder_logits_1)

        # 1 token generation
>       thunder_logits_2 = tom(thunder_token_1, torch.tensor([2], device=device))

thunder/tests/test_jit_general.py:745:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1714: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1725: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:61: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:675: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/__init__.py:224: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:503: in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
thunder/__init__.py:212: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
thunder/core/jit_ext.py:1692: in thunder_general_jit
    pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs)
thunder/core/jit_ext.py:1499: in unpack_inputs
    pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
thunder/core/jit_ext.py:1499: in <genexpr>
    pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

v = t83

    def unpack(v: Variable | Proxy) -> Proxy:
        p: Proxy
        if isinstance(v, Proxy):
            p = v
        else:
            p = v.proxy

>       assert p.history is not None, f"{p} has history None"
E       AssertionError: t83 has history None

thunder/core/jit_ext.py:1294: AssertionError
```
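For reference, the failure can be distilled to roughly the following standalone script. This is only a sketch based on the test body above: the `thunder.tests.litgpt_model` import is assumed to match the helper module the test file uses, and the executor selection is omitted for brevity. The first call (kv-cache prefill) succeeds; the assertion fires on the second, single-token call.

```python
# Sketch of a minimal repro, derived from test_litgpt_variants_kvcache above.
# Assumptions: `thunder.tests.litgpt_model` is the helper module the test
# imports; the `executors=` argument is dropped for brevity.
import torch
import thunder
from thunder.tests import litgpt_model

device = torch.device("cpu")
x = torch.randint(0, 200, (1, 2), device=device)

config = litgpt_model.Config.from_name("llama1-like")
with device:
    model = litgpt_model.GPT(config)
model.max_seq_length = 3
for p in model.parameters():
    p.requires_grad_(False)

with device:
    model.set_kv_cache(batch_size=1)
tom = thunder.jit(model)

# kv cache prefill: this call works
logits_1 = tom(x, torch.tensor([0, 1], device=device))
token_1 = torch.argmax(logits_1[:, -1], dim=-1, keepdim=True)

# 1 token generation: this second call raises
# "AssertionError: t83 has history None" in unpack_inputs
logits_2 = tom(token_1, torch.tensor([2], device=device))
```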
cc @apaz-cli