Merge branch 'master' into beta
jn-jairo committed Dec 4, 2023
2 parents 77044e0 + 26b1c0a commit 7196dba
Showing 3 changed files with 16 additions and 5 deletions.
4 changes: 2 additions & 2 deletions comfy/controlnet.py
@@ -208,7 +208,7 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
 
         def forward(self, input):
             if self.up is not None:
-                return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+                return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
             else:
                 return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
 
@@ -247,7 +247,7 @@ def __init__(
 
         def forward(self, input):
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+                return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
             else:
                 return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
 
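The change in both forward methods is the extra .to(input.dtype) cast on the stored weight: when the base weight is kept in a different dtype than the activations, adding the low-rank up/down product (which is already cast with .type(input.dtype)) and calling the functional op would otherwise mix dtypes. A minimal sketch of the pattern, using an illustrative LowRankLinear module rather than ComfyUI's actual ControlLoraOps classes:

import torch

class LowRankLinear(torch.nn.Module):
    # Illustrative stand-in for a linear layer whose base weight may live in a
    # different dtype than the activations, plus a LoRA-style low-rank delta
    # merged into the weight at forward time.
    def __init__(self, in_features, out_features, rank=4, dtype=torch.float32):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features, dtype=dtype))
        self.up = torch.nn.Parameter(torch.randn(out_features, rank, dtype=dtype))
        self.down = torch.nn.Parameter(torch.randn(rank, in_features, dtype=dtype))

    def forward(self, input):
        # Cast the base weight to the input's dtype (and device) before adding
        # the low-rank product; otherwise a float32 weight combined with a
        # lower-precision input leaves torch.nn.functional.linear with
        # mismatched dtypes.
        weight = self.weight.to(input.dtype).to(input.device)
        delta = torch.mm(self.up.flatten(start_dim=1),
                         self.down.flatten(start_dim=1)).reshape(weight.shape).type(input.dtype)
        return torch.nn.functional.linear(input, weight + delta)

x = torch.randn(2, 8, dtype=torch.bfloat16)
layer = LowRankLinear(8, 4)   # weights stored in float32
print(layer(x).dtype)         # torch.bfloat16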
7 changes: 7 additions & 0 deletions comfy/model_management.py
@@ -444,6 +444,13 @@ def dtype_size(dtype):
     dtype_size = 4
     if dtype == torch.float16 or dtype == torch.bfloat16:
         dtype_size = 2
+    elif dtype == torch.float32:
+        dtype_size = 4
+    else:
+        try:
+            dtype_size = dtype.itemsize
+        except: #Old pytorch doesn't have .itemsize
+            pass
     return dtype_size
 
 def unet_offload_device():
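The added branches let dtype_size report a per-element size for dtypes beyond fp16/bf16/fp32 (for example the float8 formats) by falling back to torch.dtype.itemsize where the installed PyTorch provides it. A rough sketch of how such a helper can feed a memory estimate; estimate_model_bytes and the example model are illustrative, not ComfyUI code:

import torch

def dtype_size(dtype):
    # Mirrors the logic added in the diff: known dtypes get hard-coded sizes,
    # anything else falls back to torch.dtype.itemsize when available.
    size = 4
    if dtype == torch.float16 or dtype == torch.bfloat16:
        size = 2
    elif dtype == torch.float32:
        size = 4
    else:
        try:
            size = dtype.itemsize   # e.g. 1 for float8 dtypes on recent PyTorch
        except AttributeError:      # old PyTorch dtypes have no .itemsize
            pass
    return size

def estimate_model_bytes(model):
    # Sum of parameter element counts times bytes per element.
    return sum(p.numel() * dtype_size(p.dtype) for p in model.parameters())

model = torch.nn.Linear(1024, 1024).half()
print(estimate_model_bytes(model))  # about 2.1 MB for fp16 weight + bias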
10 changes: 7 additions & 3 deletions comfy/sd1_clip.py
@@ -84,12 +84,16 @@ def __init__(self, version="openai/clip-vit-large-patch14", device="cpu", max_le
 
         self.inner_name = inner_name
         if dtype is not None:
-            self.transformer.to(dtype)
             inner_model = getattr(self.transformer, self.inner_name)
             if hasattr(inner_model, "embeddings"):
-                inner_model.embeddings.to(torch.float32)
+                embeddings_bak = inner_model.embeddings.to(torch.float32)
+                inner_model.embeddings = None
+                self.transformer.to(dtype)
+                inner_model.embeddings = embeddings_bak
             else:
-                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
+                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
+                self.transformer.to(dtype)
+                self.transformer.set_input_embeddings(previous_inputs)
 
         self.max_length = max_length
         if freeze:
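Previously the whole text encoder was cast to the target dtype and the embeddings were then cast back to float32, which round-trips the embedding weights through the lower precision. The new code sets the embeddings aside (or copies the input embeddings) before calling .to(dtype), so they never leave float32. A small sketch of that detach-cast-reattach pattern with a made-up module, not the actual transformers CLIP layout:

import torch

class TinyTextModel(torch.nn.Module):
    # Stand-in for a text encoder with an embedding table that should stay in
    # float32 while the rest of the model is cast to a lower precision.
    def __init__(self):
        super().__init__()
        self.embeddings = torch.nn.Embedding(100, 32)
        self.head = torch.nn.Linear(32, 32)

model = TinyTextModel()

# Set the float32 submodule aside, cast the rest, then reattach, so the
# embedding weights are never converted to float16 and back.
embeddings_bak = model.embeddings.to(torch.float32)
model.embeddings = None
model.to(torch.float16)
model.embeddings = embeddings_bak

print(model.embeddings.weight.dtype)  # torch.float32
print(model.head.weight.dtype)        # torch.float16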
