Conversation

@macpaul commented Jan 8, 2026

Issue: #11626

The main change is feat(mps): implement native-like Float8 support via LUT dequantization.
However, it ran into merge conflicts when the master branch was updated to 0.8.0 and later, so I've added a number of fixes for the errors I encountered. Please check whether they are adaptable to the master branch.

Signed-off-by: Macpaul Lin <macpaul@gmail.com>

Add a new MPS-specific operations module to handle Float8 tensor support
on Apple Silicon. Since MPS does not natively support Float8 dtypes, this
implementation uses a uint8 storage strategy combined with a GPU-accelerated
Lookup Table (LUT) for efficient dequantization, keeping data on the GPU (a sketch of the approach follows the list below).

- Add comfy/mps_ops.py: Implement cached LUT generation and index-based
  dequantization for MPS.
- Modify comfy/quant_ops.py: Add logic to view Float8 tensors as uint8
  when moving to MPS, and route dequantization to mps_ops.
- Modify comfy/float.py: Add CPU staging for stochastic rounding to
  prevent MPS casting errors during quantization.
- Modify comfy/quant_ops.py: Add fallback for fp8_linear.
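
To make the LUT idea concrete, here is a minimal sketch, assuming helper names like `build_fp8_lut` and `dequantize_fp8_mps` (illustrative only; the PR's actual implementation lives in comfy/mps_ops.py and may differ in naming and caching details). Since a Float8 value is one byte, every possible bit pattern fits in a 256-entry table, which can be decoded once on CPU (where Float8 casts are supported) and reused on the GPU:

```python
import torch

_LUT_CACHE = {}

def build_fp8_lut(fp8_dtype, device, out_dtype=torch.float32):
    # Decode all 256 uint8 bit patterns into their Float8 values on CPU,
    # then cache the resulting table on the target device.
    key = (fp8_dtype, device, out_dtype)
    if key not in _LUT_CACHE:
        bits = torch.arange(256, dtype=torch.uint8)
        # view() reinterprets the bytes as Float8; .to() decodes to floats.
        _LUT_CACHE[key] = bits.view(fp8_dtype).to(out_dtype).to(device)
    return _LUT_CACHE[key]

def dequantize_fp8_mps(qdata_uint8, fp8_dtype, out_dtype=torch.float32):
    # qdata_uint8 holds the raw Float8 bits stored as uint8, so
    # dequantization is a single GPU gather through the 256-entry LUT.
    lut = build_fp8_lut(fp8_dtype, qdata_uint8.device, out_dtype)
    return lut[qdata_uint8.to(torch.int64)]
```

With the table cached per (dtype, device, out_dtype), each dequantization is one index lookup that stays on the MPS device, rather than a CPU round-trip per tensor.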

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…ng errors

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…edTensor

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…Tensor

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…edTensor

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…pe to prevent precision mismatch RuntimeErrors

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…ike for mock QuantizedTensor

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
…r QuantizedTensor

Signed-off-by: Macpaul Lin <macpaul@gmail.com>
@rattus128 added the MacOS MPS device related issues label Jan 11, 2026
@tushar9989 commented:

This is leading to a crash when running the video_wan2_2_5B_ti2v workflow, @macpaul:

got prompt
Using split attention in VAE
Using split attention in VAE
VAE load device: mps, offload device: cpu, dtype: torch.bfloat16
Found quantization metadata version 1
Using MixedPrecisionOps for text encoder
Requested to load WanTEModel
loaded completely;  6419.48 MB loaded, full load: True
CLIP/text encoder model load device: cpu, offload device: cpu, current: cpu, dtype: torch.float16
!!! Exception during processing !!! _TensorCoreFP8LayoutBase.dequantize() missing 1 required positional argument: 'orig_dtype'
Traceback (most recent call last):
  File "/Users/tushardudani/comfyui/ComfyUI/execution.py", line 518, in execute
    output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                                                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/comfyui/ComfyUI/execution.py", line 329, in get_output_data
    return_values = await _async_map_node_over_list(prompt_id, unique_id, obj, input_data_all, obj.FUNCTION, allow_interrupt=True, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/comfyui/ComfyUI/execution.py", line 303, in _async_map_node_over_list
    await process_inputs(input_dict, i)
  File "/Users/tushardudani/comfyui/ComfyUI/execution.py", line 291, in process_inputs
    result = f(**inputs)
  File "/Users/tushardudani/comfyui/ComfyUI/nodes.py", line 77, in encode
    return (clip.encode_from_tokens_scheduled(tokens), )
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/sd.py", line 207, in encode_from_tokens_scheduled
    pooled_dict = self.encode_from_tokens(tokens, return_pooled=return_pooled, return_dict=True)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/sd.py", line 271, in encode_from_tokens
    o = self.cond_stage_model.encode_token_weights(tokens)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/sd1_clip.py", line 704, in encode_token_weights
    out = getattr(self, self.clip).encode_token_weights(token_weight_pairs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/sd1_clip.py", line 45, in encode_token_weights
    o = self.encode(to_encode)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/sd1_clip.py", line 297, in encode
    return self(tokens)
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/sd1_clip.py", line 270, in forward
    outputs = self.transformer(None, attention_mask_model, embeds=embeds, num_tokens=num_tokens, intermediate_output=intermediate_output, final_layer_norm_intermediate=self.layer_norm_hidden_state, dtype=torch.float32, embeds_info=embeds_info)
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/text_encoders/t5.py", line 249, in forward
    return self.encoder(x, attention_mask=attention_mask, **kwargs)
           ~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/text_encoders/t5.py", line 217, in forward
    x, past_bias = l(x, mask, past_bias, optimized_attention)
                   ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/text_encoders/t5.py", line 188, in forward
    x, past_bias = self.layer[0](x, mask, past_bias, optimized_attention)
                   ~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/text_encoders/t5.py", line 175, in forward
    output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
                        ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/text_encoders/t5.py", line 152, in forward
    q = self.q(x)
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/ops.py", line 676, in forward
    output = self.forward_comfy_cast_weights(input)
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/ops.py", line 648, in forward_comfy_cast_weights
    weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
                                   ~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/tushardudani/comfyui/ComfyUI/comfy/ops.py", line 119, in cast_bias_weight
    weight = weight.dequantize()
  File "/Users/tushardudani/miniconda3/lib/python3.13/site-packages/comfy_kitchen/tensor/base.py", line 286, in dequantize
    full = self.layout_cls.dequantize(qdata, self._params)
TypeError: _TensorCoreFP8LayoutBase.dequantize() missing 1 required positional argument: 'orig_dtype'
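
For reference, the failure looks like a plain signature mismatch; an illustrative reduction (class and argument names taken from the traceback, everything else assumed):

```python
# What the traceback suggests: the layout class now requires orig_dtype,
# but the caller in comfy_kitchen/tensor/base.py still passes two args.
class _TensorCoreFP8LayoutBase:
    @classmethod
    def dequantize(cls, qdata, params, orig_dtype):
        ...

# base.py line 286 effectively does:
#   full = self.layout_cls.dequantize(qdata, self._params)
# -> TypeError: dequantize() missing 1 required positional argument: 'orig_dtype'
```

In other words, the `dequantize` signature and its call site appear to have drifted apart between this PR and the installed comfy_kitchen version.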
