Fix errors when the concept is empty #1158

Merged · merged 1 commit on Feb 2, 2024
33 changes: 27 additions & 6 deletions lilac/concepts/concept.py
@@ -7,9 +7,11 @@
 from joblib import Parallel, delayed
 from pydantic import BaseModel, field_validator
 from sklearn.base import clone
+from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import LogisticRegression
 from sklearn.metrics import precision_recall_curve, roc_auc_score
 from sklearn.model_selection import KFold
+from sklearn.utils.validation import check_is_fitted
 
 from ..embeddings.embedding import get_embed_fn
 from ..signal import TextEmbeddingSignal, get_signal_cls
@@ -140,6 +142,15 @@ class ConceptMetrics(BaseModel):
   overall: OverallScore
 
 
+def _is_fitted(model: LogisticRegression) -> bool:
+  """Check if the model is fitted."""
+  try:
+    check_is_fitted(model)
+    return True
+  except NotFittedError:
+    return False
+
+
 @dataclasses.dataclass
 class LogisticEmbeddingModel:
   """A model that uses logistic regression with embeddings."""
@@ -155,7 +166,10 @@ def __post_init__(self) -> None:
 
   def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
     """Get the scores for the provided embeddings."""
-    y_probs = self._model.predict_proba(embeddings)[:, 1]
+    if _is_fitted(self._model):
+      y_probs = self._model.predict_proba(embeddings)[:, 1]
+    else:
+      y_probs = np.ones(len(embeddings)) * 0.5
     # Map [0, threshold, 1] to [0, 0.5, 1].
     power = np.log(self._threshold) / np.log(0.5)
     return y_probs**power
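
For intuition (an illustrative sketch with assumed values, not code from the PR): when the model has never been fit, every row now receives a neutral probability of 0.5 instead of raising. With a threshold of 0.5 the exponent works out to 1, so the neutral scores pass through unchanged:

import numpy as np

embeddings = np.random.randn(3, 8)  # toy 3 x 8 embedding matrix (assumed)
threshold = 0.5                     # assumed threshold, for illustration only

y_probs = np.ones(len(embeddings)) * 0.5  # fallback path for an unfitted model
power = np.log(threshold) / np.log(0.5)   # == 1.0 when threshold == 0.5
print(y_probs**power)                     # [0.5 0.5 0.5]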
@@ -173,7 +187,9 @@ def _setup_training(
   def fit(self, embeddings: np.ndarray, labels: list[bool]) -> None:
     """Fit the model to the provided embeddings and labels."""
     label_set = set(labels)
-    if len(label_set) < 2:
+    if len(label_set) == 0:
+      return
+    elif len(label_set) < 2:
       dim = embeddings.shape[1]
       random_vector = np.random.randn(dim).astype(np.float32)
       random_vector /= np.linalg.norm(random_vector)
@@ -206,7 +222,10 @@ def _fit_and_score(
       if len(set(y_train)) < 2:
         return np.array([]), np.array([])
       model.fit(X_train, y_train)
-      y_pred = model.predict_proba(X_test)[:, 1]
+      if _is_fitted(model):
+        y_pred = model.predict_proba(X_test)[:, 1]
+      else:
+        y_pred = np.ones_like(y_test) * 0.5
       return y_test, y_pred
 
     # Compute the metrics for each validation fold in parallel.
Expand Down Expand Up @@ -298,7 +317,11 @@ def score_embeddings(self, draft: DraftId, embeddings: np.ndarray) -> np.ndarray

def coef(self, draft: DraftId = DRAFT_MAIN) -> np.ndarray:
"""Get the coefficients of the underlying ML model."""
return self._get_logistic_model(draft)._model.coef_.reshape(-1)
model = self._get_logistic_model(draft)
if _is_fitted(model._model):
return model._model.coef_.reshape(-1)
else:
return np.zeros(0)

def _get_logistic_model(self, draft: DraftId = DRAFT_MAIN) -> LogisticEmbeddingModel:
"""Get the logistic model for the provided draft."""
@@ -345,8 +368,6 @@ def _compute_embeddings(self, concept: Concept) -> None:
     concept_embeddings: dict[str, np.ndarray] = {}
 
     examples = concept.data.items()
-    if not examples:
-      raise ValueError(f'Cannot sync concept "{concept.concept_name}". It has no examples.')
 
     # Compute the embeddings for the examples with cache miss.
     texts_of_missing_embeddings: dict[str, str] = {}
14 changes: 14 additions & 0 deletions lilac/concepts/db_concept_test.py
@@ -714,3 +714,17 @@ def test_embedding_not_found_in_map(
 
     with pytest.raises(ValueError, match='Example "unknown text" not in embedding map'):
       model_db.sync(model.namespace, model.concept_name, model.embedding_name)
+
+  def test_empty_concept(
+    self, concept_db_cls: Type[ConceptDB], model_db_cls: Type[ConceptModelDB]
+  ) -> None:
+    concept_db = concept_db_cls()
+    model_db = model_db_cls(concept_db)
+
+    namespace = 'test'
+    concept_name = 'test_concept'
+    concept_db.create(namespace=namespace, name=concept_name, type=ConceptType.TEXT)
+    model = model_db.create(namespace, concept_name, embedding_name='test_embedding')
+    model = model_db.sync(model.namespace, model.concept_name, model.embedding_name)
+    # Make sure the model is in sync.
+    assert model_db.in_sync(model) is True
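
To exercise just the new test locally, one option is invoking pytest from Python (a hedged example; the path and -k filter come from the diff above, and it assumes pytest plus the repo's test dependencies are installed):

import pytest

# Run only the new empty-concept test.
pytest.main(['lilac/concepts/db_concept_test.py', '-k', 'test_empty_concept'])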