add set_optimizer for programmatic fine-tuning
fschlatt committed Sep 20, 2024
1 parent cc9a322 commit b447142
Showing 2 changed files with 36 additions and 7 deletions.
36 changes: 29 additions & 7 deletions lightning_ir/base/module.py
@@ -1,6 +1,6 @@
 from collections import defaultdict
 from pathlib import Path
-from typing import Any, Dict, List, Sequence, Tuple
+from typing import Any, Dict, List, Sequence, Tuple, Type
 
 import torch
 from lightning import LightningModule
@@ -48,12 +48,15 @@ def __init__(
         :raises ValueError: If neither model nor model_name_or_path are provided
         """
         super().__init__()
         self.save_hyperparameters()
         if model is not None and model_name_or_path is not None:
             raise ValueError("Only one of model or model_name_or_path must be provided.")
         if model is None:
             if model_name_or_path is None:
                 raise ValueError("Either model or model_name_or_path must be provided.")
             model = LightningIRModel.from_pretrained(model_name_or_path, config=config)
+            # NOTE huggingface models are in eval mode by default
+            model = model.train()
+
         self.model: LightningIRModel = model
         self.config = self.model.config
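Note that the added model.train() call moves the eval-to-train switch out of on_fit_start (removed in the next hunk) and into the constructor, so a freshly loaded Hugging Face checkpoint is already in training mode as soon as the module is built.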
@@ -66,16 +69,35 @@ def __init__(
         else:
             self.loss_functions.append(loss_function)
         self.evaluation_metrics = evaluation_metrics
+        self._optimizer: torch.optim.Optimizer | None = None
         self.tokenizer = LightningIRTokenizer.from_pretrained(self.config.name_or_path, config=self.config)
 
-    def on_fit_start(self) -> None:
-        """Called at the very beginning of fit.
+    def configure_optimizers(self) -> torch.optim.Optimizer:
+        """Configures the optimizer for fine-tuning. This method is ignored when using the CLI. When using Lightning IR
+        programmatically, the optimizer must be set using :meth:`set_optimizer`.
 
-        If on DDP it is called on every process
+        :raises ValueError: If optimizer is not set
+        :return: Optimizer
+        :rtype: torch.optim.Optimizer
         """
-        # NOTE huggingface models are in eval mode by default
-        self.train()
-        return super().on_fit_start()
+        if self._optimizer is None:
+            raise ValueError("Optimizer is not set. Call `set_optimizer`.")
+        return self._optimizer
+
+    def set_optimizer(
+        self, optimizer: Type[torch.optim.Optimizer], **optimizer_kwargs: Dict[str, Any]
+    ) -> "LightningIRModule":
+        """Sets the optimizer for the model. Necessary for fine-tuning when not using the CLI.
+
+        :param optimizer: Torch optimizer class
+        :type optimizer: Type[torch.optim.Optimizer]
+        :param optimizer_kwargs: Arguments to initialize the optimizer
+        :type optimizer_kwargs: Dict[str, Any]
+        :return: self
+        :rtype: LightningIRModule
+        """
+        self._optimizer = optimizer(self.parameters(), **optimizer_kwargs)
+        return self
 
     def score(self, queries: Sequence[str] | str, docs: Sequence[Sequence[str]] | Sequence[str]) -> LightningIROutput:
         """Computes relevance scores for queries and documents.
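Taken together, the module.py changes make programmatic fine-tuning a three-step flow: construct the module, attach an optimizer class with set_optimizer, and hand the module to a Lightning Trainer, which invokes configure_optimizers when fitting starts. A minimal sketch of that flow (the checkpoint name and trainer settings are placeholders, a real run needs a datamodule, and a task-specific subclass of LightningIRModule would typically be used):

import torch
from lightning import Trainer

from lightning_ir.base.module import LightningIRModule

# Placeholder checkpoint name; any LightningIRModel checkpoint applies.
module = LightningIRModule(model_name_or_path="some-org/some-checkpoint")

# set_optimizer instantiates the optimizer over the module's parameters
# right away and returns self, so calls can be chained.
module.set_optimizer(torch.optim.AdamW, lr=1e-5, weight_decay=0.01)

trainer = Trainer(max_steps=1_000)
# trainer.fit(module, datamodule=...)  # supply a datamodule to train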
7 changes: 7 additions & 0 deletions lightning_ir/main.py
@@ -108,6 +108,13 @@ def subcommands() -> Dict[str, Set[str]]:
             "re_rank": {"model", "dataloaders", "datamodule"},
         }
 
+    def _add_configure_optimizers_method_to_model(self, subcommand: str | None) -> None:
+        import warnings
+
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            return super()._add_configure_optimizers_method_to_model(subcommand)
+
 
 def main():
     """
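Note: Lightning's CLI normally patches a configure_optimizers method onto the model and emits a warning when the model already defines one. Because LightningIRModule now implements configure_optimizers itself, this override presumably exists to run the parent hook with that now-expected warning suppressed.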
