forked from comfyanonymous/ComfyUI
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'master' into vae-fallback-cpu
- Loading branch information
Showing
64 changed files
with
8,469 additions
and
2,873 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
name: Tests CI | ||
|
||
on: [push, pull_request] | ||
|
||
jobs: | ||
test: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-node@v3 | ||
with: | ||
node-version: 18 | ||
- uses: actions/setup-python@v4 | ||
with: | ||
python-version: '3.10' | ||
- name: Install requirements | ||
run: | | ||
python -m pip install --upgrade pip | ||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu | ||
pip install -r requirements.txt | ||
- name: Run Tests | ||
run: | | ||
npm ci | ||
npm run test:generate | ||
npm test | ||
working-directory: ./tests-ui |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
import enum | ||
import torch | ||
import math | ||
import comfy.utils | ||
|
||
|
||
def lcm(a, b): #TODO: eventually replace by math.lcm (added in python3.9) | ||
return abs(a*b) // math.gcd(a, b) | ||
|
||
class CONDRegular: | ||
def __init__(self, cond): | ||
self.cond = cond | ||
|
||
def _copy_with(self, cond): | ||
return self.__class__(cond) | ||
|
||
def process_cond(self, batch_size, device, **kwargs): | ||
return self._copy_with(comfy.utils.repeat_to_batch_size(self.cond, batch_size).to(device)) | ||
|
||
def can_concat(self, other): | ||
if self.cond.shape != other.cond.shape: | ||
return False | ||
return True | ||
|
||
def concat(self, others): | ||
conds = [self.cond] | ||
for x in others: | ||
conds.append(x.cond) | ||
return torch.cat(conds) | ||
|
||
class CONDNoiseShape(CONDRegular): | ||
def process_cond(self, batch_size, device, area, **kwargs): | ||
data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]] | ||
return self._copy_with(comfy.utils.repeat_to_batch_size(data, batch_size).to(device)) | ||
|
||
|
||
class CONDCrossAttn(CONDRegular): | ||
def can_concat(self, other): | ||
s1 = self.cond.shape | ||
s2 = other.cond.shape | ||
if s1 != s2: | ||
if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen | ||
return False | ||
|
||
mult_min = lcm(s1[1], s2[1]) | ||
diff = mult_min // min(s1[1], s2[1]) | ||
if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much | ||
return False | ||
return True | ||
|
||
def concat(self, others): | ||
conds = [self.cond] | ||
crossattn_max_len = self.cond.shape[1] | ||
for x in others: | ||
c = x.cond | ||
crossattn_max_len = lcm(crossattn_max_len, c.shape[1]) | ||
conds.append(c) | ||
|
||
out = [] | ||
for c in conds: | ||
if c.shape[1] < crossattn_max_len: | ||
c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result | ||
out.append(c) | ||
return torch.cat(out) | ||
|
||
class CONDConstant(CONDRegular): | ||
def __init__(self, cond): | ||
self.cond = cond | ||
|
||
def process_cond(self, batch_size, device, **kwargs): | ||
return self._copy_with(self.cond) | ||
|
||
def can_concat(self, other): | ||
if self.cond != other.cond: | ||
return False | ||
return True | ||
|
||
def concat(self, others): | ||
return self.cond |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.