19 changes: 16 additions & 3 deletions comfy/float.py
@@ -55,13 +55,26 @@ def stochastic_rounding(value, dtype, seed=0):
     if dtype == torch.bfloat16:
         return value.to(dtype=torch.bfloat16)
     if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2:
-        generator = torch.Generator(device=value.device)
+        # MPS workaround: perform float8 conversion on CPU
+        target_device = value.device
+        use_cpu_staging = (target_device.type == "mps")
+
+        output_device = "cpu" if use_cpu_staging else target_device
+        output = torch.empty_like(value, dtype=dtype, device=output_device)
+
+        generator = torch.Generator(device=target_device)
         generator.manual_seed(seed)
-        output = torch.empty_like(value, dtype=dtype)
+
         num_slices = max(1, (value.numel() / (4096 * 4096)))
         slice_size = max(1, round(value.shape[0] / num_slices))
         for i in range(0, value.shape[0], slice_size):
-            output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator))
+            res = manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)
+            if use_cpu_staging:
+                res = res.cpu()
+            output[i:i+slice_size].copy_(res)
+
+        if use_cpu_staging:
+            return output.to(target_device)
         return output
 
     return value.to(dtype=dtype)
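As a side note (not part of the diff), a minimal sketch of how the CPU-staging path above might be exercised; the tensor shape and seed are arbitrary, and it assumes a machine where torch.backends.mps.is_available() is true.

import torch
import comfy.float

if torch.backends.mps.is_available():
    weights = torch.randn(1024, 512, device="mps")
    # With use_cpu_staging active, the float8 rounding is staged through a CPU
    # buffer and the result is transferred back to the original MPS device.
    q = comfy.float.stochastic_rounding(weights, torch.float8_e4m3fn, seed=42)
    print(q.device, q.dtype)  # expected per the hunk above: mps device, torch.float8_e4m3fn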
77 changes: 77 additions & 0 deletions comfy/mps_ops.py
@@ -0,0 +1,77 @@
import torch

_LUT_CACHE = {}

def get_lut(dtype, device):
"""
Get or create a lookup table for float8 dequantization on MPS.
Returns a Tensor[256] of dtype=torch.float16 on the specified device.
"""
key = (dtype, device)
if key in _LUT_CACHE:
return _LUT_CACHE[key]

# Generate all possible 8-bit values (0-255)
# We create them on CPU first as float8, then cast to float16, then move to MPS.
# This acts as our decoding table.

# Create uint8 pattern 0..255
byte_pattern = torch.arange(256, dtype=torch.uint8, device="cpu")

    # View the bytes as the target float8 type. Since uint8 and float8 share the
    # same element size, .view(dtype) reinterprets the buffer in place, so each
    # byte pattern i decodes to the float8 value it encodes.
    #
    # LUT construction, step by step:
    # 1. Create bytes 0..255
    # 2. View as float8 (on CPU, where float8 is supported)
    # 3. Convert to float16 (on CPU)
    # 4. Move the float16 LUT to MPS

try:
f8_tensor = byte_pattern.view(dtype)
f16_lut = f8_tensor.to(torch.float16)

# Move to the requested MPS device
lut = f16_lut.to(device)
_LUT_CACHE[key] = lut
return lut
except Exception as e:
print(f"Failed to create MPS LUT for {dtype}: {e}")
        # No usable fallback here; propagate the error to the caller.
raise e

def mps_dequantize(qdata, scale, orig_dtype, float8_dtype):
"""
Dequantize a uint8 tensor (representing float8 data) using a LUT on MPS.

Args:
qdata: Tensor of shape (...) with dtype=torch.uint8 (on MPS)
scale: Tensor (scalar)
orig_dtype: The target dtype (e.g. float16)
float8_dtype: The original float8 dtype (torch.float8_e4m3fn or torch.float8_e5m2)

Returns:
Tensor of shape (...) with dtype=orig_dtype
"""
lut = get_lut(float8_dtype, qdata.device)

# Use index_select or advanced indexing.
# Advanced indexing lut[qdata.long()] is generally efficient.
# We explicitly cast to long (int64) for indexing.
# Note: Flattening might be slightly faster depending on shape, but simple indexing is safest.

# We want the LUT to be in the target orig_dtype (likely float16 or bfloat16)
if lut.dtype != orig_dtype:
lut = lut.to(dtype=orig_dtype)

output = lut[qdata.long()]

# Apply scale
# Scale might need to be cast to orig_dtype too
if isinstance(scale, torch.Tensor):
scale = scale.to(dtype=orig_dtype)

output.mul_(scale)
return output
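The LUT trick in get_lut() can be reproduced standalone. The following sketch (not from the diff) builds the same 256-entry decode table on CPU and checks it against a direct float8-to-float16 cast; the 0.05 scale and the random byte payload are made up for illustration.

import torch

codes = torch.arange(256, dtype=torch.uint8)                # every possible byte
lut = codes.view(torch.float8_e4m3fn).to(torch.float16)     # decode table

qdata = torch.randint(0, 256, (4, 8), dtype=torch.uint8)    # fake quantized payload
scale = torch.tensor(0.05, dtype=torch.float16)

# LUT path: gather by byte value, then apply the scale.
dequant = lut[qdata.long()] * scale

# Reference path: reinterpret the same bytes as float8 and cast directly.
reference = qdata.view(torch.float8_e4m3fn).to(torch.float16) * scale

# The NaN encodings (0x7f, 0xff) compare unequal to themselves, so zero them first.
print(torch.equal(dequant.nan_to_num(), reference.nan_to_num()))  # True

Every e4m3 and e5m2 value is exactly representable in float16, which is why a float16 LUT loses nothing relative to the native float8 cast.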
22 changes: 20 additions & 2 deletions comfy/quant_ops.py
@@ -2,6 +2,7 @@
 import logging
 from typing import Tuple, Dict
 import comfy.float
+import comfy.mps_ops
 
 _LAYOUT_REGISTRY = {}
 _GENERIC_UTILS = {}
@@ -269,8 +270,18 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=

     if target_device != current_device:
         logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}")
-        new_q_data = qt._qdata.to(device=target_device)
-        new_params = _move_layout_params_to_device(qt._layout_params, target_device)
+
+        # MPS Hack: Convert Float8 to Uint8 before moving if native float8 is unsupported
+        qdata = qt._qdata
+        layout_params = qt._layout_params.copy()
+
+        if target_device.type == "mps" and qdata.element_size() == 1 and qdata.is_floating_point():  # Catch float8
+            layout_params["mps_float8_dtype"] = qdata.dtype
+            qdata = qdata.view(torch.uint8)
+
+        new_q_data = qdata.to(device=target_device)
+        new_params = _move_layout_params_to_device(layout_params, target_device)
+
         if target_dtype is not None:
             new_params["orig_dtype"] = target_dtype
         new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params)
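To illustrate the transfer trick in isolation (again, not from the diff, with made-up variable names): the float8 tensor's bytes are reinterpreted as uint8 before the move, and the real dtype travels separately so dequantization can recover it.

import torch

cpu_f8 = torch.randn(16).to(torch.float8_e4m3fn)    # float8 weights on CPU
payload = cpu_f8.view(torch.uint8)                   # same bytes, plain uint8
remembered_dtype = cpu_f8.dtype                      # stashed as "mps_float8_dtype"

if torch.backends.mps.is_available():
    on_mps = payload.to("mps")                       # uint8 transfers without float8 support

The dequantize() hunk below then takes the uint8 branch and hands the remembered dtype to comfy.mps_ops.mps_dequantize.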
@@ -431,6 +442,13 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_roun

     @staticmethod
     def dequantize(qdata, scale, orig_dtype, **kwargs):
+        if qdata.device.type == "mps":
+            if qdata.dtype == torch.uint8:
+                return comfy.mps_ops.mps_dequantize(qdata, scale, orig_dtype, kwargs.get("mps_float8_dtype", torch.float8_e4m3fn))
+            elif qdata.is_floating_point() and qdata.element_size() == 1:
+                # It is MPS float8; view it as uint8.
+                return comfy.mps_ops.mps_dequantize(qdata.view(torch.uint8), scale, orig_dtype, qdata.dtype)
+
         plain_tensor = torch.ops.aten._to_copy.default(qdata, dtype=orig_dtype)
         plain_tensor.mul_(scale)
         return plain_tensor
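Finally, a rough end-to-end sketch of the MPS branch added to dequantize(); the byte payload and scale are invented, and only the comfy.mps_ops call itself is taken from the new file above.

import torch
import comfy.mps_ops

if torch.backends.mps.is_available():
    qdata = torch.randint(0, 256, (64,), dtype=torch.uint8, device="mps")
    scale = torch.tensor(0.1, device="mps")
    out = comfy.mps_ops.mps_dequantize(qdata, scale, torch.float16, torch.float8_e4m3fn)
    print(out.shape, out.dtype, out.device)  # roughly: torch.Size([64]) torch.float16 mps:0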