diff --git a/prepare.py b/prepare.py index 0635236..b326d6e 100644 --- a/prepare.py +++ b/prepare.py @@ -12,10 +12,7 @@ def process_prompt_data(index, batch_start, prompt_embed, output_path): - prompt_dict = { - "prompt_embeds": prompt_embed, - } - np.save(output_path / f"{batch_start+index}.npy", **prompt_dict) + np.save(output_path / f"{batch_start+index}.npy", prompt_embed) return index