Skip to content

Commit

Permalink
Migrate to the stable version of discoveryengine
Browse files Browse the repository at this point in the history
  • Loading branch information
lgesuellip committed Dec 4, 2024
1 parent 019d578 commit d6773aa
Show file tree
Hide file tree
Showing 5 changed files with 1,345 additions and 1,278 deletions.
6 changes: 3 additions & 3 deletions libs/community/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/lint_imports.sh
poetry run ruff .
poetry run ruff check .
poetry run ruff format $(PYTHON_FILES) --diff
poetry run ruff --select I $(PYTHON_FILES)
poetry run ruff check --select I $(PYTHON_FILES)
mkdir $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
poetry run ruff format $(PYTHON_FILES)
poetry run ruff --select I --fix $(PYTHON_FILES)
poetry run ruff check --select I --fix $(PYTHON_FILES)

spell_check:
poetry run codespell --toml pyproject.toml
Expand Down
76 changes: 8 additions & 68 deletions libs/community/langchain_google_community/vertex_ai_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from google.protobuf.json_format import MessageToDict
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.load import Serializable, load
from langchain_core.retrievers import BaseRetriever
from langchain_core.tools import BaseTool
Expand All @@ -25,7 +24,7 @@
from langchain_google_community._utils import get_client_info

if TYPE_CHECKING:
from google.cloud.discoveryengine_v1beta import ( # type: ignore[import, attr-defined]
from google.cloud.discoveryengine_v1 import ( # type: ignore[import, attr-defined]
ConversationalSearchServiceClient,
SearchRequest,
SearchResult,
Expand Down Expand Up @@ -69,7 +68,7 @@ def __reduce__(self) -> Any:
def validate_environment(cls, values: Dict) -> Any:
"""Validates the environment."""
try:
from google.cloud import discoveryengine_v1beta # noqa: F401
from google.cloud import discoveryengine_v1 # noqa: F401
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
Expand Down Expand Up @@ -279,23 +278,6 @@ class VertexAISearchRetriever(BaseRetriever, _BaseVertexAISearchRetriever):
https://cloud.google.com/generative-ai-app-builder/docs/boost-search-results
https://cloud.google.com/generative-ai-app-builder/docs/reference/rest/v1beta/BoostSpec
"""
custom_embedding: Optional[Embeddings] = None
"""Custom embedding model for the retriever. (Bring your own embedding)
It needs to match the embedding model that was used to embed docs in the datastore.
It needs to be a langchain embedding VertexAIEmbeddings(project="{PROJECT}")
If you provide an embedding model, you also need to provide a ranking_expression and
a custom_embedding_field_path.
https://cloud.google.com/generative-ai-app-builder/docs/bring-embeddings
"""
custom_embedding_field_path: Optional[str] = None
""" The field path for the custom embedding used in the Vertex AI datastore schema.
"""
custom_embedding_ratio: Optional[float] = 0.0
"""Controls the ranking of results. Value should be between 0 and 1.
It will generate the ranking_expression in the following manner:
"{custom_embedding_ratio} * dotProduct({custom_embedding_field_path}) +
{1 - custom_embedding_ratio} * relevance_score"
"""

_client: SearchServiceClient = PrivateAttr()
_serving_config: str = PrivateAttr()
Expand All @@ -308,7 +290,7 @@ class VertexAISearchRetriever(BaseRetriever, _BaseVertexAISearchRetriever):
def __init__(self, **kwargs: Any) -> None:
"""Initializes private fields."""
try:
from google.cloud.discoveryengine_v1beta import SearchServiceClient
from google.cloud.discoveryengine_v1 import SearchServiceClient
except ImportError as exc:
raise ImportError(
"Could not import google-cloud-discoveryengine python package. "
Expand Down Expand Up @@ -340,7 +322,7 @@ def __init__(self, **kwargs: Any) -> None:
def _get_content_spec_kwargs(self) -> Optional[Dict[str, Any]]:
"""Prepares a ContentSpec object."""

from google.cloud.discoveryengine_v1beta import SearchRequest
from google.cloud.discoveryengine_v1 import SearchRequest

if self.engine_data_type == 0:
if self.get_extractive_answers:
Expand Down Expand Up @@ -382,7 +364,7 @@ def _get_content_spec_kwargs(self) -> Optional[Dict[str, Any]]:

def _create_search_request(self, query: str) -> SearchRequest:
"""Prepares a SearchRequest object."""
from google.cloud.discoveryengine_v1beta import SearchRequest
from google.cloud.discoveryengine_v1 import SearchRequest

query_expansion_spec = SearchRequest.QueryExpansionSpec(
condition=self.query_expansion_condition,
Expand All @@ -401,46 +383,6 @@ def _create_search_request(self, query: str) -> SearchRequest:
else:
content_search_spec = None

if (
self.custom_embedding is not None
or self.custom_embedding_field_path is not None
):
if self.custom_embedding is None:
raise ValueError(
"Please provide a custom embedding model if you provide a "
"custom_embedding_field_path."
)
if self.custom_embedding_field_path is None:
raise ValueError(
"Please provide a custom_embedding_field_path if you provide a "
"custom embedding model."
)
if self.custom_embedding_ratio is None:
raise ValueError(
"Please provide a custom_embedding_ratio if you provide a "
"custom embedding model or a custom_embedding_field_path."
)
if not 0 <= self.custom_embedding_ratio <= 1:
raise ValueError(
"Custom embedding ratio must be between 0 and 1 "
f"when using custom embeddings. Got {self.custom_embedding_ratio}"
)
embedding_vector = SearchRequest.EmbeddingSpec.EmbeddingVector(
field_path=self.custom_embedding_field_path,
vector=self.custom_embedding.embed_query(query),
)
embedding_spec = SearchRequest.EmbeddingSpec(
embedding_vectors=[embedding_vector]
)
ranking_expression = (
f"{self.custom_embedding_ratio} * "
f"dotProduct({self.custom_embedding_field_path}) + "
f"{1 - self.custom_embedding_ratio} * relevance_score"
)
else:
embedding_spec = None
ranking_expression = None

return SearchRequest(
query=query,
filter=self.filter,
Expand All @@ -454,8 +396,6 @@ def _create_search_request(self, query: str) -> SearchRequest:
boost_spec=SearchRequest.BoostSpec(**self.boost_spec)
if self.boost_spec
else None,
embedding_spec=embedding_spec,
ranking_expression=ranking_expression,
)

def _get_relevant_documents(
Expand Down Expand Up @@ -517,7 +457,7 @@ class VertexAIMultiTurnSearchRetriever(BaseRetriever, _BaseVertexAISearchRetriev

def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
from google.cloud.discoveryengine_v1beta import (
from google.cloud.discoveryengine_v1 import (
ConversationalSearchServiceClient,
)

Expand Down Expand Up @@ -545,7 +485,7 @@ def _get_relevant_documents(
self, query: str, *, run_manager: CallbackManagerForRetrieverRun
) -> List[Document]:
"""Get documents relevant for a query."""
from google.cloud.discoveryengine_v1beta import (
from google.cloud.discoveryengine_v1 import (
ConverseConversationRequest,
TextInput,
)
Expand Down Expand Up @@ -599,7 +539,7 @@ def _get_content_spec_kwargs(self) -> Optional[Dict[str, Any]]:
Returns:
kwargs for the specification of the content.
"""
from google.cloud.discoveryengine_v1beta import SearchRequest
from google.cloud.discoveryengine_v1 import SearchRequest

kwargs = super()._get_content_spec_kwargs() or {}

Expand Down
Loading

0 comments on commit d6773aa

Please sign in to comment.