Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into nik-md
Browse files Browse the repository at this point in the history
  • Loading branch information
nsthorat committed Dec 19, 2023
2 parents 28f5045 + 649c756 commit 91a1b28
Show file tree
Hide file tree
Showing 9 changed files with 881 additions and 82 deletions.
7 changes: 5 additions & 2 deletions lilac/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,11 @@ def __init__(self, config: Config) -> None:
super().__init__(config)

def run() -> None:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve())
try:
loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve())
except RuntimeError:
self.run()

self.thread = Thread(target=run)

Expand Down
9 changes: 8 additions & 1 deletion lilac/server_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test our public REST API."""
import os
from time import sleep

from fastapi.testclient import TestClient
from pytest_mock import MockerFixture
Expand All @@ -12,7 +13,7 @@
UserInfo,
get_session_user,
)
from .server import app
from .server import app, start_server, stop_server

client = TestClient(app)

Expand Down Expand Up @@ -161,3 +162,9 @@ def user() -> UserInfo:
),
auth_enabled=True,
)


def test_start_and_stop_server() -> None:
start_server()
sleep(1)
stop_server()
37 changes: 24 additions & 13 deletions lilac/signals/cluster_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import numpy as np
from pydantic import Field as PyField
from sklearn.cluster import HDBSCAN
from typing_extensions import override

from ..embeddings.embedding import get_embed_fn
Expand All @@ -13,6 +12,7 @@
from ..utils import DebugTimer

CLUSTER_ID = 'cluster_id'
MEMBERSHIP_PROB = 'membership_prob'
MIN_CLUSTER_SIZE = 5
UMAP_N_COMPONENTS = 10

Expand Down Expand Up @@ -46,7 +46,12 @@ class ClusterHDBScan(VectorSignal):
@override
def fields(self) -> Field:
return field(
fields=[field(dtype='string_span', fields={CLUSTER_ID: field('int32', categorical=True)})]
fields=[
field(
dtype='string_span',
fields={CLUSTER_ID: field('int32', categorical=True), MEMBERSHIP_PROB: field('float32')},
)
]
)

@override
Expand Down Expand Up @@ -83,13 +88,17 @@ def _cluster_span_vectors(
f'UMAP: Reducing dimensionality of {len(all_vectors)} vectors '
f'of dimensionality {all_vectors[0].size} to {self.umap_n_components}'
):
reducer = umap.UMAP(
n_components=self.umap_n_components,
n_neighbors=30,
min_dist=0.0,
random_state=self.umap_random_state,
)
all_vectors = reducer.fit_transform(all_vectors)
dim = all_vectors[0].size
if self.umap_n_components < dim:
reducer = umap.UMAP(
n_components=self.umap_n_components,
n_neighbors=30,
min_dist=0.0,
random_state=self.umap_random_state,
)
all_vectors = reducer.fit_transform(all_vectors)

from sklearn.cluster import HDBSCAN

with DebugTimer('HDBSCAN: Clustering'):
hdbscan = HDBSCAN(min_cluster_size=self.min_cluster_size, n_jobs=-1)
Expand All @@ -99,11 +108,13 @@ def _cluster_span_vectors(
for spans in all_spans:
span_clusters: list[Item] = []
for text_span in spans:
cluster_id: Optional[int] = int(hdbscan.labels_[span_index])
cluster_id = int(hdbscan.labels_[span_index])
membership_prob = float(hdbscan.probabilities_[span_index])
start, end = text_span
if cluster_id == -1:
cluster_id = None
span_clusters.append(span(start, end, {CLUSTER_ID: cluster_id}))
metadata = {CLUSTER_ID: cluster_id, MEMBERSHIP_PROB: membership_prob}
if cluster_id < 0:
metadata = {CLUSTER_ID: -1}
span_clusters.append(span(start, end, metadata))
span_index += 1

yield span_clusters
45 changes: 35 additions & 10 deletions lilac/signals/cluster_hdbscan_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
from pytest_mock import MockerFixture
from typing_extensions import override

from ..data.dataset import SortOrder
from ..data.dataset_test_utils import TestDataMaker, enriched_item
from ..schema import Item, RichData, lilac_embedding, span
from ..schema import ROWID, Item, RichData, lilac_embedding, span
from ..signal import TextEmbeddingSignal, clear_signal_registry, register_signal
from .cluster_hdbscan import ClusterHDBScan

Expand All @@ -19,8 +20,9 @@
EMBEDDINGS: dict[str, list[float]] = {
'a': [1.0, 0.0, 0.0],
'b': [0.0, 1.0, 0.0],
'c': [1.0, 0.1, 0.0],
'd': [0.0, 0.9, 0.0],
'c': [0.999, 0.0, 0.0],
'd': [0.0, 0.999, 0.0],
'outlier': [0.0, 0.0, 1.0],
}


Expand Down Expand Up @@ -54,19 +56,42 @@ def compute(self, data: Iterable[RichData]) -> Iterator[Item]:


def test_simple_data(make_test_data: TestDataMaker, mocker: MockerFixture) -> None:
dataset = make_test_data([{'text': 'a'}, {'text': 'b'}, {'text': 'c'}, {'text': 'd'}])
dataset = make_test_data(
[{'text': 'a'}, {'text': 'b'}, {'text': 'c'}, {'text': 'd'}, {'text': 'outlier'}]
)
dataset.compute_embedding('test_embedding', 'text')

signal = ClusterHDBScan(
embedding='test_embedding', min_cluster_size=2, umap_n_components=2, umap_random_state=1337
embedding='test_embedding', min_cluster_size=2, umap_n_components=3, umap_random_state=1337
)
dataset.compute_signal(signal, 'text')
signal_key = signal.key(is_computed_signal=True)
result = dataset.select_rows(combine_columns=True)
result = dataset.select_rows(combine_columns=True, sort_by=[ROWID], sort_order=SortOrder.ASC)
expected_result = [
{'text': enriched_item('a', {signal_key: [span(0, 1, {'cluster_id': 0})]})},
{'text': enriched_item('b', {signal_key: [span(0, 1, {'cluster_id': 1})]})},
{'text': enriched_item('c', {signal_key: [span(0, 1, {'cluster_id': 0})]})},
{'text': enriched_item('d', {signal_key: [span(0, 1, {'cluster_id': 1})]})},
{
'text': enriched_item(
'a', {signal_key: [span(0, 1, {'cluster_id': 0, 'membership_prob': 1.0})]}
)
},
{
'text': enriched_item(
'b', {signal_key: [span(0, 1, {'cluster_id': 1, 'membership_prob': 1.0})]}
)
},
{
'text': enriched_item(
'c', {signal_key: [span(0, 1, {'cluster_id': 0, 'membership_prob': 1.0})]}
)
},
{
'text': enriched_item(
'd', {signal_key: [span(0, 1, {'cluster_id': 1, 'membership_prob': 1.0})]}
)
},
{
'text': enriched_item(
'outlier', {signal_key: [span(0, 7, {'cluster_id': -1, 'membership_prob': None})]}
)
},
]
assert list(result) == expected_result
Loading

0 comments on commit 91a1b28

Please sign in to comment.