diff --git a/comfy/controlnet.py b/comfy/controlnet.py
index 5921e6b1d19..6d37aa74f69 100644
--- a/comfy/controlnet.py
+++ b/comfy/controlnet.py
@@ -208,7 +208,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
 
     def forward(self, input):
         if self.up is not None:
-            return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+            return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
         else:
             return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
 
@@ -247,7 +247,7 @@ def __init__(
 
     def forward(self, input):
         if self.up is not None:
-            return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+            return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
         else:
             return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
 
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 8e5b7df853f..9e06b40497c 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -444,6 +444,13 @@ def dtype_size(dtype):
     dtype_size = 4
     if dtype == torch.float16 or dtype == torch.bfloat16:
         dtype_size = 2
+    elif dtype == torch.float32:
+        dtype_size = 4
+    else:
+        try:
+            dtype_size = dtype.itemsize
+        except: #Old pytorch doesn't have .itemsize
+            pass
     return dtype_size
 
 def unet_offload_device():
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 58acb97fce7..4e9f6bffe01 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -84,12 +84,16 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
 
         self.inner_name = inner_name
         if dtype is not None:
-            self.transformer.to(dtype)
             inner_model = getattr(self.transformer, self.inner_name)
             if hasattr(inner_model, "embeddings"):
-                inner_model.embeddings.to(torch.float32)
+                embeddings_bak = inner_model.embeddings.to(torch.float32)
+                inner_model.embeddings = None
+                self.transformer.to(dtype)
+                inner_model.embeddings = embeddings_bak
             else:
-                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
+                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
+                self.transformer.to(dtype)
+                self.transformer.set_input_embeddings(previous_inputs)
 
         self.max_length = max_length
         if freeze:
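
Some context on the three changes above, each with a small illustrative sketch. All module and function names in the sketches are hypothetical stand-ins, not the in-tree classes.

The `controlnet.py` hunks change `self.weight.to(input.device)` to `self.weight.to(input.dtype).to(input.device)`, so the stored base weight is cast to the activation dtype before the low-rank `up @ down` delta is merged in; without the cast, the addition mixes dtypes whenever the weight is stored in a different precision than the input. A minimal sketch of the patched path, assuming a `LowRankLinear` layer shaped like the ControlNet one:

```python
import torch

class LowRankLinear(torch.nn.Module):
    """Hypothetical stand-in for the patched ControlNet linear layer:
    a frozen base `weight` plus a low-rank `up @ down` delta."""
    def __init__(self, in_features, out_features, rank=4, weight_dtype=torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(
            torch.zeros(out_features, in_features, dtype=weight_dtype), requires_grad=False)
        self.up = torch.nn.Parameter(torch.zeros(out_features, rank))
        self.down = torch.nn.Parameter(torch.zeros(rank, in_features))
        self.bias = None

    def forward(self, input):
        # The fix: cast the stored weight to the *input* dtype before the
        # merge, so the addition below never mixes dtypes when the weight
        # is stored in a different precision than the activations.
        weight = self.weight.to(input.dtype).to(input.device)
        delta = torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))
        return torch.nn.functional.linear(
            input, weight + delta.reshape(self.weight.shape).type(input.dtype), self.bias)

# Example: weight stored in float32, activations in bfloat16.
layer = LowRankLinear(8, 16)
out = layer(torch.randn(2, 8, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16
```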
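
The `model_management.py` hunk teaches `dtype_size` about dtypes beyond the hard-coded float16/bfloat16/float32 cases by falling back to `torch.dtype.itemsize`, guarded because older PyTorch releases don't expose that attribute. A sketch of the same logic written with early returns (the in-tree version mutates a local `dtype_size` variable and uses a bare `except:`):

```python
import torch

def dtype_size(dtype):
    # Common dtypes are handled explicitly; everything else falls back to
    # torch.dtype.itemsize. Old PyTorch builds don't expose .itemsize, in
    # which case the patch keeps the 4-byte default.
    if dtype == torch.float16 or dtype == torch.bfloat16:
        return 2
    if dtype == torch.float32:
        return 4
    try:
        return dtype.itemsize
    except AttributeError:
        return 4

assert dtype_size(torch.bfloat16) == 2
assert dtype_size(torch.float64) == 8  # via .itemsize on recent PyTorch
```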
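
The `sd1_clip.py` change stops the token embeddings from round-tripping through the low-precision dtype: previously the whole transformer was cast to `dtype` first and the embeddings cast back to float32 afterwards, which rounds their values down to `dtype` and back. The patch now sets the embeddings aside (or copies them, in the `set_input_embeddings` branch) before the model-wide cast and reattaches them afterwards, so they keep their original float32 values. A runnable sketch of that move on a hypothetical toy module:

```python
import torch

class TinyTextEncoder(torch.nn.Module):
    """Hypothetical toy stand-in for the CLIP inner model: an embedding
    table to keep in float32, plus a body we do want in lower precision."""
    def __init__(self):
        super().__init__()
        self.embeddings = torch.nn.Embedding(16, 8)
        self.body = torch.nn.Linear(8, 8)

model = TinyTextEncoder()

# Same move as the patch: set the embeddings aside, cast everything else,
# then reattach them, so their float32 values are never rounded to fp16.
embeddings_bak = model.embeddings.to(torch.float32)
model.embeddings = None            # nn.Module allows None-ing a submodule
model.to(torch.float16)            # skips the detached embeddings
model.embeddings = embeddings_bak

assert model.embeddings.weight.dtype == torch.float32
assert model.body.weight.dtype == torch.float16
```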