Skip to content

Commit d64bd7d

Browse files
vertexai: Add embeddings_task_type parameter to embed_query and embed_documents (#716)
1 parent: 5788f75 · commit: d64bd7d

File tree

3 files changed

+113
-17
lines changed

3 files changed

+113
-17
lines changed

libs/vertexai/langchain_google_vertexai/embeddings.py

Lines changed: 31 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,18 @@
4545
_MIN_BATCH_SIZE = 5
4646

4747

48+
EmbeddingTaskTypes = Literal[
49+
"RETRIEVAL_QUERY",
50+
"RETRIEVAL_DOCUMENT",
51+
"SEMANTIC_SIMILARITY",
52+
"CLASSIFICATION",
53+
"CLUSTERING",
54+
"QUESTION_ANSWERING",
55+
"FACT_VERIFICATION",
56+
"CODE_RETRIEVAL_QUERY",
57+
]
58+
59+
4860
class GoogleEmbeddingModelType(str, Enum):
4961
TEXT = auto()
5062
MULTIMODAL = auto()
@@ -63,6 +75,7 @@ class GoogleEmbeddingModelVersion(str, Enum):
6375
EMBEDDINGS_NOV_2023 = auto()
6476
EMBEDDINGS_DEC_2023 = auto()
6577
EMBEDDINGS_MAY_2024 = auto()
78+
EMBEDDINGS_NOV_2024 = auto()
6679

6780
@classmethod
6881
def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
@@ -82,6 +95,8 @@ def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
8295
or "text-multilingual-embedding-preview-0409" in value.lower()
8396
):
8497
return GoogleEmbeddingModelVersion.EMBEDDINGS_MAY_2024
98+
if "text-embedding-005" in value.lower():
99+
return GoogleEmbeddingModelVersion.EMBEDDINGS_NOV_2024
85100

86101
return GoogleEmbeddingModelVersion.EMBEDDINGS_JUNE_2023
87102

@@ -376,17 +391,7 @@ def embed(
376391
self,
377392
texts: List[str],
378393
batch_size: int = 0,
379-
embeddings_task_type: Optional[
380-
Literal[
381-
"RETRIEVAL_QUERY",
382-
"RETRIEVAL_DOCUMENT",
383-
"SEMANTIC_SIMILARITY",
384-
"CLASSIFICATION",
385-
"CLUSTERING",
386-
"QUESTION_ANSWERING",
387-
"FACT_VERIFICATION",
388-
]
389-
] = None,
394+
embeddings_task_type: Optional[EmbeddingTaskTypes] = None,
390395
dimensions: Optional[int] = None,
391396
) -> List[List[float]]:
392397
"""Embed a list of strings.
@@ -406,6 +411,8 @@ def embed(
406411
for Semantic Textual Similarity (STS).
407412
CLASSIFICATION - Embeddings will be used for classification.
408413
CLUSTERING - Embeddings will be used for clustering.
414+
CODE_RETRIEVAL_QUERY - Embeddings will be used for
415+
code retrieval for Java and Python.
409416
The following are only supported on preview models:
410417
QUESTION_ANSWERING
411418
FACT_VERIFICATION
@@ -447,7 +454,11 @@ def embed(
447454
return embeddings
448455

449456
def embed_documents(
450-
self, texts: List[str], batch_size: int = 0
457+
self,
458+
texts: List[str],
459+
batch_size: int = 0,
460+
*,
461+
embeddings_task_type: EmbeddingTaskTypes = "RETRIEVAL_DOCUMENT",
451462
) -> List[List[float]]:
452463
"""Embed a list of documents.
453464
@@ -460,9 +471,14 @@ def embed_documents(
460471
Returns:
461472
List of embeddings, one for each text.
462473
"""
463-
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
474+
return self.embed(texts, batch_size, embeddings_task_type)
464475

465-
def embed_query(self, text: str) -> List[float]:
476+
def embed_query(
477+
self,
478+
text: str,
479+
*,
480+
embeddings_task_type: EmbeddingTaskTypes = "RETRIEVAL_QUERY",
481+
) -> List[float]:
466482
"""Embed a text.
467483
468484
Args:
@@ -471,7 +487,7 @@ def embed_query(self, text: str) -> List[float]:
471487
Returns:
472488
Embedding for the text.
473489
"""
474-
return self.embed([text], 1, "RETRIEVAL_QUERY")[0]
490+
return self.embed([text], 1, embeddings_task_type)[0]
475491

476492
@deprecated(
477493
since="2.0.1", removal="3.0.0", alternative="VertexAIEmbeddings.embed_images()"

libs/vertexai/tests/integration_tests/test_embeddings.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,25 @@ def test_langchain_google_vertexai_embedding_documents(
5353
assert model.model_name == model_name
5454

5555

56+
@pytest.mark.release
57+
@pytest.mark.parametrize(
58+
"model_name, embeddings_dim",
59+
_EMBEDDING_MODELS,
60+
)
61+
def test_langchain_google_vertexai_embedding_documents_with_task_type(
62+
model_name: str,
63+
embeddings_dim: int,
64+
) -> None:
65+
documents = ["foo bar"] * 8
66+
model = VertexAIEmbeddings(model_name)
67+
output = model.embed_documents(documents)
68+
assert len(output) == 8
69+
for embedding in output:
70+
assert len(embedding) == embeddings_dim
71+
assert model.model_name == model.client._model_id
72+
assert model.model_name == model_name
73+
74+
5675
@pytest.mark.release
5776
@pytest.mark.parametrize(
5877
"model_name, embeddings_dim",
@@ -65,6 +84,21 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -
6584
assert len(output) == embeddings_dim
6685

6786

87+
@pytest.mark.release
88+
@pytest.mark.parametrize(
89+
"model_name, embeddings_dim",
90+
_EMBEDDING_MODELS,
91+
)
92+
def test_langchain_google_vertexai_embedding_query_with_task_type(
93+
model_name: str,
94+
embeddings_dim: int,
95+
) -> None:
96+
document = "foo bar"
97+
model = VertexAIEmbeddings(model_name)
98+
output = model.embed_query(document)
99+
assert len(output) == embeddings_dim
100+
101+
68102
@pytest.mark.release
69103
@pytest.mark.parametrize(
70104
"dim, expected_dim",

libs/vertexai/tests/unit_tests/test_embeddings.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
from typing import Any, Dict
2-
from unittest.mock import MagicMock
2+
from unittest.mock import MagicMock, patch
33

44
import pytest
55
from pydantic import model_validator
66
from typing_extensions import Self
77

88
from langchain_google_vertexai import VertexAIEmbeddings
9-
from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType
9+
from langchain_google_vertexai.embeddings import (
10+
EmbeddingTaskTypes,
11+
GoogleEmbeddingModelType,
12+
)
1013

1114

1215
def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
@@ -29,6 +32,49 @@ def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:
2932
assert len(batches) == 2
3033

3134

35+
@patch.object(VertexAIEmbeddings, "embed")
36+
def test_embed_documents_with_question_answering_task(mock_embed) -> None:
37+
mock_embeddings = MockVertexAIEmbeddings("text-embedding-005")
38+
texts = [f"text {i}" for i in range(5)]
39+
40+
embedding_dimension = 768
41+
embeddings_task_type: EmbeddingTaskTypes = "QUESTION_ANSWERING"
42+
43+
mock_embed.return_value = [[0.001] * embedding_dimension for _ in texts]
44+
45+
embeddings = mock_embeddings.embed_documents(
46+
texts=texts, embeddings_task_type=embeddings_task_type
47+
)
48+
49+
assert isinstance(embeddings, list)
50+
assert len(embeddings) == len(texts)
51+
assert len(embeddings[0]) == embedding_dimension
52+
53+
# Verify embed() was called correctly
54+
mock_embed.assert_called_once_with(texts, 0, embeddings_task_type)
55+
56+
57+
@patch.object(VertexAIEmbeddings, "embed")
58+
def test_embed_query_with_question_answering_task(mock_embed) -> None:
59+
mock_embeddings = MockVertexAIEmbeddings("text-embedding-005")
60+
text = "text 0"
61+
62+
embedding_dimension = 768
63+
embeddings_task_type: EmbeddingTaskTypes = "QUESTION_ANSWERING"
64+
65+
mock_embed.return_value = [[0.001] * embedding_dimension]
66+
67+
embedding = mock_embeddings.embed_query(
68+
text=text, embeddings_task_type=embeddings_task_type
69+
)
70+
71+
assert isinstance(embedding, list)
72+
assert len(embedding) == embedding_dimension
73+
74+
# Verify embed() was called correctly
75+
mock_embed.assert_called_once_with([text], 1, embeddings_task_type)
76+
77+
3278
class MockVertexAIEmbeddings(VertexAIEmbeddings):
3379
"""
3480
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client

0 commit comments

Comments
 (0)