
Commit c05e5dd
Merge branch 'master' into beta
jn-jairo committed Dec 11, 2023
2 parents 242406c + 5792663 commit c05e5dd
Showing 3 changed files with 59 additions and 29 deletions.
comfy/model_management.py: 3 changes (3 additions & 0 deletions)
@@ -517,6 +517,9 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32

+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
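
Note: with the new is_device_cpu branch, the text encoder dtype now defaults to fp16 even on CPU (unless an explicit fp32 flag is set), presumably because the manual_cast ops added in comfy/ops.py below upcast the weights per call. A minimal, hypothetical sketch of the resulting selection order; the function name and arguments are illustrative, not the real comfy API:

import torch

def pick_text_encoder_dtype(device_is_cpu, fp32_text_enc=False, fp16_ok=True):
    # Illustrative mirror of text_encoder_dtype() after this change.
    if fp32_text_enc:      # an explicit fp32 override still wins
        return torch.float32
    if device_is_cpu:      # new branch: CPU text encoders default to fp16
        return torch.float16
    return torch.float16 if fp16_ok else torch.float32  # stand-in for should_use_fp16()
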
comfy/ops.py: 33 changes (33 additions & 0 deletions)
@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")

+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
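
Note: every op in manual_cast follows the same pattern: parameters keep their storage dtype, and cast_bias_weight copies them to the incoming tensor's device and dtype on each forward call, so no torch.autocast context is needed around the model. A hedged usage sketch of that pattern; CastLinear is an illustrative stand-in, not part of this diff:

import torch
import torch.nn as nn

class CastLinear(nn.Linear):
    # Same idea as comfy.ops.manual_cast.Linear: cast weight/bias to the
    # input's device and dtype right before the matmul.
    def forward(self, input):
        weight = self.weight.to(device=input.device, dtype=input.dtype)
        bias = None if self.bias is None else self.bias.to(device=input.device, dtype=input.dtype)
        return torch.nn.functional.linear(input, weight, bias)

layer = CastLinear(8, 4).half()   # parameters stored in fp16 (e.g. a CPU text encoder)
x = torch.randn(2, 8)             # fp32 activations
print(layer(x).dtype)             # torch.float32: the matmul runs in the input's dtype
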
comfy/sd1_clip.py: 52 changes (23 additions & 29 deletions)
@@ -78,7 +78,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
         with open(textmodel_json_config) as f:
             config = json.load(f)

-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers

         self.max_length = max_length
@@ -160,37 +160,31 @@ def forward(self, tokens):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)

-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
-        else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
-                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-                for x in range(attention_mask.shape[0]):
-                    for y in range(attention_mask.shape[1]):
-                        attention_mask[x, y] = 1
-                        if tokens[x, y] == max_token:
-                            break
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
+            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            for x in range(attention_mask.shape[0]):
+                for y in range(attention_mask.shape[1]):
+                    attention_mask[x, y] = 1
+                    if tokens[x, y] == max_token:
+                        break

-            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
-            self.transformer.set_input_embeddings(backup_embeds)
+        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        self.transformer.set_input_embeddings(backup_embeds)

-            if self.layer == "last":
-                z = outputs[0]
-            else:
-                z = outputs[1]
+        if self.layer == "last":
+            z = outputs[0]
+        else:
+            z = outputs[1]

-            if outputs[2] is not None:
-                pooled_output = outputs[2].float()
-            else:
-                pooled_output = None
+        if outputs[2] is not None:
+            pooled_output = outputs[2].float()
+        else:
+            pooled_output = None

-            if self.text_projection is not None and pooled_output is not None:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+        if self.text_projection is not None and pooled_output is not None:
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output

     def encode(self, tokens):
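
Note: the forward() body is otherwise unchanged; dropping the precision_scope/autocast wrapper only dedents it, since dtype handling now happens inside the manual_cast ops. The attention-mask loop it keeps marks every position up to and including the first occurrence of the last embedding index (max_token; for a stock SD1 CLIP vocabulary of 49408 entries that is the end-of-text id 49407). A small standalone illustration of that loop, with toy token values chosen only for demonstration:

import torch

def build_mask(tokens, max_token):
    # Same per-row loop as in forward() above: 1 up to and including
    # the first max_token, 0 for everything after it.
    mask = torch.zeros_like(tokens)
    for x in range(mask.shape[0]):
        for y in range(mask.shape[1]):
            mask[x, y] = 1
            if tokens[x, y] == max_token:
                break
    return mask

tokens = torch.tensor([[320, 1125, 49407, 49407, 49407]])
print(build_mask(tokens, 49407))   # tensor([[1, 1, 1, 0, 0]])
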
