Support torch.where(condition) with thunder.jit #124
Comments
As of now, we cannot support data-dependent ops, alas...
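The following sketch shows why this op is data-dependent. The one-argument overload returns one index tensor per dimension, which the PyTorch docs describe as identical to `torch.nonzero(condition, as_tuple=True)`. The length of those tensors equals the number of `True` entries. The two masks below are arbitrary examples with the same shape:

```python
import torch

# Two boolean masks with the same shape (2, 3) but different values.
mask_a = torch.tensor([[True, False, False], [False, True, False]])
mask_b = torch.tensor([[True, True, True], [False, True, True]])

# One-argument torch.where returns a tuple of index tensors, one per dimension.
rows_a, cols_a = torch.where(mask_a)
rows_b, cols_b = torch.where(mask_b)

# Same result as torch.nonzero(..., as_tuple=True).
for got, ref in zip(torch.where(mask_b), torch.nonzero(mask_b, as_tuple=True)):
    assert torch.equal(got, ref)

# Output length = number of True entries, so it depends on the values,
# not only on the input shape; shape metadata alone cannot determine it.
print(rows_a.shape, rows_b.shape)  # torch.Size([2]) torch.Size([5])
```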
@carmocca, looking at the code, I think the solution could be to modify the model in the package. The result of
@nikitaved Faster and better code is very welcome in LitGPT. I benchmarked a few different implementations when this was added, and this one came out as the best overall (see the description and discussion in Lightning-AI/litgpt#823). It would be useful to see them compared with whatever you propose.
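The sketch below shows one possible way to modify the model. It is an assumption, not something proposed in this thread or used in LitGPT. It replaces the `torch.where`-based gather with a dense formulation: every expert runs on every token, and the routing weight is zero for tokens that did not pick that expert. All shapes are then static. The cost is roughly `n_expert / n_expert_per_token` times more expert FLOPs, so it would need the same benchmarking as the implementations in Lightning-AI/litgpt#823. `moe_dense` and its arguments are hypothetical names.

```python
import torch
from torch import nn


def moe_dense(x: torch.Tensor, gate: nn.Module, experts: nn.ModuleList, k: int) -> torch.Tensor:
    """Top-k MoE without data-dependent shapes (hypothetical helper).

    x: (N, C) flattened tokens; gate maps C -> n_expert router logits.
    """
    router = gate(x)                          # (N, n_expert)
    probs, indices = torch.topk(router, k)    # (N, k) each
    probs = torch.softmax(probs, dim=-1)      # renormalize over the chosen k experts
    y = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        # (N, k) bool: True where token n chose expert e in top-k slot j.
        picked = indices == e
        # Per-token weight for expert e; 0.0 for tokens that did not pick it.
        weight = (probs * picked).sum(dim=-1)  # (N,)
        y = y + weight[:, None] * expert(x)    # expert runs on all N tokens
    return y
```

Because `torch.topk` returns distinct indices per token, each token selects a given expert at most once. The summed weight therefore equals the single `probs[token_idx, expert_idx]` entry used by the `torch.where` loop in the traceback below.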
The error message is not friendly and doesn't tell that
In [1]: import torch
In [2]: import thunder
In [3]: from litgpt import Config
In [4]: from litgpt.model import LLaMAMoE
In [5]: config = Config.from_name("Mixtral-8x7B-v0.1")
In [6]: model = LLaMAMoE(config).to(dtype=torch.bfloat16, device="cuda")
In [7]: jit_model = thunder.jit(model)
In [8]: x = torch.randn(2, config.block_size, config.n_embd, dtype=torch.bfloat16, device="cuda")
In [9]: jit_model(x);
Traceback:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[9], line 1
----> 1 jit_model(x);
File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/dev/lightning-thunder/thunder/__init__.py:194, in ThunderModule.forward(self, *args, **kwargs)
193 def forward(self, *args, **kwargs):
--> 194 res = self._forward_fn(*args, **kwargs)
195 return res
File ~/dev/lightning-thunder/thunder/__init__.py:629, in jit.<locals>.fn_(*args, **kwargs)
626 cs.last_trace_host_start = time.time_ns()
627 cs.calls += 1
--> 629 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
630 cs.last_trace_host_execution_start = time.time_ns()
632 result = cache_entry.computation_fn(*inps)
File ~/dev/lightning-thunder/thunder/__init__.py:262, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
260 tok = _cache_info_ctx.set({})
261 try:
--> 262 res = fn(*args, **kwargs)
263 finally:
264 _cache_info_ctx.reset(tok)
File ~/dev/lightning-thunder/thunder/__init__.py:504, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
502 prologue_trc: TraceCtx
503 computation_trc: TraceCtx
--> 504 prologue_trc, computation_trc, *maybe_epilogue = interpreter(
505 fn, args, kwargs, sharp_edges=cd.sharp_edges
506 )
508 if maybe_epilogue:
509 epilogue_traces = maybe_epilogue
File ~/dev/lightning-thunder/thunder/__init__.py:175, in _general_frontend(fn, args, kwargs, sharp_edges)
174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175 return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1440, in thunder_general_jit(fn, args, kwargs, sharp_edges)
1438 with general_jit_ctx(ctx):
1439 with tracectx(computation_trace):
-> 1440 result = jfn(*args, **kwargs)
1441 prims.python_return(result)
1442 process_recorded_modifications(ctx, epilogue_trace)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6684, in interpret.<locals>.fn_(*args, **kwargs)
6682 assert isinstance(e, BaseException), e
6683 runtimectx.curexc = None
-> 6684 raise e
6686 return interpretation_result
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6647, in interpret.<locals>.fn_.<locals>.getfn.<locals>.fn_2()
6646 def fn_2(args, kwargs):
-> 6647 return fn(*args, **kwargs)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
6045 def _impl(fn, *args, **kwargs):
-> 6046 return fn.__func__(fn.__self__, *args, **kwargs)
File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl()
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
6045 def _impl(fn, *args, **kwargs):
-> 6046 return fn.__func__(fn.__self__, *args, **kwargs)
File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl()
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
6045 def _impl(fn, *args, **kwargs):
-> 6046 return fn.__func__(fn.__self__, *args, **kwargs)
File ~/dev/litgpt/litgpt/model.py:347, in LLaMAMoE.forward()
345 y = torch.zeros_like(x) # (B*T, C)
346 for mask, expert in zip(masks, self.experts):
--> 347 token_idx, expert_idx = torch.where(mask)
348 y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
349 return y.view(B, T, C)
File ~/dev/lightning-thunder/thunder/core/interpreter.py:1258, in interpreter_needs_wrap.<locals>.wrapping_wrapper(*args, **kwargs)
1255 ukwargs = kwargs
1257 try:
-> 1258 res = ufn(*uargs, **ukwargs)
1260 # If result is a WrappedValue, we trust its provenance record
1261 if isinstance(res, WrappedValue):
File ~/dev/lightning-thunder/thunder/core/symbol.py:250, in Symbol.__call__(self, *args, **kwargs)
248 else:
249 trace.push_scope(subsymbols)
--> 250 result = self.meta(*args, **kwargs)
251 trace.pop_scope()
253 bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)
File ~/dev/lightning-thunder/thunder/core/langctxs.py:124, in langctx.__call__.<locals>._fn(*args, **kwargs)
122 try:
123 tok = set_langctx(self.langctx)
--> 124 result = fn(*args, **kwargs)
125 return result
126 finally:
TypeError: where() missing 2 required positional arguments: 'a' and 'b'
@IvanYashchuk, looks like we should update the meta function for
Might be a very nice issue for external contributors...
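The traceback suggests that the one-argument call reaches thunder's three-argument `where`, which expects `a` and `b`. The plain-PyTorch sketch below roughly illustrates the dispatch such a meta function would need. `where_dispatch` is a hypothetical name, and this is not thunder's code. Routing the one-argument form to a nonzero-style op does not remove the underlying limitation, because that op still has a data-dependent output shape.

```python
import torch


def where_dispatch(condition, input=None, other=None):
    """Hypothetical front-end covering both torch.where overloads."""
    if input is None and other is None:
        # One-argument form: indices of True entries, one tensor per dim.
        # The output length depends on the data (see torch.nonzero).
        return torch.nonzero(condition, as_tuple=True)
    if input is None or other is None:
        raise TypeError("where() expects either 1 or 3 arguments")
    # Three-argument form: elementwise select. The output shape is the
    # broadcast of the input shapes, known without looking at any values.
    return torch.where(condition, input, other)
```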
triage review:
We were referring to a parameter that would be analogous to https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html
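The linked `jax.numpy.nonzero` function accepts `size` and `fill_value` arguments. The caller fixes the output length up front, and missing entries are padded. The minimal PyTorch sketch below illustrates that idea using a hypothetical helper, `nonzero_padded`, which is not a thunder or PyTorch API. With a static `size`, every intermediate has a shape that is known before the values are seen:

```python
import torch


def nonzero_padded(mask: torch.Tensor, size: int, fill_value: int = -1):
    """Like torch.where(mask), but always returns `size` indices per dim (hypothetical)."""
    flat = mask.reshape(-1)
    n = flat.numel()
    # Stable ascending sort on 0 (True) / 1 (False) puts True positions first,
    # in their original order.
    _, order = torch.sort((~flat).to(torch.int32), stable=True)
    if size > n:  # static check on shapes: pad so we can always slice `size` entries
        order = torch.cat([order, order.new_zeros(size - n)])
    order = order[:size]
    # Slots beyond the number of True entries get fill_value.
    valid = torch.arange(size, device=mask.device) < flat.sum()
    # Unravel flat indices into one index tensor per dimension.
    coords = []
    for dim in reversed(mask.shape):
        coords.append(order % dim)
        order = torch.div(order, dim, rounding_mode="floor")
    coords.reverse()
    return tuple(torch.where(valid, c, torch.full_like(c, fill_value)) for c in coords)
```

When `size` is smaller than the number of `True` entries, this sketch keeps only the first `size` indices.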
This prototype adds
@IvanYashchuk @kshitij12345 Is this still unsupported when using ThunderFX? If
lightning-thunder/thunder/tests/test_dynamo.py Lines 346 to 349 in b28d5b3
Will update the issue title to reflect the request for
🚀 Feature
Motivation
Mixtral uses it: https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/moe_one_file_ref.py#L215
Minimal Repro
Pitch
Support
from https://pytorch.org/docs/stable/generated/torch.where.html
Additional context
We already support torch.where(condition, input, other): https://github.com/search?q=repo%3ALightning-AI%2Flightning-thunder+%22def+where%22&type=code
cc @apaz-cli