Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vertexai: Add embeddings_task_type parameter to embed_query and embed_documents #716

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,18 @@
_MIN_BATCH_SIZE = 5


# Task types accepted by the Vertex AI text-embedding API. The chosen value is
# forwarded to the service so it can optimize embeddings for the downstream use
# case (retrieval, similarity, classification, clustering, ...).
# NOTE(review): per the docstring below, QUESTION_ANSWERING and
# FACT_VERIFICATION are only supported on preview models.
EmbeddingTaskTypes = Literal[
    "RETRIEVAL_QUERY",
    "RETRIEVAL_DOCUMENT",
    "SEMANTIC_SIMILARITY",
    "CLASSIFICATION",
    "CLUSTERING",
    "QUESTION_ANSWERING",
    "FACT_VERIFICATION",
    "CODE_RETRIEVAL_QUERY",
]


class GoogleEmbeddingModelType(str, Enum):
    """Coarse family of a Google embedding model: plain text vs. multimodal."""

    TEXT = auto()
    MULTIMODAL = auto()
Expand All @@ -63,6 +75,7 @@ class GoogleEmbeddingModelVersion(str, Enum):
EMBEDDINGS_NOV_2023 = auto()
EMBEDDINGS_DEC_2023 = auto()
EMBEDDINGS_MAY_2024 = auto()
EMBEDDINGS_NOV_2024 = auto()

@classmethod
def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
Expand All @@ -82,6 +95,8 @@ def _missing_(cls, value: Any) -> "GoogleEmbeddingModelVersion":
or "text-multilingual-embedding-preview-0409" in value.lower()
):
return GoogleEmbeddingModelVersion.EMBEDDINGS_MAY_2024
if "text-embedding-005" in value.lower():
return GoogleEmbeddingModelVersion.EMBEDDINGS_NOV_2024

return GoogleEmbeddingModelVersion.EMBEDDINGS_JUNE_2023

Expand Down Expand Up @@ -376,17 +391,7 @@ def embed(
self,
texts: List[str],
batch_size: int = 0,
embeddings_task_type: Optional[
Literal[
"RETRIEVAL_QUERY",
"RETRIEVAL_DOCUMENT",
"SEMANTIC_SIMILARITY",
"CLASSIFICATION",
"CLUSTERING",
"QUESTION_ANSWERING",
"FACT_VERIFICATION",
]
] = None,
embeddings_task_type: Optional[EmbeddingTaskTypes] = None,
dimensions: Optional[int] = None,
) -> List[List[float]]:
"""Embed a list of strings.
Expand All @@ -406,6 +411,8 @@ def embed(
for Semantic Textual Similarity (STS).
CLASSIFICATION - Embeddings will be used for classification.
CLUSTERING - Embeddings will be used for clustering.
CODE_RETRIEVAL_QUERY - Embeddings will be used for
code retrieval for Java and Python.
The following are only supported on preview models:
QUESTION_ANSWERING
FACT_VERIFICATION
Expand Down Expand Up @@ -447,7 +454,11 @@ def embed(
return embeddings

def embed_documents(
self, texts: List[str], batch_size: int = 0
self,
texts: List[str],
batch_size: int = 0,
*,
embeddings_task_type: EmbeddingTaskTypes = "RETRIEVAL_DOCUMENT",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a `*` so that `embeddings_task_type` can be provided by name only.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

) -> List[List[float]]:
"""Embed a list of documents.

Expand All @@ -460,9 +471,14 @@ def embed_documents(
Returns:
List of embeddings, one for each text.
"""
return self.embed(texts, batch_size, "RETRIEVAL_DOCUMENT")
return self.embed(texts, batch_size, embeddings_task_type)

def embed_query(self, text: str) -> List[float]:
def embed_query(
self,
text: str,
*,
embeddings_task_type: EmbeddingTaskTypes = "RETRIEVAL_QUERY",
) -> List[float]:
"""Embed a text.

Args:
Expand All @@ -471,7 +487,7 @@ def embed_query(self, text: str) -> List[float]:
Returns:
Embedding for the text.
"""
return self.embed([text], 1, "RETRIEVAL_QUERY")[0]
return self.embed([text], 1, embeddings_task_type)[0]

@deprecated(
since="2.0.1", removal="3.0.0", alternative="VertexAIEmbeddings.embed_images()"
Expand Down
34 changes: 34 additions & 0 deletions libs/vertexai/tests/integration_tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,25 @@ def test_langchain_google_vertexai_embedding_documents(
assert model.model_name == model_name


@pytest.mark.release
@pytest.mark.parametrize(
    "model_name, embeddings_dim",
    _EMBEDDING_MODELS,
)
def test_langchain_google_vertexai_embedding_documents_with_task_type(
    model_name: str,
    embeddings_dim: int,
) -> None:
    """Embedding documents with an explicit task type returns valid vectors.

    Bug fix: the original body never passed ``embeddings_task_type``, so this
    test (despite its name) exercised only the default code path. Pass the
    keyword explicitly so the new parameter is actually covered.
    """
    documents = ["foo bar"] * 8
    model = VertexAIEmbeddings(model_name)
    output = model.embed_documents(
        documents, embeddings_task_type="RETRIEVAL_DOCUMENT"
    )
    assert len(output) == 8
    for embedding in output:
        assert len(embedding) == embeddings_dim
    assert model.model_name == model.client._model_id
    assert model.model_name == model_name


@pytest.mark.release
@pytest.mark.parametrize(
"model_name, embeddings_dim",
Expand All @@ -65,6 +84,21 @@ def test_langchain_google_vertexai_embedding_query(model_name, embeddings_dim) -
assert len(output) == embeddings_dim


@pytest.mark.release
@pytest.mark.parametrize(
"model_name, embeddings_dim",
_EMBEDDING_MODELS,
)
def test_langchain_google_vertexai_embedding_query_with_task_type(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this integration test, or would a unit test be enough? (We've already verified in a test above that `embeddings_task` is passed to the Google API and that it returns a valid output.)

We've got too many integration tests now, and the execution time keeps getting longer :) — it might be a good idea to keep the total number of them reasonable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Totally makes sense! However, I'm not sure how we would write a clean unit test for the new argument in embed_documents() and embed_query(), as these methods in turn call embed() and return its response. It almost seems like we would be validating the embed() method instead.

If we only want to check if the arguments are passed correctly, one approach could be to mock the embed() method, capture the arguments, and then call the VertexAIEmbedding's embed() method. Happy to remove the tests if you think it's not required.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just by mocking the API:

def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved the tests into unit tests.

model_name: str,
embeddings_dim: int,
) -> None:
document = "foo bar"
model = VertexAIEmbeddings(model_name)
output = model.embed_query(document)
assert len(output) == embeddings_dim


@pytest.mark.release
@pytest.mark.parametrize(
"dim, expected_dim",
Expand Down
50 changes: 48 additions & 2 deletions libs/vertexai/tests/unit_tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from typing import Any, Dict
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import pytest
from pydantic import model_validator
from typing_extensions import Self

from langchain_google_vertexai import VertexAIEmbeddings
from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType
from langchain_google_vertexai.embeddings import (
EmbeddingTaskTypes,
GoogleEmbeddingModelType,
)


def test_langchain_google_vertexai_embed_image_multimodal_only() -> None:
Expand All @@ -29,6 +32,49 @@ def test_langchain_google_vertexai_no_dups_dynamic_batch_size() -> None:
assert len(batches) == 2


@patch.object(VertexAIEmbeddings, "embed")
def test_embed_documents_with_question_answering_task(mock_embed) -> None:
    """embed_documents forwards a caller-supplied task type to embed()."""
    dim = 768
    task: EmbeddingTaskTypes = "QUESTION_ANSWERING"
    docs = ["text 0", "text 1", "text 2", "text 3", "text 4"]

    mock_embed.return_value = [[0.001] * dim for _ in docs]
    embedder = MockVertexAIEmbeddings("text-embedding-005")

    result = embedder.embed_documents(texts=docs, embeddings_task_type=task)

    assert isinstance(result, list)
    assert len(result) == len(docs)
    assert len(result[0]) == dim

    # embed() must receive the texts, the default batch size (0), and the
    # task type chosen by the caller.
    mock_embed.assert_called_once_with(docs, 0, task)


@patch.object(VertexAIEmbeddings, "embed")
def test_embed_query_with_question_answering_task(mock_embed) -> None:
    """embed_query forwards a caller-supplied task type to embed()."""
    dim = 768
    task: EmbeddingTaskTypes = "QUESTION_ANSWERING"
    query = "text 0"

    mock_embed.return_value = [[0.001] * dim]
    embedder = MockVertexAIEmbeddings("text-embedding-005")

    vector = embedder.embed_query(text=query, embeddings_task_type=task)

    assert isinstance(vector, list)
    assert len(vector) == dim

    # embed() must receive a single-item list, batch size 1, and the task
    # type chosen by the caller.
    mock_embed.assert_called_once_with([query], 1, task)


class MockVertexAIEmbeddings(VertexAIEmbeddings):
"""
A mock class for avoiding instantiating VertexAI and the EmbeddingModel client
Expand Down
Loading