Skip to content

Commit

Permalink
Control to Treatment Matching (#797)
Browse files Browse the repository at this point in the history
* Many to one matching without replacement.
* feat: first draft allowing control to treatment.
* fix: writing tests, fixing code.
* fix: formatting.
  • Loading branch information
spohngellert-o authored Oct 7, 2024
1 parent c4d7c60 commit 8e4a5bf
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 26 deletions.
58 changes: 32 additions & 26 deletions causalml/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class NearestNeighborMatch:
ratio (int): ratio of control / treatment to be matched.
shuffle (bool): whether to shuffle the treatment group data before
matching
treatment_to_control (bool): if True (default), match each treatment unit
to control units; if False, match each control unit to treatment units
random_state (numpy.random.RandomState or int): RandomState or an int
seed
n_jobs (int): The number of parallel jobs to run for neighbors search.
Expand All @@ -103,6 +105,7 @@ def __init__(
replace=False,
ratio=1,
shuffle=True,
treatment_to_control=True,
random_state=None,
n_jobs=-1,
):
Expand All @@ -123,6 +126,7 @@ def __init__(
self.replace = replace
self.ratio = ratio
self.shuffle = shuffle
self.treatment_to_control = treatment_to_control
self.random_state = check_random_state(random_state)
self.n_jobs = n_jobs

Expand All @@ -144,16 +148,19 @@ def match(self, data, treatment_col, score_cols):
treatment = data.loc[data[treatment_col] == 1, score_cols]
control = data.loc[data[treatment_col] == 0, score_cols]

# Picks whether to use treatment or control for matching direction
match_from = treatment if self.treatment_to_control else control
match_to = control if self.treatment_to_control else treatment
sdcal = self.caliper * np.std(data[score_cols].values)

if self.replace:
scaler = StandardScaler()
scaler.fit(data[score_cols])
treatment_scaled = pd.DataFrame(
scaler.transform(treatment), index=treatment.index
match_from_scaled = pd.DataFrame(
scaler.transform(match_from), index=match_from.index
)
control_scaled = pd.DataFrame(
scaler.transform(control), index=control.index
match_to_scaled = pd.DataFrame(
scaler.transform(match_to), index=match_to.index
)

# SD is the same as caliper because we use a StandardScaler above
Expand All @@ -162,21 +169,20 @@ def match(self, data, treatment_col, score_cols):
matching_model = NearestNeighbors(
n_neighbors=self.ratio, n_jobs=self.n_jobs
)
matching_model.fit(control_scaled)
distances, indices = matching_model.kneighbors(treatment_scaled)

matching_model.fit(match_to_scaled)
distances, indices = matching_model.kneighbors(match_from_scaled)
# distances and indices are (n_obs, self.ratio) matrices.
# To index easily, reshape distances, indices and match_from_scaled into
# the (n_obs * self.ratio, 1) matrices and data frame.
distances = distances.T.flatten()
indices = indices.T.flatten()
treatment_scaled = pd.concat([treatment_scaled] * self.ratio, axis=0)
match_from_scaled = pd.concat([match_from_scaled] * self.ratio, axis=0)

cond = (distances / np.sqrt(len(score_cols))) < sdcal
# Deduplicate the indices of the match-from group
t_idx_matched = np.unique(treatment_scaled.loc[cond].index)
from_idx_matched = np.unique(match_from_scaled.loc[cond].index)
# XXX: Should we deduplicate the indices of the match-to group too?
c_idx_matched = np.array(control_scaled.iloc[indices[cond]].index)
to_idx_matched = np.array(match_to_scaled.iloc[indices[cond]].index)
else:
assert len(score_cols) == 1, (
"Matching on multiple columns is only supported using the "
Expand All @@ -187,31 +193,31 @@ def match(self, data, treatment_col, score_cols):
score_col = score_cols[0]

if self.shuffle:
t_indices = self.random_state.permutation(treatment.index)
from_indices = self.random_state.permutation(match_from.index)
else:
t_indices = treatment.index
from_indices = match_from.index

t_idx_matched = []
c_idx_matched = []
control["unmatched"] = True
from_idx_matched = []
to_idx_matched = []
match_to["unmatched"] = True

for t_idx in t_indices:
for from_idx in from_indices:
dist = np.abs(
control.loc[control.unmatched, score_col]
- treatment.loc[t_idx, score_col]
match_to.loc[match_to.unmatched, score_col]
- match_from.loc[from_idx, score_col]
)
# Gets self.ratio lowest dists
c_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio]
c_idx_list = dist.index[c_np_idx_list]
for i, c_idx in enumerate(c_idx_list):
if dist[c_idx] <= sdcal:
to_np_idx_list = np.argpartition(dist, self.ratio)[: self.ratio]
to_idx_list = dist.index[to_np_idx_list]
for i, to_idx in enumerate(to_idx_list):
if dist[to_idx] <= sdcal:
if i == 0:
t_idx_matched.append(t_idx)
c_idx_matched.append(c_idx)
control.loc[c_idx, "unmatched"] = False
from_idx_matched.append(from_idx)
to_idx_matched.append(to_idx)
match_to.loc[to_idx, "unmatched"] = False

return data.loc[
np.concatenate([np.array(t_idx_matched), np.array(c_idx_matched)])
np.concatenate([np.array(from_idx_matched), np.array(to_idx_matched)])
]

def match_by_group(self, data, treatment_col, score_cols, groupby_col):
Expand Down
21 changes: 21 additions & 0 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,27 @@ def test_nearest_neighbor_match_by_group(generate_unmatched_data):
assert sum(matched[TREATMENT_COL] == 0) == sum(matched[TREATMENT_COL] != 0)


def test_nearest_neighbor_match_control_to_treatment(generate_unmatched_data):
    """
    Checks control-to-treatment matching.

    The matcher is configured with:
        replace=True
        treatment_to_control=False
        ratio=2
    The matched frame is then expected to contain twice as many treatment
    rows as control rows.
    """
    df, _ = generate_unmatched_data()

    matcher = NearestNeighborMatch(
        replace=True,
        ratio=2,
        treatment_to_control=False,
        random_state=RANDOM_SEED,
    )
    matched = matcher.match(
        data=df, treatment_col=TREATMENT_COL, score_cols=[SCORE_COL]
    )

    # Count matched rows on each side of the treatment indicator.
    n_control = sum(matched[TREATMENT_COL] == 0)
    n_treatment = sum(matched[TREATMENT_COL] != 0)
    assert 2 * n_control == n_treatment


def test_match_optimizer(generate_unmatched_data):
df, features = generate_unmatched_data()

Expand Down

0 comments on commit 8e4a5bf

Please sign in to comment.