Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Mar 19, 2024
1 parent ab8cf84 commit 8b93eaa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,11 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt):
else:
if hf_model_name == "openai/clip-vit-large-patch14":
from transformers import CLIPProcessor

tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
hf_subfolder = "" # CLIPProcessor does not have a subfolder
else:
hf_subfolder="text_encoder"
hf_subfolder = "text_encoder"

tokenizer = CLIPTokenizer.from_pretrained(
hf_model_name,
Expand All @@ -93,6 +94,7 @@ def run_torch_clip(hf_model_name, hf_auth_token, prompt):
)

from transformers import CLIPTextModel

model = CLIPTextModel.from_pretrained(
hf_model_name,
subfolder=hf_subfolder,
Expand Down
20 changes: 10 additions & 10 deletions models/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@

class StableDiffusionTest(unittest.TestCase):
def testExportT5Model(self):
arguments["hf_model_name"]="google/t5-v1_1-small",
arguments["hf_model_name"] = "google/t5-v1_1-small"
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
safe_prefix = "google_t5-v1_1-small"
with self.assertRaises(SystemExit) as cm:
clip.export_clip_model(
hf_model_name="google/t5-v1_1-small",
hf_model_name=arguments["hf_model_name"],
hf_auth_token=None,
compile_to="vmfb",
external_weights=None,
Expand All @@ -84,7 +84,7 @@ def testExportT5Model(self):
upload_ir=upload_ir_var == "upload",
)
self.assertEqual(cm.exception.code, None)
arguments["vmfb_path"] = safe_prefix+".vmfb"
arguments["vmfb_path"] = safe_prefix + ".vmfb"
turbine = clip_runner.run_clip(
arguments["device"],
arguments["prompt"],
Expand All @@ -98,10 +98,10 @@ def testExportT5Model(self):
)
err = utils.largest_error(torch_output, turbine[0])
assert err < 9e-5
os.remove(safe_prefix+".vmfb")
os.remove(safe_prefix + ".vmfb")

def testExportClipVitLarge14(self):
arguments["hf_model_name"]="openai/clip-vit-large-patch14",
arguments["hf_model_name"] = "openai/clip-vit-large-patch14"
safe_prefix = "openai_clip-vit-large-patch14"
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
with self.assertRaises(SystemExit) as cm:
Expand All @@ -110,15 +110,15 @@ def testExportClipVitLarge14(self):
hf_auth_token=None,
compile_to="vmfb",
external_weights="safetensors",
external_weight_path=safe_prefix+".safetensors",
external_weight_path=safe_prefix + ".safetensors",
device="cpu",
target_triple=None,
max_alloc=None,
upload_ir=upload_ir_var == "upload",
)
self.assertEqual(cm.exception.code, None)
arguments["external_weight_path"] = safe_prefix+".safetensors"
arguments["vmfb_path"] = safe_prefix+".vmfb"
arguments["external_weight_path"] = safe_prefix + ".safetensors"
arguments["vmfb_path"] = safe_prefix + ".vmfb"
turbine = clip_runner.run_clip(
arguments["device"],
arguments["prompt"],
Expand All @@ -132,8 +132,8 @@ def testExportClipVitLarge14(self):
)
err = utils.largest_error(torch_output, turbine[0])
assert err < 9e-5
os.remove(safe_prefix+".safetensors")
os.remove(safe_prefix+".vmfb")
os.remove(safe_prefix + ".safetensors")
os.remove(safe_prefix + ".vmfb")

def testExportClipModel(self):
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
Expand Down

0 comments on commit 8b93eaa

Please sign in to comment.