Skip to content

Commit

Permalink
changed ndkl(before) threshold to k #80
Browse files Browse the repository at this point in the history
  • Loading branch information
Hamedloghmani committed Aug 23, 2023
1 parent f415724 commit 39d6f9b
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def rerank(preds, labels, output, ratios, algorithm: str = 'det_greedy', k_max:
return idx, probs, (finish_time - start_time)

@staticmethod
def eval_fairness(preds, labels, reranked_idx, ratios, output, algorithm, k_max, eq_op: bool = False) -> pandas.DataFrame:
def eval_fairness(preds, labels, reranked_idx, ratios, output, algorithm, k_max, eq_op: bool = False, metrics: list = ['ndkl',]) -> pandas.DataFrame:
"""
Args:
preds: loaded predictions from a .pred file
Expand All @@ -164,18 +164,28 @@ def eval_fairness(preds, labels, reranked_idx, ratios, output, algorithm, k_max,
# because the mapping between popular/nonpopular and protected/nonprotected is reversed
if algorithm == 'fa-ir':
labels = [not value for value in labels]

dic_before = {'ndkl':[]}; dic_after={'ndkl':[]}
for i, team in enumerate(tqdm(preds)):
if eq_op:
r = {True: 1 - ratios[i], False: ratios[i]}
else:
r = ratios
member_popularity_probs = [(m, labels[m], float(team[m])) for m in range(len(team))]
member_popularity_probs.sort(key=lambda x: x[2], reverse=True)
#IMPORTANT: the ratios keys should match the labels!
dic_before['ndkl'].append(reranking.ndkl([label for _, label, _ in member_popularity_probs], r))
dic_after['ndkl'].append(reranking.ndkl([labels[int(m)] for m in reranked_idx[i]], r))
dic_before, dic_after = dict(), dict()
if 'ndkl' in metrics:
dic_before['ndkl'], dic_after['ndkl'] = list(), list()
for i, team in enumerate(tqdm(preds)):
if eq_op:
r = {True: 1 - ratios[i], False: ratios[i]}
else:
r = ratios
member_popularity_probs = [(m, labels[m], float(team[m])) for m in range(k_max)]
member_popularity_probs.sort(key=lambda x: x[2], reverse=True)
#IMPORTANT: the ratios keys should match the labels!
dic_before['ndkl'].append(reranking.ndkl([label for _, label, _ in member_popularity_probs], r))
dic_after['ndkl'].append(reranking.ndkl([labels[int(m)] for m in reranked_idx[i]], r))
if 'skew' in metrics:
dic_before['skew'], dic_after['skew'] = list(), list()
for i, team in enumerate(tqdm(preds)):
if eq_op:
r = {True: 1 - ratios[i], False: ratios[i]}
else:
r = ratios
member_popularity_probs = [(m, labels[m], float(team[m])) for m in range(len(team))]
member_popularity_probs.sort(key=lambda x: x[2], reverse=True)

df_before = pd.DataFrame(dic_before).mean(axis=0).to_frame('mean.before')
df_after = pd.DataFrame(dic_after).mean(axis=0).to_frame('mean.after')
Expand Down

0 comments on commit 39d6f9b

Please sign in to comment.