CI failure in thunder/tests/test_jit_general.py::test_litgpt_variants_kvcache with PyTorch 2.5dev #669

Status: Closed · Fixed by #778
Opened by t-vi on Jun 27, 2024
Labels: bug, high priority

t-vi (Collaborator) commented on Jun 27, 2024

I can reproduce:

____________________________________________________________________________________________________ test_litgpt_variants_kvcache[cpu-llama1-like] _____________________________________________________________________________________________________

name = 'llama1-like', device = device(type='cpu')

    @skipif_not_pytorch_2_1
    @pytest.mark.parametrize(
        "name",
        (
            # TODO this seems flaky on CI - the cause is unclear
            # "gpt-neox-like",
            "llama1-like",
            "long-context-like",
            "llama2-like",
            "falcon-7b-like",
            "falcon-40b-like",
            "codellama2-like",
            pytest.param(
                "mixtral-like",
                marks=pytest.mark.xfail(raises=(NotImplementedError, TypeError), reason="topk and where", strict=True),
            ),
        ),
    )
    @pytest.mark.parametrize(
        "device",
        ("cpu", "cuda"),
    )
    def test_litgpt_variants_kvcache(name, device):
        import torch._dynamo  # this monkeypatches torch.manual_seed
    
        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available")
    
        device = torch.device(device)
        x = torch.randint(0, 200, (1, 2), device=device)
        config = litgpt_model.Config.from_name(name)
    
        with device:
            model = litgpt_model.GPT(config)
            model.max_seq_length = 3
    
        for p in model.parameters():
            p.requires_grad_(False)
    
        executors = nvfuserex if device.type == "cuda" else torchex
        executors = [sdpa_ex] + executors
    
        def sample(logits):
            return torch.argmax(logits[:, -1], dim=-1, keepdim=True)
    
        # the reference is 2 regular forward without the kv cache
        logits_1 = model(x)
        token_1 = sample(logits_1)
        logits_2 = model(torch.cat((x, token_1), dim=-1))
    
        with device:
            model.set_kv_cache(batch_size=1)
        tom = thunder.jit(model, executors=executors)  # , disable_torch_autograd_support=True
    
        # kv cache prefill
        thunder_logits_1 = tom(x, torch.tensor([0, 1], device=device))
        thunder_token_1 = sample(thunder_logits_1)
        # 1 token generation
>       thunder_logits_2 = tom(thunder_token_1, torch.tensor([2], device=device))

thunder/tests/test_jit_general.py:745: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1714: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
/usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1725: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/module.py:61: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:675: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/__init__.py:224: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:503: in get_computation_and_inputs
    jit_results: TraceResults = interpreter(
thunder/__init__.py:212: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges, record_history=record_history)
thunder/core/jit_ext.py:1692: in thunder_general_jit
    pro_to_comp_proxies, pro_to_epi_proxies = unpack_inputs(ctx, prologue_trace, pro_to_comp, pro_to_epi, args, kwargs)
thunder/core/jit_ext.py:1499: in unpack_inputs
    pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
thunder/core/jit_ext.py:1499: in <genexpr>
    pro_to_comp = tuple(sorted((unpack(v) for v in pro_to_comp_inps), key=lambda x: param_ordering[id(x)][1]))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

v = t83

    def unpack(v: Variable | Proxy) -> Proxy:
        p: Proxy
        if isinstance(v, Proxy):
            p = v
        else:
            p = v.proxy
    
>       assert p.history is not None, f"{p} has history None"
E       AssertionError: t83 has history None

thunder/core/jit_ext.py:1294: AssertionError

cc @apaz-cli
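
The assertion fires while the prologue is unpacking inputs for the second `thunder.jit` call: a proxy (`t83`, likely one of the tensors introduced by `set_kv_cache`) reaches `unpack_inputs` without a recorded provenance history. Below is a minimal standalone sketch of the failing path, distilled from the test above; it assumes `thunder.tests.litgpt_model` is importable (the helper the test uses) and that the default executors are enough to trigger the assertion on CPU — the CI run above additionally registers `sdpa_ex` and, on CUDA, the nvFuser executor.

```python
# Minimal repro sketch distilled from test_litgpt_variants_kvcache.
# Assumes lightning-thunder is installed together with its test helpers
# (thunder.tests.litgpt_model) and a PyTorch 2.5dev build.
import torch
import thunder
from thunder.tests import litgpt_model

device = torch.device("cpu")
config = litgpt_model.Config.from_name("llama1-like")

with device:
    model = litgpt_model.GPT(config)
    model.max_seq_length = 3

# inference only, as in the test
for p in model.parameters():
    p.requires_grad_(False)

x = torch.randint(0, 200, (1, 2), device=device)

with device:
    model.set_kv_cache(batch_size=1)

tom = thunder.jit(model)

# kv cache prefill succeeds ...
logits_1 = tom(x, torch.tensor([0, 1], device=device))
token_1 = torch.argmax(logits_1[:, -1], dim=-1, keepdim=True)

# ... the second call (single-token generation) is where the traceback above
# ends in "AssertionError: t83 has history None".
logits_2 = tom(token_1, torch.tensor([2], device=device))
```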
