Support torch.where(condition) with thunder.jit #124

Open
Tracked by #194 ...
carmocca opened this issue Apr 3, 2024 · 11 comments
Labels
enhancement New feature or request help wanted Extra attention is needed operators

Comments

@carmocca
Contributor

carmocca commented Apr 3, 2024

🚀 Feature

Motivation

Mixtral uses it: https://github.com/mistralai/mistral-src/blob/8598cf582091a596671be31990448e0620017851/moe_one_file_ref.py#L215

Minimal Repro

import torch
import thunder

def fn(cond):
    return torch.where(cond)

# Currently fails: only torch.where(condition, input, other) is supported
thunder.jit(fn)(torch.randn(3) > 0)

Pitch

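Support the torch.where(condition) overload from https://pytorch.org/docs/stable/generated/torch.where.html, which the PyTorch documentation describes as identical to torch.nonzero(condition, as_tuple=True).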
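Because this overload reduces to torch.nonzero, one possible lowering is to split nonzero's (count, ndim) index matrix into one column per dimension. A minimal sketch in plain PyTorch; `where_via_nonzero` is a hypothetical name for illustration, not a Thunder function:

```python
import torch

def where_via_nonzero(cond: torch.Tensor) -> tuple[torch.Tensor, ...]:
    # torch.nonzero returns an (n, cond.ndim) index matrix; splitting it into
    # columns gives one 1-D index tensor per dimension, like torch.where(cond)
    return tuple(torch.nonzero(cond).unbind(1))

cond = torch.tensor([[True, False], [False, True]])
rows, cols = where_via_nonzero(cond)
print(rows, cols)  # tensor([0, 1]) tensor([0, 1])
```

Note that the result's length still depends on how many elements of `cond` are True, which is what makes the overload hard to trace.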

Additional context

We already support torch.where(condition, input, other): https://github.com/search?q=repo%3ALightning-AI%2Flightning-thunder+%22def+where%22&type=code

cc @apaz-cli

@nikitaved
Contributor

As of now, we cannot support data-dependent ops (operations whose output shape depends on the input values), alas...
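To make "data-dependent" concrete: with the single-argument overload, two inputs of the same shape can produce outputs of different shapes, so a tracer that only sees shapes and dtypes cannot know the result shape in advance. The ternary overload does not have this problem. A small illustration in plain PyTorch (the helper name is hypothetical):

```python
import torch

def where_output_shapes(cond: torch.Tensor) -> list[torch.Size]:
    # The single-argument overload returns one index tensor per dimension of cond
    return [t.shape for t in torch.where(cond)]

a = torch.tensor([True, False, True])
b = torch.tensor([False, False, True])

# Same input shape (3,), but the output shapes differ with the values
print(where_output_shapes(a))  # [torch.Size([2])]
print(where_output_shapes(b))  # [torch.Size([1])]

# The ternary overload's output shape is the broadcast shape of its inputs,
# which is known from shapes alone
print(torch.where(a, torch.zeros(3), torch.ones(1)).shape)  # torch.Size([3])
```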

@nikitaved
Contributor

nikitaved commented Apr 3, 2024

@carmocca, looking at the code, I think the solution could be to modify the model in the package. The result of topk can be sorted, and then we would not need to apply where at all. This would also eliminate the device sync (syncs, actually) caused by where.

@carmocca
Contributor Author

carmocca commented Apr 3, 2024

@nikitaved Faster and better code is very welcome in LitGPT. I benchmarked a few different implementations when this was added, and this one came out as the best in general (see the description and discussion in Lightning-AI/litgpt#823). It would be useful to compare them against whatever you propose.
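For comparison only, here is one shape-static alternative. It is neither the sorted-topk idea proposed above nor one of the benchmarked LitGPT implementations: it runs every expert on every token and weights the outputs with a dense routing matrix built by `scatter`, so no tensor shape depends on the routing decisions. The sketch simplifies LLaMAMoE (2-D input, `nn.Linear` experts, no dtype casting), and both function names are hypothetical:

```python
import torch
from torch import nn

def moe_where(x, gate, experts, k):
    # Reference: the data-dependent routing pattern from LLaMAMoE.forward
    probs, indices = torch.topk(gate(x), k)  # (N, k)
    probs = probs.softmax(dim=1)
    masks = indices.unsqueeze(-1) == torch.arange(len(experts), device=x.device)  # (N, k, E)
    masks = masks.permute(2, 0, 1)  # (E, N, k)
    y = torch.zeros_like(x)
    for mask, expert in zip(masks, experts):
        token_idx, expert_idx = torch.where(mask)  # shapes depend on the routing
        y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
    return y

def moe_dense(x, gate, experts, k):
    # Shape-static alternative: every expert processes every token
    probs, indices = torch.topk(gate(x), k)  # (N, k)
    probs = probs.softmax(dim=1)
    # Scatter the k routing weights into a dense (N, E) matrix; unselected experts get 0
    weights = torch.zeros(x.shape[0], len(experts), dtype=x.dtype, device=x.device)
    weights = weights.scatter(1, indices, probs)
    y = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        y = y + weights[:, e, None] * expert(x)
    return y
```

The trade-off is compute: the dense version does roughly n_expert / n_expert_per_token times the expert work (4x for Mixtral's 8 experts with top-2 routing), so it only makes sense if avoiding the syncs and dynamic shapes outweighs that cost.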

@IvanYashchuk
Collaborator

The error message is unfriendly: it doesn't say that torch.where(condition) is unsupported:

In [1]: import torch

In [2]: import thunder

In [3]: from litgpt import Config

In [4]: from litgpt.model import LLaMAMoE

In [5]: config = Config.from_name("Mixtral-8x7B-v0.1")

In [6]: model = LLaMAMoE(config).to(dtype=torch.bfloat16, device="cuda")

In [7]: jit_model = thunder.jit(model)

In [8]: x = torch.randn(2, config.block_size, config.n_embd, dtype=torch.bfloat16, device="cuda")

In [9]: jit_model(x);

Traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[9], line 1
----> 1 jit_model(x);

File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/__init__.py:194, in ThunderModule.forward(self, *args, **kwargs)
    193 def forward(self, *args, **kwargs):
--> 194     res = self._forward_fn(*args, **kwargs)
    195     return res

File ~/dev/lightning-thunder/thunder/__init__.py:629, in jit.<locals>.fn_(*args, **kwargs)
    626 cs.last_trace_host_start = time.time_ns()
    627 cs.calls += 1
--> 629 cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    630 cs.last_trace_host_execution_start = time.time_ns()
    632 result = cache_entry.computation_fn(*inps)

File ~/dev/lightning-thunder/thunder/__init__.py:262, in _with_cache_info_ctx.<locals>.cache_info_wrapper(*args, **kwargs)
    260 tok = _cache_info_ctx.set({})
    261 try:
--> 262     res = fn(*args, **kwargs)
    263 finally:
    264     _cache_info_ctx.reset(tok)

File ~/dev/lightning-thunder/thunder/__init__.py:504, in jit.<locals>.get_computation_and_inputs(*args, **kwargs)
    502     prologue_trc: TraceCtx
    503     computation_trc: TraceCtx
--> 504     prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    505         fn, args, kwargs, sharp_edges=cd.sharp_edges
    506     )
    508 if maybe_epilogue:
    509     epilogue_traces = maybe_epilogue

File ~/dev/lightning-thunder/thunder/__init__.py:175, in _general_frontend(fn, args, kwargs, sharp_edges)
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)

File ~/dev/lightning-thunder/thunder/core/jit_ext.py:1440, in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1438 with general_jit_ctx(ctx):
   1439     with tracectx(computation_trace):
-> 1440         result = jfn(*args, **kwargs)
   1441         prims.python_return(result)
   1442         process_recorded_modifications(ctx, epilogue_trace)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6684, in interpret.<locals>.fn_(*args, **kwargs)
   6682     assert isinstance(e, BaseException), e
   6683     runtimectx.curexc = None
-> 6684     raise e
   6686 return interpretation_result

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6647, in interpret.<locals>.fn_.<locals>.getfn.<locals>.fn_2()
   6646 def fn_2(args, kwargs):
-> 6647     return fn(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl()
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/pytorch/main/torch/nn/modules/module.py:1520, in Module._call_impl()
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/dev/lightning-thunder/thunder/core/interpreter.py:6046, in _call_dispatch.<locals>._impl()
   6045 def _impl(fn, *args, **kwargs):
-> 6046     return fn.__func__(fn.__self__, *args, **kwargs)

File ~/dev/litgpt/litgpt/model.py:347, in LLaMAMoE.forward()
    345 y = torch.zeros_like(x)  # (B*T, C)
    346 for mask, expert in zip(masks, self.experts):
--> 347     token_idx, expert_idx = torch.where(mask)
    348     y[token_idx] += probs[token_idx, expert_idx, None] * expert(x[token_idx])
    349 return y.view(B, T, C)

File ~/dev/lightning-thunder/thunder/core/interpreter.py:1258, in interpreter_needs_wrap.<locals>.wrapping_wrapper(*args, **kwargs)
   1255     ukwargs = kwargs
   1257 try:
-> 1258     res = ufn(*uargs, **ukwargs)
   1260     # If result is a WrappedValue, we trust its provenance record
   1261     if isinstance(res, WrappedValue):

File ~/dev/lightning-thunder/thunder/core/symbol.py:250, in Symbol.__call__(self, *args, **kwargs)
    248 else:
    249     trace.push_scope(subsymbols)
--> 250     result = self.meta(*args, **kwargs)
    251     trace.pop_scope()
    253 bsym = self.bind(*args, **kwargs, output=result, subsymbols=subsymbols)

File ~/dev/lightning-thunder/thunder/core/langctxs.py:124, in langctx.__call__.<locals>._fn(*args, **kwargs)
    122 try:
    123     tok = set_langctx(self.langctx)
--> 124     result = fn(*args, **kwargs)
    125     return result
    126 finally:

TypeError: where() missing 2 required positional arguments: 'a' and 'b'

@nikitaved
Contributor

@IvanYashchuk, looks like we should update the meta function for where. To be frank, I did not even know about this overload...

This might be a very nice issue for external contributors...
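A minimal sketch of the kind of argument check that could replace the "missing 2 required positional arguments" TypeError with an explicit message. `where_checked` is a hypothetical plain-PyTorch function; it is not Thunder's actual meta function for where, whose internals are not shown in this thread:

```python
import torch

def where_checked(condition, a=None, b=None):
    """Hypothetical front end for where that rejects the single-argument
    overload with an explicit message instead of a missing-argument error."""
    if a is None and b is None:
        raise NotImplementedError(
            "torch.where(condition) is not supported: its output shape depends on "
            "the values in `condition`. Use torch.where(condition, input, other) instead."
        )
    if a is None or b is None:
        raise TypeError("where() expects either 1 or 3 arguments")
    # Ternary overload: output shape is the broadcast of the input shapes
    return torch.where(condition, a, b)
```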

@mruberry
Collaborator

mruberry commented Apr 15, 2024

triage review:

  • can the call to torch.where(condition) in Mixtral use the hypothetical shape parameter to nonzero to make the output shape known at compile time?
  • we should implement nonzero(..., shape=...)

@carmocca
Contributor Author

nonzero doesn't have a shape= argument. Did you mean as_tuple=?

@mruberry
Collaborator

> nonzero doesn't have a shape= argument. Did you mean as_tuple=?

We were referring to a parameter that would be analogous to the size parameter of jax.numpy.nonzero:

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.nonzero.html
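A minimal sketch of what a fixed-size nonzero could look like, written with ordinary PyTorch ops whose output shapes depend only on `size` and the input's rank, never on its values. `nonzero_fixed_size` and its `fill_value=-1` padding default are hypothetical choices modeled loosely on the JAX parameters above; this is not an existing PyTorch or Thunder API:

```python
import torch

def nonzero_fixed_size(cond: torch.Tensor, size: int, fill_value: int = -1) -> torch.Tensor:
    """Hypothetical fixed-size nonzero returning a (size, cond.ndim) index matrix.

    The first min(count, size) rows are the coordinates of True elements in
    row-major order (as torch.nonzero returns them); remaining rows are filled
    with `fill_value`. Extra True elements beyond `size` are dropped.
    """
    assert cond.ndim >= 1 and cond.numel() > 0, "sketch handles non-empty tensors only"
    flat = cond.reshape(-1).bool()
    n = flat.numel()
    pos = torch.arange(n, device=cond.device)
    # True elements keep their position as sort key; False elements are pushed
    # past every True element, so argsort lists True positions first, in order.
    key = torch.where(flat, pos, pos + n)
    order = torch.argsort(key)
    if size > n:
        # Pad with dummy positions; they are replaced by fill_value below.
        order = torch.cat([order, order.new_zeros(size - n)])
    order = order[:size]
    # Rows at or beyond the number of True elements are padding; this
    # comparison stays on the device, so no host sync is needed.
    valid = torch.arange(size, device=cond.device) < flat.sum()
    # Unravel flat positions into per-dimension coordinates (row-major).
    coords = []
    rem = order
    for dim in reversed(cond.shape):
        coords.append(rem % dim)
        rem = rem // dim
    coords = torch.stack(coords[::-1], dim=1)
    return torch.where(valid[:, None], coords, torch.full_like(coords, fill_value))
```

For the MoE loop in the traceback above, each expert's mask of shape (B*T, n_expert_per_token) contains at most one True per token, so `size=B*T` would be a static upper bound; the padded rows would then have to be neutralized (for example by zeroing their routing weights) rather than used as real indices.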

@IvanYashchuk
Collaborator

This prototype adds torch.where(boolean_tensor) #303

@mruberry
Collaborator

@IvanYashchuk @kshitij12345 Is this still unsupported when using ThunderFX? If torch.where(condition) is supported when using ThunderFX (because the operator is sent to PyTorch for execution?), then maybe we can close or amend this issue to refer more specifically to using torch.where(condition) with the Thunder interpreter as the entrypoint?

@kshitij12345
Collaborator

torch.where(condition) works with the ThunderFX path by sending it to PyTorch. We also have a test for this:

@instantiate(dtypes=NOTHING, executors=[DynamoThunderExecutor])
def test_where_nonzero_overload(executor, device: str, dtype: dtypes.dtype):
    # Verify that `torch.where(cond)` leads to graph break and `torch.where(cond, x, y)`
    # is correctly passed to `thunder`.

I will update the issue title to reflect that the request is for supporting torch.where(condition) with the thunder.jit entrypoint.

kshitij12345 changed the title from "Support torch.where(condition)" to "Support torch.where(condition) with thunder.jit" on Oct 29, 2024
6 participants