Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
dan-garvey committed Mar 19, 2024
1 parent 7bcb42a commit 9593f69
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
21 changes: 14 additions & 7 deletions models/turbine_models/custom_models/sd_inference/clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,15 @@ def export_clip_model(
):
if "google/t5" in hf_model_name:
from transformers import T5Tokenizer, T5Model

tokenizer = T5Tokenizer.from_pretrained(hf_model_name)
text_encoder_model = T5Model.from_pretrained(hf_model_name)

else:
#TODO: Add better filtering mechanism for things that require CLIPProcessor
# TODO: Add better filtering mechanism for things that require CLIPProcessor
if hf_model_name == "openai/clip-vit-large-patch14":
tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
hf_subfolder = "" # CLIPProcessor does not have a subfolder
hf_subfolder = "" # CLIPProcessor does not have a subfolder
else:
# Load the tokenizer and text encoder to tokenize and encode the text.
tokenizer = CLIPTokenizer.from_pretrained(
Expand All @@ -80,20 +81,19 @@ def export_clip_model(
)
hf_subfolder = "text_encoder"


text_encoder_model = CLIPTextModel.from_pretrained(
hf_model_name,
subfolder=hf_subfolder,
token=hf_auth_token,
)


mapper = {}
utils.save_external_weights(
mapper, text_encoder_model, external_weights, external_weight_path
)

if "google/t5" in hf_model_name:

class CompiledClip(CompiledModule):
if external_weights:
params = export_parameters(
Expand All @@ -105,10 +105,17 @@ class CompiledClip(CompiledModule):
else:
params = export_parameters(text_encoder_model)

def main(self, inp=AbstractTensor(1, 77, dtype=torch.int64),
decoder_input_ids=AbstractTensor(1, 77, dtype=torch.int64)):
return jittable(text_encoder_model.forward)(input_ids=inp, decoder_input_ids=decoder_input_ids)
def main(
self,
inp=AbstractTensor(1, 77, dtype=torch.int64),
decoder_input_ids=AbstractTensor(1, 77, dtype=torch.int64),
):
return jittable(text_encoder_model.forward)(
input_ids=inp, decoder_input_ids=decoder_input_ids
)

else:

class CompiledClip(CompiledModule):
if external_weights:
params = export_parameters(
Expand Down
1 change: 0 additions & 1 deletion models/turbine_models/tests/sd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def testExportClipVitLarge14(self):
)
self.assertEqual(cm.exception.code, None)


def testExportClipModel(self):
upload_ir_var = os.environ.get("TURBINE_TANK_ACTION", "not_upload")
with self.assertRaises(SystemExit) as cm:
Expand Down

0 comments on commit 9593f69

Please sign in to comment.