add test for torch.float16 and torch.bfloat16 support
Summary:

# context

* We found that the new operator `permute_multi_embedding` can't support `torch.float16` in an inference test.
* Added a test to cover the dtype support.
* Before the operator change, we saw the following error:

```
Failures:
  1) torchrec.sparse.tests.test_jagged_tensor.TestKeyedTensorRegroupOp: test_multi_permute_dtype
  1) RuntimeError: expected scalar type Float but found Half
    File "torchrec/sparse/tests/test_jagged_tensor.py", line 2798, in test_multi_permute_dtype
      outputs = torch.ops.fbgemm.permute_multi_embedding(
    File "torch/_ops.py", line 1113, in __call__
      return self._op(*args, **(kwargs or {}))
```

* The suspicion is that the CPU operator accesses tensor data via `data_ptr<float>`, which limits the supported dtype to `float32`:

```
auto outp = outputs[out_tensor][b].data_ptr<float>() + out_offset;
auto inp = inputs[in_tensor][b].data_ptr<float>() + in_offset;
```

# changes

* Use `FBGEMM_DISPATCH_FLOATING_TYPES` to dispatch the dtype to the template parameter `scalar_t` (see the sketch below).
* After the change, the operator supports `float16` and `bfloat16`.

WARNING: somehow this operator still can't support `int` types.

Differential Revision: D57143637
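A minimal sketch of the dispatch pattern described above, assuming `FBGEMM_DISPATCH_FLOATING_TYPES` follows the usual `AT_DISPATCH_*` conventions (dtype, op name, lambda, with `scalar_t` bound inside the lambda). The copy body and the `length` variable are illustrative, not the actual operator code:

```
// Assumed includes (illustrative; the header providing
// FBGEMM_DISPATCH_FLOATING_TYPES varies by FBGEMM version):
//   #include <cstring>  // std::memcpy

// Replace the hard-coded data_ptr<float> accesses with a dtype dispatch,
// so scalar_t is bound to float/double/at::Half/at::BFloat16 based on the
// runtime dtype. `length` is a hypothetical element count for the segment.
FBGEMM_DISPATCH_FLOATING_TYPES(
    inputs[in_tensor][b].scalar_type(),
    "permute_multi_embedding_cpu",
    [&] {
      auto outp = outputs[out_tensor][b].data_ptr<scalar_t>() + out_offset;
      auto inp = inputs[in_tensor][b].data_ptr<scalar_t>() + in_offset;
      // copy one embedding segment in the tensor's own element type
      std::memcpy(outp, inp, length * sizeof(scalar_t));
    });
```

Because the dispatch macro covers floating types only, integer dtypes still fall through to an error, which is consistent with the WARNING above.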
commit 0de38a6 (parent ada1050)
Showing 1 changed file with 70 additions and 54 deletions.