# Post-patch form of compute_reduced_embeddings from utils/clusters_utils.py,
# reconstructed from a whitespace-collapsed diff. Relies on the module's
# existing imports of List, np (numpy) and TSNE.
def compute_reduced_embeddings(
    embeddings: List[List[float]],
    *,
    expected_dim: int = 384,
) -> List[List[float]]:
    """Project embeddings to 2-D with t-SNE and shift them to be non-negative.

    Any embedding whose length differs from ``expected_dim`` is replaced by an
    all-zero vector of that length so the input forms a rectangular matrix.
    NOTE(review): this substitution is silent; malformed vectors all collapse
    onto the same input point. Confirm that is the intended handling.

    Args:
        embeddings: One vector per item.
        expected_dim: Required length of every vector. Defaults to 384, the
            value previously hard-coded in this function.

    Returns:
        One ``[x, y]`` pair per input embedding. The global minimum over all
        coordinates is subtracted, so the smallest returned value is 0.
        An empty input yields an empty list.
    """
    # t-SNE cannot be fitted on zero samples; return early instead of letting
    # the estimator fail on an empty matrix.
    if len(embeddings) == 0:
        return []

    zero_vector = [0.0] * expected_dim
    cleaned = [
        emb if len(emb) == expected_dim else zero_vector for emb in embeddings
    ]
    # A float matrix (not dtype=object) is what the estimator actually consumes.
    matrix = np.asarray(cleaned, dtype=float)

    tsne_model = TSNE(n_components=2, random_state=42)
    tsne_data = tsne_model.fit_transform(matrix)

    # A single scalar minimum (across both axes) shifts every coordinate >= 0.
    return (tsne_data - tsne_data.min()).tolist()