diff --git a/comfy/float.py b/comfy/float.py
index 521316fd2fac..848c6ff68aa3 100644
--- a/comfy/float.py
+++ b/comfy/float.py
@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: float8 output tensors are not supported on MPS, so stage the output on CPU and move it back at the end.
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
diff --git a/comfy/mps_ops.py b/comfy/mps_ops.py
new file mode 100644
index 000000000000..a5930356b53c
--- /dev/null
+++ b/comfy/mps_ops.py
@@ -0,0 +1,60 @@
+import logging
+import torch
+
+_LUT_CACHE = {}
+
+def get_lut(dtype, device):
+    """
+    Get or create a lookup table for float8 dequantization on MPS.
+    Returns a Tensor[256] of dtype=torch.float16 on the specified device.
+    """
+    key = (dtype, device)
+    if key in _LUT_CACHE:
+        return _LUT_CACHE[key]
+
+    # Build the decoding table on CPU, where float8 dtypes are supported:
+    # 1. Create the 256 possible byte patterns 0..255.
+    # 2. Bit-cast (.view) them to the target float8 dtype.
+    # 3. Widen to float16 (every float8 value is exactly representable).
+    # 4. Move the float16 LUT to the requested device.
+    byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")
+
+    try:
+        f8_tensor = byte_pattern.view(dtype)
+        f16_lut = f8_tensor.to(torch.float16)
+
+        lut = f16_lut.to(device)
+        _LUT_CACHE[key] = lut
+        return lut
+    except Exception as e:
+        logging.error(f"Failed to create MPS LUT for {dtype}: {e}")
+        raise
+
+def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
+    """
+    Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.
+
+    Args:
+        qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
+        scale: Tensor (scalar) or Python number
+        orig_dtype: The target dtype (e.g. torch.float16)
+        float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)
+
+    Returns:
+        Tensor of shape (...) with dtype=orig_dtype
+    """
+    lut = get_lut(float8_dtype, qdata.device)
+
+    # Decode via advanced indexing; indices must be int64.
+    # The LUT is stored as float16; cast it to the target orig_dtype (e.g. bfloat16) if needed.
+    if lut.dtype != orig_dtype:
+        lut = lut.to(dtype=orig_dtype)
+
+    output = lut[qdata.long()]
+
+    # Apply the scale, casting it to orig_dtype if it is a tensor.
+    if isinstance(scale, torch.Tensor):
+        scale = scale.to(dtype=orig_dtype)
+
+    output.mul_(scale)
+    return output
diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py
index cd96541d78eb..8531e974f95d 100644
--- a/comfy/quant_ops.py
+++ b/comfy/quant_ops.py
@@ -2,6 +2,7 @@
 import logging
 from typing import Tuple, Dict
 import comfy.float
+import comfy.mps_ops
 
 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -269,8 +270,18 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=
     if target_device != current_device:
         logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
-        new_q_data = qt._qdata.to(device=target_device)
-        new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+
+        # MPS workaround: native float8 is unsupported on MPS, so move the raw bits as uint8 and record the original dtype.
+        qdata = qt._qdata
+        layout_params = qt._layout_params.copy()
+
+        if target_device.type == "mps" and qdata.element_size() == 1 and qdata.is_floating_point():  # Catch float8
+            layout_params["mps_float8_dtype"] = qdata.dtype
+            qdata = qdata.view(torch.uint8)
+
+        new_q_data = qdata.to(device=target_device)
+        new_params = _move_layout_params_to_device(layout_params, target_device)
+
         if target_dtype is not None:
             new_params["orig_dtype"] = target_dtype
         new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
 
@@ -431,6 +442,13 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_roun
 
     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # Float8 tensor that reached MPS natively; reinterpret its bits as uint8 for the LUT path.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
         plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
         plain_tensor.mul_(scale)
         return plain_tensor
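
Sanity check (not part of the patch): a minimal sketch, run entirely on CPU, of the same LUT idea used by comfy.mps_ops.get_lut, verifying that decoding float8 bit patterns through a 256-entry float16 table matches a direct cast. The name build_lut is illustrative, not an API of the patch.

import torch

def build_lut(float8_dtype):
    # All 256 byte patterns, bit-cast to float8, then widened to float16
    # (same construction as get_lut above, kept on CPU for the check).
    return torch.arange(256, dtype=torch.uint8).view(float8_dtype).to(torch.float16)

f8 = torch.randn(1024).to(torch.float8_e4m3fn)    # sample quantized data
lut = build_lut(torch.float8_e4m3fn)
decoded = lut[f8.view(torch.uint8).long()]        # LUT decode path
reference = f8.to(torch.float16)                  # direct cast path
mask = torch.isfinite(reference)                  # NaN bit patterns never compare equal
assert torch.equal(decoded[mask], reference[mask])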