From 48ec94f4011cf7f6c5507cbf2bab3af7ac2df8a3 Mon Sep 17 00:00:00 2001
From: Brian Lee
Date: Fri, 26 Jan 2024 15:43:05 -0500
Subject: [PATCH] Add progress bars for JINA embedding for local clustering
 (#1138)

---
 lilac/data/clustering.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/lilac/data/clustering.py b/lilac/data/clustering.py
index f08441cf..7b57a73d 100644
--- a/lilac/data/clustering.py
+++ b/lilac/data/clustering.py
@@ -14,6 +14,7 @@
   BaseModel,
 )
 from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_random_exponential
+from tqdm import tqdm
 
 from ..batch_utils import compress_docs, flatten_path_iter, group_by_sorted_key_iter
 from ..dataset_format import DatasetFormatInputSelector
@@ -341,6 +342,8 @@ def extract_text(item: Item) -> Item:
 
   dataset.map(extract_text, output_path=cluster_output_path, overwrite=True)
 
+  total_len = dataset.stats(temp_text_path).total_count
+
   cluster_ids_exists = schema.has_field((*cluster_output_path, CLUSTER_ID))
   if not cluster_ids_exists or overwrite:
     if task_info:
@@ -352,7 +355,10 @@ def compute_clusters(items: Iterator[Item]) -> Iterator[Item]:
       items, items2 = itertools.tee(items)
       docs: Iterator[Optional[str]] = (item.get(TEXT_COLUMN) for item in items)
       cluster_items = sparse_to_dense_compute(
-        docs, lambda x: _hdbscan_cluster(x, min_cluster_size, use_garden)
+        docs,
+        lambda x: _hdbscan_cluster(
+          x, min_cluster_size, use_garden, num_docs=total_len, task_info=task_info
+        ),
       )
       for item, cluster_item in zip(items2, cluster_items):
         yield {**item, **(cluster_item or {})}
@@ -365,7 +371,6 @@
       overwrite=True,
     )
 
-  total_len = dataset.stats(temp_text_path).total_count
   cluster_titles_exist = schema.has_field((*cluster_output_path, CLUSTER_TITLE))
   if not cluster_titles_exist or overwrite or recompute_titles:
     if task_info:
@@ -491,6 +496,8 @@ def _hdbscan_cluster(
   docs: Iterator[str],
   min_cluster_size: int = MIN_CLUSTER_SIZE,
   use_garden: bool = False,
+  num_docs: Optional[int] = None,
+  task_info: Optional[TaskInfo] = None,
 ) -> Iterator[Item]:
   """Cluster docs with HDBSCAN."""
   if use_garden:
@@ -500,14 +507,23 @@
     response = remote_fn({'gzipped_docs': gzipped_docs})
     yield from response['clusters']
 
+  if task_info:
+    task_info.message = 'Computing embeddings'
+    task_info.total_progress = 0
+    task_info.total_len = num_docs
   with DebugTimer('Computing embeddings'):
     jina = JinaV2Small()
     jina.setup()
-    response = jina.compute(list(docs))
+    response = []
+    for doc in tqdm(docs, position=0, desc='Computing embeddings', total=num_docs):
+      response.extend(jina.compute([doc]))
+      if task_info and task_info.total_progress is not None:
+        task_info.total_progress += 1
    jina.teardown()
+    del docs, jina
 
   all_vectors = np.array([r[0][EMBEDDING_KEY] for r in response], dtype=np.float32)
-  del response, docs
+  del response
   gc.collect()
 
   # Use UMAP to reduce the dimensionality before hdbscan to speed up clustering.
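
Note: the following is a standalone sketch, not part of the patch, illustrating the dual
progress-reporting pattern the diff applies: the document iterator is wrapped in tqdm for
console output while the same count is mirrored into a shared task object that a UI can
poll. `Task`, `embed_one`, and `embed_all` are hypothetical stand-ins for lilac's TaskInfo
and JinaV2Small.compute.

from dataclasses import dataclass
from typing import Iterable, Iterator, Optional

from tqdm import tqdm


@dataclass
class Task:
  """Hypothetical stand-in for lilac's TaskInfo progress fields."""

  message: str = ''
  total_progress: Optional[int] = None
  total_len: Optional[int] = None


def embed_one(doc: str) -> list[float]:
  """Toy embedding so the sketch runs without downloading a model."""
  return [float(len(doc))]


def embed_all(
  docs: Iterable[str],
  num_docs: Optional[int] = None,
  task: Optional[Task] = None,
) -> Iterator[list[float]]:
  """Embed docs one at a time, mirroring progress to tqdm and `task`."""
  if task:
    task.message = 'Computing embeddings'
    task.total_progress = 0
    task.total_len = num_docs
  # tqdm tolerates total=None; it then shows a count instead of a percentage.
  for doc in tqdm(docs, desc='Computing embeddings', total=num_docs):
    yield embed_one(doc)
    if task and task.total_progress is not None:
      task.total_progress += 1


if __name__ == '__main__':
  task = Task()
  vectors = list(embed_all(['a', 'bb', 'ccc'], num_docs=3, task=task))
  assert task.total_progress == 3

Embedding one document per compute call, as the patch does, makes progress granular at the
cost of batching; chunking docs into small batches would preserve most of the throughput
while still updating progress once per chunk.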