diff --git a/src/langchain_google_cloud_sql_mysql/chat_message_history.py b/src/langchain_google_cloud_sql_mysql/chat_message_history.py
index 56584b6..93de88e 100644
--- a/src/langchain_google_cloud_sql_mysql/chat_message_history.py
+++ b/src/langchain_google_cloud_sql_mysql/chat_message_history.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import List

 import sqlalchemy
 from langchain_core.chat_history import BaseChatMessageHistory
@@ -78,7 +77,7 @@ def _verify_schema(self) -> None:
         )

     @property
-    def messages(self) -> List[BaseMessage]:  # type: ignore
+    def messages(self) -> list[BaseMessage]:  # type: ignore
         """Retrieve the messages from Cloud SQL"""
         query = f"SELECT data, type FROM `{self.table_name}` WHERE session_id = :session_id ORDER BY id;"
         with self.engine.connect() as conn:
diff --git a/src/langchain_google_cloud_sql_mysql/engine.py b/src/langchain_google_cloud_sql_mysql/engine.py
index 7cdc8d6..301b21b 100644
--- a/src/langchain_google_cloud_sql_mysql/engine.py
+++ b/src/langchain_google_cloud_sql_mysql/engine.py
@@ -15,13 +15,14 @@
 # TODO: Remove below import when minimum supported Python version is 3.10
 from __future__ import annotations

-from typing import TYPE_CHECKING, Dict, List, Optional
+from typing import TYPE_CHECKING, Optional, Sequence

 import google.auth
 import google.auth.transport.requests
 import requests
 import sqlalchemy
 from google.cloud.sql.connector import Connector, RefreshStrategy
+from sqlalchemy.engine.row import Row, RowMapping

 from .version import __version__

@@ -75,7 +76,7 @@ def _get_iam_principal_email(
     url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
     response = requests.get(url)
     response.raise_for_status()
-    response_json: Dict = response.json()
+    response_json: dict = response.json()
     email = response_json.get("email")
     if email is None:
         raise ValueError(
@@ -235,7 +236,7 @@ def _execute_outside_tx(self, query: str, params: Optional[dict] = None) -> None
             conn = conn.execution_options(isolation_level="AUTOCOMMIT")
             conn.execute(sqlalchemy.text(query), params)

-    def _fetch(self, query: str, params: Optional[dict] = None):
+    def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
         """Fetch results from a SQL query."""
         with self.engine.connect() as conn:
             result = conn.execute(sqlalchemy.text(query), params)
@@ -243,7 +244,7 @@ def _fetch(self, query: str, params: Optional[dict] = None):
             result_fetch = result_map.fetchall()
         return result_fetch

-    def _fetch_rows(self, query: str, params: Optional[dict] = None):
+    def _fetch_rows(self, query: str, params: Optional[dict] = None) -> Sequence[Row]:
         """Fetch results from a SQL query as rows."""
         with self.engine.connect() as conn:
             result = conn.execute(sqlalchemy.text(query), params)
@@ -283,7 +284,7 @@ def init_chat_history_table(self, table_name: str) -> None:
     def init_document_table(
         self,
         table_name: str,
-        metadata_columns: List[sqlalchemy.Column] = [],
+        metadata_columns: list[sqlalchemy.Column] = [],
         content_column: str = "page_content",
         metadata_json_column: Optional[str] = "langchain_metadata",
         overwrite_existing: bool = False,
@@ -293,7 +294,7 @@ def init_document_table(

         Args:
             table_name (str): The MySQL database table name.
-            metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
+            metadata_columns (list[sqlalchemy.Column]): A list of SQLAlchemy Columns
                 to create for custom metadata. Optional.
             content_column (str): The column to store document content.
                 Deafult: `page_content`.
@@ -347,7 +348,7 @@ def init_vectorstore_table(
         vector_size: int,
         content_column: str = "content",
         embedding_column: str = "embedding",
-        metadata_columns: List[Column] = [],
+        metadata_columns: list[Column] = [],
         metadata_json_column: str = "langchain_metadata",
         id_column: str = "langchain_id",
         overwrite_existing: bool = False,
@@ -363,7 +364,7 @@ def init_vectorstore_table(
                 Default: `page_content`.
             embedding_column (str) : Name of the column to store vector embeddings.
                 Default: `embedding`.
-            metadata_columns (List[Column]): A list of Columns to create for custom
+            metadata_columns (list[Column]): A list of Columns to create for custom
                 metadata. Default: []. Optional.
             metadata_json_column (str): The column to store extra metadata in JSON format.
                 Default: `langchain_metadata`. Optional.
diff --git a/src/langchain_google_cloud_sql_mysql/loader.py b/src/langchain_google_cloud_sql_mysql/loader.py
index facbfe8..67aaecc 100644
--- a/src/langchain_google_cloud_sql_mysql/loader.py
+++ b/src/langchain_google_cloud_sql_mysql/loader.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Any, Dict, Iterable, Iterator, List, Optional, cast
+from typing import Any, Iterable, Iterator, Optional, cast

 import pymysql
 import sqlalchemy
@@ -28,13 +28,13 @@ def _parse_doc_from_row(
     content_columns: Iterable[str],
     metadata_columns: Iterable[str],
-    row: Dict,
+    row: dict,
     metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
 ) -> Document:
     page_content = " ".join(
         str(row[column]) for column in content_columns if column in row
     )
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     # unnest metadata from langchain_metadata column
     if row.get(metadata_json_column):
         for k, v in row[metadata_json_column].items():
@@ -51,9 +51,9 @@ def _parse_row_from_doc(
     doc: Document,
     content_column: str = DEFAULT_CONTENT_COL,
     metadata_json_column: str = DEFAULT_METADATA_COL,
-) -> Dict:
+) -> dict:
     doc_metadata = doc.metadata.copy()
-    row: Dict[str, Any] = {content_column: doc.page_content}
+    row: dict[str, Any] = {content_column: doc.page_content}
     for entry in doc.metadata:
         if entry in column_names:
             row[entry] = doc_metadata[entry]
@@ -72,8 +72,8 @@ def __init__(
         engine: MySQLEngine,
         table_name: str = "",
         query: str = "",
-        content_columns: Optional[List[str]] = None,
-        metadata_columns: Optional[List[str]] = None,
+        content_columns: Optional[list[str]] = None,
+        metadata_columns: Optional[list[str]] = None,
         metadata_json_column: Optional[str] = None,
     ):
         """
@@ -89,9 +89,9 @@ def __init__(
             engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
             table_name (str): The MySQL database table name. (OneOf: table_name, query).
             query (str): The query to execute in MySQL format. (OneOf: table_name, query).
-            content_columns (List[str]): The columns to write into the `page_content`
+            content_columns (list[str]): The columns to write into the `page_content`
                 of the document. Optional.
-            metadata_columns (List[str]): The columns to write into the `metadata` of the document.
+            metadata_columns (list[str]): The columns to write into the `metadata` of the document.
                 Optional.
             metadata_json_column (str): The name of the JSON column to use as the
                 metadata’s base dictionary. Default: `langchain_metadata`. Optional.
@@ -110,12 +110,12 @@ def __init__(
                 "entire table or 'query' to load a specific query."
             )

-    def load(self) -> List[Document]:
+    def load(self) -> list[Document]:
         """
         Load langchain documents from a Cloud SQL MySQL database.

         Returns:
-            (List[langchain_core.documents.Document]): a list of Documents with metadata from
+            (list[langchain_core.documents.Document]): a list of Documents with metadata from
                 specific columns.
         """
         return list(self.lazy_load())
@@ -231,13 +231,13 @@ def __init__(
         )
         self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL

-    def add_documents(self, docs: List[Document]) -> None:
+    def add_documents(self, docs: list[Document]) -> None:
         """
         Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
         stored in langchain_metadata JSON column.

         Args:
-            docs (List[langchain_core.documents.Document]): a list of documents to be saved.
+            docs (list[langchain_core.documents.Document]): a list of documents to be saved.
         """
         with self.engine.connect() as conn:
             for doc in docs:
@@ -250,13 +250,13 @@ def add_documents(self, docs: List[Document]) -> None:
                 conn.execute(sqlalchemy.insert(self._table).values(row))
             conn.commit()

-    def delete(self, docs: List[Document]) -> None:
+    def delete(self, docs: list[Document]) -> None:
         """
         Delete all instances of a document from the DocumentSaver table by matching the entire Document
         object.

         Args:
-            docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
+            docs (list[langchain_core.documents.Document]): a list of documents to be deleted.
         """
         with self.engine.connect() as conn:
             for doc in docs:
diff --git a/src/langchain_google_cloud_sql_mysql/vectorstore.py b/src/langchain_google_cloud_sql_mysql/vectorstore.py
index 9028d01..7974a2a 100644
--- a/src/langchain_google_cloud_sql_mysql/vectorstore.py
+++ b/src/langchain_google_cloud_sql_mysql/vectorstore.py
@@ -16,7 +16,7 @@
 from __future__ import annotations

 import json
-from typing import Any, Iterable, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Type, Union

 import numpy as np
 from langchain_core.documents import Document
@@ -33,6 +33,9 @@
 )
 from .loader import _parse_doc_from_row

+if TYPE_CHECKING:
+    from sqlalchemy.engine.row import Row, RowMapping
+
 DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex"


@@ -44,8 +47,8 @@ def __init__(
         table_name: str,
         content_column: str = "content",
         embedding_column: str = "embedding",
-        metadata_columns: List[str] = [],
-        ignore_metadata_columns: Optional[List[str]] = None,
+        metadata_columns: list[str] = [],
+        ignore_metadata_columns: Optional[list[str]] = None,
         id_column: str = "langchain_id",
         metadata_json_column: Optional[str] = "langchain_metadata",
         k: int = 4,
@@ -60,8 +63,8 @@ def __init__(
             table_name (str): Name of an existing table or table to be created.
             content_column (str): Column that represent a Document's page_content. Defaults to "content".
             embedding_column (str): Column for embedding vectors. The embedding is generated from the document value. Defaults to "embedding".
-            metadata_columns (List[str]): Column(s) that represent a document's metadata.
-            ignore_metadata_columns (List[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None.
+            metadata_columns (list[str]): Column(s) that represent a document's metadata.
+            ignore_metadata_columns (list[str]): Column(s) to ignore in pre-existing tables for a document's metadata. Can not be used with metadata_columns. Defaults to None.
             id_column (str): Column that represents the Document's id. Defaults to "langchain_id".
             metadata_json_column (str): Column to store metadata as JSON. Defaults to "langchain_metadata".
             k (int): The number of documents to return as the final result of a similarity search. Defaults to 4.
@@ -141,7 +144,7 @@ def __get_db_name(self) -> str:
         result = self.engine._fetch("SELECT DATABASE();")
         return result[0]["DATABASE()"]

-    def __get_column_names(self) -> List[str]:
+    def __get_column_names(self) -> list[str]:
         results = self.engine._fetch(
             f"SELECT COLUMN_NAME FROM `INFORMATION_SCHEMA`.`COLUMNS` WHERE `TABLE_SCHEMA` = '{self.db_name}' AND `TABLE_NAME` = '{self.table_name}'"
         )
@@ -150,11 +153,11 @@ def __get_column_names(self) -> List[str]:
     def _add_embeddings(
         self,
         texts: Iterable[str],
-        embeddings: List[List[float]],
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[List[str]] = None,
+        embeddings: list[list[float]],
+        metadatas: Optional[list[dict]] = None,
+        ids: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         if not ids:
             ids = ["NULL" for _ in texts]
         if not metadatas:
@@ -198,10 +201,10 @@ def _add_embeddings(
     def add_texts(
         self,
         texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[List[str]] = None,
+        metadatas: Optional[list[dict]] = None,
+        ids: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         embeddings = self.embedding_service.embed_documents(list(texts))
         ids = self._add_embeddings(
             texts, embeddings, metadatas=metadatas, ids=ids, **kwargs
@@ -210,10 +213,10 @@ def add_texts(

     def add_documents(
         self,
-        documents: List[Document],
-        ids: Optional[List[str]] = None,
+        documents: list[Document],
+        ids: Optional[list[str]] = None,
         **kwargs: Any,
-    ) -> List[str]:
+    ) -> list[str]:
         texts = [doc.page_content for doc in documents]
         metadatas = [doc.metadata for doc in documents]
         ids = self.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
@@ -221,7 +224,7 @@ def add_documents(

     def delete(
         self,
-        ids: Optional[List[str]] = None,
+        ids: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> bool:
         if not ids:
@@ -234,7 +237,7 @@ def delete(
         self.engine._execute(query)
         return True

-    def apply_vector_index(self, vector_index: VectorIndex):
+    def apply_vector_index(self, vector_index: VectorIndex) -> None:
         # Construct the default index name
         if not vector_index.name:
             vector_index.name = f"{self.table_name}_{DEFAULT_INDEX_NAME_SUFFIX}"
@@ -243,7 +246,7 @@ def apply_vector_index(self, vector_index: VectorIndex):
         # After applying an index to the table, set the query option search type to be ANN
         self.query_options.search_type = SearchType.ANN

-    def alter_vector_index(self, vector_index: VectorIndex):
+    def alter_vector_index(self, vector_index: VectorIndex) -> None:
         existing_index_name = self._get_vector_index_name()
         if not existing_index_name:
             raise ValueError("No existing vector index found.")
@@ -258,7 +261,9 @@ def alter_vector_index(self, vector_index: VectorIndex):
         )
         self.__exec_apply_vector_index(query_template, vector_index)

-    def __exec_apply_vector_index(self, query_template: str, vector_index: VectorIndex):
+    def __exec_apply_vector_index(
+        self, query_template: str, vector_index: VectorIndex
+    ) -> None:
         index_options = []
         if vector_index.index_type:
             index_options.append(f"index_type={vector_index.index_type.value}")
@@ -275,7 +280,7 @@ def __exec_apply_vector_index(self, query_template: str, vector_index: VectorInd
         stmt = query_template.format(index_options_query)
         self.engine._execute_outside_tx(stmt)

-    def _get_vector_index_name(self):
+    def _get_vector_index_name(self) -> Optional[str]:
         query = f"SELECT index_name FROM mysql.vector_indexes WHERE table_name='{self.db_name}.{self.table_name}';"
         result = self.engine._fetch(query)
         if result:
@@ -283,7 +288,7 @@ def _get_vector_index_name(self):
         else:
             return None

-    def drop_vector_index(self):
+    def drop_vector_index(self) -> Optional[str]:
         existing_index_name = self._get_vector_index_name()
         if existing_index_name:
             self.engine._execute_outside_tx(
@@ -295,16 +300,16 @@
     @classmethod
     def from_texts(  # type: ignore[override]
         cls: Type[MySQLVectorStore],
-        texts: List[str],
+        texts: list[str],
         embedding: Embeddings,
         engine: MySQLEngine,
         table_name: str,
-        metadatas: Optional[List[dict]] = None,
-        ids: Optional[List[str]] = None,
+        metadatas: Optional[list[dict]] = None,
+        ids: Optional[list[str]] = None,
         content_column: str = "content",
         embedding_column: str = "embedding",
-        metadata_columns: List[str] = [],
-        ignore_metadata_columns: Optional[List[str]] = None,
+        metadata_columns: list[str] = [],
+        ignore_metadata_columns: Optional[list[str]] = None,
         id_column: str = "langchain_id",
         metadata_json_column: str = "langchain_metadata",
         query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
@@ -328,15 +333,15 @@ def from_texts(  # type: ignore[override]
     @classmethod
     def from_documents(  # type: ignore[override]
         cls: Type[MySQLVectorStore],
-        documents: List[Document],
+        documents: list[Document],
         embedding: Embeddings,
         engine: MySQLEngine,
         table_name: str,
-        ids: Optional[List[str]] = None,
+        ids: Optional[list[str]] = None,
         content_column: str = "content",
         embedding_column: str = "embedding",
-        metadata_columns: List[str] = [],
-        ignore_metadata_columns: Optional[List[str]] = None,
+        metadata_columns: list[str] = [],
+        ignore_metadata_columns: Optional[list[str]] = None,
         id_column: str = "langchain_id",
         metadata_json_column: str = "langchain_metadata",
         query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
@@ -365,7 +370,7 @@ def similarity_search(
         k: Optional[int] = None,
         filter: Optional[str] = None,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Searches for similar documents based on a text query.

         Args:
@@ -385,12 +390,12 @@ def similarity_search(

     def similarity_search_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: Optional[int] = None,
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Searches for similar documents based on a vector embedding.

         Args:
@@ -420,7 +425,7 @@ def similarity_search_with_score(
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Searches for similar documents based on a text query and returns their scores.

         Args:
@@ -445,12 +450,12 @@ def similarity_search_with_score(

     def similarity_search_with_score_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: Optional[int] = None,
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Searches for similar documents based on a vector embedding and returns their scores.

         Args:
@@ -505,7 +510,7 @@ def max_marginal_relevance_search(
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Performs Maximal Marginal Relevance (MMR) search based on a text query.

         Args:
@@ -534,14 +539,14 @@ def max_marginal_relevance_search(

     def max_marginal_relevance_search_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         **kwargs: Any,
-    ) -> List[Document]:
+    ) -> list[Document]:
         """Performs Maximal Marginal Relevance (MMR) search based on a vector embedding.

         Args:
@@ -570,14 +575,14 @@ def max_marginal_relevance_search_by_vector(

     def max_marginal_relevance_search_with_score_by_vector(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: Optional[int] = None,
         fetch_k: Optional[int] = None,
         lambda_mult: Optional[float] = None,
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         **kwargs: Any,
-    ) -> List[Tuple[Document, float]]:
+    ) -> list[tuple[Document, float]]:
         """Performs Maximal Marginal Relevance (MMR) search based on a vector embedding and returns documents with scores.

         Args:
@@ -637,12 +642,12 @@ def max_marginal_relevance_search_with_score_by_vector(

     def _query_collection(
         self,
-        embedding: List[float],
+        embedding: list[float],
         k: Optional[int] = None,
         filter: Optional[str] = None,
         query_options: Optional[QueryOptions] = None,
         map_results: Optional[bool] = True,
-    ) -> List[Any]:
+    ) -> Union[Sequence[Row], Sequence[RowMapping]]:
         column_names = self.__get_column_names()
         # Apply vector_to_string to the embedding_column
         for i, v in enumerate(column_names):
@@ -671,7 +676,6 @@ def _query_collection(
         )

         stmt = f"SELECT {column_query}, {distance_function}({self.embedding_column}, string_to_vector('{embedding}')) AS distance FROM `{self.table_name}` WHERE NEAREST({self.embedding_column}) TO (string_to_vector('{embedding}'), 'num_neighbors={k}{num_partitions}') {filter} ORDER BY distance;"
-        # return self.engine._fetch(stmt)
         if map_results:
             return self.engine._fetch(stmt)
         else:
@@ -680,7 +684,7 @@ def _query_collection(


 ### The following is copied from langchain-community until it's moved into core
-Matrix = Union[List[List[float]], List[np.ndarray], np.ndarray]
+Matrix = Union[list[list[float]], list[np.ndarray], np.ndarray]


 def maximal_marginal_relevance(
@@ -688,7 +692,7 @@ def maximal_marginal_relevance(
     embedding_list: list,
     lambda_mult: float = 0.5,
     k: int = 4,
-) -> List[int]:
+) -> list[int]:
     """Calculate maximal marginal relevance."""
     if min(k, len(embedding_list)) <= 0:
         return []
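For reference, a minimal standalone sketch of the typing pattern this patch adopts: PEP 585 built-in generics (`list[...]`, `dict[...]`) in place of `typing.List`/`typing.Dict`, plus SQLAlchemy's `Row`/`RowMapping` result types for fetch helpers. The `fetch_mappings`/`fetch_rows` helpers and the in-memory SQLite engine below are illustrative assumptions, not part of the library's API.

```python
# Sketch only: mirrors the annotation style used in MySQLEngine._fetch / _fetch_rows,
# but against an assumed SQLite engine so it can run standalone.
from __future__ import annotations  # lets list[...] annotations work on older Pythons

from typing import Optional, Sequence

import sqlalchemy
from sqlalchemy.engine.row import Row, RowMapping


def fetch_mappings(
    engine: sqlalchemy.engine.Engine, query: str, params: Optional[dict] = None
) -> Sequence[RowMapping]:
    """Return rows as dict-like RowMapping objects (keyed by column name)."""
    with engine.connect() as conn:
        result = conn.execute(sqlalchemy.text(query), params)
        return result.mappings().fetchall()


def fetch_rows(
    engine: sqlalchemy.engine.Engine, query: str, params: Optional[dict] = None
) -> Sequence[Row]:
    """Return rows as tuple-like Row objects (indexed by position)."""
    with engine.connect() as conn:
        result = conn.execute(sqlalchemy.text(query), params)
        return result.fetchall()


if __name__ == "__main__":
    # Hypothetical table just to exercise both helpers.
    engine = sqlalchemy.create_engine("sqlite:///:memory:")
    with engine.begin() as conn:
        conn.execute(sqlalchemy.text("CREATE TABLE docs (id INTEGER, page_content TEXT)"))
        conn.execute(sqlalchemy.text("INSERT INTO docs VALUES (1, 'hello world')"))

    mappings: Sequence[RowMapping] = fetch_mappings(engine, "SELECT * FROM docs")
    print(mappings[0]["page_content"])  # access by column name

    rows: Sequence[Row] = fetch_rows(engine, "SELECT * FROM docs")
    print(rows[0][1])  # access by position
```

Keeping `Row`/`RowMapping` under `if TYPE_CHECKING:` in vectorstore.py, as the patch does, avoids a runtime import that is only needed for annotations.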