diff --git a/src/main.py b/src/main.py index 824c228..317b068 100644 --- a/src/main.py +++ b/src/main.py @@ -149,7 +149,7 @@ def rerank(preds, labels, output, ratios, algorithm: str = 'det_greedy', k_max: return idx, probs, (finish_time - start_time) @staticmethod - def eval_fairness(preds, labels, reranked_idx, ratios, output, algorithm, k_max, eq_op: bool = False) -> pandas.DataFrame: + def eval_fairness(preds, labels, reranked_idx, ratios, output, algorithm, k_max, eq_op: bool = False, metrics: list = ['ndkl',]) -> pandas.DataFrame: """ Args: preds: loaded predictions from a .pred file @@ -164,18 +164,28 @@ def eval_fairness(preds, labels, reranked_idx, ratios, output, algorithm, k_max, # because the mapping between popular/nonpopular and protected/nonprotected is reversed if algorithm == 'fa-ir': labels = [not value for value in labels] - - dic_before = {'ndkl':[]}; dic_after={'ndkl':[]} - for i, team in enumerate(tqdm(preds)): - if eq_op: - r = {True: 1 - ratios[i], False: ratios[i]} - else: - r = ratios - member_popularity_probs = [(m, labels[m], float(team[m])) for m in range(len(team))] - member_popularity_probs.sort(key=lambda x: x[2], reverse=True) - #IMPORTANT: the ratios keys should match the labels! - dic_before['ndkl'].append(reranking.ndkl([label for _, label, _ in member_popularity_probs], r)) - dic_after['ndkl'].append(reranking.ndkl([labels[int(m)] for m in reranked_idx[i]], r)) + dic_before, dic_after = dict(), dict() + if 'ndkl' in metrics: + dic_before['ndkl'], dic_after['ndkl'] = list(), list() + for i, team in enumerate(tqdm(preds)): + if eq_op: + r = {True: 1 - ratios[i], False: ratios[i]} + else: + r = ratios + member_popularity_probs = [(m, labels[m], float(team[m])) for m in range(k_max)] + member_popularity_probs.sort(key=lambda x: x[2], reverse=True) + #IMPORTANT: the ratios keys should match the labels! + dic_before['ndkl'].append(reranking.ndkl([label for _, label, _ in member_popularity_probs], r)) + dic_after['ndkl'].append(reranking.ndkl([labels[int(m)] for m in reranked_idx[i]], r)) + if 'skew' in metrics: + dic_before['skew'], dic_after['skew'] = list(), list() + for i, team in enumerate(tqdm(preds)): + if eq_op: + r = {True: 1 - ratios[i], False: ratios[i]} + else: + r = ratios + member_popularity_probs = [(m, labels[m], float(team[m])) for m in range(len(team))] + member_popularity_probs.sort(key=lambda x: x[2], reverse=True) df_before = pd.DataFrame(dic_before).mean(axis=0).to_frame('mean.before') df_after = pd.DataFrame(dic_after).mean(axis=0).to_frame('mean.after')