Add UI for concept scoring and improve model selection #423

Merged 8 commits on Jul 6, 2023
22 changes: 21 additions & 1 deletion data/concept/lilac/toxicity/concept.json
@@ -5862,7 +5862,27 @@
"label": true,
"text": "is so bad today that some folks don't even know how to turn their headlights on. And yes....those people are stupid and lazy and don't take care of what they own.",
"id": "63d6182d4bde4f62a2ce2eb969e03c2b"
},
"b4339269275343019c56fc7ef4c665ee": {
"label": false,
"text": "\n\nVIOLENT PROTESTS BREAK OUT AT FOXCONN\u2019S \u2018IPHONE CITY\u2019 (2 MINUTE READ)",
"id": "b4339269275343019c56fc7ef4c665ee"
},
"075d0bcdd42c4c309e4145f299a3d994": {
"label": false,
"text": " [\n\n\nBricks is a library of natural language processing modules that can be used in any project. It contains code that can be copied and pasted from an online platform. There are three categories of modules: classifiers, extractors, and generators. The modules can help with sentence complexity estimations, sentiment analysis, and more. \n\nSTOP PAYING FOR CORPORATE-CONTROLLED MONGODB (SPONSO",
"id": "075d0bcdd42c4c309e4145f299a3d994"
},
"da2d7729324d4b09864d8c4fbbad93ab": {
"label": false,
"text": ".\n\n\n\ud83c\udf81 \n\nMISCELLANEOUS\n\nASK HN: PEOPLE WHO WERE LAID OFF OR QUIT RECENTLY, HOW ARE YOU DOING? (HACKER NEWS THREAD",
"id": "da2d7729324d4b09864d8c4fbbad93ab"
},
"1e252a08b4a94259ab26eddbbda567cc": {
"label": false,
"text": "####################################################",
"id": "1e252a08b4a94259ab26eddbbda567cc"
}
},
"version": 148
"version": 152
}
71 changes: 40 additions & 31 deletions src/concepts/concept.py
@@ -7,10 +7,11 @@
import numpy as np
from joblib import Parallel, delayed
from pydantic import BaseModel, validator
from scipy.interpolate import interp1d
from sklearn.base import BaseEstimator, clone
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import precision_recall_curve, roc_auc_score
from sklearn.model_selection import KFold

from ..db_manager import get_dataset
@@ -26,7 +27,11 @@
DEFAULT_NUM_NEG_EXAMPLES = 100

# The maximum number of cross-validation models to train.
MAX_NUM_CROSS_VAL_MODELS = 30
MAX_NUM_CROSS_VAL_MODELS = 15
# The β weight to use for the F-beta score: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.fbeta_score.html
# β = 0.5 means we value precision 2x as much as recall.
# β = 2 means we value recall 2x as much as precision.
F_BETA_WEIGHT = 0.5


class ConceptColumnInfo(BaseModel):
@@ -141,7 +146,8 @@ class ConceptMetrics(BaseModel):
class LogisticEmbeddingModel:
"""A model that uses logistic regression with embeddings."""

version: int = -1
_metrics: Optional[ConceptMetrics] = None
_threshold: float = 0.5

def __post_init__(self) -> None:
# See `notebooks/Toxicity.ipynb` for an example of training a concept model.
@@ -151,7 +157,10 @@ def __post_init__(self) -> None:
def score_embeddings(self, embeddings: np.ndarray) -> np.ndarray:
"""Get the scores for the provided embeddings."""
try:
return self._model.predict_proba(embeddings)[:, 1]
y_probs = self._model.predict_proba(embeddings)[:, 1]
# Map [0, threshold, 1] to [0, 0.5, 1].
interpolate_fn = interp1d([0, self._threshold, 1], [0, 0.4999, 1])
return interpolate_fn(y_probs)
except NotFittedError:
return np.random.rand(len(embeddings))

@@ -187,9 +196,11 @@ def fit(self, embeddings: np.ndarray, labels: list[bool],
f'Length of embeddings ({len(embeddings)}) must match length of labels ({len(labels)})')
X_train, y_train, sample_weights = self._setup_training(embeddings, labels, implicit_negatives)
self._model.fit(X_train, y_train, sample_weights)
self._metrics, self._threshold = self._compute_metrics(embeddings, labels, implicit_negatives)

def compute_metrics(self, embeddings: np.ndarray, labels: list[bool],
implicit_negatives: Optional[np.ndarray]) -> ConceptMetrics:
def _compute_metrics(
self, embeddings: np.ndarray, labels: list[bool],
implicit_negatives: Optional[np.ndarray]) -> tuple[Optional[ConceptMetrics], float]:
"""Return the concept metrics."""
labels = np.array(labels)
n_splits = min(len(labels), MAX_NUM_CROSS_VAL_MODELS)
@@ -198,6 +209,8 @@ def compute_metrics(self, embeddings: np.ndarray, labels: list[bool],
def _fit_and_score(model: BaseEstimator, X_train: np.ndarray, y_train: np.ndarray,
sample_weights: np.ndarray, X_test: np.ndarray,
y_test: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
if len(set(y_train)) < 2:
return np.array([]), np.array([])
model.fit(X_train, y_train, sample_weights)
y_pred = model.predict_proba(X_test)[:, 1]
return y_test, y_pred
@@ -214,18 +227,25 @@ def _fit_and_score(model: BaseEstimator, X_train: np.ndarray, y_train: np.ndarra

y_test = np.concatenate([y_test for y_test, _ in results], axis=0)
y_pred = np.concatenate([y_pred for _, y_pred in results], axis=0)
y_pred_binary = y_pred >= 0.5
f1_val = f1_score(y_test, y_pred_binary)
precision_val = precision_score(y_test, y_pred_binary)
recall_val = recall_score(y_test, y_pred_binary)
if len(set(y_test)) < 2:
return None, 0.5
roc_auc_val = roc_auc_score(y_test, y_pred)

return ConceptMetrics(
f1=f1_val,
precision=precision_val,
recall=recall_val,
precision, recall, thresholds = precision_recall_curve(y_test, y_pred)
numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
denom = (F_BETA_WEIGHT**2 * precision) + recall
f1_scores = np.divide(numerator, denom, out=np.zeros_like(denom), where=(denom != 0))
max_f1: float = np.max(f1_scores)
max_f1_index = np.argmax(f1_scores)
max_f1_thresh: float = thresholds[max_f1_index]
max_f1_prec: float = precision[max_f1_index]
max_f1_recall: float = recall[max_f1_index]
metrics = ConceptMetrics(
f1=max_f1,
precision=max_f1_prec,
recall=max_f1_recall,
roc_auc=roc_auc_val,
overall=_get_overall_score(f1_val))
overall=_get_overall_score(max_f1))
return metrics, max_f1_thresh


def draft_examples(concept: Concept, draft: DraftId) -> dict[str, Example]:
@@ -272,6 +292,10 @@ class ConceptModel:
_logistic_models: dict[DraftId, LogisticEmbeddingModel] = dataclasses.field(default_factory=dict)
_negative_vectors: Optional[np.ndarray] = None

def get_metrics(self, concept: Concept) -> Optional[ConceptMetrics]:
"""Return the metrics for this model."""
return self._get_logistic_model(DRAFT_MAIN)._metrics

def __post_init__(self) -> None:
if self.column_info:
self.column_info.path = normalize_path(self.column_info.path)
@@ -311,18 +335,6 @@ def _get_logistic_model(self, draft: DraftId) -> LogisticEmbeddingModel:
self._logistic_models[draft] = LogisticEmbeddingModel()
return self._logistic_models[draft]

def compute_metrics(self, concept: Concept) -> ConceptMetrics:
"""Compute the metrics for the provided concept using the model."""
examples = draft_examples(concept, DRAFT_MAIN)
embeddings = np.array([self._embeddings[id] for id in examples.keys()])
labels = [example.label for example in examples.values()]
implicit_embeddings: Optional[np.ndarray] = None
implicit_labels: Optional[list[bool]] = None
model = self._get_logistic_model(DRAFT_MAIN)
model_str = f'{self.namespace}/{self.concept_name}/{self.embedding_name}/{self.version}'
with DebugTimer(f'Computing metrics for {model_str}'):
return model.compute_metrics(embeddings, labels, self._negative_vectors)

def sync(self, concept: Concept) -> bool:
"""Update the model with the latest labeled concept data."""
if concept.version == self.version:
@@ -343,9 +355,6 @@ def sync(self, concept: Concept) -> bool:
with DebugTimer(f'Fitting model for "{concept_path}"'):
model.fit(embeddings, labels, self._negative_vectors)

# Synchronize the model version with the concept version.
model.version = concept.version

# Synchronize the model version with the concept version.
self.version = concept.version

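The core of the concept.py change: rather than thresholding predicted probabilities at a fixed 0.5, `_compute_metrics` sweeps the precision/recall curve built from cross-validated predictions, picks the threshold that maximizes the F-beta score (β = 0.5, so precision is weighted over recall), and `score_embeddings` then piecewise-linearly remaps raw probabilities so the learned threshold lands at 0.5. A minimal standalone sketch of that technique with synthetic labels; the helper names are illustrative, not part of the PR:

import numpy as np
from scipy.interpolate import interp1d
from sklearn.metrics import precision_recall_curve

F_BETA_WEIGHT = 0.5  # beta < 1 values precision over recall.

def pick_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> float:
  """Return the probability threshold that maximizes the F-beta score."""
  precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
  # precision/recall carry one more point than thresholds; trim to align.
  precision, recall = precision[:-1], recall[:-1]
  numerator = (1 + F_BETA_WEIGHT**2) * precision * recall
  denom = F_BETA_WEIGHT**2 * precision + recall
  f_beta = np.divide(numerator, denom, out=np.zeros_like(denom), where=denom != 0)
  return float(thresholds[np.argmax(f_beta)])

def remap_scores(y_prob: np.ndarray, threshold: float) -> np.ndarray:
  """Map [0, threshold, 1] onto [0, ~0.5, 1] so 0.5 separates the classes."""
  return interp1d([0.0, threshold, 1.0], [0.0, 0.4999, 1.0])(y_prob)

y_true = np.array([0, 0, 0, 1, 0, 1, 1, 1])
y_prob = np.array([0.05, 0.2, 0.35, 0.3, 0.4, 0.65, 0.8, 0.9])
threshold = pick_threshold(y_true, y_prob)
print(threshold, remap_scores(y_prob, threshold))

Trimming the final precision/recall point keeps the arrays aligned with `thresholds`; the remap preserves ordering while shifting the decision boundary to 0.5, which is what lets the UI color scores consistently.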
20 changes: 19 additions & 1 deletion src/concepts/db_concept.py
@@ -30,6 +30,7 @@
ExampleIn,
)

CONCEPTS_DIR = 'concept'
DATASET_CONCEPTS_DIR = '.concepts'
CONCEPT_JSON_FILENAME = 'concept.json'

@@ -164,6 +165,11 @@ def remove_all(self, namespace: str, concept_name: str) -> None:
"""Remove all the models associated with a concept."""
pass

@abc.abstractmethod
def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
"""List all the models associated with a concept."""
pass

@abc.abstractmethod
def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
"""Get the dataset columns where this concept was applied to."""
@@ -244,6 +250,18 @@ def remove_all(self, namespace: str, concept_name: str) -> None:
for dir in dirs:
shutil.rmtree(dir, ignore_errors=True)

@override
def get_models(self, namespace: str, concept_name: str) -> list[ConceptModel]:
"""List all the models associated with a concept."""
model_files = glob.iglob(os.path.join(_concept_output_dir(namespace, concept_name), '*.pkl'))
models: list[ConceptModel] = []
for model_file in model_files:
embedding_name = os.path.basename(model_file)[:-len('.pkl')]
model = self.get(namespace, concept_name, embedding_name)
if model:
models.append(model)
return models

@override
def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptColumnInfo]:
datasets_path = os.path.join(data_path(), DATASETS_DIR_NAME)
@@ -264,7 +282,7 @@ def get_column_infos(self, namespace: str, concept_name: str) -> list[ConceptCol

def _concept_output_dir(namespace: str, name: str) -> str:
"""Return the output directory for a given concept."""
return os.path.join(data_path(), 'concept', namespace, name)
return os.path.join(data_path(), CONCEPTS_DIR, namespace, name)


def _concept_json_path(namespace: str, name: str) -> str:
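The new `get_models` backs the list endpoint: each fitted model is pickled as <embedding_name>.pkl under the concept's output directory, so listing models reduces to a glob plus the existing `get` lookup. A rough standalone sketch of that directory scan, with an assumed on-disk layout:

import glob
import os

def list_model_embeddings(output_dir: str) -> list[str]:
  """Return embedding names that have a serialized model under output_dir."""
  model_files = glob.iglob(os.path.join(output_dir, '*.pkl'))
  # Files are named <embedding_name>.pkl; strip the extension to recover the name.
  return [os.path.basename(path)[:-len('.pkl')] for path in model_files]

print(list_model_embeddings('./data/concept/lilac/toxicity'))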
64 changes: 28 additions & 36 deletions src/router_concept.py
@@ -123,19 +123,30 @@ class ConceptModelInfo(BaseModel):
embedding_name: str
version: int
column_info: Optional[ConceptColumnInfo] = None
metrics: Optional[ConceptMetrics] = None


class ConceptModelResponse(BaseModel):
"""Response body for the get_concept_model endpoint."""
model: ConceptModelInfo
model_synced: bool


@router.get('/{namespace}/{concept_name}/{embedding_name}')
def get_concept_model(namespace: str,
concept_name: str,
embedding_name: str,
sync_model: bool = False) -> ConceptModelResponse:
@router.get('/{namespace}/{concept_name}/model')
def get_concept_models(namespace: str, concept_name: str) -> list[ConceptModelInfo]:
"""Get a concept model from a database."""
concept = DISK_CONCEPT_DB.get(namespace, concept_name)
if not concept:
raise HTTPException(
status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')
models = DISK_CONCEPT_MODEL_DB.get_models(namespace, concept_name)
return [
ConceptModelInfo(
namespace=m.namespace,
concept_name=m.concept_name,
embedding_name=m.embedding_name,
version=m.version,
column_info=m.column_info,
metrics=m.get_metrics(concept)) for m in models
]


@router.get('/{namespace}/{concept_name}/model/{embedding_name}')
def get_concept_model(namespace: str, concept_name: str, embedding_name: str) -> ConceptModelInfo:
"""Get a concept model from a database."""
concept = DISK_CONCEPT_DB.get(namespace, concept_name)
if not concept:
@@ -145,43 +156,24 @@ def get_concept_model(
model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name)
if not model:
model = DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name)

if sync_model:
model_synced = DISK_CONCEPT_MODEL_DB.sync(model)
else:
model_synced = DISK_CONCEPT_MODEL_DB.in_sync(model)
model_synced = DISK_CONCEPT_MODEL_DB.sync(model)
model_info = ConceptModelInfo(
namespace=model.namespace,
concept_name=model.concept_name,
embedding_name=model.embedding_name,
version=model.version,
column_info=model.column_info)
return ConceptModelResponse(model=model_info, model_synced=model_synced)
column_info=model.column_info,
metrics=model.get_metrics(concept))
return model_info


class MetricsBody(BaseModel):
"""Request body for the compute_metrics endpoint."""
column_info: Optional[ConceptColumnInfo] = None


@router.post('/{namespace}/{concept_name}/{embedding_name}/compute_metrics')
def compute_metrics(namespace: str, concept_name: str, embedding_name: str,
body: MetricsBody) -> ConceptMetrics:
"""Compute the metrics for the concept model."""
concept = DISK_CONCEPT_DB.get(namespace, concept_name)
if not concept:
raise HTTPException(
status_code=404, detail=f'Concept "{namespace}/{concept_name}" was not found')

column_info = body.column_info
model = DISK_CONCEPT_MODEL_DB.get(namespace, concept_name, embedding_name, column_info)
if model is None:
model = DISK_CONCEPT_MODEL_DB.create(namespace, concept_name, embedding_name, column_info)
model_updated = DISK_CONCEPT_MODEL_DB.sync(model)
return model.compute_metrics(concept)


@router.post('/{namespace}/{concept_name}/{embedding_name}/score', response_model_exclude_none=True)
@router.post(
'/{namespace}/{concept_name}/model/{embedding_name}/score', response_model_exclude_none=True)
def score(namespace: str, concept_name: str, embedding_name: str, body: ScoreBody) -> ScoreResponse:
"""Score examples along the specified concept."""
concept = DISK_CONCEPT_DB.get(namespace, concept_name)
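The route changes move everything model-related under a /model/ prefix: a collection GET returns every synced model with its metrics, the single-model GET now always syncs (the old sync_model query flag and the separate compute_metrics endpoint are gone), and scoring follows the same prefix. A hedged usage sketch against a locally running server; the host, port, and example names are assumptions, not part of the PR:

import requests

BASE = 'http://localhost:8000/api/v1/concepts'  # Assumed host/port.

# List every synced model (with metrics) for a concept.
models = requests.get(f'{BASE}/lilac/toxicity/model').json()

# Fetch one model for a specific embedding; the server syncs it on read.
model = requests.get(f'{BASE}/lilac/toxicity/model/test_embedding').json()

# Score raw text against the concept.
body = {'examples': [{'text': 'hello world'}, {'text': 'hello'}]}
scores = requests.post(
    f'{BASE}/lilac/toxicity/model/test_embedding/score', json=body).json()
print(scores)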
31 changes: 7 additions & 24 deletions src/server_concept_test.py
@@ -23,7 +23,6 @@
from .data.dataset_utils import lilac_embedding
from .router_concept import (
ConceptModelInfo,
ConceptModelResponse,
CreateConceptOptions,
MergeConceptDraftOptions,
ScoreBody,
@@ -340,35 +339,19 @@ def test_concept_model_sync(mocker: MockerFixture) -> None:
assert response.status_code == 200

# Get the concept model.
url = '/api/v1/concepts/concept_namespace/concept/test_embedding?sync_model=False'
url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding'
response = client.get(url)
assert response.status_code == 200
assert ConceptModelResponse.parse_obj(response.json()) == ConceptModelResponse(
model=ConceptModelInfo(
namespace='concept_namespace',
concept_name='concept',
embedding_name='test_embedding',
version=-1),
# The model shouldn't yet be synced because we set sync_model=False.
model_synced=False)

# Sync the concept model.
url = '/api/v1/concepts/concept_namespace/concept/test_embedding?sync_model=True'
response = client.get(url)
assert response.status_code == 200
assert ConceptModelResponse.parse_obj(response.json()) == ConceptModelResponse(
model=ConceptModelInfo(
namespace='concept_namespace',
concept_name='concept',
embedding_name='test_embedding',
version=1),
# The model should be synced because we set sync_model=True.
model_synced=True)
assert ConceptModelInfo.parse_obj(response.json()) == ConceptModelInfo(
namespace='concept_namespace',
concept_name='concept',
embedding_name='test_embedding',
version=1)

# Score an example.
mock_score_emb = mocker.patch.object(LogisticEmbeddingModel, 'score_embeddings', autospec=True)
mock_score_emb.return_value = np.array([0.9, 1.0])
url = '/api/v1/concepts/concept_namespace/concept/test_embedding/score'
url = '/api/v1/concepts/concept_namespace/concept/model/test_embedding/score'
score_body = ScoreBody(examples=[ScoreExample(text='hello world'), ScoreExample(text='hello')])
response = client.post(url, json=score_body.dict())
assert response.status_code == 200
30 changes: 30 additions & 0 deletions web/blueprint/src/lib/components/concepts/ConceptHoverPill.svelte
@@ -0,0 +1,30 @@
<script lang="ts">
import {formatValue, type ConceptMetrics} from '$lilac';
import {scoreToColor, scoreToText} from './colors';
export let metrics: ConceptMetrics;
</script>

<table>
<tr>
<td>Overall score</td><td class={scoreToColor[metrics.overall]}>
{scoreToText[metrics.overall]}
</td>
</tr>
<tr><td>F1</td><td>{formatValue(metrics.f1)}</td></tr>
<tr><td>Recall</td><td>{formatValue(metrics.recall)}</td></tr>
<tr><td>Precision</td><td>{formatValue(metrics.precision)}</td></tr>
<tr><td>Area under ROC</td><td>{formatValue(metrics.roc_auc)}</td></tr>
</table>

<style lang="postcss">
:global(.concept-score-pill .bx--tooltip__label) {
@apply mr-1 inline-block h-full truncate;
max-width: 5rem;
}
:global(.concept-score-pill .bx--tooltip__content) {
@apply flex flex-col items-center;
}
table td {
@apply px-2 py-1;
}
</style>