diff --git a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
index a47b2b2eba..29a1236db8 100644
--- a/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/sparse_ops.py
@@ -972,6 +972,10 @@ def float_to_hfp8_quantized(
     return torch.empty_like(input, dtype=torch.uint8)
 
 
+def hfp8_quantized_to_float(input: Tensor, ebits: int, exponent_bias: int) -> Tensor:
+    return torch.empty_like(input, dtype=torch.float32)
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
@@ -1102,6 +1106,11 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
         "fbgemm::FloatToHFP8Quantized",
         float_to_hfp8_quantized,
     )
+    impl_abstract(
+        "fbgemm::HFP8QuantizedToFloat",
+        hfp8_quantized_to_float,
+    )
+
     _setup.done = True
 
 
diff --git a/fbgemm_gpu/test/quantize/failures_dict_fast.json b/fbgemm_gpu/test/quantize/failures_dict_fast.json
index fa917c72cc..53acb1cea7 100644
--- a/fbgemm_gpu/test/quantize/failures_dict_fast.json
+++ b/fbgemm_gpu/test/quantize/failures_dict_fast.json
@@ -53,19 +53,6 @@
         "status": "xfail"
       }
     },
-    "fbgemm::HFP8QuantizedToFloat": {
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_cpu": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache": {
-        "comment": "",
-        "status": "xfail"
-      },
-      "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_gpu_no_cache_fp8_2048": {
-        "comment": "",
-        "status": "xfail"
-      }
-    }
+    "fbgemm::HFP8QuantizedToFloat": {}
   }
 }
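
Note: the patch registers an abstract (meta/fake-tensor) implementation for
fbgemm::HFP8QuantizedToFloat, so the three SplitTableBatchedEmbeddingsTest
fake-tensor tests that previously had to be marked xfail now pass, and their
entries are dropped from failures_dict_fast.json. A minimal sketch of what the
registration enables, assuming fbgemm_gpu is importable; the ebits=4 and
exponent_bias=7 values below are illustrative HFP8 parameters, not taken from
this patch:

    # Sketch: exercising the newly registered abstract impl under FakeTensorMode.
    import torch
    import fbgemm_gpu  # noqa: F401  -- loads the ops and their abstract impls

    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        hfp8 = torch.empty(128, dtype=torch.uint8)  # fake uint8 HFP8 payload
        # ebits / exponent_bias values are illustrative (assumption).
        out = torch.ops.fbgemm.HFP8QuantizedToFloat(hfp8, 4, 7)
        # The abstract impl only propagates metadata -- same shape, float32
        # dtype -- without running the real dequantization kernel.
        assert out.shape == hfp8.shape
        assert out.dtype == torch.float32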