2 changes: 1 addition & 1 deletion Makefile
@@ -8,7 +8,7 @@ sh = uv run --no-sync --frozen
.PHONY: install
install:
rm -rf uv.lock
uv sync --all-groups
uv sync --all-groups --extra catboost --extra peft --extra sentence-transformers --extra transformers

.PHONY: test
test:
13 changes: 7 additions & 6 deletions pyproject.toml
@@ -31,9 +31,9 @@ classifiers=[
]
requires-python = ">=3.10,<3.13"
dependencies = [
"sentence-transformers (>=3,<4)",
"torch (>=2.0.0,<3.0.0)",
"scikit-learn (>=1.5,<2.0)",
"scikit-multilearn (==0.2.0)",
"iterative-stratification (>=0.1.9)",
"appdirs (>=1.4,<2.0)",
"optuna (>=4.0.0,<5.0.0)",
"pathlib (>=1.0.1,<2.0.0)",
@@ -43,15 +43,16 @@ dependencies = [
"datasets (>=3.2.0,<4.0.0)",
"xxhash (>=3.5.0,<4.0.0)",
"python-dotenv (>=1.0.1,<2.0.0)",
"transformers[torch] (>=4.49.0,<5.0.0)",
"peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)",
"catboost (>=1.2.8,<2.0.0)",
"aiometer (>=1.0.0,<2.0.0)",
"aiofiles (>=24.1.0,<25.0.0)",
"threadpoolctl (>=3.0.0,<4.0.0)",
]

[project.optional-dependencies]
catboost = ["catboost (>=1.2.8,<2.0.0)"]
peft = ["peft (>= 0.10.0, !=0.15.0, !=0.15.1, <1.0.0)"]
transformers = ["transformers (>=4.49.0,<5.0.0)"]
sentence-transformers = ["sentence-transformers (>=3,<4)"]
dspy = [
"dspy (>=2.6.5,<3.0.0)",
]
@@ -252,7 +253,7 @@ module = [
"xeger",
"appdirs",
"sre_yield",
"skmultilearn.model_selection",
"iterstrat.ml_stratifiers",
"hydra",
"hydra.*",
"transformers",
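With catboost, peft, sentence-transformers, and transformers declared as extras instead of core dependencies, their presence becomes an environment property that code has to probe at call time. A minimal illustration of that consequence using only the standard library (an illustration only; the PR's own mechanism is the `require` helper added in src/autointent/_utils.py further down):

import importlib.util

# Probe which of the new optional extras are importable in the current
# environment and print the matching install hint when one is missing.
for extra, module in [
    ("catboost", "catboost"),
    ("peft", "peft"),
    ("sentence-transformers", "sentence_transformers"),
    ("transformers", "transformers"),
]:
    present = importlib.util.find_spec(module) is not None
    status = "installed" if present else f"missing, install with pip install 'autointent[{extra}]'"
    print(f"{extra}: {status}")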
82 changes: 50 additions & 32 deletions src/autointent/_dump_tools/unit_dumpers.py
@@ -2,30 +2,27 @@
import json
import logging
from pathlib import Path
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar

import aiofiles
import joblib
import numpy as np
import numpy.typing as npt
from catboost import CatBoostClassifier
from peft import PeftModel
from pydantic import BaseModel
from sklearn.base import BaseEstimator
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
)

from autointent import Embedder, Ranker, VectorIndex
from autointent._utils import require
from autointent._wrappers import BaseTorchModule
from autointent.schemas import TagsList

from .base import BaseObjectDumper, ModuleSimpleAttributes

if TYPE_CHECKING:
from catboost import CatBoostClassifier
from peft import PeftModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

T = TypeVar("T")
logger = logging.getLogger(__name__)

@@ -204,11 +201,11 @@ def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, BaseModel)


class PeftModelDumper(BaseObjectDumper[PeftModel]):
class PeftModelDumper(BaseObjectDumper["PeftModel"]):
dir_or_file_name = "peft_models"

@staticmethod
def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None:
def dump(obj: "PeftModel", path: Path, exists_ok: bool) -> None:
path.mkdir(parents=True, exist_ok=exists_ok)
if obj._is_prompt_learning: # noqa: SLF001
# strategy to save prompt learning models: save prompt encoder and bert classifier separately
@@ -224,56 +221,72 @@ def dump(obj: PeftModel, path: Path, exists_ok: bool) -> None:
merged_model.save_pretrained(lora_path)

@staticmethod
def load(path: Path, **kwargs: Any) -> PeftModel: # noqa: ANN401, ARG004
def load(path: Path, **kwargs: Any) -> "PeftModel": # noqa: ANN401, ARG004
peft = require("peft", extra="peft")
transformers = require("transformers", extra="transformers")
if (path / "ptuning").exists():
# prompt learning model
ptuning_path = path / "ptuning"
model = AutoModelForSequenceClassification.from_pretrained(ptuning_path / "base_model")
return PeftModel.from_pretrained(model, ptuning_path / "peft")
model = transformers.AutoModelForSequenceClassification.from_pretrained(ptuning_path / "base_model")
return peft.PeftModel.from_pretrained(model, ptuning_path / "peft") # type: ignore[no-any-return]
if (path / "lora").exists():
# merged lora model
lora_path = path / "lora"
return AutoModelForSequenceClassification.from_pretrained(lora_path) # type: ignore[no-any-return]
return transformers.AutoModelForSequenceClassification.from_pretrained(lora_path) # type: ignore[no-any-return]
msg = f"Invalid PeftModel directory structure at {path}. Expected 'ptuning' or 'lora' subdirectory."
raise ValueError(msg)

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, PeftModel)
try:
peft = require("peft", extra="peft")
return isinstance(obj, peft.PeftModel)
except ImportError:
return False


class HFModelDumper(BaseObjectDumper[PreTrainedModel]):
class HFModelDumper(BaseObjectDumper["PreTrainedModel"]):
dir_or_file_name = "hf_models"

@staticmethod
def dump(obj: PreTrainedModel, path: Path, exists_ok: bool) -> None:
def dump(obj: "PreTrainedModel", path: Path, exists_ok: bool) -> None:
path.mkdir(parents=True, exist_ok=exists_ok)
obj.save_pretrained(path)

@staticmethod
def load(path: Path, **kwargs: Any) -> PreTrainedModel: # noqa: ANN401, ARG004
return AutoModelForSequenceClassification.from_pretrained(path) # type: ignore[no-any-return]
def load(path: Path, **kwargs: Any) -> "PreTrainedModel": # noqa: ANN401, ARG004
transformers = require("transformers", extra="transformers")
return transformers.AutoModelForSequenceClassification.from_pretrained(path) # type: ignore[no-any-return]

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, PreTrainedModel)
try:
transformers = require("transformers", extra="transformers")
return isinstance(obj, transformers.PreTrainedModel)
except ImportError:
return False


class HFTokenizerDumper(BaseObjectDumper[PreTrainedTokenizer | PreTrainedTokenizerFast]):
class HFTokenizerDumper(BaseObjectDumper["PreTrainedTokenizer | PreTrainedTokenizerFast"]):
dir_or_file_name = "hf_tokenizers"

@staticmethod
def dump(obj: PreTrainedTokenizer | PreTrainedTokenizerFast, path: Path, exists_ok: bool) -> None:
def dump(obj: "PreTrainedTokenizer | PreTrainedTokenizerFast", path: Path, exists_ok: bool) -> None:
path.mkdir(parents=True, exist_ok=exists_ok)
obj.save_pretrained(path)

@staticmethod
def load(path: Path, **kwargs: Any) -> PreTrainedTokenizer | PreTrainedTokenizerFast: # noqa: ANN401, ARG004
return AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return,no-untyped-call]
def load(path: Path, **kwargs: Any) -> "PreTrainedTokenizer | PreTrainedTokenizerFast": # noqa: ANN401, ARG004
transformers = require("transformers", extra="transformers")
return transformers.AutoTokenizer.from_pretrained(path) # type: ignore[no-any-return]

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, PreTrainedTokenizer | PreTrainedTokenizerFast)
try:
transformers = require("transformers", extra="transformers")
return isinstance(obj, transformers.PreTrainedTokenizer | transformers.PreTrainedTokenizerFast)
except ImportError:
return False


class TorchModelDumper(BaseObjectDumper[BaseTorchModule]):
@@ -303,20 +316,25 @@ def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, BaseTorchModule)


class CatBoostDumper(BaseObjectDumper[CatBoostClassifier]):
class CatBoostDumper(BaseObjectDumper["CatBoostClassifier"]):
dir_or_file_name = "catboost_models"

@staticmethod
def dump(obj: CatBoostClassifier, path: Path, exists_ok: bool) -> None: # noqa: ARG004
def dump(obj: "CatBoostClassifier", path: Path, exists_ok: bool) -> None: # noqa: ARG004
path.parent.mkdir(parents=True, exist_ok=True)
obj.save_model(str(path), format="cbm")

@staticmethod
def load(path: Path, **kwargs: Any) -> CatBoostClassifier: # noqa: ANN401, ARG004
model = CatBoostClassifier()
def load(path: Path, **kwargs: Any) -> "CatBoostClassifier": # noqa: ANN401, ARG004
catboost = require("catboost", extra="catboost")
model = catboost.CatBoostClassifier()
model.load_model(str(path))
return model

@classmethod
def check_isinstance(cls, obj: Any) -> bool: # noqa: ANN401
return isinstance(obj, CatBoostClassifier)
try:
catboost = require("catboost", extra="catboost")
return isinstance(obj, catboost.CatBoostClassifier)
except ImportError:
return False
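A hedged usage sketch of the reworked dumpers in an environment without the catboost extra (the artifact path is hypothetical): importing the module no longer pulls in the optional packages, `check_isinstance` degrades to False, and only an actual load raises the install hint produced by `require`.

from pathlib import Path

from autointent._dump_tools.unit_dumpers import CatBoostDumper

# Importing the dumpers module no longer imports catboost, peft, or transformers.
print(CatBoostDumper.check_isinstance(object()))  # False, even without catboost installed

try:
    CatBoostDumper.load(Path("runs/demo/catboost_models/scorer.cbm"))  # hypothetical path
except ImportError as err:
    # "Missing dependency 'catboost' required for this feature.
    #  Install with `pip install autointent[catboost]`."
    print(err)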
24 changes: 23 additions & 1 deletion src/autointent/_utils.py
@@ -1,6 +1,7 @@
"""Utils."""

from typing import TypeVar
import importlib
from typing import Any, TypeVar

import torch

@@ -25,3 +26,24 @@ def detect_device() -> str:
if torch.mps.is_available():
return "mps"
return "cpu"


def require(dependency: str, extra: str | None = None) -> Any: # noqa: ANN401
"""Try to import dependency, raise informative ImportError if missing.

Args:
dependency: The name of the module to import
extra: Optional extra package name for pip install instructions

Returns:
The imported module

Raises:
ImportError: If the dependency is not installed
"""
try:
return importlib.import_module(dependency)
except ImportError as e:
extra_info = f" Install with `pip install autointent[{extra}]`." if extra else ""
msg = f"Missing dependency '{dependency}' required for this feature.{extra_info}"
raise ImportError(msg) from e
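A short usage sketch of the new helper (example values): a successful call returns the imported module, a failed one raises an ImportError carrying the pip-extra hint.

from autointent._utils import require

np = require("numpy")  # importable here, so the module object is returned
print(np.__version__)

try:
    peft = require("peft", extra="peft")  # raises only when the peft extra is absent
except ImportError as err:
    print(err)  # "Missing dependency 'peft' required for this feature. Install with `pip install autointent[peft]`."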
34 changes: 21 additions & 13 deletions src/autointent/_wrappers/embedder/sentence_transformers.py
@@ -2,28 +2,29 @@
import tempfile
from functools import lru_cache
from pathlib import Path
from typing import Literal, cast, overload
from typing import TYPE_CHECKING, Literal, cast, overload
from uuid import uuid4

import huggingface_hub
import numpy as np
import numpy.typing as npt
import torch
from datasets import Dataset
from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer, SentenceTransformerTrainingArguments
from sentence_transformers.losses import BatchAllTripletLoss
from sentence_transformers.training_args import BatchSamplers
from sklearn.model_selection import train_test_split
from transformers import EarlyStoppingCallback, TrainerCallback

from autointent._hash import Hasher
from autointent._utils import require
from autointent.configs import EmbedderFineTuningConfig, TaskTypeEnum
from autointent.configs._embedder import SentenceTransformerEmbeddingConfig
from autointent.custom_types import ListOfLabels

from .base import BaseEmbeddingBackend
from .utils import get_embeddings_path

if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
from transformers import TrainerCallback

logger = logging.getLogger(__name__)


@@ -48,6 +49,7 @@ class SentenceTransformerEmbeddingBackend(BaseEmbeddingBackend):
"""SentenceTransformer-based embedding backend implementation."""

supports_training: bool = True
_model: "SentenceTransformer | None"

def __init__(self, config: SentenceTransformerEmbeddingConfig) -> None:
"""Initialize the SentenceTransformer backend.
@@ -56,7 +58,7 @@ def __init__(self, config: SentenceTransformerEmbeddingConfig) -> None:
config: Configuration for SentenceTransformer embeddings.
"""
self.config = config
self._model: SentenceTransformer | None = None
self._model = None
self._trained: bool = False

def clear_ram(self) -> None:
@@ -68,10 +70,12 @@ def clear_ram(self) -> None:
self._model = None
torch.cuda.empty_cache()

def _load_model(self) -> SentenceTransformer:
def _load_model(self) -> "SentenceTransformer":
"""Load sentence transformers model to device."""
if self._model is None:
res = SentenceTransformer(
# Lazy import sentence-transformers
st = require("sentence_transformers", extra="sentence-transformers")
res = st.SentenceTransformer(
self.config.model_name,
device=self.config.device,
prompts=self.config.get_prompt_config(),
@@ -228,13 +232,17 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin

model = self._load_model()

# Lazy import sentence-transformers training components (only needed for fine-tuning)
st = require("sentence_transformers", extra="sentence-transformers")
transformers = require("transformers", extra="transformers")

x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=config.val_fraction)
tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
val_ds = Dataset.from_dict({"text": x_val, "label": y_val})

loss = BatchAllTripletLoss(model=model, margin=config.margin)
loss = st.losses.BatchAllTripletLoss(model=model, margin=config.margin)
with tempfile.TemporaryDirectory() as tmp_dir:
args = SentenceTransformerTrainingArguments(
args = st.SentenceTransformerTrainingArguments(
save_strategy="epoch",
save_total_limit=1,
output_dir=tmp_dir,
@@ -245,19 +253,19 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFin
warmup_ratio=config.warmup_ratio,
fp16=config.fp16,
bf16=config.bf16,
batch_sampler=BatchSamplers.NO_DUPLICATES,
batch_sampler=st.training_args.BatchSamplers.NO_DUPLICATES,
metric_for_best_model="eval_loss",
load_best_model_at_end=True,
eval_strategy="epoch",
greater_is_better=False,
)
callbacks: list[TrainerCallback] = [
EarlyStoppingCallback(
transformers.EarlyStoppingCallback(
early_stopping_patience=config.early_stopping_patience,
early_stopping_threshold=config.early_stopping_threshold,
)
]
trainer = SentenceTransformerTrainer(
trainer = st.SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=tr_ds,
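The pattern applied throughout both Python files, distilled into a standalone sketch (the function name and argument are illustrative, not from the PR): optional types appear only as string annotations resolved under TYPE_CHECKING, while the real import happens through `require` inside the method that needs it, so importing the module stays cheap and a missing extra fails with an install hint.

from typing import TYPE_CHECKING

from autointent._utils import require

if TYPE_CHECKING:
    # Evaluated by type checkers only; never executed at runtime.
    from transformers import PreTrainedModel


def load_classifier(path: str) -> "PreTrainedModel":
    # Lazy import: raises the `pip install autointent[transformers]` hint if the extra is missing.
    transformers = require("transformers", extra="transformers")
    return transformers.AutoModelForSequenceClassification.from_pretrained(path)  # type: ignore[no-any-return]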