Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Non-topk-related issue in mixtral-like model tests. #125

Closed
nikitaved opened this issue Apr 3, 2024 · 3 comments
Closed

Non-topk-related issue in mixtral-like model tests. #125

nikitaved opened this issue Apr 3, 2024 · 3 comments
Labels
bug Something isn't working help wanted Extra attention is needed

Comments

@nikitaved
Copy link
Contributor

🐛 Bug

Now that we have topk supported, it is time to unlock some tests. However, the following diff:

diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index d1d55073..ad69a721 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -613,7 +613,7 @@ def test_nanogpt():
         "falcon-7b-like",
         "falcon-40b-like",
         "codellama2-like",
-        pytest.param("mixtral-like", marks=pytest.mark.xfail(raises=TypeError, reason="topk", strict=True)),
+        "mixtral-like",
     ),
 )
 @pytest.mark.parametrize(

Breaks `pytest -sv thunder/tests/test_jit_general.py -k test_litgpt_variants[cpu-mixtral-like]` with:

___________________________________________________________________________________________________ test_litgpt_variants[cpu-mixtral-like] ___________________________________________________________________________________________________

name = 'mixtral-like', device = device(type='cpu')

    @skipif_not_pytorch_2_1
    @pytest.mark.parametrize(
        "name",
        (
            "gpt-neox-like",
            "llama1-like",
            "long-context-like",
            "llama2-like",
            "falcon-7b-like",
            "falcon-40b-like",
            "codellama2-like",
            "mixtral-like",
        ),
    )
    @pytest.mark.parametrize(
        "device",
        ("cpu", "cuda"),
    )
    def test_litgpt_variants(name, device):
        if device == "cuda" and not torch.cuda.is_available():
            pytest.skip("CUDA not available")
    
        device = torch.device(device)
    
        x = torch.randint(0, 200, (5, 5), device=device)
        config = litgpt_model.Config.from_name(name)
    
        with device:
            reference = litgpt_model.GPT(config)
        expected_logits = reference(x)
    
        expected_logits.sum().backward()
    
        with device:
            model = litgpt_model.GPT(config)
        model.load_state_dict(reference.state_dict())
        tom = thunder.jit(model, executors=nvfuserex if device.type == "cuda" else torchex)
>       actual_logits = tom(x)

thunder/tests/test_jit_general.py:642: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/__init__.py:194: in forward
    res = self._forward_fn(*args, **kwargs)
thunder/__init__.py:629: in fn_
    cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
thunder/__init__.py:262: in cache_info_wrapper
    res = fn(*args, **kwargs)
thunder/__init__.py:504: in get_computation_and_inputs
    prologue_trc, computation_trc, *maybe_epilogue = interpreter(
thunder/__init__.py:175: in _general_frontend
    return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
thunder/core/jit_ext.py:1430: in thunder_general_jit
    result = jfn(*args, **kwargs)
thunder/core/interpreter.py:6684: in fn_
    raise e
thunder/core/interpreter.py:6647: in fn_2
    return fn(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:94: in forward
    x = block(x, cos, sin, mask, input_pos)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:187: in forward
    x = self.mlp(self.norm_2(x)) + x
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1527: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/torch/nn/modules/module.py:1536: in _call_impl
    return forward_call(*args, **kwargs)
thunder/core/interpreter.py:6046: in _impl
    return fn.__func__(fn.__self__, *args, **kwargs)
../../miniconda3/envs/thunder_dev/lib/python3.10/site-packages/litgpt/model.py:347: in forward
    token_idx, expert_idx = torch.where(mask)
thunder/core/interpreter.py:1258: in wrapping_wrapper
    res = ufn(*uargs, **ukwargs)
thunder/core/symbol.py:250: in __call__
    result = self.meta(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

args = (t157,), kwargs = {}, tok = <Token used var=<ContextVar name='langctx' at 0x7fa2ad45a340> at 0x7f9bf1b6bdc0>

    @wraps(fn)
    def _fn(*args, **kwargs):
        try:
            tok = set_langctx(self.langctx)
>           result = fn(*args, **kwargs)
E           TypeError: where() missing 2 required positional arguments: 'a' and 'b'

thunder/core/langctxs.py:124: TypeError
========================================================================================================== short test summary info ===========================================================================================================
FAILED thunder/tests/test_jit_general.py::test_litgpt_variants[cpu-mixtral-like] - TypeError: where() missing 2 required positional arguments: 'a' and 'b'
=============================================================================================== 1 failed, 54 deselected, 10 warnings in 8.04s ================================================================================================
@nikitaved nikitaved added bug Something isn't working help wanted Extra attention is needed labels Apr 3, 2024
@nikitaved
Copy link
Contributor Author

Sorry, @carmocca , I missed your related issue #124

@carmocca
Copy link
Contributor

carmocca commented Apr 3, 2024

No problem! Maybe let's keep #124 instead since it's the more generic ask (support for nonzero)

@nikitaved
Copy link
Contributor Author

Closing in favor of #124

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working help wanted Extra attention is needed
Projects
None yet
Development

No branches or pull requests

2 participants