diff --git a/prepare.py b/prepare.py index 9236549..0635236 100644 --- a/prepare.py +++ b/prepare.py @@ -15,7 +15,7 @@ def process_prompt_data(index, batch_start, prompt_embed, output_path): prompt_dict = { "prompt_embeds": prompt_embed, } - np.savez(output_path / f"{batch_start+index}.npz", **prompt_dict) + np.save(output_path / f"{batch_start+index}.npy", **prompt_dict) return index