Skip to content

Commit

Permalink
Fix get_data() data naming bug introduced in recent commit
Browse files Browse the repository at this point in the history
  • Loading branch information
bradyneal committed Mar 27, 2021
1 parent 1a3f003 commit a9e9469
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions train_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ def get_data(args):
data_name = args.data.lower()
ate = None
ites = None
if data_name == "lalonde" or data_name == "lalonde_psid":
if data_name == "lalonde" or data_name == "lalonde_psid" or data_name == "lalonde_psid1":
w, t, y = load_lalonde(obs_version="psid", dataroot=args.dataroot)
elif data_name == "lalonde_rct":
w, t, y = load_lalonde(rct=True, dataroot=args.dataroot)
elif data_name == "lalonde_cps":
elif data_name == "lalonde_cps" or data_name == "lalonde_cps1":
w, t, y = load_lalonde(obs_version="cps", dataroot=args.dataroot)
elif data_name.startswith("lbidd"):
# Valid string formats: lbidd_<link>_<n> and lbidd_<link>_<n>_counterfactual
Expand Down

0 comments on commit a9e9469

Please sign in to comment.