Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: add and update type hints #101

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import List

import sqlalchemy
from langchain_core.chat_history import BaseChatMessageHistory
Expand Down Expand Up @@ -78,7 +77,7 @@ def _verify_schema(self) -> None:
)

@property
def messages(self) -> List[BaseMessage]: # type: ignore
def messages(self) -> list[BaseMessage]: # type: ignore
"""Retrieve the messages from Cloud SQL"""
query = f"SELECT data, type FROM `{self.table_name}` WHERE session_id = :session_id ORDER BY id;"
with self.engine.connect() as conn:
Expand Down
17 changes: 9 additions & 8 deletions src/langchain_google_cloud_sql_mysql/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations

from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Optional, Sequence

import google.auth
import google.auth.transport.requests
import requests
import sqlalchemy
from google.cloud.sql.connector import Connector, RefreshStrategy
from sqlalchemy.engine.row import Row, RowMapping

from .version import __version__

Expand Down Expand Up @@ -75,7 +76,7 @@ def _get_iam_principal_email(
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
response = requests.get(url)
response.raise_for_status()
response_json: Dict = response.json()
response_json: dict = response.json()
email = response_json.get("email")
if email is None:
raise ValueError(
Expand Down Expand Up @@ -235,15 +236,15 @@ def _execute_outside_tx(self, query: str, params: Optional[dict] = None) -> None
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
conn.execute(sqlalchemy.text(query), params)

def _fetch(self, query: str, params: Optional[dict] = None):
def _fetch(self, query: str, params: Optional[dict] = None) -> Sequence[RowMapping]:
"""Fetch results from a SQL query."""
with self.engine.connect() as conn:
result = conn.execute(sqlalchemy.text(query), params)
result_map = result.mappings()
result_fetch = result_map.fetchall()
return result_fetch

def _fetch_rows(self, query: str, params: Optional[dict] = None):
def _fetch_rows(self, query: str, params: Optional[dict] = None) -> Sequence[Row]:
"""Fetch results from a SQL query as rows."""
with self.engine.connect() as conn:
result = conn.execute(sqlalchemy.text(query), params)
Expand Down Expand Up @@ -283,7 +284,7 @@ def init_chat_history_table(self, table_name: str) -> None:
def init_document_table(
self,
table_name: str,
metadata_columns: List[sqlalchemy.Column] = [],
metadata_columns: list[sqlalchemy.Column] = [],
content_column: str = "page_content",
metadata_json_column: Optional[str] = "langchain_metadata",
overwrite_existing: bool = False,
Expand All @@ -293,7 +294,7 @@ def init_document_table(

Args:
table_name (str): The MySQL database table name.
metadata_columns (List[sqlalchemy.Column]): A list of SQLAlchemy Columns
metadata_columns (list[sqlalchemy.Column]): A list of SQLAlchemy Columns
to create for custom metadata. Optional.
content_column (str): The column to store document content.
Deafult: `page_content`.
Expand Down Expand Up @@ -347,7 +348,7 @@ def init_vectorstore_table(
vector_size: int,
content_column: str = "content",
embedding_column: str = "embedding",
metadata_columns: List[Column] = [],
metadata_columns: list[Column] = [],
metadata_json_column: str = "langchain_metadata",
id_column: str = "langchain_id",
overwrite_existing: bool = False,
Expand All @@ -363,7 +364,7 @@ def init_vectorstore_table(
Default: `page_content`.
embedding_column (str) : Name of the column to store vector embeddings.
Default: `embedding`.
metadata_columns (List[Column]): A list of Columns to create for custom
metadata_columns (list[Column]): A list of Columns to create for custom
metadata. Default: []. Optional.
metadata_json_column (str): The column to store extra metadata in JSON format.
Default: `langchain_metadata`. Optional.
Expand Down
30 changes: 15 additions & 15 deletions src/langchain_google_cloud_sql_mysql/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Any, Dict, Iterable, Iterator, List, Optional, cast
from typing import Any, Iterable, Iterator, Optional, cast

import pymysql
import sqlalchemy
Expand All @@ -28,13 +28,13 @@
def _parse_doc_from_row(
content_columns: Iterable[str],
metadata_columns: Iterable[str],
row: Dict,
row: dict,
metadata_json_column: Optional[str] = DEFAULT_METADATA_COL,
) -> Document:
page_content = " ".join(
str(row[column]) for column in content_columns if column in row
)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
# unnest metadata from langchain_metadata column
if row.get(metadata_json_column):
for k, v in row[metadata_json_column].items():
Expand All @@ -51,9 +51,9 @@ def _parse_row_from_doc(
doc: Document,
content_column: str = DEFAULT_CONTENT_COL,
metadata_json_column: str = DEFAULT_METADATA_COL,
) -> Dict:
) -> dict:
doc_metadata = doc.metadata.copy()
row: Dict[str, Any] = {content_column: doc.page_content}
row: dict[str, Any] = {content_column: doc.page_content}
for entry in doc.metadata:
if entry in column_names:
row[entry] = doc_metadata[entry]
Expand All @@ -72,8 +72,8 @@ def __init__(
engine: MySQLEngine,
table_name: str = "",
query: str = "",
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
content_columns: Optional[list[str]] = None,
metadata_columns: Optional[list[str]] = None,
metadata_json_column: Optional[str] = None,
):
"""
Expand All @@ -89,9 +89,9 @@ def __init__(
engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
table_name (str): The MySQL database table name. (OneOf: table_name, query).
query (str): The query to execute in MySQL format. (OneOf: table_name, query).
content_columns (List[str]): The columns to write into the `page_content`
content_columns (list[str]): The columns to write into the `page_content`
of the document. Optional.
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
metadata_columns (list[str]): The columns to write into the `metadata` of the document.
Optional.
metadata_json_column (str): The name of the JSON column to use as the metadata’s base
dictionary. Default: `langchain_metadata`. Optional.
Expand All @@ -110,12 +110,12 @@ def __init__(
"entire table or 'query' to load a specific query."
)

def load(self) -> List[Document]:
def load(self) -> list[Document]:
"""
Load langchain documents from a Cloud SQL MySQL database.

Returns:
(List[langchain_core.documents.Document]): a list of Documents with metadata from
(list[langchain_core.documents.Document]): a list of Documents with metadata from
specific columns.
"""
return list(self.lazy_load())
Expand Down Expand Up @@ -231,13 +231,13 @@ def __init__(
)
self.metadata_json_column = metadata_json_column or DEFAULT_METADATA_COL

def add_documents(self, docs: List[Document]) -> None:
def add_documents(self, docs: list[Document]) -> None:
"""
Save documents in the DocumentSaver table. Document’s metadata is added to columns if found or
stored in langchain_metadata JSON column.

Args:
docs (List[langchain_core.documents.Document]): a list of documents to be saved.
docs (list[langchain_core.documents.Document]): a list of documents to be saved.
"""
with self.engine.connect() as conn:
for doc in docs:
Expand All @@ -250,13 +250,13 @@ def add_documents(self, docs: List[Document]) -> None:
conn.execute(sqlalchemy.insert(self._table).values(row))
conn.commit()

def delete(self, docs: List[Document]) -> None:
def delete(self, docs: list[Document]) -> None:
"""
Delete all instances of a document from the DocumentSaver table by matching the entire Document
object.

Args:
docs (List[langchain_core.documents.Document]): a list of documents to be deleted.
docs (list[langchain_core.documents.Document]): a list of documents to be deleted.
"""
with self.engine.connect() as conn:
for doc in docs:
Expand Down
Loading
Loading