Contextualized bias mitigation (#5176)
* added linear and hard debiasers
* worked on documentation
* committing changes before branch switch
* committing changes before switching branch
* finished bias direction, linear and hard debiasers, need to write tests
* finished bias direction test
* Commiting changes before switching branch
* finished hard and linear debiasers
* finished OSCaR
* bias mitigators tests and bias metrics remaining
* added bias mitigator tests
* added bias mitigator tests
* finished tests for bias mitigation methods
* fixed gpu issues
* fixed gpu issues
* fixed gpu issues
* resolve issue with count_nonzero not being differentiable
* added more references
* fairness during finetuning
* finished bias mitigator wrapper
* added reference
* updated CHANGELOG and fixed minor docs issues
* move id tensors to embedding device
* fixed to use predetermined bias direction
* fixed minor doc errors
* snli reader registration issue
* fixed _pretrained from params issue
* fixed device issues
* evaluate bias mitigation initial commit
* finished evaluate bias mitigation
* handles multiline prediction files
* fixed minor bugs
* fixed minor bugs
* improved prediction diff JSON format
* forgot to resolve a conflict
* Refactored evaluate bias mitigation to use NLI metric
* Added SNLIPredictionsDiff class
* ensured dataloader is same for bias mitigated and baseline models
* finished evaluate bias mitigation
* Update CHANGELOG.md
* Replaced local data files with github raw content links
* Update allennlp/fairness/bias_mitigator_applicator.py (Co-authored-by: Pete <petew@allenai.org>)
* deleted evaluate_bias_mitigation from git tracking
* removed evaluate-bias-mitigation instances from rest of repo
* addressed Akshita's comments
* moved bias mitigator applicator test to allennlp-models
* removed unnecessary files

Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local>
Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-106.us-west-2.compute.internal>
Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-108.us-west-2.compute.internal>
Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-1-108.us-west-2.compute.internal>
Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
Co-authored-by: Pete <petew@allenai.org>
1 parent aa52a9a, commit b92fd9a
Showing 12 changed files with 2,557 additions and 4 deletions.
@@ -0,0 +1,269 @@
import torch
from typing import Union, Optional
from os import PathLike

from allennlp.fairness.bias_direction import (
    BiasDirection,
    PCABiasDirection,
    PairedPCABiasDirection,
    TwoMeansBiasDirection,
    ClassificationNormalBiasDirection,
)
from allennlp.fairness.bias_utils import load_word_pairs, load_words

from allennlp.common import Registrable
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data import Vocabulary
class BiasDirectionWrapper(Registrable):
    """
    Parent class for bias direction wrappers.
    """

    def __init__(self):
        self.direction: BiasDirection = None
        self.noise: float = None

    def __call__(self, module):
        raise NotImplementedError

    def train(self, mode: bool = True):
        """
        # Parameters

        mode : `bool`, optional (default=`True`)
            Sets `requires_grad` to value of `mode` for bias direction.
        """
        self.direction.requires_grad = mode

    def add_noise(self, t: torch.Tensor):
        """
        # Parameters

        t : `torch.Tensor`
            Tensor to which to add small amount of Gaussian noise.
        """
        return t + self.noise * torch.randn(t.size(), device=t.device)
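Concrete wrappers register themselves under a name, load seed-word token IDs in `__init__`, and implement `__call__` to turn an embedding module into a bias direction tensor. The following is a minimal sketch of that contract, purely for illustration; the `"toy"` name and hard-coded token IDs are not part of the commit.

@BiasDirectionWrapper.register("toy")  # hypothetical name, for illustration only
class ToyBiasDirectionWrapper(BiasDirectionWrapper):
    def __init__(self, noise: float = 1e-10):
        # real wrappers load seed-word token IDs from a file here
        # (see the registered classes below)
        self.ids = [torch.arange(8), torch.arange(8, 16)]
        self.direction = PCABiasDirection()
        self.noise = noise

    def __call__(self, module):
        # embed the IDs with the module, add a trivial amount of noise,
        # and compute the bias direction from the resulting embeddings
        embeddings = torch.cat(
            [module.forward(i.to(module.weight.device)) for i in self.ids]
        )
        return self.direction(self.add_noise(embeddings))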
@BiasDirectionWrapper.register("pca") | ||
class PCABiasDirectionWrapper(BiasDirectionWrapper): | ||
""" | ||
# Parameters | ||
seed_words_file : `Union[PathLike, str]` | ||
Path of file containing seed words. | ||
tokenizer : `Tokenizer` | ||
Tokenizer used to tokenize seed words. | ||
direction_vocab : `Vocabulary`, optional (default=`None`) | ||
Vocabulary of tokenizer. If `None`, assumes tokenizer is of | ||
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. | ||
namespace : `str`, optional (default=`"tokens"`) | ||
Namespace of direction_vocab to use when tokenizing. | ||
Disregarded when direction_vocab is `None`. | ||
requires_grad : `bool`, optional (default=`False`) | ||
Option to enable gradient calculation for bias direction. | ||
noise : `float`, optional (default=`1e-10`) | ||
To avoid numerical instability if embeddings are initialized uniformly. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seed_words_file: Union[PathLike, str], | ||
tokenizer: Tokenizer, | ||
direction_vocab: Optional[Vocabulary] = None, | ||
namespace: str = "tokens", | ||
requires_grad: bool = False, | ||
noise: float = 1e-10, | ||
): | ||
self.ids = load_words(seed_words_file, tokenizer, direction_vocab, namespace) | ||
self.direction = PCABiasDirection(requires_grad=requires_grad) | ||
self.noise = noise | ||
|
||
def __call__(self, module): | ||
# embed subword token IDs and mean pool to get | ||
# embedding of original word | ||
ids_embeddings = [] | ||
for i in self.ids: | ||
i = i.to(module.weight.device) | ||
ids_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids_embeddings = torch.cat(ids_embeddings) | ||
|
||
# adding trivial amount of noise | ||
# to eliminate linear dependence amongst all embeddings | ||
# when training first starts | ||
ids_embeddings = self.add_noise(ids_embeddings) | ||
|
||
return self.direction(ids_embeddings) | ||
|
||
|
||
@BiasDirectionWrapper.register("paired_pca") | ||
class PairedPCABiasDirectionWrapper(BiasDirectionWrapper): | ||
""" | ||
# Parameters | ||
seed_word_pairs_file : `Union[PathLike, str]` | ||
Path of file containing seed word pairs. | ||
tokenizer : `Tokenizer` | ||
Tokenizer used to tokenize seed words. | ||
direction_vocab : `Vocabulary`, optional (default=`None`) | ||
Vocabulary of tokenizer. If `None`, assumes tokenizer is of | ||
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. | ||
namespace : `str`, optional (default=`"tokens"`) | ||
Namespace of direction_vocab to use when tokenizing. | ||
Disregarded when direction_vocab is `None`. | ||
requires_grad : `bool`, optional (default=`False`) | ||
Option to enable gradient calculation for bias direction. | ||
noise : `float`, optional (default=`1e-10`) | ||
To avoid numerical instability if embeddings are initialized uniformly. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seed_word_pairs_file: Union[PathLike, str], | ||
tokenizer: Tokenizer, | ||
direction_vocab: Optional[Vocabulary] = None, | ||
namespace: str = "tokens", | ||
requires_grad: bool = False, | ||
noise: float = 1e-10, | ||
): | ||
self.ids1, self.ids2 = load_word_pairs( | ||
seed_word_pairs_file, tokenizer, direction_vocab, namespace | ||
) | ||
self.direction = PairedPCABiasDirection(requires_grad=requires_grad) | ||
self.noise = noise | ||
|
||
def __call__(self, module): | ||
# embed subword token IDs and mean pool to get | ||
# embedding of original word | ||
ids1_embeddings = [] | ||
for i in self.ids1: | ||
i = i.to(module.weight.device) | ||
ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids2_embeddings = [] | ||
for i in self.ids2: | ||
i = i.to(module.weight.device) | ||
ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids1_embeddings = torch.cat(ids1_embeddings) | ||
ids2_embeddings = torch.cat(ids2_embeddings) | ||
|
||
ids1_embeddings = self.add_noise(ids1_embeddings) | ||
ids2_embeddings = self.add_noise(ids2_embeddings) | ||
|
||
return self.direction(ids1_embeddings, ids2_embeddings) | ||
|
||
|
||
@BiasDirectionWrapper.register("two_means") | ||
class TwoMeansBiasDirectionWrapper(BiasDirectionWrapper): | ||
""" | ||
# Parameters | ||
seed_word_pairs_file : `Union[PathLike, str]` | ||
Path of file containing seed word pairs. | ||
tokenizer : `Tokenizer` | ||
Tokenizer used to tokenize seed words. | ||
direction_vocab : `Vocabulary`, optional (default=`None`) | ||
Vocabulary of tokenizer. If `None`, assumes tokenizer is of | ||
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. | ||
namespace : `str`, optional (default=`"tokens"`) | ||
Namespace of direction_vocab to use when tokenizing. | ||
Disregarded when direction_vocab is `None`. | ||
requires_grad : `bool`, optional (default=`False`) | ||
Option to enable gradient calculation for bias direction. | ||
noise : `float`, optional (default=`1e-10`) | ||
To avoid numerical instability if embeddings are initialized uniformly. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seed_word_pairs_file: Union[PathLike, str], | ||
tokenizer: Tokenizer, | ||
direction_vocab: Optional[Vocabulary] = None, | ||
namespace: str = "tokens", | ||
requires_grad: bool = False, | ||
noise: float = 1e-10, | ||
): | ||
self.ids1, self.ids2 = load_word_pairs( | ||
seed_word_pairs_file, tokenizer, direction_vocab, namespace | ||
) | ||
self.direction = TwoMeansBiasDirection(requires_grad=requires_grad) | ||
self.noise = noise | ||
|
||
def __call__(self, module): | ||
# embed subword token IDs and mean pool to get | ||
# embedding of original word | ||
ids1_embeddings = [] | ||
for i in self.ids1: | ||
i = i.to(module.weight.device) | ||
ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids2_embeddings = [] | ||
for i in self.ids2: | ||
i = i.to(module.weight.device) | ||
ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids1_embeddings = torch.cat(ids1_embeddings) | ||
ids2_embeddings = torch.cat(ids2_embeddings) | ||
|
||
ids1_embeddings = self.add_noise(ids1_embeddings) | ||
ids2_embeddings = self.add_noise(ids2_embeddings) | ||
|
||
return self.direction(ids1_embeddings, ids2_embeddings) | ||
|
||
|
||
@BiasDirectionWrapper.register("classification_normal") | ||
class ClassificationNormalBiasDirectionWrapper(BiasDirectionWrapper): | ||
""" | ||
# Parameters | ||
seed_word_pairs_file : `Union[PathLike, str]` | ||
Path of file containing seed word pairs. | ||
tokenizer : `Tokenizer` | ||
Tokenizer used to tokenize seed words. | ||
direction_vocab : `Vocabulary`, optional (default=`None`) | ||
Vocabulary of tokenizer. If `None`, assumes tokenizer is of | ||
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute. | ||
namespace : `str`, optional (default=`"tokens"`) | ||
Namespace of direction_vocab to use when tokenizing. | ||
Disregarded when direction_vocab is `None`. | ||
noise : `float`, optional (default=`1e-10`) | ||
To avoid numerical instability if embeddings are initialized uniformly. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
seed_word_pairs_file: Union[PathLike, str], | ||
tokenizer: Tokenizer, | ||
direction_vocab: Optional[Vocabulary] = None, | ||
namespace: str = "tokens", | ||
noise: float = 1e-10, | ||
): | ||
self.ids1, self.ids2 = load_word_pairs( | ||
seed_word_pairs_file, tokenizer, direction_vocab, namespace | ||
) | ||
self.direction = ClassificationNormalBiasDirection() | ||
self.noise = noise | ||
|
||
def __call__(self, module): | ||
# embed subword token IDs and mean pool to get | ||
# embedding of original word | ||
ids1_embeddings = [] | ||
for i in self.ids1: | ||
i = i.to(module.weight.device) | ||
ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids2_embeddings = [] | ||
for i in self.ids2: | ||
i = i.to(module.weight.device) | ||
ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True)) | ||
ids1_embeddings = torch.cat(ids1_embeddings) | ||
ids2_embeddings = torch.cat(ids2_embeddings) | ||
|
||
ids1_embeddings = self.add_noise(ids1_embeddings) | ||
ids2_embeddings = self.add_noise(ids2_embeddings) | ||
|
||
return self.direction(ids1_embeddings, ids2_embeddings) |
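A rough usage sketch, not part of this commit: each wrapper is constructed from a seed-word file and a tokenizer, and is then called with an embedding module (anything exposing `weight` and `forward`, such as `torch.nn.Embedding`) to produce a bias direction in that embedding space. The snippet assumes the classes defined above are in scope; the tokenizer choice, vocabulary, file path, file format (one seed word per line), and embedding size are illustrative assumptions.

import torch
from allennlp.data import Vocabulary
from allennlp.data.tokenizers import WhitespaceTokenizer

# Hypothetical setup: a vocabulary and embedding layer that would normally
# come from a trained model rather than be created from scratch.
vocab = Vocabulary()  # assumed to already contain the seed words
tokenizer = WhitespaceTokenizer()
embedding = torch.nn.Embedding(vocab.get_vocab_size("tokens"), 50)

# Compute a PCA bias direction from seed words (file path is illustrative).
bias_direction = PCABiasDirectionWrapper(
    seed_words_file="seed_words.txt",
    tokenizer=tokenizer,
    direction_vocab=vocab,
)
direction = bias_direction(embedding)  # direction tensor in the embedding space
bias_direction.train(False)  # freeze the direction (requires_grad = False)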