Move load_from_folder() to loading.py
bradyneal committed Mar 25, 2021
1 parent a1f2365 commit 3f2e781
Showing 4 changed files with 46 additions and 42 deletions.
44 changes: 44 additions & 0 deletions loading.py
@@ -1,6 +1,11 @@
import os
import json
import zipfile
from pathlib import Path

import torch
from addict import Dict

from train_generator import get_args, main


@@ -27,3 +32,42 @@ def load_gen(saveroot='save', dataroot=None):
for state_dict, net in zip(state_dicts, model.networks):
net.load_state_dict(state_dict)
return model, args


def load_from_folder(dataset, checkpoint_dir="./GenModelCkpts"):
checkpoint_dir = Path(checkpoint_dir).resolve()
dataset_roots = os.listdir(checkpoint_dir)
dataset_stem = dataset.split('_')[0]
subdata_stem = dataset.split('_')[-1]

assert dataset_stem in dataset_roots
subdatasets = os.listdir(checkpoint_dir / dataset_stem)
assert subdata_stem in subdatasets

subdata_path = checkpoint_dir / Path(dataset_stem) / Path(subdata_stem)
# Check if unzipping is necessary
if (
len(os.listdir(subdata_path)) == 1
and ".zip" in os.listdir(subdata_path)[0]
):
zip_name = os.listdir(subdata_path)[0]
zip_path = subdata_path / zip_name
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(subdata_path)

subfolders = [f.path for f in os.scandir(subdata_path) if f.is_dir()]
assert len(subfolders) == 1

model_folder = subdata_path / Path(subfolders[0])

with open(model_folder / "args.txt") as f:
args = Dict(json.load(f))

args.saveroot = model_folder
args.dataroot = "./datasets/"
args.comet = False

# Now load model
model, args = load_gen(saveroot=str(args.saveroot), dataroot="./datasets")

return model, args
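
For context, a minimal usage sketch of the relocated function, as downstream scripts such as make_datasets.py and uai_experiments.py now call it after this commit. The checkpoint layout shown matches what load_from_folder() above walks; the dataset name "lalonde_psid" is purely illustrative, not taken from the commit.

# Usage sketch (not part of this commit): after the move, callers import from loading.py.
# Assumes checkpoints live under ./GenModelCkpts/<root>/<sub-dataset>/<single run folder>/,
# each containing an args.txt; the dataset name "lalonde_psid" is hypothetical.
from loading import load_from_folder

model, args = load_from_folder("lalonde_psid", checkpoint_dir="./GenModelCkpts")
# model and args are whatever load_gen() returns for that run's save folder.
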
2 changes: 1 addition & 1 deletion make_datasets.py
@@ -4,7 +4,7 @@
from pathlib import Path
import time

from run_metrics import load_from_folder
from loading import load_from_folder

from data.lalonde import load_lalonde
from data.twins import load_twins
40 changes: 0 additions & 40 deletions run_metrics.py
@@ -7,7 +7,6 @@
from loading import load_gen
import numpy as np
from collections import OrderedDict
import time
from tqdm import tqdm


@@ -65,45 +64,6 @@ def get_multivariate_results(model, include_w, num_tests=100, n=1000):
return summary


def load_from_folder(dataset, checkpoint_dir="./GenModelCkpts"):
checkpoint_dir = Path(checkpoint_dir).resolve()
dataset_roots = os.listdir(checkpoint_dir)
dataset_stem = dataset.split('_')[0]
subdata_stem = dataset.split('_')[-1]

assert dataset_stem in dataset_roots
subdatasets = os.listdir(checkpoint_dir / dataset_stem)
assert subdata_stem in subdatasets

subdata_path = checkpoint_dir / Path(dataset_stem) / Path(subdata_stem)
# Check if unzipping is necessary
if (
len(os.listdir(subdata_path)) == 1
and ".zip" in os.listdir(subdata_path)[0]
):
zip_name = os.listdir(subdata_path)[0]
zip_path = subdata_path / zip_name
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(subdata_path)

subfolders = [f.path for f in os.scandir(subdata_path) if f.is_dir()]
assert len(subfolders) == 1

model_folder = subdata_path / Path(subfolders[0])

with open(model_folder / "args.txt") as f:
args = Dict(json.load(f))

args.saveroot = model_folder
args.dataroot = "./datasets/"
args.comet = False

# Now load model
model, args = load_gen(saveroot=str(args.saveroot), dataroot="./datasets")

return model, args


def evaluate_directory(
checkpoint_dir="./GenModelCkpts",
# checkpoint_dir="./LinearModelCkpts",
2 changes: 1 addition & 1 deletion uai_experiments.py
@@ -8,7 +8,7 @@
from causal_estimators.standardization_estimator import \
StandardizationEstimator, StratifiedStandardizationEstimator
from evaluation import run_model_cv
from run_metrics import load_from_folder
from loading import load_from_folder

from sklearn.linear_model import LogisticRegression, LinearRegression, Lasso, Ridge, ElasticNet, RidgeClassifier
from sklearn.svm import SVR, LinearSVR, SVC, LinearSVC
