Conversation

@HyperExtendedReality HyperExtendedReality commented Jan 15, 2026

- Add automatic detection and default to `bfloat16` (or `fp16` fallback) when no explicit dtype is provided, based on device capabilities (see the sketch below)
- Respect provided `dtype_llama`/`dtype` consistently across the Gemma model, projection layer, and connectors
- Remove the forced `out.float()` in `encode_token_weights` to prevent downgrading to fp32 after projection
- This allows SageAttention's optimized kernel to run instead of falling back to PyTorch attention

Fixes the warning:
"Error running sage attention: Input tensors must be in dtype of torch.float16 or torch.bfloat16, using pytorch attention instead."
