Skip to content

Commit

Permalink
better error message
Browse files Browse the repository at this point in the history
still failing as `_scaled_mm` requires the secomd matrix to be column
major:

```
E               NotImplementedError: Failing to map `torch._scaled_mm` to `thunder.torch` op of [Symbol name=_scaled_mm] with args of [<TensorProxy(name="t166", dtype=thunder.dtypes.float8_e4m3fn, shape=(16, 32))>, <TensorProxy(name="t169", dtype=thunder.dtypes.float8_e4m3fn, shape=(32, 64))>, <TensorProxy(name="t170", dtype=thunder.dtypes.float32, shape=())>, <TensorProxy(name="t171", dtype=thunder.dtypes.float32, shape=())>, None, None, torch.float32, True]
E               BoundSymbol in question is
E               ```python
E               t165 = manual_float8_matmul_with_args_in_float8_127377658692416_2(input_fp8, t164)  # t165: "cuda:0 f32[16, 64]"
E                 # t102 = ltorch.reshape(input_fp8, -1, 32)  # t102: "cuda:0 f32[16, 32]"
E                   # t102 = prims.reshape(input_fp8, (16, 32))  # t102: "cuda:0 f32[16, 32]"
E                 # t103 = ltorch.spmm(t102, t164)  # t103: "cuda:0 f32[16, 64]"
E                 # t165 = prims.shallow_copy(t103)  # t165: "cuda:0 f32[16, 64]"
E               ```
E               Corresponding torch.fx Graph is
E               ```python
E               class <lambda>(torch.nn.Module):
E                   def forward(self, arg0, arg1, arg2, arg3, arg4, arg5):
E                       arg0_1: "f8e4m3fn[16, 32]"; arg1_1: "f32[]"; arg3_1: "f8e4m3fn[32, 64]"; arg4_1: "f32[]";
E
E                       arg0_1, arg1_1, arg2_1, arg2_2, arg2_3, arg2_4, arg2_5, arg2_6, arg2_7, arg2_8, arg2_9, arg2_10, arg2_11, arg2_12, arg2_13, arg2_14, arg2_15, arg3_1, arg4_1, arg5_1, arg5_2, arg5_3, arg5_4, arg5_5, arg5_6, arg5_7, arg5_8, arg5_9, arg5_10, arg5_11, arg5_12, arg5_13, arg5_14, arg5_15, = fx_pytree.tree_flatten_spec([arg0, arg1, arg2, arg3, arg4, arg5], self._in_spec)
E                       # No stacktrace found for following nodes
E                       view: "f8e4m3fn[16, 32]" = torch.ops.aten.view.default(arg0_1, [-1, 32]);  arg0_1 = None
E                       t: "f8e4m3fn[64, 32]" = torch.ops.aten.t.default(arg3_1);  arg3_1 = None
E                       clone: "f8e4m3fn[64, 32]" = torch.ops.aten.clone.default(t, memory_format = torch.contiguous_format);  t = None
E                       t_1: "f8e4m3fn[32, 64]" = torch.ops.aten.t.default(clone);  clone = None
E                       reciprocal: "f32[]" = torch.ops.aten.reciprocal.default(arg1_1);  arg1_1 = None
E                       reciprocal_1: "f32[]" = torch.ops.aten.reciprocal.default(arg4_1);  arg4_1 = None
E                       _scaled_mm: "f32[16, 64]" = torch.ops.aten._scaled_mm.default(view, t_1, reciprocal, reciprocal_1, None, None, torch.float32, True);  view = t_1 = reciprocal = reciprocal_1 = None
E                       return pytree.tree_unflatten([_scaled_mm, None], self._out_spec)
E
E               ```
E               Original error is Exception encountered when doing automatic registration for _scaled_mm, please use manual registration: RuntimeError('mat2 must be col_major')
```

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
  • Loading branch information
crcrpar committed Nov 26, 2024
1 parent 4e3c181 commit 0de44ee
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions thunder/transforms/tensor_subclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,8 +329,10 @@ def translate_fx_graph_into_bsym(
out = ltorch_op(*arg_proxies)
except Exception as e:
msg = (
f"Failing to map {node=} to {ltorch_op=} with {arg_proxies = }\n"
f"BoundSymbol in question is\n{bsym}\nCorresponding torch.fx Graph is\n{fx_graph.print_readable(print_output=False)}\n"
f"Failing to map `torch.{node}` to `thunder.torch` op of "
f"{ltorch_op} with args of {arg_proxies}\n"
f"BoundSymbol in question is\n```python\n{bsym}\n```\n"
f"Corresponding torch.fx Graph is\n```python\n{fx_graph.print_readable(print_output=False)}\n```\n"
f"Original error is {e}"
)
raise type(e)(msg)
Expand Down

0 comments on commit 0de44ee

Please sign in to comment.