diff --git a/comfy/model_management.py b/comfy/model_management.py
index 0371754fd59..42e2ee7fd3f 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -517,6 +517,9 @@ def text_encoder_dtype(device=None):
     elif args.fp32_text_enc:
         return torch.float32
 
+    if is_device_cpu(device):
+        return torch.float16
+
     if should_use_fp16(device, prioritize_performance=False):
         return torch.float16
     else:
diff --git a/comfy/ops.py b/comfy/ops.py
index deb849d63c9..e48568409a1 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -29,6 +29,39 @@ def conv_nd(dims, *args, **kwargs):
     else:
         raise ValueError(f"unsupported dimensions: {dims}")
 
+def cast_bias_weight(s, input):
+    bias = None
+    if s.bias is not None:
+        bias = s.bias.to(device=input.device, dtype=input.dtype)
+    weight = s.weight.to(device=input.device, dtype=input.dtype)
+    return weight, bias
+
+class manual_cast:
+    class Linear(Linear):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.linear(input, weight, bias)
+
+    class Conv2d(Conv2d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class Conv3d(Conv3d):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return self._conv_forward(input, weight, bias)
+
+    class GroupNorm(GroupNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)
+
+    class LayerNorm(LayerNorm):
+        def forward(self, input):
+            weight, bias = cast_bias_weight(self, input)
+            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)
+
 @contextmanager
 def use_comfy_ops(device=None, dtype=None): # Kind of an ugly hack but I can't think of a better way
     old_torch_nn_linear = torch.nn.Linear
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4530168ab7a..6ffef515ede 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -78,7 +78,7 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
         with open(textmodel_json_config) as f:
             config = json.load(f)
 
-        self.transformer = model_class(config, dtype, device, comfy.ops)
+        self.transformer = model_class(config, dtype, device, comfy.ops.manual_cast)
         self.num_layers = self.transformer.num_layers
 
         self.max_length = max_length
@@ -160,37 +160,31 @@ def forward(self, tokens):
         tokens = self.set_up_textual_embeddings(tokens, backup_embeds)
         tokens = torch.LongTensor(tokens).to(device)
 
-        if self.transformer.dtype != torch.float32:
-            precision_scope = torch.autocast
+        attention_mask = None
+        if self.enable_attention_masks:
+            attention_mask = torch.zeros_like(tokens)
+            max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
+            for x in range(attention_mask.shape[0]):
+                for y in range(attention_mask.shape[1]):
+                    attention_mask[x, y] = 1
+                    if tokens[x, y] == max_token:
+                        break
+
+        outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
+        self.transformer.set_input_embeddings(backup_embeds)
+
+        if self.layer == "last":
+            z = outputs[0]
         else:
-            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
-
-        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
-            attention_mask = None
-            if self.enable_attention_masks:
-                attention_mask = torch.zeros_like(tokens)
-                max_token = self.transformer.get_input_embeddings().weight.shape[0] - 1
-                for x in range(attention_mask.shape[0]):
-                    for y in range(attention_mask.shape[1]):
-                        attention_mask[x, y] = 1
-                        if tokens[x, y] == max_token:
-                            break
-
-            outputs = self.transformer(tokens, attention_mask, intermediate_output=self.layer_idx, final_layer_norm_intermediate=self.layer_norm_hidden_state)
-            self.transformer.set_input_embeddings(backup_embeds)
-
-            if self.layer == "last":
-                z = outputs[0]
-            else:
-                z = outputs[1]
+            z = outputs[1]
 
-            if outputs[2] is not None:
-                pooled_output = outputs[2].float()
-            else:
-                pooled_output = None
+        if outputs[2] is not None:
+            pooled_output = outputs[2].float()
+        else:
+            pooled_output = None
 
-            if self.text_projection is not None and pooled_output is not None:
-                pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
+        if self.text_projection is not None and pooled_output is not None:
+            pooled_output = pooled_output.float().to(self.text_projection.device) @ self.text_projection.float()
         return z.float(), pooled_output
 
     def encode(self, tokens):
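
Note for reviewers (not part of the patch): a minimal standalone sketch of the pattern `manual_cast` introduces. Parameters stay in their storage dtype (fp16 even on CPU, per the `text_encoder_dtype` change above) and are cast to the input's device and dtype on every forward call, which is why the `torch.autocast` wrapper in `sd1_clip.py` can be dropped. `CastLinear` is a hypothetical stand-in for `comfy.ops.manual_cast.Linear`.

```python
import torch

def cast_bias_weight(s, input):
    # Same helper as in the diff: move/cast the layer's parameters
    # to match the incoming activation.
    bias = None
    if s.bias is not None:
        bias = s.bias.to(device=input.device, dtype=input.dtype)
    weight = s.weight.to(device=input.device, dtype=input.dtype)
    return weight, bias

class CastLinear(torch.nn.Linear):
    # Hypothetical stand-in for comfy.ops.manual_cast.Linear.
    def forward(self, input):
        weight, bias = cast_bias_weight(self, input)
        return torch.nn.functional.linear(input, weight, bias)

layer = CastLinear(8, 4).half()  # weights stored in fp16, works on CPU too
x = torch.randn(2, 8)            # fp32 activations
out = layer(x)                   # weight/bias cast to fp32 for this call
assert out.dtype == torch.float32
```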
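
A similar sketch of what the new attention-mask loop in `forward` computes. `max_token` is the last embedding index; assuming the standard SD1 CLIP vocab of 49408 entries, that is 49407, the end-of-text/padding id. The token ids below are toy values. Each row of the mask enables attention up to and including the first end-of-text token:

```python
import torch

max_token = 49407  # assumed: vocab_size - 1, CLIP's end-of-text/pad id
tokens = torch.tensor([[49406, 320, 2368, 49407, 49407, 49407]])

attention_mask = torch.zeros_like(tokens)
for x in range(attention_mask.shape[0]):
    for y in range(attention_mask.shape[1]):
        attention_mask[x, y] = 1      # attend to this position...
        if tokens[x, y] == max_token:
            break                     # ...but stop after the first EOT

print(attention_mask)  # tensor([[1, 1, 1, 1, 0, 0]])
```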