Skip to content

Commit

Permalink
chore: update type hint
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Oct 17, 2024
1 parent 7d74821 commit 4a953c6
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/langchain_google_cloud_sql_mysql/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from __future__ import annotations

import json
from typing import Any, Iterable, Optional, Type, Union
from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Type, Union

import numpy as np
from langchain_core.documents import Document
Expand All @@ -33,6 +33,9 @@
)
from .loader import _parse_doc_from_row

if TYPE_CHECKING:
from sqlalchemy.engine.row import Row, RowMapping

DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex"


Expand Down Expand Up @@ -644,7 +647,7 @@ def _query_collection(
filter: Optional[str] = None,
query_options: Optional[QueryOptions] = None,
map_results: Optional[bool] = True,
) -> list[Any]:
) -> Union[Sequence[Row], Sequence[RowMapping]]:
column_names = self.__get_column_names()
# Apply vector_to_string to the embedding_column
for i, v in enumerate(column_names):
Expand Down Expand Up @@ -673,7 +676,6 @@ def _query_collection(
)
stmt = f"SELECT {column_query}, {distance_function}({self.embedding_column}, string_to_vector('{embedding}')) AS distance FROM `{self.table_name}` WHERE NEAREST({self.embedding_column}) TO (string_to_vector('{embedding}'), 'num_neighbors={k}{num_partitions}') {filter} ORDER BY distance;"

# return self.engine._fetch(stmt)
if map_results:
return self.engine._fetch(stmt)
else:
Expand Down

0 comments on commit 4a953c6

Please sign in to comment.