Respect the self._split param when computing embeddings for a text #1193

Merged: 3 commits, Feb 27, 2024

All embedding signals previously passed clustering_spacy_chunker unconditionally; this PR selects identity_chunker instead whenever self._split is False, so unsplit documents are embedded as a single chunk. It also hardens concept example validation and guards the web UI against undefined span props.
4 changes: 3 additions & 1 deletion lilac/concepts/concept.py

```diff
@@ -66,8 +66,10 @@ class ExampleIn(BaseModel):

   @field_validator('text')
   @classmethod
-  def parse_text(cls, text: str) -> str:
+  def parse_text(cls, text: Optional[str]) -> Optional[str]:
     """Fixes surrogate errors in text: https://github.com/ijl/orjson/blob/master/README.md#str ."""
+    if not text:
+      return None
     return text.encode('utf-8', 'replace').decode('utf-8')
```
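Note: outside of Pydantic, the updated validator behaves as in this minimal sketch. With errors='replace' on encode, Python substitutes '?' for un-encodable lone surrogates, and empty or missing text now normalizes to None:

```python
from typing import Optional


def parse_text(text: Optional[str]) -> Optional[str]:
  """Normalize empty text to None and repair surrogate errors."""
  if not text:
    return None
  # encode(..., 'replace') swaps un-encodable code points (e.g. lone
  # surrogates) for '?', which keeps downstream JSON serialization happy.
  return text.encode('utf-8', 'replace').decode('utf-8')


assert parse_text(None) is None
assert parse_text('') is None
assert parse_text('hello') == 'hello'
assert parse_text('bad\ud800text') == 'bad?text'
```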
2 changes: 2 additions & 0 deletions lilac/concepts/db_concept.py

```diff
@@ -473,6 +473,8 @@ def _validate_examples(
     self, examples: List[Union[ExampleIn, Example]], type: ConceptType
   ) -> None:
     for example in examples:
+      if not example.text and not example.img:
+        raise ValueError('The example must have a text or image associated with it.')
       inferred_type = 'text' if example.text else 'unknown'
       if inferred_type != type:
         raise ValueError(f'Example type "{inferred_type}" does not match concept type "{type}".')
```
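Note: this guard pairs with the parse_text change above. Because empty text is now normalized to None, an example with neither text nor image fails fast here instead of slipping through as text=''. A minimal sketch of the check, with the field access flattened into plain arguments:

```python
from typing import Optional


def validate_example(text: Optional[str], img: Optional[bytes]) -> None:
  """Reject examples that carry neither text nor an image."""
  if not text and not img:
    raise ValueError('The example must have a text or image associated with it.')


validate_example('some text', None)  # ok
try:
  validate_example(None, None)  # e.g. text that parse_text normalized to None
except ValueError as err:
  print(err)  # -> The example must have a text or image associated with it.
```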
11 changes: 8 additions & 3 deletions lilac/embeddings/bge.py

```diff
@@ -1,9 +1,10 @@
 """Gegeral Text Embeddings (GTE) model. Open-source model, designed to run on device."""
 import gc
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast

 from typing_extensions import override

+from ..splitters.chunk_splitter import TextChunk
 from ..utils import log

 if TYPE_CHECKING:
@@ -16,7 +17,7 @@
 from ..signal import TextEmbeddingSignal
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
-from .embedding import chunked_compute_embedding
+from .embedding import chunked_compute_embedding, identity_chunker
 from .transformer_utils import SENTENCE_TRANSFORMER_BATCH_SIZE

 # See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
@@ -69,11 +70,15 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(
       lambda docs: self._model.encode(docs)['dense_vecs'],
       docs,
       self.local_batch_size * 16,
-      chunker=clustering_spacy_chunker,
+      chunker=chunker,
     )

   @override
```
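Note: identity_chunker is imported from .embedding, but its definition is outside this diff. Assuming TextChunk pairs a chunk's text with its (start, end) character span, as the compute_garden code in gte.py below suggests, a no-op chunker satisfying the same interface would look roughly like this:

```python
# Assumed shape of TextChunk: (chunk_text, (start_offset, end_offset)).
TextChunk = tuple[str, tuple[int, int]]


def identity_chunker(text: str) -> list[TextChunk]:
  """Return the whole document as one chunk spanning the full text."""
  return [(text, (0, len(text)))]
```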
11 changes: 7 additions & 4 deletions lilac/embeddings/cohere.py

```diff
@@ -1,15 +1,16 @@
 """Cohere embeddings."""
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast

 import numpy as np
 from typing_extensions import override

 from ..env import env
 from ..schema import Item
 from ..signal import TextEmbeddingSignal
+from ..splitters.chunk_splitter import TextChunk
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
-from .embedding import chunked_compute_embedding
+from .embedding import chunked_compute_embedding, identity_chunker

 if TYPE_CHECKING:
   from cohere import Client
@@ -65,6 +66,8 @@ def _embed_fn(docs: list[str]) -> list[np.ndarray]:
       ).embeddings
     ]

-    return chunked_compute_embedding(
-      _embed_fn, docs, self.local_batch_size, chunker=clustering_spacy_chunker
-    )
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
+    return chunked_compute_embedding(_embed_fn, docs, self.local_batch_size, chunker=chunker)
```
16 changes: 12 additions & 4 deletions lilac/embeddings/gte.py

```diff
@@ -1,7 +1,7 @@
 """Gegeral Text Embeddings (GTE) model. Open-source model, designed to run on device."""
 import gc
 import itertools
-from typing import TYPE_CHECKING, ClassVar, Iterator, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Iterator, Optional, cast

 import modal
 from typing_extensions import override
@@ -19,7 +19,7 @@
 from ..signal import TextEmbeddingSignal
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
-from .embedding import chunked_compute_embedding
+from .embedding import chunked_compute_embedding, identity_chunker
 from .transformer_utils import SENTENCE_TRANSFORMER_BATCH_SIZE, setup_model_device

 # See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
@@ -69,17 +69,25 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(
-      self._model.encode, docs, self.local_batch_size * 16, chunker=clustering_spacy_chunker
+      self._model.encode, docs, self.local_batch_size * 16, chunker=chunker
     )

   @override
   def compute_garden(self, docs: Iterator[str]) -> Iterator[Item]:
     # Trim the docs to the max context size.

     trimmed_docs = (doc[:GTE_CONTEXT_SIZE] for doc in docs)
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     text_chunks: Iterator[tuple[int, TextChunk]] = (
-      (i, chunk) for i, doc in enumerate(trimmed_docs) for chunk in clustering_spacy_chunker(doc)
+      (i, chunk) for i, doc in enumerate(trimmed_docs) for chunk in chunker(doc)
     )
     text_chunks, text_chunks_2 = itertools.tee(text_chunks)
     chunk_texts = (chunk[0] for _, chunk in text_chunks)
```
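Note: compute_garden tags every chunk with the index of its source document and tees the stream: one branch feeds raw chunk texts to the model, the other keeps the (index, span) bookkeeping needed to regroup embeddings per document. A self-contained sketch of that pattern with stand-ins for the chunker and the model:

```python
import itertools


def chunker(text: str) -> list[tuple[str, tuple[int, int]]]:
  """Stand-in no-op chunker: one chunk covering the whole document."""
  return [(text, (0, len(text)))]


docs = ['first document', 'the second document']
text_chunks = ((i, chunk) for i, doc in enumerate(docs) for chunk in chunker(doc))
text_chunks, text_chunks_2 = itertools.tee(text_chunks)
chunk_texts = (chunk_text for _, (chunk_text, _) in text_chunks)

embeddings = [len(text) for text in chunk_texts]  # stand-in for the model
for (doc_index, (_, span)), emb in zip(text_chunks_2, embeddings):
  print(doc_index, span, emb)  # group by doc_index to rebuild per-doc items
```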
12 changes: 8 additions & 4 deletions lilac/embeddings/nomic_embed.py

```diff
@@ -1,10 +1,12 @@
 """Gegeral Text Embeddings (GTE) model. Open-source model, designed to run on device."""
 import gc
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast

 import numpy as np
 from typing_extensions import override

+from ..splitters.chunk_splitter import TextChunk
+
 if TYPE_CHECKING:
   from sentence_transformers import SentenceTransformer
@@ -14,7 +16,7 @@
 from ..signal import TextEmbeddingSignal
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
-from .embedding import chunked_compute_embedding
+from .embedding import chunked_compute_embedding, identity_chunker
 from .transformer_utils import SENTENCE_TRANSFORMER_BATCH_SIZE, setup_model_device

 # See https://huggingface.co/spaces/mteb/leaderboard for leaderboard of models.
@@ -76,9 +78,11 @@ def _encode(doc: list[str]) -> list[np.ndarray]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
-    return chunked_compute_embedding(
-      _encode, docs, self.local_batch_size * 16, chunker=clustering_spacy_chunker
-    )
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
+    return chunked_compute_embedding(_encode, docs, self.local_batch_size * 16, chunker=chunker)

   @override
   def teardown(self) -> None:
```
11 changes: 7 additions & 4 deletions lilac/embeddings/openai.py

```diff
@@ -1,5 +1,5 @@
 """OpenAI embeddings."""
-from typing import ClassVar, Optional
+from typing import Callable, ClassVar, Optional, cast

 import numpy as np
 from tenacity import retry, stop_after_attempt, wait_random_exponential
@@ -8,9 +8,10 @@
 from ..env import env
 from ..schema import Item
 from ..signal import TextEmbeddingSignal
+from ..splitters.chunk_splitter import TextChunk
 from ..splitters.spacy_splitter import clustering_spacy_chunker
 from ..tasks import TaskExecutionType
-from .embedding import chunked_compute_embedding
+from .embedding import chunked_compute_embedding, identity_chunker

 API_NUM_PARALLEL_REQUESTS = 10
 API_OPENAI_BATCH_SIZE = 128
@@ -92,6 +93,8 @@ def embed_fn(texts: list[str]) -> list[np.ndarray]:
       )
       return [np.array(embedding.embedding, dtype=np.float32) for embedding in response.data]

-    return chunked_compute_embedding(
-      embed_fn, docs, self.local_batch_size, chunker=clustering_spacy_chunker
-    )
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
+    return chunked_compute_embedding(embed_fn, docs, self.local_batch_size, chunker=chunker)
```
11 changes: 8 additions & 3 deletions lilac/embeddings/sbert.py

```diff
@@ -1,8 +1,9 @@
 """Sentence-BERT embeddings. Open-source models, designed to run on device."""
-from typing import TYPE_CHECKING, ClassVar, Optional
+from typing import TYPE_CHECKING, Callable, ClassVar, Optional, cast

 from typing_extensions import override

+from ..splitters.chunk_splitter import TextChunk
 from ..tasks import TaskExecutionType

 if TYPE_CHECKING:
@@ -12,7 +13,7 @@
 from ..schema import Item
 from ..signal import TextEmbeddingSignal
 from ..splitters.spacy_splitter import clustering_spacy_chunker
-from .embedding import chunked_compute_embedding
+from .embedding import chunked_compute_embedding, identity_chunker
 from .transformer_utils import SENTENCE_TRANSFORMER_BATCH_SIZE, setup_model_device

 # The `all-mpnet-base-v2` model provides the best quality, while `all-MiniLM-L6-v2`` is 5 times
@@ -47,8 +48,12 @@ def compute(self, docs: list[str]) -> list[Optional[Item]]:
     # While we get docs in batches of 1024, the chunker expands that by a factor of 3-10.
     # The sentence transformer API actually does batching internally, so we pass
     # local_batch_size * 16 to allow the library to see all the chunks at once.
+    chunker = cast(
+      Callable[[str], list[TextChunk]],
+      clustering_spacy_chunker if self._split else identity_chunker,
+    )
     return chunked_compute_embedding(
-      self._model.encode, docs, self.local_batch_size * 16, chunker=clustering_spacy_chunker
+      self._model.encode, docs, self.local_batch_size * 16, chunker=chunker
     )

   @override
```
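Note: the user-visible effect, sketched under the assumption that TextEmbeddingSignal exposes a split constructor argument backing self._split (the flag named in the PR title; its exact spelling is not shown in this diff):

```python
from lilac.embeddings.sbert import SBERT

# Hypothetical usage: `split=False` is assumed to set self._split.
signal = SBERT(split=False)
signal.setup()  # load the sentence-transformers model
# With this PR, each document is embedded as a single chunk instead of
# being run through clustering_spacy_chunker first.
items = signal.compute(['A long document that should not be split.'])
```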
Changes to a Svelte component in the web UI (file name not captured in this view). The guards below stop reactive statements from running against still-undefined props before data loads, and conceptsInMenu is initialized so the menu can render before the concepts query resolves:

```diff
@@ -76,7 +76,7 @@
   };
   $: {
     pathToSpans = {};
-    spanPaths.forEach(sp => {
+    (spanPaths || []).forEach(sp => {
       if (row == null) return;
       let valueNodes = getValueNodes(row, sp);
       const isSpanNestedUnder = pathMatchesPrefix(sp, path);
@@ -97,7 +97,7 @@
   let spanPathToValueInfos: Record<string, SpanValueInfo[]> = {};
   $: {
     spanPathToValueInfos = {};
-    for (const spanValueInfo of spanValueInfos) {
+    for (const spanValueInfo of spanValueInfos || []) {
       const spanPathStr = serializePath(spanValueInfo.spanPath);
       if (spanPathToValueInfos[spanPathStr] == null) {
         spanPathToValueInfos[spanPathStr] = [];
@@ -206,7 +206,7 @@
   $: {
     if (model != null && editor != null) {
       let minPosition: Monaco.Position | null = null;
-      for (const renderSpan of monacoSpans) {
+      for (const renderSpan of monacoSpans || []) {
         const span = L.span(renderSpan.span)!;
         const position = model.getPositionAt(span.start);
@@ -381,7 +381,7 @@

   const conceptQuery = queryConcepts();
   $: concepts = $conceptQuery.data;
-  let conceptsInMenu: Set<string>;
+  let conceptsInMenu: Set<string> = new Set();
   let addToConceptItems: DropdownItem[] = [];

   $: {
```