Merge pull request #52 from JulesBelveze/docs/tutorials
docs: add tutorials
JulesBelveze authored Jun 16, 2023
2 parents 6bdf094 + 0a8fc51 commit 5491980
Showing 184 changed files with 1,679 additions and 34,845 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/documentation.yaml
@@ -15,7 +15,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install sphinx furo sphinx-copybutton sphinxext-opengraph myst-parser autodoc_pydantic
pip install sphinx furo sphinx-copybutton sphinxext-opengraph myst-parser autodoc_pydantic nbsphinx
sudo apt-get install pandoc
- name: Build and Commit
uses: sphinx-notes/pages@master
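The workflow change adds `nbsphinx` to the Sphinx dependencies and installs `pandoc`, which nbsphinx needs to convert notebook cells, so the new tutorial notebooks can be built into the documentation. As a rough sketch (the repository's actual `docs/conf.py` is not part of this excerpt, so every setting below is an assumption), the matching extension list could look like:

# docs/conf.py -- illustrative sketch only, not the repository's actual configuration
extensions = [
    "myst_parser",                     # Markdown sources
    "sphinx_copybutton",               # copy button on code blocks
    "sphinxext.opengraph",             # OpenGraph metadata
    "sphinxcontrib.autodoc_pydantic",  # document pydantic models
    "nbsphinx",                        # render the tutorial notebooks (requires pandoc)
]
nbsphinx_execute = "never"  # assumption: notebooks are committed with their outputs
html_theme = "furo"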
3 changes: 3 additions & 0 deletions .gitignore
@@ -7,6 +7,9 @@
*/*/__pycache__/
*/*/*/__pycache__/

# docs
docs/_build

# data
bert-squeeze/data/classification
bert-squeeze/data/unlabeled
19 changes: 12 additions & 7 deletions bert_squeeze/assistants/configs/distil.yaml
@@ -14,6 +14,7 @@ general:
train:
adam_eps: 1e-8
accumulation_steps: 1
alpha: 0.5
auto_lr: false
discriminative_learning: true
dropout: 0.2
@@ -35,13 +36,15 @@ train:
model:
_target_: bert_squeeze.distillation.distiller.Distiller
teacher:
_target_: transformers.models.auto.AutoModelForSequenceClassification.from_pretrained
pretrained_model_name_or_path: "bert-base-uncased"
_target_: bert_squeeze.models.lt_bert.LtCustomBert
pretrained_model: "bert-base-uncased"
num_labels: ${general.num_labels}
training_config: ${train}
student:
_target_: transformers.models.auto.AutoModelForSequenceClassification.from_pretrained
pretrained_model_name_or_path: "bert-base-cased"
_target_: bert_squeeze.models.lt_bert.LtCustomBert
pretrained_model: "bert-base-cased"
num_labels: ${general.num_labels}
training_config: ${train}
training_config: ${train}


@@ -55,7 +58,7 @@ data:
split:
text_col: "text"
label_col: "label"
tokenizer_name: ${model.teacher.pretrained_model_name_or_path}
tokenizer_name: ${model.teacher.pretrained_model}
max_length: 256
student_module:
_target_: bert_squeeze.data.modules.transformer_module.TransformerDataModule
@@ -65,5 +68,7 @@ data:
split: ${data.teacher_module.dataset_config.split}
text_col: ${data.teacher_module.dataset_config.text_col}
label_col: ${data.teacher_module.dataset_config.label_col}
tokenizer_name: ${model.student.pretrained_model_name_or_path}
max_length: 256
tokenizer_name: ${model.student.pretrained_model}
max_length: 256

callbacks:
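In the distillation config, the teacher and student now target `bert_squeeze.models.lt_bert.LtCustomBert` (which wraps a pre-trained checkpoint and takes a `training_config`) instead of calling `AutoModelForSequenceClassification.from_pretrained` directly, an `alpha` weight is exposed under `train`, and the tokenizer names follow the renamed `pretrained_model` key. A minimal sketch of how such a config can be resolved and instantiated, assuming Hydra-style `_target_` handling (the assistant's real entry point is not shown in this diff):

# Illustrative sketch, not the library's documented entry point.
from hydra.utils import instantiate
from omegaconf import OmegaConf

conf = OmegaConf.load("bert_squeeze/assistants/configs/distil.yaml")
OmegaConf.resolve(conf)              # expands ${train}, ${general.num_labels}, ...
distiller = instantiate(conf.model)  # builds the Distiller with LtCustomBert teacher and student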
16 changes: 11 additions & 5 deletions bert_squeeze/assistants/train_assistant.py
@@ -15,7 +15,7 @@
"lstm": "train_lstm.yaml",
"deebert": "train_deebert.yaml",
"fastbert": "train_fastbert.yaml",
"theseus-bert": "train_theseus_bert.yaml",
"theseusbert": "train_theseus_bert.yaml",
}


@@ -59,11 +59,17 @@ def __init__(
logger_kwargs: Dict[str, Any] = None,
callbacks: List[Callback] = None,
):
conf = OmegaConf.load(
resource_filename(
"bert_squeeze", os.path.join("assistants/configs", CONFIG_MAPPER[name])
try:
conf = OmegaConf.load(
resource_filename(
"bert_squeeze",
os.path.join("assistants/configs", CONFIG_MAPPER[name]),
)
)
except KeyError:
raise ValueError(
f"'{name}' is not a valid configuration name, please use one of the following: {CONFIG_MAPPER.keys()}"
)
)
if (
data_kwargs is not None
and data_kwargs.get("dataset_config", {}).get("path") is not None
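Besides renaming the `theseus-bert` key to `theseusbert`, the constructor now wraps the config lookup in a `try`/`except` so that an unknown assistant name raises a `ValueError` listing the supported names rather than a bare `KeyError`. A hypothetical usage sketch (the class name and import path are assumed from the file name; they are not visible in this diff):

# Hypothetical usage; the exported class name is an assumption.
from bert_squeeze.assistants.train_assistant import TrainAssistant

assistant = TrainAssistant("theseusbert")  # note the renamed key (was "theseus-bert")

try:
    TrainAssistant("theseus-bert")         # old key, no longer in CONFIG_MAPPER
except ValueError as err:
    print(err)                             # lists the valid configuration names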
46 changes: 28 additions & 18 deletions bert_squeeze/data/modules/distillation_module.py
@@ -62,6 +62,25 @@ def __init__(
self.test = None
self.val = None

@staticmethod
def _concat_dataset(
a: Union[datasets.Dataset, datasets.DatasetDict],
b: Union[datasets.Dataset, datasets.DatasetDict],
) -> Union[datasets.Dataset, datasets.DatasetDict]:
""""""
assert type(a) == type(b) and a.keys() == b.keys()

if isinstance(a, datasets.DatasetDict):
concat_dataset = datasets.DatasetDict(
{
key: datasets.concatenate_datasets([a[key], b[key]], axis=1)
for key in a.keys()
}
)
else:
concat_dataset = datasets.concatenate_datasets([a, b], axis=1)
return concat_dataset

def create_hard_dataset(self) -> datasets.Dataset:
""""""
hard_dataset = self.labeler.label_dataset()
@@ -176,14 +195,7 @@ def featurize(self) -> datasets.DatasetDict:
)

# Merging the student and teacher datasets into a single one
concat_dataset = datasets.DatasetDict(
{
key: datasets.concatenate_datasets(
[teacher_data[key], student_data[key]], axis=1
)
for key in ["train", "test", "validation"]
}
)
concat_dataset = self._concat_dataset(teacher_data, student_data)
concat_dataset = concat_dataset.shuffle()
concat_dataset.set_format(type="torch")
return concat_dataset
@@ -200,32 +212,30 @@ def setup(self, stage: Optional[str] = None) -> None:
featurized_dataset = self.featurize()

self.train = featurized_dataset["train"]
self.val = featurized_dataset["validation"]
self.val = (
featurized_dataset["validation"]
if "validation" in featurized_dataset.keys()
else featurized_dataset["test"]
)
self.test = featurized_dataset["test"]

def train_dataloader(self) -> DataLoader:
"""
Returns:
DataLoader: Train dataloader
"""
return DataLoader(
self.train, batch_size=self.train_batch_size, drop_last=True, num_workers=0
)
return DataLoader(self.train, batch_size=self.train_batch_size, drop_last=True)

def test_dataloader(self) -> DataLoader:
"""
Returns:
DataLoader: Test dataloader
"""
return DataLoader(
self.test, batch_size=self.eval_batch_size, drop_last=True, num_workers=0
)
return DataLoader(self.test, batch_size=self.eval_batch_size, drop_last=True)

def val_dataloader(self) -> DataLoader:
"""
Returns:
DataLoader: Validation dataloader
"""
return DataLoader(
self.val, batch_size=self.eval_batch_size, drop_last=True, num_workers=0
)
return DataLoader(self.val, batch_size=self.eval_batch_size, drop_last=True)
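Three things change in this module: the teacher/student merge is factored into a `_concat_dataset` helper that accepts either a `datasets.Dataset` or a `DatasetDict`, `setup` falls back to the `test` split when the dataset has no `validation` split, and the dataloaders no longer pin `num_workers=0`. A toy sketch of what the column-wise (`axis=1`) merge in `_concat_dataset` produces (toy columns, not the module's real featurized outputs):

# Toy illustration of the axis=1 merge: teacher (t_*) and student (s_*) columns
# for the same rows end up side by side in one dataset.
import datasets

teacher = datasets.Dataset.from_dict({"t_input_ids": [[1, 2], [3, 4]], "t_labels": [0, 1]})
student = datasets.Dataset.from_dict({"s_input_ids": [[5, 6], [7, 8]], "s_labels": [0, 1]})

merged = datasets.concatenate_datasets([teacher, student], axis=1)
print(merged.column_names)  # ['t_input_ids', 't_labels', 's_input_ids', 's_labels']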
3 changes: 0 additions & 3 deletions bert_squeeze/data/modules/transformer_module.py
@@ -108,7 +108,6 @@ def train_dataloader(self) -> DataLoader:
# collate_fn=self._collate_fn(),
batch_size=self.train_batch_size,
drop_last=True,
num_workers=0,
)

def test_dataloader(self) -> DataLoader:
@@ -121,7 +120,6 @@ def test_dataloader(self) -> DataLoader:
# collate_fn=self._collate_fn(),
batch_size=self.eval_batch_size,
drop_last=True,
num_workers=0,
)

def val_dataloader(self) -> DataLoader:
@@ -134,7 +132,6 @@ def val_dataloader(self) -> DataLoader:
# collate_fn=self._collate_fn(),
batch_size=self.eval_batch_size,
drop_last=True,
num_workers=0,
)


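As in the distillation module, the hard-coded `num_workers=0` is removed from all three dataloaders; since PyTorch's `DataLoader` already defaults to zero workers, behaviour is unchanged and the argument was simply redundant. If worker processes were wanted later, they could be passed explicitly at the call site, for example (hypothetical, not part of this commit):

# Hypothetical follow-up, not in this commit: choose the worker count at the call site.
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8))  # stand-in for the featurized dataset
loader = DataLoader(dataset, batch_size=4, drop_last=True, num_workers=2)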
90 changes: 60 additions & 30 deletions bert_squeeze/distillation/distiller.py
@@ -12,6 +12,7 @@
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from transformers import AdamW
from transformers.modeling_outputs import SequenceClassifierOutput

from ..utils.losses import LabelSmoothingLoss
from ..utils.losses.distillation_losses import KLDivLoss
@@ -229,16 +230,10 @@ def configure_optimizers(self) -> Tuple[List, List]:
eps=self.params.adam_eps,
)
elif self.params.optimizer == "bertadam":
num_training_steps = (
len(self.train_dataloader())
* self.params.num_epochs
// self.params.accumulation_steps
)
optimizer = BertAdam(
optimizer_parameters,
lr=self.params.learning_rates[0],
warmup=self.params.warmup_ratio,
t_total=num_training_steps,
)
elif self.params.optimizer == "adam":
optimizer = torch.optim.Adam(
Expand All @@ -249,7 +244,11 @@ def configure_optimizers(self) -> Tuple[List, List]:

if self.params.lr_scheduler:
scheduler = ReduceLROnPlateau(optimizer)
lr_scheduler = {"scheduler": scheduler, "name": "NeptuneLogger"}
lr_scheduler = {
"scheduler": scheduler,
"name": "NeptuneLogger",
"monitor": "loss",
}
return [optimizer], [lr_scheduler]

return [optimizer], []
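Two changes in `configure_optimizers`: the `bertadam` branch no longer precomputes `t_total` from the train dataloader length, and the `ReduceLROnPlateau` entry now carries a `monitor` key, which PyTorch Lightning requires so it knows which logged metric drives the plateau detection. A self-contained sketch of that scheduler contract (dummy optimizer; the surrounding method is elided):

# Self-contained sketch of the Lightning scheduler dict used above.
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-3)

lr_scheduler = {
    "scheduler": ReduceLROnPlateau(optimizer),  # plateau-based LR decay
    "name": "NeptuneLogger",                    # display name for the logger
    "monitor": "loss",                          # logged metric the scheduler reads
}
# configure_optimizers would then return: [optimizer], [lr_scheduler]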
@@ -361,11 +360,17 @@ def get_teacher_logits(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
self.teacher.eval()
teacher_inputs = {
key[2:]: val for key, val in batch.items() if key.startswith("t_")
key[2:]: val
for key, val in batch.items()
if key.startswith("t_") and "labels" not in key
}
with torch.no_grad():
logits = self.teacher.forward(**teacher_inputs)
return logits
outputs = self.teacher.forward(**teacher_inputs)

if isinstance(outputs, SequenceClassifierOutput):
return outputs.logits

return outputs

@overrides
def get_student_logits(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
@@ -380,10 +385,16 @@ def get_student_logits(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
student logits
"""
student_inputs = {
key[2:]: val for key, val in batch.items() if key.startswith("s_")
key[2:]: val
for key, val in batch.items()
if key.startswith("s_") and "labels" not in key
}
logits = self.student.forward(**student_inputs)
return logits
outputs = self.student.forward(**student_inputs)

if isinstance(outputs, SequenceClassifierOutput):
return outputs.logits

return outputs

@overrides
def loss(
@@ -413,7 +424,7 @@ def loss(
# Ignore soft labeled indices (where label is `ignore_index`)
active_idx = labels != ignore_index
if active_idx.sum().item() > 0:
objective = self.loss_lce(student_logits[active_idx], labels[active_idx])
objective = self.loss_ce(student_logits[active_idx], labels[active_idx])
else:
objective = torch.tensor(0.0).to(labels.device)
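The hard-label term now goes through `self.loss_ce` (previously referenced as `loss_lce`), and the config above exposes an `alpha` weight. How exactly the Distiller blends the hard cross-entropy with the KL distillation term is not visible in this excerpt; a common weighting, given purely as an assumption rather than the library's formula, looks like:

# Assumed blending of the two objectives; the Distiller's actual formula is not shown in this diff.
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, labels, alpha=0.5, temperature=2.0):
    """Weighted sum of the soft (KL to the teacher) and hard (cross-entropy) terms."""
    soft = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=-1),
        F.softmax(teacher_logits / temperature, dim=-1),
        reduction="batchmean",
    ) * temperature ** 2
    hard = F.cross_entropy(student_logits, labels)
    return alpha * soft + (1.0 - alpha) * hard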

@@ -430,7 +441,7 @@ def training_step(self, batch, _) -> torch.Tensor:
loss = self.loss(t_logits, s_logits, batch["s_labels"])

self.s_scorer.add(s_logits.detach().cpu(), batch["s_labels"].cpu(), loss)
if self.global_step > 0 and self.global_step % self.config.logging_steps == 0:
if self.global_step > 0 and self.global_step % self.params.logging_steps == 0:
logging_loss = {
f"train/{key}": torch.stack(val).mean()
for key, val in self.s_scorer.losses.items()
@@ -473,11 +484,14 @@ def validation_step(self, batch, _) -> Dict:
@overrides
def on_validation_epoch_end(self) -> None:
""""""
all_logits = torch.cat([pred["logits"] for pred in self.validation_step_outputs])
all_probs = F.softmax(all_logits, dim=-1)
labels_probs = [all_probs[:, i] for i in range(all_probs.shape[-1])]
if not self.trainer.sanity_checking:
all_logits = torch.cat(
[pred["logits"] for pred in self.validation_step_outputs]
)
all_probs = F.softmax(all_logits, dim=-1)
labels_probs = [all_probs[:, i] for i in range(all_probs.shape[-1])]
self.log_eval_report(labels_probs)

self.log_eval_report(labels_probs)
self.s_valid_scorer.reset()

@overrides
@@ -530,11 +544,16 @@ def get_teacher_logits(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
self.teacher.eval()
teacher_inputs = {
key[2:]: val for key, val in batch.items() if key.startswith("t_")
key[2:]: val
for key, val in batch.items()
if key.startswith("t_") and "labels" not in key
}
with torch.no_grad():
logits = self.teacher.forward(**teacher_inputs)
return logits
outputs = self.teacher.forward(**teacher_inputs)

if isinstance(outputs, SequenceClassifierOutput):
return outputs.logits
return outputs

@overrides
def get_student_logits(
@@ -552,23 +571,31 @@ def get_student_logits(
the original text and one prediction for the translation.
"""
student_inputs = {
key[2:]: val for key, val in batch.items() if key.startswith("s_")
key[2:]: val
for key, val in batch.items()
if key.startswith("s_") and "labels" not in key
}
original_logits = self.student.forward(
original_outputs = self.student.forward(
**{
key: val
for key, val in student_inputs.items()
if not key.startswith("translation")
}
)
translation_logits = self.student.forward(
if isinstance(original_outputs, SequenceClassifierOutput):
original_outputs = original_outputs.logits

translation_outputs = self.student.forward(
**{
key: val
for key, val in student_inputs.items()
if key.startswith("translation")
}
)
return original_logits, translation_logits
if isinstance(translation_outputs, SequenceClassifierOutput):
translation_outputs = translation_outputs.logits

return original_outputs, translation_outputs

@overrides
def loss(
@@ -645,11 +672,14 @@ def validation_step(self, batch, _) -> Dict:
@overrides
def on_validation_epoch_end(self) -> None:
""""""
all_logits = torch.cat([pred["logits"] for pred in self.validation_step_outputs])
all_probs = F.softmax(all_logits, dim=-1)
labels_probs = [all_probs[:, i] for i in range(all_probs.shape[-1])]
if not self.trainer.sanity_checking:
all_logits = torch.cat(
[pred["logits"] for pred in self.validation_step_outputs]
)
all_probs = F.softmax(all_logits, dim=-1)
labels_probs = [all_probs[:, i] for i in range(all_probs.shape[-1])]
self.log_eval_report(labels_probs)

self.log_eval_report(labels_probs)
self.s_valid_scorer.reset()

@overrides
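The recurring pattern in this file: because the teacher and student may now be either bert-squeeze modules that return raw logit tensors or Hugging Face models that return a `SequenceClassifierOutput`, every `get_*_logits` helper drops the `*_labels` entries from its inputs and unwraps `.logits` when present, and the epoch-end hooks skip reporting during Lightning's sanity-check pass. A standalone sketch of the unwrapping idiom:

# Standalone sketch of the logits-unwrapping idiom used throughout the Distiller.
import torch
from transformers.modeling_outputs import SequenceClassifierOutput

def extract_logits(outputs) -> torch.Tensor:
    """Accept either a raw logits tensor or a SequenceClassifierOutput."""
    if isinstance(outputs, SequenceClassifierOutput):
        return outputs.logits
    return outputs

raw = torch.randn(2, 3)
wrapped = SequenceClassifierOutput(logits=raw)
assert torch.equal(extract_logits(raw), extract_logits(wrapped))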