Skip to content

Commit

Permalink
Add RealCause pre-computed dataset loading function
Browse files Browse the repository at this point in the history
  • Loading branch information
bradyneal committed Mar 25, 2021
1 parent 3f2e781 commit 103b950
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 3 deletions.
3 changes: 3 additions & 0 deletions consts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
REALCAUSE_DATASETS_FOLDER = 'realcause_datasets'
N_SAMPLE_SEEDS = 100
N_AGG_SEEDS = 100
18 changes: 18 additions & 0 deletions loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,29 @@
import json
import zipfile
from pathlib import Path
import pandas as pd

import torch
from addict import Dict

from train_generator import get_args, main
from consts import REALCAUSE_DATASETS_FOLDER, N_SAMPLE_SEEDS


def load_realcause_dataset(dataset, sample=0):
valid_datasets = {'lalonde_cps', 'lalonde_psid', 'twins'}
dataset = dataset.lower()
if dataset not in valid_datasets:
raise ValueError('Invalid dataset "{}" ... Valid datasets: {}'
.format(dataset, valid_datasets))
if not isinstance(sample, int):
raise ValueError('sample must be an integer')
if 0 < sample >= N_SAMPLE_SEEDS:
raise ValueError('sample must be between 0 and {}'
.format(N_SAMPLE_SEEDS - 1))

dataset_file = Path(REALCAUSE_DATASETS_FOLDER) / '{}_sample{}.csv'.format(dataset, sample)
return pd.read_csv(dataset_file)


def load_gen(saveroot='save', dataroot=None):
Expand Down
5 changes: 2 additions & 3 deletions make_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

from data.lalonde import load_lalonde
from data.twins import load_twins
from consts import REALCAUSE_DATASETS_FOLDER, N_SAMPLE_SEEDS, N_AGG_SEEDS

FOLDER = Path('realcause_datasets')
FOLDER = Path(REALCAUSE_DATASETS_FOLDER)
FOLDER.mkdir(parents=True, exist_ok=True)

psid_gen_model, args = load_from_folder(dataset='lalonde_psid1')
Expand All @@ -25,8 +26,6 @@
w_dfs = [psid_w, cps_w, twins_w]
names = ['lalonde_psid', 'lalonde_cps', 'twins']

N_SAMPLE_SEEDS = 100
N_AGG_SEEDS = 100
dfs = []
print('N samples:', N_SAMPLE_SEEDS)
print('N seeds per sample:', N_AGG_SEEDS)
Expand Down

0 comments on commit 103b950

Please sign in to comment.