Pick attention implementation based on device in llama code.
comfyanonymous committed Dec 18, 2024
1 parent ca457f7 · commit a4f59bc
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion comfy/text_encoders/llama.py
@@ -4,7 +4,7 @@
 from dataclasses import dataclass
 from typing import Optional, Any
 
-from comfy.ldm.modules.attention import optimized_attention
+from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
 import comfy.ldm.common_dit
 
@@ -81,6 +81,7 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
     ):
         batch_size, seq_length, _ = hidden_states.shape
 
@@ -124,6 +125,7 @@ def forward(
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         freqs_cis: Optional[torch.Tensor] = None,
+        optimized_attention=None,
     ):
         # Self Attention
         residual = x
@@ -132,6 +134,7 @@ def forward(
             hidden_states=x,
             attention_mask=attention_mask,
             freqs_cis=freqs_cis,
+            optimized_attention=optimized_attention,
         )
         x = residual + x
 
@@ -180,6 +183,7 @@ def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_
             mask += causal_mask
         else:
             mask = causal_mask
+        optimized_attention = optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
 
         intermediate = None
         if intermediate_output is not None:
@@ -191,6 +195,7 @@ def forward(self, x, attention_mask=None, intermediate_output=None, final_layer_
                 x=x,
                 attention_mask=mask,
                 freqs_cis=freqs_cis,
+                optimized_attention=optimized_attention,
             )
            if i == intermediate_output:
                intermediate = x.clone()
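For context: the change replaces the module-level optimized_attention import with optimized_attention_for_device, which the model's forward now calls once per pass with the input's device, whether a mask is in use, and small_input=True; the returned callable is then handed down to each decoder layer through the new optimized_attention argument. Below is a minimal, self-contained sketch of that dependency-injection pattern under stated assumptions: attention_for_device, TinyLayer, and TinyModel are hypothetical stand-ins (not ComfyUI's API), and the plain PyTorch SDPA fallback is illustrative only, not ComfyUI's actual selection logic.

# Minimal sketch of the pattern, with hypothetical stand-in names.
from typing import Callable, Optional

import torch
import torch.nn.functional as F


def attention_for_device(device: torch.device, mask: bool = False,
                         small_input: bool = False) -> Callable:
    # Stand-in for a per-device selector: pick whichever attention function
    # suits this device / input shape. For illustration we always fall back
    # to PyTorch SDPA; a real selector would branch on device type and the
    # kernels available there.
    def sdpa(q, k, v, heads, mask=None):
        # `heads` only mirrors the injected-callable interface; unused here.
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return sdpa


class TinyLayer(torch.nn.Module):
    def forward(self, x, mask=None, optimized_attention=None):
        # The layer no longer chooses an attention kernel itself; it calls
        # whatever callable the model resolved and passed down.
        q = k = v = x
        return x + optimized_attention(q, k, v, x.shape[1], mask=mask)


class TinyModel(torch.nn.Module):
    def __init__(self, n_layers: int = 2):
        super().__init__()
        self.layers = torch.nn.ModuleList(TinyLayer() for _ in range(n_layers))

    def forward(self, x, mask: Optional[torch.Tensor] = None):
        # Resolve the implementation once per forward pass, mirroring
        # optimized_attention_for_device(x.device, mask=mask is not None, small_input=True)
        optimized_attention = attention_for_device(
            x.device, mask=mask is not None, small_input=True)
        for layer in self.layers:
            x = layer(x, mask=mask, optimized_attention=optimized_attention)
        return x


if __name__ == "__main__":
    x = torch.randn(1, 4, 8, 8)  # (batch, heads, tokens, head_dim)
    print(TinyModel()(x).shape)  # torch.Size([1, 4, 8, 8])

Resolving the callable at the model level lets the choice depend on runtime facts (device type and whether a mask is present) instead of being fixed at import time, which is what the commit title describes.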
