Skip to content

Commit

Permalink
adding load gen and minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
CW-Huang committed Sep 9, 2020
1 parent aca97ab commit a984588
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 15 deletions.
7 changes: 4 additions & 3 deletions helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ def create(*args):


class Logging:
def __init__(self, saveroot, filename='log.txt'):
def __init__(self, saveroot, filename='log.txt', log_=True):
self.log_path = os.path.join(saveroot, filename)
self.log_ = log_

def info(self, s, print_=True, log_=True):
def info(self, s, print_=True):
if print_:
print(f'{datetime.now()} / {s}')
if log_:
if self.log_:
with open(self.log_path, 'a+') as f_log:
f_log.write(f'{datetime.now()} / {s} \n')
19 changes: 19 additions & 0 deletions load_gen.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os
import json
import torch
from train_generator import get_args, main


def load_gen(saveroot='save'):
args = get_args()
args_path = os.path.join(saveroot, 'args.txt')
args.__dict__.update(json.load(open(args_path, 'r')))
print(args)

args.train = False
args.eval = False
model = main(args, False, False)
state_dicts = torch.load(model.savepath)
for state_dict, net in zip(state_dicts, model.networks):
net.load_state_dict(state_dict)
return model, args
8 changes: 5 additions & 3 deletions models/distributions/flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def sigmoid_flow_integral(x, ndim=4, params=None):
return x_pre


def sigmoid_flow_inverse(y, ndim=4, params=None, logit_end=True, x=None, tol=1e-2, max_iter=100, lr=0.1):
def sigmoid_flow_inverse(y, ndim=4, params=None, logit_end=True, x=None, tol=1e-2, max_iter=100, lr=0.1, verbose=False):
if logit_end:
y = torch.sigmoid(y)
if x is None:
Expand All @@ -72,12 +72,14 @@ def closure():
optimizer.step(closure)

error_new = (sigmoid_flow(x, 0, ndim=ndim, params=params, logit_end=False)[0] - y).abs().max().item()
print('inversion error', error_new)
if verbose:
print('inversion error', error_new)
torch.cuda.empty_cache()
gc.collect()

if error_new > error_old:
print('learning rate too large for inversion')
if verbose:
print('learning rate too large for inversion')
return sigmoid_flow_inverse(y, ndim=ndim, params=params, logit_end=False, x=x)
else:
return x
Expand Down
1 change: 1 addition & 0 deletions models/nonlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def __init__(self, w, t, y, seed=1,
batch_size=training_params.batch_size,
shuffle=True)

self.best_val_loss = float('inf')

def _matricize(self, data):
return [np.reshape(d, [d.shape[0], -1]) for d in data]
Expand Down
21 changes: 12 additions & 9 deletions train_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,15 @@ def evaluate(args, model):
return summary, all_runs


def main(args):
def main(args, save_args=True, log_=True):
helpers.create(*args.saveroot.split('/'))
logger = helpers.Logging(args.saveroot, 'log.txt')
logger = helpers.Logging(args.saveroot, 'log.txt', log_)
logger.info(args)

# save args
with open(os.path.join(args.saveroot, 'args.txt'), 'w') as file:
file.write(json.dumps(args.__dict__, indent=4))
if save_args:
with open(os.path.join(args.saveroot, 'args.txt'), 'w') as file:
file.write(json.dumps(args.__dict__, indent=4))

# dataset
logger.info(f'getting data: {args.data}')
Expand Down Expand Up @@ -168,9 +169,8 @@ def main(args):
return model


if __name__ == "__main__":

parser = argparse.ArgumentParser()
def get_args():
parser = argparse.ArgumentParser(description='causal-gen')

# dataset
parser.add_argument('--data', type=str, default='lalonde') # TODO: fix choices
Expand Down Expand Up @@ -212,5 +212,8 @@ def main(args):
# evaluation
parser.add_argument('--num_univariate_tests', type=int, default=100)

arguments = parser.parse_args()
main(arguments)
return parser.parse_args()


if __name__ == "__main__":
main(get_args())

0 comments on commit a984588

Please sign in to comment.