Skip to content

Commit

Permalink
Adding model checkpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
51N84D committed Mar 29, 2021
1 parent b21ad84 commit 6e3d1c9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
Binary file added GenModelCkpts.zip
Binary file not shown.
21 changes: 12 additions & 9 deletions loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,25 @@ def load_gen(saveroot='save', dataroot=None):
return model, args


def load_from_folder(dataset, checkpoint_dir="./GenModelCkpts"):
checkpoint_dir = Path(checkpoint_dir).resolve()
def load_from_folder(dataset, checkpoint_path="./GenModelCkpts.zip"):
checkpoint_path = Path(checkpoint_path).resolve()
root_dir = checkpoint_path.parent
checkpoint_dir = root_dir / checkpoint_path.stem
if not checkpoint_dir.is_dir():
with zipfile.ZipFile(checkpoint_path, "r") as zip_ref:
zip_ref.extractall()

dataset_roots = os.listdir(checkpoint_dir)
dataset_stem = dataset.split('_')[0]
subdata_stem = dataset.split('_')[-1]
dataset_stem = dataset.split("_")[0]
subdata_stem = dataset.split("_")[-1]

assert dataset_stem in dataset_roots
subdatasets = os.listdir(checkpoint_dir / dataset_stem)
assert subdata_stem in subdatasets

subdata_path = checkpoint_dir / Path(dataset_stem) / Path(subdata_stem)
# Check if unzipping is necessary
if (
len(os.listdir(subdata_path)) == 1
and ".zip" in os.listdir(subdata_path)[0]
):
if len(os.listdir(subdata_path)) == 1 and ".zip" in os.listdir(subdata_path)[0]:
zip_name = os.listdir(subdata_path)[0]
zip_path = subdata_path / zip_name
with zipfile.ZipFile(zip_path, "r") as zip_ref:
Expand All @@ -88,4 +91,4 @@ def load_from_folder(dataset, checkpoint_dir="./GenModelCkpts"):
# Now load model
model, args = load_gen(saveroot=str(args.saveroot), dataroot="./datasets")

return model, args
return model, args

0 comments on commit 6e3d1c9

Please sign in to comment.