Add progress bars for JINA embedding for local clustering (#1138)
brilee authored Jan 26, 2024
1 parent 589fa9d commit 48ec94f
Showing 1 changed file with 20 additions and 4 deletions.
24 changes: 20 additions & 4 deletions lilac/data/clustering.py
@@ -14,6 +14,7 @@
   BaseModel,
 )
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
+from tqdm import tqdm
 
 from ..batch_utils import compress_docs, flatten_path_iter, group_by_sorted_key_iter
 from ..dataset_format import DatasetFormatInputSelector
@@ -341,6 +342,8 @@ def extract_text(item: Item) -> Item:
 
   dataset.map(extract_text, output_path=cluster_output_path, overwrite=True)
 
+  total_len = dataset.stats(temp_text_path).total_count
+
   cluster_ids_exists = schema.has_field((*cluster_output_path, CLUSTER_ID))
   if not cluster_ids_exists or overwrite:
     if task_info:
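
The total count is fetched up front because the documents later arrive as a streaming iterator, which has no length; both tqdm and the task's total_len need a precomputed denominator to render a percentage. A minimal sketch of the difference, using an illustrative generator rather than anything from this commit:

from tqdm import tqdm

def stream():
  # A streaming source: len() is unavailable on a generator.
  for i in range(1000):
    yield i

# Without a total, tqdm shows only a running count and rate.
for _ in tqdm(stream()):
  pass

# With a precomputed total (analogous to
# dataset.stats(temp_text_path).total_count above), tqdm can
# render a percentage and an ETA.
for _ in tqdm(stream(), total=1000):
  pass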
@@ -352,7 +355,10 @@ def compute_clusters(items: Iterator[Item]) -> Iterator[Item]:
     items, items2 = itertools.tee(items)
     docs: Iterator[Optional[str]] = (item.get(TEXT_COLUMN) for item in items)
     cluster_items = sparse_to_dense_compute(
-      docs, lambda x: _hdbscan_cluster(x, min_cluster_size, use_garden)
+      docs,
+      lambda x: _hdbscan_cluster(
+        x, min_cluster_size, use_garden, num_docs=total_len, task_info=task_info
+      ),
     )
     for item, cluster_item in zip(items2, cluster_items):
       yield {**item, **(cluster_item or {})}
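
For context on the surrounding pattern: itertools.tee duplicates the item stream so one branch can be reduced to bare text for embedding and clustering while the other branch is zipped back with the cluster results. A standalone sketch of that shape; sparse_to_dense_compute is Lilac-internal, so enrich below is a hypothetical stand-in:

import itertools
from typing import Iterable, Iterator, Optional

def enrich(docs: Iterable[Optional[str]]) -> Iterator[dict]:
  # Stand-in for the embed + HDBSCAN step: yields one result per input doc.
  for doc in docs:
    yield {'cluster_id': 0 if doc else -1}

items = [{'text': 'a'}, {'text': 'b'}, {}]
items1, items2 = itertools.tee(iter(items))
docs = (item.get('text') for item in items1)
for item, cluster_item in zip(items2, enrich(docs)):
  print({**item, **(cluster_item or {})})

Note that tee buffers whatever the slower branch has not yet consumed, so the two branches should advance roughly in lockstep, as the zip here does.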
@@ -365,7 +371,6 @@ def compute_clusters(items: Iterator[Item]) -> Iterator[Item]:
       overwrite=True,
     )
 
-  total_len = dataset.stats(temp_text_path).total_count
   cluster_titles_exist = schema.has_field((*cluster_output_path, CLUSTER_TITLE))
   if not cluster_titles_exist or overwrite or recompute_titles:
     if task_info:
@@ -491,6 +496,8 @@ def _hdbscan_cluster(
   docs: Iterator[str],
   min_cluster_size: int = MIN_CLUSTER_SIZE,
   use_garden: bool = False,
+  num_docs: Optional[int] = None,
+  task_info: Optional[TaskInfo] = None,
 ) -> Iterator[Item]:
   """Cluster docs with HDBSCAN."""
   if use_garden:
@@ -500,14 +507,23 @@ def _hdbscan_cluster(
     response = remote_fn({'gzipped_docs': gzipped_docs})
     yield from response['clusters']
 
+  if task_info:
+    task_info.message = 'Computing embeddings'
+    task_info.total_progress = 0
+    task_info.total_len = num_docs
   with DebugTimer('Computing embeddings'):
     jina = JinaV2Small()
     jina.setup()
-    response = jina.compute(list(docs))
+    response = []
+    for doc in tqdm(docs, position=0, desc='Computing embeddings', total=num_docs):
+      response.extend(jina.compute([doc]))
+      if task_info and task_info.total_progress is not None:
+        task_info.total_progress += 1
     jina.teardown()
 
+  del docs, jina
   all_vectors = np.array([r[0][EMBEDDING_KEY] for r in response], dtype=np.float32)
-  del response, docs
+  del response
   gc.collect()
 
   # Use UMAP to reduce the dimensionality before hdbscan to speed up clustering.
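
The substance of the change: instead of a single opaque jina.compute(list(docs)) call over the whole corpus, documents are embedded one at a time inside a tqdm loop, and a TaskInfo counter is bumped in step so UI progress mirrors the terminal bar. A self-contained sketch of the same pattern; JinaV2Small and TaskInfo are Lilac classes, so the stubs below only mimic their shape:

from dataclasses import dataclass
from typing import Iterator, Optional

from tqdm import tqdm

@dataclass
class FakeTaskInfo:
  # Mimics the two TaskInfo fields this commit updates.
  message: str = ''
  total_progress: Optional[int] = 0
  total_len: Optional[int] = None

def fake_embed(doc: str) -> list:
  # Stand-in for JinaV2Small.compute([doc]).
  return [float(len(doc))]

def embed_with_progress(
  docs: Iterator[str],
  num_docs: Optional[int] = None,
  task_info: Optional[FakeTaskInfo] = None,
) -> list:
  if task_info:
    task_info.message = 'Computing embeddings'
    task_info.total_progress = 0
    task_info.total_len = num_docs
  response = []
  # Per-document loop: slower than one batched call, but observable.
  for doc in tqdm(docs, desc='Computing embeddings', total=num_docs):
    response.append(fake_embed(doc))
    if task_info and task_info.total_progress is not None:
      task_info.total_progress += 1
  return response

info = FakeTaskInfo()
vectors = embed_with_progress(iter(['a', 'bb', 'ccc']), num_docs=3, task_info=info)

The trade-off is worth noting: calling compute([doc]) per item forfeits whatever cross-document batching the embedding model could do, in exchange for granular progress; embedding in small fixed-size chunks would be a middle ground.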