diff --git a/libs/vertexai/Makefile b/libs/vertexai/Makefile index d86dffd2..e68f0a70 100644 --- a/libs/vertexai/Makefile +++ b/libs/vertexai/Makefile @@ -11,6 +11,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/ test tests integration_test integration_tests: poetry run pytest --release $(TEST_FILE) +test_watch: + poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE) + # Run unit tests and generate a coverage report. coverage: poetry run pytest --cov \ @@ -33,7 +36,6 @@ lint_tests: PYTHON_FILES=tests lint_tests: MYPY_CACHE=.mypy_cache_test lint lint_diff lint_package lint_tests: - ./scripts/check_pydantic.sh . ./scripts/lint_imports.sh poetry run ruff check . poetry run ruff format $(PYTHON_FILES) --diff diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py b/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py index d401f78f..a31759bf 100644 --- a/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py +++ b/libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py @@ -4,7 +4,7 @@ from langchain_core.messages.tool import tool_call from langchain_core.output_parsers import BaseGenerationOutputParser from langchain_core.outputs import ChatGeneration, Generation -from langchain_core.pydantic_v1 import BaseModel +from pydantic import BaseModel, ConfigDict class ToolsOutputParser(BaseGenerationOutputParser): @@ -12,8 +12,9 @@ class ToolsOutputParser(BaseGenerationOutputParser): args_only: bool = False pydantic_schemas: Optional[List[Type[BaseModel]]] = None - class Config: - extra = "forbid" + model_config = ConfigDict( + extra="forbid", + ) def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: """Parse a list of candidate model Generations into a specific format. diff --git a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py index e6205d79..976398fa 100644 --- a/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py +++ b/libs/vertexai/langchain_google_vertexai/_anthropic_utils.py @@ -26,9 +26,9 @@ ToolMessage, ) from langchain_core.messages.ai import UsageMetadata -from langchain_core.pydantic_v1 import BaseModel from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_tool +from pydantic import BaseModel if TYPE_CHECKING: from anthropic.types import RawMessageStreamEvent # type: ignore diff --git a/libs/vertexai/langchain_google_vertexai/_base.py b/libs/vertexai/langchain_google_vertexai/_base.py index 9bcd09a1..d5ac6bfe 100644 --- a/libs/vertexai/langchain_google_vertexai/_base.py +++ b/libs/vertexai/langchain_google_vertexai/_base.py @@ -22,7 +22,8 @@ from google.protobuf import json_format from google.protobuf.struct_pb2 import Value from langchain_core.outputs import Generation, LLMResult -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Self from vertexai.generative_models._generative_models import ( # type: ignore SafetySettingsType, ) @@ -91,14 +92,15 @@ class _VertexAIBase(BaseModel): "when making API calls. If not provided, credentials will be ascertained from " "the environment." 
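The change that follows, and that repeats across every model class in this diff, is the pydantic v1 to v2 migration: the inner "class Config" becomes model_config = ConfigDict(...) (with allow_population_by_field_name renamed to populate_by_name), and dict-based @root_validator methods become @model_validator methods. A minimal sketch of the pattern, using a hypothetical Example model rather than code from this repository:

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self

class Example(BaseModel):
    name: str = ""

    # v1 spelled this as:
    #     class Config:
    #         allow_population_by_field_name = True
    model_config = ConfigDict(populate_by_name=True)

    # v1 used @root_validator(pre=False, skip_on_failure=True), which
    # received and returned a `values` dict; a mode="after" validator
    # receives the constructed instance and returns it.
    @model_validator(mode="after")
    def fill_default_name(self) -> Self:
        if not self.name:
            self.name = "anonymous"
        return self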
- class Config: - """Configuration for this pydantic object.""" + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) - allow_population_by_field_name = True - arbitrary_types_allowed = True - - @root_validator(pre=True) - def validate_params_base(cls, values: dict) -> dict: + @model_validator(mode="before") + @classmethod + def validate_params_base(cls, values: dict) -> Any: if "model" in values and "model_name" not in values: values["model_name"] = values.pop("model") if values.get("project") is None: @@ -108,7 +110,7 @@ def validate_params_base(cls, values: dict) -> dict: if values.get("api_endpoint"): api_endpoint = values["api_endpoint"] else: - location = values.get("location", cls.__fields__["location"].default) + location = values.get("location", cls.model_fields["location"].default) api_endpoint = f"{location}-{constants.PREDICTION_API_BASE_PATH}" client_options = ClientOptions(api_endpoint=api_endpoint) if values.get("client_cert_source"): @@ -311,26 +313,26 @@ class _BaseVertexAIModelGarden(_VertexAIBase): single_example_per_request: bool = True "LLM endpoint currently serves only the first example in the request" - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that the python package exists in environment.""" - if not values["project"]: + if not self.project: raise ValueError( "A GCP project should be provided to run inference on Model Garden!" ) client_options = ClientOptions( - api_endpoint=f"{values['location']}-aiplatform.googleapis.com" + api_endpoint=f"{self.location}-aiplatform.googleapis.com" ) client_info = get_client_info(module="vertex-ai-model-garden") - values["client"] = PredictionServiceClient( + self.client = PredictionServiceClient( client_options=client_options, client_info=client_info ) - values["async_client"] = PredictionServiceAsyncClient( + self.async_client = PredictionServiceAsyncClient( client_options=client_options, client_info=client_info ) - return values + return self @property def endpoint_path(self) -> str: diff --git a/libs/vertexai/langchain_google_vertexai/chains.py b/libs/vertexai/langchain_google_vertexai/chains.py index a6f63455..6f3c2680 100644 --- a/libs/vertexai/langchain_google_vertexai/chains.py +++ b/libs/vertexai/langchain_google_vertexai/chains.py @@ -13,8 +13,8 @@ StrOutputParser, ) from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.runnables import Runnable +from pydantic import BaseModel from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser @@ -51,7 +51,12 @@ def _create_structured_runnable_extra_step( *, prompt: Optional[BasePromptTemplate] = None, ) -> Runnable: - names = [schema.schema()["title"] for schema in functions] + names = [ + schema.model_json_schema()["title"] + if hasattr(schema, "model_json_schema") + else schema.schema()["title"] + for schema in functions + ] if hasattr(llm, "is_gemini_advanced") and llm._is_gemini_advanced: # type: ignore llm_with_functions = llm.bind( functions=functions, @@ -111,7 +116,7 @@ def create_structured_runnable( from langchain_google_vertexai import ChatVertexAI, create_structured_runnable from langchain_core.prompts import ChatPromptTemplate - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class 
RecordPerson(BaseModel): diff --git a/libs/vertexai/langchain_google_vertexai/chat_models.py b/libs/vertexai/langchain_google_vertexai/chat_models.py index ad60efe2..fba782ce 100644 --- a/libs/vertexai/langchain_google_vertexai/chat_models.py +++ b/libs/vertexai/langchain_google_vertexai/chat_models.py @@ -62,7 +62,7 @@ ) from langchain_core.output_parsers.openai_tools import parse_tool_calls from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult -from langchain_core.pydantic_v1 import BaseModel, root_validator, Field +from pydantic import BaseModel, Field, model_validator from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableGenerator from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.pydantic import is_basemodel_subclass @@ -124,6 +124,9 @@ _format_to_gapic_tool, _ToolType, ) +from pydantic import ConfigDict +from typing_extensions import Self + logger = logging.getLogger(__name__) @@ -762,7 +765,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): Tool calling: .. code-block:: python - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class GetWeather(BaseModel): '''Get the current weather in a given location''' @@ -800,7 +803,7 @@ class GetPopulation(BaseModel): from typing import Optional - from langchain_core.pydantic_v1 import BaseModel, Field + from pydantic import BaseModel, Field class Joke(BaseModel): '''Joke to tell user.''' @@ -1024,11 +1027,10 @@ def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None: kwargs["model_name"] = model_name super().__init__(**kwargs) - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) @classmethod def is_lc_serializable(self) -> bool: @@ -1039,57 +1041,65 @@ def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "chat_models", "vertexai"] - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that the python package exists in environment.""" - safety_settings = values.get("safety_settings") - tuned_model_name = values.get("tuned_model_name") - values["model_family"] = GoogleModelFamily(values["model_name"]) + safety_settings = self.safety_settings + tuned_model_name = self.tuned_model_name + self.model_family = GoogleModelFamily(self.model_name) - if values["model_name"] == "chat-bison-default": + if self.model_name == "chat-bison-default": logger.warning( "Model_name will become a required arg for ChatVertexAI " "starting from Sep-01-2024. 
Currently the default is set to " "chat-bison" ) - values["model_name"] = "chat-bison" + self.model_name = "chat-bison" - if values.get("full_model_name") is not None: + if self.full_model_name is not None: pass - elif values.get("tuned_model_name") is not None: - values["full_model_name"] = _format_model_name( - values["tuned_model_name"], - location=values["location"], - project=values["project"], + elif self.tuned_model_name is not None: + self.full_model_name = _format_model_name( + self.tuned_model_name, + location=self.location, + project=cast(str, self.project), ) else: - values["full_model_name"] = _format_model_name( - values["model_name"], - location=values["location"], - project=values["project"], + self.full_model_name = _format_model_name( + self.model_name, + location=self.location, + project=cast(str, self.project), ) - if safety_settings and not is_gemini_model(values["model_family"]): + if safety_settings and not is_gemini_model(self.model_family): raise ValueError("Safety settings are only supported for Gemini models") if tuned_model_name: - generative_model_name = values["tuned_model_name"] + generative_model_name = self.tuned_model_name else: - generative_model_name = values["model_name"] - - if not is_gemini_model(values["model_family"]): - cls._init_vertexai(values) - if values["model_family"] == GoogleModelFamily.CODEY: + generative_model_name = self.model_name + + if not is_gemini_model(self.model_family): + values = { + "project": self.project, + "location": self.location, + "credentials": self.credentials, + "api_transport": self.api_transport, + "api_endpoint": self.api_endpoint, + "default_metadata": self.default_metadata, + } + self._init_vertexai(values) + if self.model_family == GoogleModelFamily.CODEY: model_cls = CodeChatModel model_cls_preview = PreviewCodeChatModel else: model_cls = ChatModel model_cls_preview = PreviewChatModel - values["client"] = model_cls.from_pretrained(generative_model_name) - values["client_preview"] = model_cls_preview.from_pretrained( + self.client = model_cls.from_pretrained(generative_model_name) + self.client_preview = model_cls_preview.from_pretrained( generative_model_name ) - return values + return self @property def _is_gemini_advanced(self) -> bool: @@ -1647,7 +1657,7 @@ def with_structured_output( Example: Pydantic schema, exclude raw: .. code-block:: python - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel from langchain_google_vertexai import ChatVertexAI class AnswerWithJustification(BaseModel): @@ -1666,7 +1676,7 @@ class AnswerWithJustification(BaseModel): Example: Pydantic schema, include raw: .. code-block:: python - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel from langchain_google_vertexai import ChatVertexAI class AnswerWithJustification(BaseModel): @@ -1687,7 +1697,7 @@ class AnswerWithJustification(BaseModel): Example: Dict schema, exclude raw: .. 
code-block:: python - from langchain_core.pydantic_v1 import BaseModel + from pydantic import BaseModel from langchain_core.utils.function_calling import convert_to_openai_function from langchain_google_vertexai import ChatVertexAI diff --git a/libs/vertexai/langchain_google_vertexai/embeddings.py b/libs/vertexai/langchain_google_vertexai/embeddings.py index ce3f03df..e2fb6e3e 100644 --- a/libs/vertexai/langchain_google_vertexai/embeddings.py +++ b/libs/vertexai/langchain_google_vertexai/embeddings.py @@ -17,7 +17,8 @@ from google.cloud.aiplatform import telemetry from langchain_core.embeddings import Embeddings from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator +from typing_extensions import Self from vertexai.language_models import ( # type: ignore TextEmbeddingInput, TextEmbeddingModel, @@ -100,24 +101,33 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings): # Instance context instance: Dict[str, Any] = {} #: :meta private: - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + model_config = ConfigDict( + extra="forbid", + protected_namespaces=(), + ) + + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validates that the python package exists in environment.""" - cls._init_vertexai(values) - _, user_agent = get_user_agent(f"{cls.__name__}_{values['model_name']}") # type: ignore + values = { + "project": self.project, + "location": self.location, + "credentials": self.credentials, + "api_transport": self.api_transport, + "api_endpoint": self.api_endpoint, + "default_metadata": self.default_metadata, + } + self._init_vertexai(values) + _, user_agent = get_user_agent(f"{self.__class__.__name__}_{self.model_name}") with telemetry.tool_context_manager(user_agent): if ( - GoogleEmbeddingModelType(values["model_name"]) + GoogleEmbeddingModelType(self.model_name) == GoogleEmbeddingModelType.MULTIMODAL ): - values["client"] = MultiModalEmbeddingModel.from_pretrained( - values["model_name"] - ) + self.client = MultiModalEmbeddingModel.from_pretrained(self.model_name) else: - values["client"] = TextEmbeddingModel.from_pretrained( - values["model_name"] - ) - return values + self.client = TextEmbeddingModel.from_pretrained(self.model_name) + return self def __init__( self, diff --git a/libs/vertexai/langchain_google_vertexai/functions_utils.py b/libs/vertexai/langchain_google_vertexai/functions_utils.py index 455d298c..f4499e8b 100644 --- a/libs/vertexai/langchain_google_vertexai/functions_utils.py +++ b/libs/vertexai/langchain_google_vertexai/functions_utils.py @@ -21,7 +21,6 @@ from langchain_core.exceptions import OutputParserException from langchain_core.output_parsers import BaseOutputParser from langchain_core.outputs import ChatGeneration, Generation -from langchain_core.pydantic_v1 import BaseModel from langchain_core.tools import BaseTool from langchain_core.tools import tool as callable_as_lc_tool from langchain_core.utils.function_calling import ( @@ -29,6 +28,7 @@ convert_to_openai_tool, ) from langchain_core.utils.json_schema import dereference_refs +from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -69,18 +69,19 @@ class _ToolDictLike(TypedDict): _ALLOWED_SCHEMA_FIELDS_SET = set(_ALLOWED_SCHEMA_FIELDS) -def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]: +def _format_json_schema_to_gapic_v1(schema: Dict[str, Any]) -> Dict[str, 
Any]: + """Format a JSON schema from a Pydantic V1 BaseModel to gapic.""" converted_schema: Dict[str, Any] = {} for key, value in schema.items(): if key == "definitions": continue elif key == "items": - converted_schema["items"] = _format_json_schema_to_gapic(value) + converted_schema["items"] = _format_json_schema_to_gapic_v1(value) elif key == "properties": if "properties" not in converted_schema: converted_schema["properties"] = {} for pkey, pvalue in value.items(): - converted_schema["properties"][pkey] = _format_json_schema_to_gapic( + converted_schema["properties"][pkey] = _format_json_schema_to_gapic_v1( pvalue ) continue @@ -92,7 +93,57 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]: "Only first value for 'allOf' key is supported. " f"Got {len(value)}, ignoring other than first value!" ) - return _format_json_schema_to_gapic(value[0]) + return _format_json_schema_to_gapic_v1(value[0]) + elif key not in _ALLOWED_SCHEMA_FIELDS_SET: + logger.warning(f"Key '{key}' is not supported in schema, ignoring") + else: + converted_schema[key] = value + return converted_schema + + +def _format_json_schema_to_gapic( + schema: Dict[str, Any], + parent_key: Optional[str] = None, + required_fields: Optional[list] = None, +) -> Dict[str, Any]: + """Format a JSON schema from a Pydantic V2 BaseModel to gapic.""" + converted_schema: Dict[str, Any] = {} + for key, value in schema.items(): + if key == "definitions": + continue + elif key == "items": + converted_schema["items"] = _format_json_schema_to_gapic( + value, parent_key, required_fields + ) + elif key == "properties": + if "properties" not in converted_schema: + converted_schema["properties"] = {} + for pkey, pvalue in value.items(): + converted_schema["properties"][pkey] = _format_json_schema_to_gapic( + pvalue, pkey, schema.get("required", []) + ) + continue + elif key in ["type", "_type"]: + converted_schema["type"] = str(value).upper() + elif key == "allOf": + if len(value) > 1: + logger.warning( + "Only first value for 'allOf' key is supported. " + f"Got {len(value)}, ignoring other than first value!" 
+ ) + return _format_json_schema_to_gapic(value[0], parent_key, required_fields) + elif key == "anyOf": + if len(value) == 2 and any(v.get("type") == "null" for v in value): + non_null_type = next(v for v in value if v.get("type") != "null") + converted_schema.update( + _format_json_schema_to_gapic( + non_null_type, parent_key, required_fields + ) + ) + # Remove the field from required if it exists + if required_fields and parent_key in required_fields: + required_fields.remove(parent_key) + continue elif key not in _ALLOWED_SCHEMA_FIELDS_SET: logger.warning(f"Key '{key}' is not supported in schema, ignoring") else: @@ -100,9 +151,14 @@ def _format_json_schema_to_gapic(schema: Dict[str, Any]) -> Dict[str, Any]: return converted_schema -def _dict_to_gapic_schema(schema: Dict[str, Any]) -> gapic.Schema: +def _dict_to_gapic_schema( + schema: Dict[str, Any], pydantic_version: str = "v1" +) -> gapic.Schema: dereferenced_schema = dereference_refs(schema) - formatted_schema = _format_json_schema_to_gapic(dereferenced_schema) + if pydantic_version == "v1": + formatted_schema = _format_json_schema_to_gapic_v1(dereferenced_schema) + else: + formatted_schema = _format_json_schema_to_gapic(dereferenced_schema) json_schema = json.dumps(formatted_schema) return gapic.Schema.from_json(json_schema) @@ -124,8 +180,14 @@ def _format_base_tool_to_function_declaration( ), ) - schema = tool.args_schema.schema() - parameters = _dict_to_gapic_schema(schema) + if hasattr(tool.args_schema, "model_json_schema"): + schema = tool.args_schema.model_json_schema() + pydantic_version = "v2" + else: + schema = tool.args_schema.schema() + pydantic_version = "v1" + + parameters = _dict_to_gapic_schema(schema, pydantic_version=pydantic_version) return gapic.FunctionDeclaration( name=tool.name or schema.get("title"), @@ -137,12 +199,17 @@ def _format_base_tool_to_function_declaration( def _format_pydantic_to_function_declaration( pydantic_model: Type[BaseModel], ) -> gapic.FunctionDeclaration: - schema = pydantic_model.schema() + if hasattr(pydantic_model, "model_json_schema"): + schema = pydantic_model.model_json_schema() + pydantic_version = "v2" + else: + schema = pydantic_model.schema() + pydantic_version = "v1" return gapic.FunctionDeclaration( name=schema["title"], description=schema.get("description", ""), - parameters=_dict_to_gapic_schema(schema), + parameters=_dict_to_gapic_schema(schema, pydantic_version=pydantic_version), ) diff --git a/libs/vertexai/langchain_google_vertexai/gemma.py b/libs/vertexai/langchain_google_vertexai/gemma.py index 55e945e8..a854af05 100644 --- a/libs/vertexai/langchain_google_vertexai/gemma.py +++ b/libs/vertexai/langchain_google_vertexai/gemma.py @@ -21,7 +21,8 @@ Generation, LLMResult, ) -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Self from langchain_google_vertexai._base import _BaseVertexAIModelGarden from langchain_google_vertexai._utils import enforce_stop_tokens @@ -76,6 +77,8 @@ class _GemmaBase(BaseModel): top_k: Optional[int] = None """The top-k value to use for sampling.""" + model_config = ConfigDict(protected_namespaces=()) + @property def _default_params(self) -> Dict[str, Any]: """Get the default parameters for calling gemma.""" @@ -127,10 +130,10 @@ def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None: kwargs["model_name"] = model_name super().__init__(**kwargs) - class Config: - """Configuration for this pydantic 
object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + protected_namespaces=(), + ) @property def _llm_type(self) -> str: @@ -200,10 +203,9 @@ class _GemmaLocalKaggleBase(_GemmaBase): model_name: str = Field(default="gemma_2b_en", alias="model") """Gemma model name.""" - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None: """Needed for mypy typing to recognize model_name as a valid arg.""" @@ -211,11 +213,11 @@ def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None: kwargs["model_name"] = model_name super().__init__(**kwargs) - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that llama-cpp-python library is installed.""" try: - os.environ["KERAS_BACKEND"] = values["keras_backend"] + os.environ["KERAS_BACKEND"] = self.keras_backend from keras_nlp.models import GemmaCausalLM # type: ignore except ImportError: raise ImportError( @@ -224,8 +226,8 @@ def validate_environment(cls, values: Dict) -> Dict: "use this model: pip install keras-nlp keras>=3 kaggle" ) - values["client"] = GemmaCausalLM.from_preset(values["model_name"]) - return values + self.client = GemmaCausalLM.from_preset(self.model_name) + return self @property def _default_params(self) -> Dict[str, Any]: @@ -239,7 +241,7 @@ def _get_params(self, **kwargs) -> Dict[str, Any]: return {**self._default_params, **params} -class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM): # type: ignore +class GemmaLocalKaggle(_GemmaLocalKaggleBase, BaseLLM): """Local gemma chat model loaded from Kaggle.""" def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None: @@ -269,7 +271,7 @@ def _llm_type(self) -> str: return "gemma_local_kaggle" -class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel): # type: ignore +class GemmaChatLocalKaggle(_GemmaLocalKaggleBase, BaseChatModel): parse_response: bool = False """Whether to post-process the chat response and clean repeations """ """or multi-turn statements.""" @@ -313,13 +315,12 @@ class _GemmaLocalHFBase(_GemmaBase): model_name: str = Field(default="google/gemma-2b", alias="model") """Gemma model name.""" - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that llama-cpp-python library is installed.""" try: from transformers import AutoTokenizer, GemmaForCausalLM # type: ignore @@ -330,15 +331,15 @@ def validate_environment(cls, values: Dict) -> Dict: "use this model: pip install transformers>=4.38.1" ) - values["tokenizer"] = AutoTokenizer.from_pretrained( - values["model_name"], token=values["hf_access_token"] + self.tokenizer = AutoTokenizer.from_pretrained( + self.model_name, token=self.hf_access_token ) - values["client"] = GemmaForCausalLM.from_pretrained( - values["model_name"], - token=values["hf_access_token"], - cache_dir=values["cache_dir"], + self.client = GemmaForCausalLM.from_pretrained( + self.model_name, + token=self.hf_access_token, + cache_dir=self.cache_dir, ) - return values + 
return self @property def _default_params(self) -> Dict[str, Any]: @@ -360,7 +361,7 @@ def _run(self, prompt: str, **kwargs: Any) -> str: )[0] -class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM): # type: ignore +class GemmaLocalHF(_GemmaLocalHFBase, BaseLLM): """Local gemma model loaded from HuggingFace.""" def _generate( @@ -382,7 +383,7 @@ def _llm_type(self) -> str: return "gemma_local_hf" -class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel): # type: ignore +class GemmaChatLocalHF(_GemmaLocalHFBase, BaseChatModel): parse_response: bool = False """Whether to post-process the chat response and clean repetitions """ """or multi-turn statements.""" diff --git a/libs/vertexai/langchain_google_vertexai/llms.py b/libs/vertexai/langchain_google_vertexai/llms.py index b9dc76db..303169fb 100644 --- a/libs/vertexai/langchain_google_vertexai/llms.py +++ b/libs/vertexai/langchain_google_vertexai/llms.py @@ -9,7 +9,8 @@ ) from langchain_core.language_models.llms import BaseLLM, LangSmithParams from langchain_core.outputs import Generation, GenerationChunk, LLMResult -from langchain_core.pydantic_v1 import Field, root_validator +from pydantic import ConfigDict, Field, model_validator +from typing_extensions import Self from vertexai.generative_models import ( # type: ignore[import-untyped] Candidate, GenerativeModel, @@ -120,10 +121,9 @@ def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None: kwargs["model_name"] = model_name super().__init__(**kwargs) - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) @classmethod def is_lc_serializable(self) -> bool: @@ -134,19 +134,27 @@ def get_lc_namespace(cls) -> List[str]: """Get the namespace of the langchain object.""" return ["langchain", "llms", "vertexai"] - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: """Validate that the python package exists in environment.""" - tuned_model_name = values.get("tuned_model_name") - safety_settings = values["safety_settings"] - values["model_family"] = GoogleModelFamily(values["model_name"]) - is_gemini = is_gemini_model(values["model_family"]) - cls._init_vertexai(values) + tuned_model_name = self.tuned_model_name or None + safety_settings = self.safety_settings + self.model_family = GoogleModelFamily(self.model_name) + is_gemini = is_gemini_model(self.model_family) + values = { + "project": self.project, + "location": self.location, + "credentials": self.credentials, + "api_transport": self.api_transport, + "api_endpoint": self.api_endpoint, + "default_metadata": self.default_metadata, + } + self._init_vertexai(values) if safety_settings and (not is_gemini or tuned_model_name): raise ValueError("Safety settings are only supported for Gemini models") - if values["model_family"] == GoogleModelFamily.CODEY: + if self.model_family == GoogleModelFamily.CODEY: model_cls = CodeGenerationModel preview_model_cls = PreviewCodeGenerationModel elif is_gemini: @@ -157,31 +165,31 @@ def validate_environment(cls, values: Dict) -> Dict: preview_model_cls = PreviewTextGenerationModel if tuned_model_name: - generative_model_name = values["tuned_model_name"] + generative_model_name = self.tuned_model_name else: - generative_model_name = values["model_name"] + generative_model_name = self.model_name if is_gemini: - values["client"] = model_cls( 
model_name=generative_model_name, safety_settings=safety_settings ) - values["client_preview"] = preview_model_cls( + self.client_preview = preview_model_cls( model_name=generative_model_name, safety_settings=safety_settings ) else: if tuned_model_name: - values["client"] = model_cls.get_tuned_model(generative_model_name) - values["client_preview"] = preview_model_cls.get_tuned_model( + self.client = model_cls.get_tuned_model(generative_model_name) + self.client_preview = preview_model_cls.get_tuned_model( generative_model_name ) else: - values["client"] = model_cls.from_pretrained(generative_model_name) - values["client_preview"] = preview_model_cls.from_pretrained( + self.client = model_cls.from_pretrained(generative_model_name) + self.client_preview = preview_model_cls.from_pretrained( generative_model_name ) - if values["streaming"] and values["n"] > 1: + if self.streaming and self.n > 1: raise ValueError("Only one candidate can be generated with streaming!") - return values + return self def _get_ls_params( self, stop: Optional[List[str]] = None, **kwargs: Any diff --git a/libs/vertexai/langchain_google_vertexai/model_garden.py b/libs/vertexai/langchain_google_vertexai/model_garden.py index 86d6f00e..4fee0562 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden.py @@ -39,13 +39,14 @@ Generation, LLMResult, ) -from langchain_core.pydantic_v1 import BaseModel, Field, root_validator from langchain_core.runnables import ( Runnable, RunnableMap, RunnablePassthrough, ) from langchain_core.tools import BaseTool +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Self from langchain_google_vertexai._anthropic_parsers import ( ToolsOutputParser, @@ -63,10 +64,10 @@ class VertexAIModelGarden(_BaseVertexAIModelGarden, BaseLLM): """Large language models served from Vertex AI Model Garden.""" - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + protected_namespaces=(), + ) # Needed so that mypy doesn't flag missing aliased init args. def __init__(self, **kwargs: Any) -> None: @@ -137,37 +138,36 @@ class ChatAnthropicVertex(_VertexAICommon, BaseChatModel): stream_usage: bool = True # Whether to include usage metadata in streaming output credentials: Optional[Credentials] = None - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True + model_config = ConfigDict( + populate_by_name=True, + ) # Needed so that mypy doesn't flag missing aliased init args. 
def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) - @root_validator(pre=False, skip_on_failure=True) - def validate_environment(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment(self) -> Self: from anthropic import ( # type: ignore AnthropicVertex, AsyncAnthropicVertex, ) - values["client"] = AnthropicVertex( - project_id=values["project"], - region=values["location"], - max_retries=values["max_retries"], - access_token=values["access_token"], - credentials=values["credentials"], + self.client = AnthropicVertex( + project_id=self.project, + region=self.location, + max_retries=self.max_retries, + access_token=self.access_token, + credentials=self.credentials, ) - values["async_client"] = AsyncAnthropicVertex( - project_id=values["project"], - region=values["location"], - max_retries=values["max_retries"], - access_token=values["access_token"], - credentials=values["credentials"], + self.async_client = AsyncAnthropicVertex( + project_id=self.project, + region=self.location, + max_retries=self.max_retries, + access_token=self.access_token, + credentials=self.credentials, ) - return values + return self @property def _default_params(self): diff --git a/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py b/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py index 030b1c69..9129f2af 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py +++ b/libs/vertexai/langchain_google_vertexai/model_garden_maas/_base.py @@ -24,7 +24,8 @@ CallbackManagerForLLMRun, ) from langchain_core.language_models.llms import create_base_retry_decorator -from langchain_core.pydantic_v1 import root_validator +from pydantic import ConfigDict, model_validator +from typing_extensions import Self from langchain_google_vertexai._base import _VertexAIBase @@ -101,11 +102,10 @@ class _BaseVertexMaasModelGarden(_VertexAIBase): model_family: Optional[VertexMaaSModelFamily] = None timeout: int = 120 - class Config: - """Configuration for this pydantic object.""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True + model_config = ConfigDict( + populate_by_name=True, + arbitrary_types_allowed=True, + ) def __init__(self, **kwargs): super().__init__(**kwargs) @@ -129,16 +129,16 @@ def __init__(self, **kwargs): timeout=self.timeout, ) - @root_validator(pre=True) - def validate_environment_model_garden(cls, values: Dict) -> Dict: + @model_validator(mode="after") + def validate_environment_model_garden(self) -> Self: """Resolve the model family and normalize the model name.""" - family = VertexMaaSModelFamily(values["model_name"]) - values["model_family"] = family + family = VertexMaaSModelFamily(self.model_name) + self.model_family = family if family == VertexMaaSModelFamily.MISTRAL: - model = values["model_name"].split("@")[0] - values["full_model_name"] = values["model_name"] - values["model_name"] = model - return values + model = self.model_name.split("@")[0] if self.model_name else None + self.full_model_name = self.model_name + self.model_name = model + return self def _enrich_params(self, params: Dict[str, Any]) -> Dict[str, Any]: """Fix params to be compliant with Vertex AI.""" diff --git a/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py b/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py index d80c4dd7..ebfdeaaa 100644 --- a/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py +++ 
b/libs/vertexai/langchain_google_vertexai/model_garden_maas/llama.py @@ -114,7 +114,7 @@ def _parse_response_candidate_llama( ) -class VertexModelGardenLlama(_BaseVertexMaasModelGarden, BaseChatModel): # type: ignore[misc] +class VertexModelGardenLlama(_BaseVertexMaasModelGarden, BaseChatModel): """Integration for Llama 3.1 on Google Cloud Vertex AI Model-as-a-Service. For more information, see: diff --git a/libs/vertexai/langchain_google_vertexai/vision_models.py b/libs/vertexai/langchain_google_vertexai/vision_models.py index 499f5e58..8d5d0e8f 100644 --- a/libs/vertexai/langchain_google_vertexai/vision_models.py +++ b/libs/vertexai/langchain_google_vertexai/vision_models.py @@ -9,7 +9,7 @@ from langchain_core.outputs import ChatResult, LLMResult from langchain_core.outputs.chat_generation import ChatGeneration from langchain_core.outputs.generation import Generation -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from vertexai.preview.vision_models import ( # type: ignore[import-untyped] GeneratedImage, ImageGenerationModel, @@ -39,6 +39,8 @@ class _BaseImageTextModel(BaseModel): project: Union[str, None] = Field(default=None) """Google cloud project""" + model_config = ConfigDict(protected_namespaces=()) + @property def client(self) -> ImageTextModel: if self.cached_client is None: @@ -334,6 +336,8 @@ class _BaseVertexAIImageGenerator(BaseModel): project: Union[str, None] = Field(default=None) """Google cloud project id""" + model_config = ConfigDict(protected_namespaces=()) + @property def client(self) -> ImageGenerationModel: if not self.cached_client: diff --git a/libs/vertexai/poetry.lock b/libs/vertexai/poetry.lock index 5a63e1ec..677b9e55 100644 --- a/libs/vertexai/poetry.lock +++ b/libs/vertexai/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
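A side effect of moving to pydantic v2, visible in the ConfigDict(protected_namespaces=()) additions above: v2 reserves the model_ field-name prefix and warns about fields such as model_name unless that protection is disabled. A minimal sketch of the behavior, with hypothetical classes rather than repository code:

from pydantic import BaseModel, ConfigDict

class Warns(BaseModel):
    # pydantic v2 emits a UserWarning here: field "model_name" conflicts
    # with the protected namespace "model_".
    model_name: str = "gemini-pro"

class Quiet(BaseModel):
    # Clearing protected_namespaces opts out of the warning, which is why
    # the model classes in this diff set it.
    model_config = ConfigDict(protected_namespaces=())
    model_name: str = "gemini-pro"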
[[package]] name = "aiohappyeyeballs" @@ -133,9 +133,6 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} - [[package]] name = "anthropic" version = "0.34.0" @@ -1313,24 +1310,24 @@ files = [ [[package]] name = "langchain" -version = "0.2.14" +version = "0.3.0.dev1" description = "Building applications with LLMs through composability" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = ">=3.9,<4.0" files = [] develop = false [package.dependencies] aiohttp = "^3.8.3" async-timeout = {version = "^4.0.0", markers = "python_version < \"3.11\""} -langchain-core = "^0.2.32" -langchain-text-splitters = "^0.2.0" +langchain-core = "^0.3.0.dev2" +langchain-text-splitters = "^0.3.0.dev1" langsmith = "^0.1.17" numpy = [ {version = ">=1,<2", markers = "python_version < \"3.12\""}, {version = ">=1.26.0,<2.0.0", markers = "python_version >= \"3.12\""}, ] -pydantic = ">=1,<3" +pydantic = "^2.7.4" PyYAML = ">=5.3" requests = "^2" SQLAlchemy = ">=1.4,<3" @@ -1339,27 +1336,24 @@ tenacity = "^8.1.0,!=8.4.0" [package.source] type = "git" url = "https://github.com/langchain-ai/langchain.git" -reference = "HEAD" -resolved_reference = "bda3becbe77a22ce49d77c59035727f3b2ed64f1" +reference = "v0.3rc" +resolved_reference = "6c8d626d701b9a9e365270d06a80e3a82de19963" subdirectory = "libs/langchain" [[package]] name = "langchain-core" -version = "0.2.33" +version = "0.3.0.dev4" description = "Building applications with LLMs through composability" optional = false -python-versions = ">=3.8.1,<4.0" +python-versions = ">=3.9,<4.0" files = [] develop = false [package.dependencies] jsonpatch = "^1.33" -langsmith = "^0.1.75" +langsmith = "^0.1.112" packaging = ">=23.2,<25" -pydantic = [ - {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, - {version = ">=2.7.4,<3.0.0", markers = "python_full_version >= \"3.12.4\""}, -] +pydantic = "^2.7.4" PyYAML = ">=5.3" tenacity = "^8.1.0,!=8.4.0" typing-extensions = ">=4.7" @@ -1367,74 +1361,81 @@ typing-extensions = ">=4.7" [package.source] type = "git" url = "https://github.com/langchain-ai/langchain.git" -reference = "HEAD" -resolved_reference = "bda3becbe77a22ce49d77c59035727f3b2ed64f1" +reference = "v0.3rc" +resolved_reference = "6c8d626d701b9a9e365270d06a80e3a82de19963" subdirectory = "libs/core" [[package]] name = "langchain-mistralai" -version = "0.1.12" +version = "0.2.0.dev1" description = "An integration package connecting Mistral and LangChain" optional = false -python-versions = "<4.0,>=3.8.1" -files = [ - {file = "langchain_mistralai-0.1.12-py3-none-any.whl", hash = "sha256:5dc5f3a63a646f848eb5007e410745a11667dcd2bc42939049a84f59f85a9737"}, - {file = "langchain_mistralai-0.1.12.tar.gz", hash = "sha256:d13a55aa84d7defd7a547919643188fd8c18d5a15ac139a1deebbe7a0889047b"}, -] +python-versions = ">=3.9,<4.0" +files = [] +develop = false [package.dependencies] httpx = ">=0.25.2,<1" httpx-sse = ">=0.3.1,<1" -langchain-core = ">=0.2.26,<0.3.0" +langchain-core = "^0.3.0.dev4" tokenizers = ">=0.15.1,<1" +[package.source] +type = "git" +url = "https://github.com/langchain-ai/langchain.git" +reference = "v0.3rc" +resolved_reference = "2070d659a06d1919093df56b6a63180f56ef8afa" +subdirectory = "libs/partners/mistralai" + [[package]] name = "langchain-standard-tests" version = "0.1.1" description = "Standard tests for LangChain implementations" optional = false 
-python-versions = ">=3.8.1,<4.0" +python-versions = ">=3.9,<4.0" files = [] develop = false [package.dependencies] httpx = "^0.27.0" -langchain-core = ">=0.1.40,<0.3" +langchain-core = ">=0.3.0.dev1" pytest = ">=7,<9" +syrupy = "^4" [package.source] type = "git" url = "https://github.com/langchain-ai/langchain.git" -reference = "HEAD" -resolved_reference = "bda3becbe77a22ce49d77c59035727f3b2ed64f1" +reference = "v0.3rc" +resolved_reference = "6c8d626d701b9a9e365270d06a80e3a82de19963" subdirectory = "libs/standard-tests" [[package]] name = "langchain-text-splitters" -version = "0.2.2" +version = "0.3.0.dev1" description = "LangChain text splitting utilities" optional = false -python-versions = "<4.0,>=3.8.1" +python-versions = "<4.0,>=3.9" files = [ - {file = "langchain_text_splitters-0.2.2-py3-none-any.whl", hash = "sha256:1c80d4b11b55e2995f02d2a326c0323ee1eeff24507329bb22924e420c782dff"}, - {file = "langchain_text_splitters-0.2.2.tar.gz", hash = "sha256:a1e45de10919fa6fb080ef0525deab56557e9552083600455cb9fa4238076140"}, + {file = "langchain_text_splitters-0.3.0.dev1-py3-none-any.whl", hash = "sha256:85abe6ab1aa95e8cc3bf985cd9e53848de616c21d3590a25ac13a694d409f133"}, + {file = "langchain_text_splitters-0.3.0.dev1.tar.gz", hash = "sha256:5e13ca0f27719406c6c5575f48cdfb89755f02dd0f1c8af5d1f8a1a9f391f3a2"}, ] [package.dependencies] -langchain-core = ">=0.2.10,<0.3.0" +langchain-core = ">=0.3.0.dev1,<0.4.0" [[package]] name = "langsmith" -version = "0.1.99" +version = "0.1.117" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." optional = false python-versions = "<4.0,>=3.8.1" files = [ - {file = "langsmith-0.1.99-py3-none-any.whl", hash = "sha256:ef8d1d74a2674c514aa429b0171a9fbb661207dc3835142cca0e8f1bf97b26b0"}, - {file = "langsmith-0.1.99.tar.gz", hash = "sha256:b5c6a1f158abda61600a4a445081ee848b4a28b758d91f2793dc02aeffafcaf1"}, + {file = "langsmith-0.1.117-py3-none-any.whl", hash = "sha256:e936ee9bcf8293b0496df7ba462a3702179fbe51f9dc28744b0fbec0dbf206ae"}, + {file = "langsmith-0.1.117.tar.gz", hash = "sha256:a1b532f49968b9339bcaff9118d141846d52ed3d803f342902e7448edf1d662b"}, ] [package.dependencies] +httpx = ">=0.23.0,<1" orjson = ">=3.9.14,<4.0.0" pydantic = [ {version = ">=1,<3", markers = "python_full_version < \"3.12.4\""}, @@ -1599,48 +1600,6 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] -[[package]] -name = "numexpr" -version = "2.8.6" -description = "Fast numerical expression evaluator for NumPy" -optional = false -python-versions = ">=3.7" -files = [ - {file = "numexpr-2.8.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:80acbfefb68bd92e708e09f0a02b29e04d388b9ae72f9fcd57988aca172a7833"}, - {file = "numexpr-2.8.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6e884687da8af5955dc9beb6a12d469675c90b8fb38b6c93668c989cfc2cd982"}, - {file = "numexpr-2.8.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9ef7e8aaa84fce3aba2e65f243d14a9f8cc92aafd5d90d67283815febfe43eeb"}, - {file = "numexpr-2.8.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dee04d72307c09599f786b9231acffb10df7d7a74b2ce3681d74a574880d13ce"}, - {file = "numexpr-2.8.6-cp310-cp310-win32.whl", hash = "sha256:211804ec25a9f6d188eadf4198dd1a92b2f61d7d20993c6c7706139bc4199c5b"}, - {file = "numexpr-2.8.6-cp310-cp310-win_amd64.whl", hash = "sha256:18b1804923cfa3be7bbb45187d01c0540c8f6df4928c22a0f786e15568e9ebc5"}, - {file = 
"numexpr-2.8.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:95b9da613761e4fc79748535b2a1f58cada22500e22713ae7d9571fa88d1c2e2"}, - {file = "numexpr-2.8.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:47b45da5aa25600081a649f5e8b2aa640e35db3703f4631f34bb1f2f86d1b5b4"}, - {file = "numexpr-2.8.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84979bf14143351c2db8d9dd7fef8aca027c66ad9df9cb5e75c93bf5f7b5a338"}, - {file = "numexpr-2.8.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d36528a33aa9c23743b3ea686e57526a4f71e7128a1be66210e1511b09c4e4e9"}, - {file = "numexpr-2.8.6-cp311-cp311-win32.whl", hash = "sha256:681812e2e71ff1ba9145fac42d03f51ddf6ba911259aa83041323f68e7458002"}, - {file = "numexpr-2.8.6-cp311-cp311-win_amd64.whl", hash = "sha256:27782177a0081bd0aab229be5d37674e7f0ab4264ef576697323dd047432a4cd"}, - {file = "numexpr-2.8.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ef6e8896457a60a539cb6ba27da78315a9bb31edb246829b25b5b0304bfcee91"}, - {file = "numexpr-2.8.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e640bc0eaf1b59f3dde52bc02bbfda98e62f9950202b0584deba28baf9f36bbb"}, - {file = "numexpr-2.8.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d126938c2c3784673c9c58d94e00b1570aa65517d9c33662234d442fc9fb5795"}, - {file = "numexpr-2.8.6-cp37-cp37m-win32.whl", hash = "sha256:e93d64cd20940b726477c3cb64926e683d31b778a1e18f9079a5088fd0d8e7c8"}, - {file = "numexpr-2.8.6-cp37-cp37m-win_amd64.whl", hash = "sha256:31cf610c952eec57081171f0b4427f9bed2395ec70ec432bbf45d260c5c0cdeb"}, - {file = "numexpr-2.8.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b5f96c89aa0b1f13685ec32fa3d71028db0b5981bfd99a0bbc271035949136b3"}, - {file = "numexpr-2.8.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c8f37f7a6af3bdd61f2efd1cafcc083a9525ab0aaf5dc641e7ec8fc0ae2d3aa1"}, - {file = "numexpr-2.8.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:38b8b90967026bbc36c7aa6e8ca3b8906e1990914fd21f446e2a043f4ee3bc06"}, - {file = "numexpr-2.8.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1967c16f61c27df1cdc43ba3c0ba30346157048dd420b4259832276144d0f64e"}, - {file = "numexpr-2.8.6-cp38-cp38-win32.whl", hash = "sha256:15469dc722b5ceb92324ec8635411355ebc702303db901ae8cc87f47c5e3a124"}, - {file = "numexpr-2.8.6-cp38-cp38-win_amd64.whl", hash = "sha256:95c09e814b0d6549de98b5ded7cdf7d954d934bb6b505432ff82e83a6d330bda"}, - {file = "numexpr-2.8.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:aa0f661f5f4872fd7350cc9895f5d2594794b2a7e7f1961649a351724c64acc9"}, - {file = "numexpr-2.8.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:8e3e6f1588d6c03877cb3b3dcc3096482da9d330013b886b29cb9586af5af3eb"}, - {file = "numexpr-2.8.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8564186aad5a2c88d597ebc79b8171b52fd33e9b085013e1ff2208f7e4b387e3"}, - {file = "numexpr-2.8.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d6a88d71c166e86b98d34701285d23e3e89d548d9f5ae3f4b60919ac7151949f"}, - {file = "numexpr-2.8.6-cp39-cp39-win32.whl", hash = "sha256:c48221b6a85494a7be5a022899764e58259af585dff031cecab337277278cc93"}, - {file = "numexpr-2.8.6-cp39-cp39-win_amd64.whl", hash = "sha256:6d7003497d82ef19458dce380b36a99343b96a3bd5773465c2d898bf8f5a38f9"}, - {file = "numexpr-2.8.6.tar.gz", hash = "sha256:6336f8dba3f456e41a4ffc3c97eb63d89c73589ff6e1707141224b930263260d"}, -] - -[package.dependencies] 
-numpy = ">=1.13.3" - [[package]] name = "numexpr" version = "2.10.1" @@ -1682,43 +1641,6 @@ files = [ [package.dependencies] numpy = ">=1.23.0" -[[package]] -name = "numpy" -version = "1.24.4" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, - {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, - {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, - {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, - {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, - {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, - {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, - {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, - {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, - {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, - {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, - {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, - {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, - {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, - {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, - {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, - {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, - {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, - {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, - {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, - {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, - {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, - {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, - {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, - {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, -] - [[package]] name = "numpy" version = "1.26.4" @@ -2833,5 +2755,5 @@ mistral = ["langchain-mistralai"] [metadata] lock-version = "2.0" -python-versions = ">=3.8.1,<4.0" -content-hash = "8240d23b8fc671345e140d2720298d549a40e05da220a901ee8b72a1b51bcb25" +python-versions = ">=3.9,<4.0" +content-hash = "00bdba3e2d330934ca0bd2384a01fdaaa42f3f4eefe33eafede493f946f5f604" diff --git a/libs/vertexai/pyproject.toml b/libs/vertexai/pyproject.toml index 00f6b259..380a97d1 100644 --- a/libs/vertexai/pyproject.toml +++ b/libs/vertexai/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langchain-google-vertexai" -version = "1.0.10" +version = "2.0.0.dev1" description = "An integration package connecting Google VertexAI and LangChain" authors = [] readme = "README.md" @@ -11,15 +11,16 @@ license = "MIT" "Source Code" = "https://github.com/langchain-ai/langchain-google/tree/main/libs/vertexai" [tool.poetry.dependencies] -python = ">=3.8.1,<4.0" -langchain-core = ">=0.2.33,<0.3" +python = ">=3.9,<4.0" +langchain-core = { version = "^0.3.0.dev4", allow-prereleases = true } google-cloud-aiplatform = "^1.56.0" google-cloud-storage = "^2.17.0" # optional dependencies anthropic = { extras = ["vertexai"], version = ">=0.30.0,<1", optional = true } -langchain-mistralai = { version = ">=0.1.12,<1", optional = true } +langchain-mistralai = { version = "^0.2.0.dev1", allow-prereleases = true } httpx = "^0.27.0" httpx-sse = "^0.4.0" +pydantic = ">=2,<3" [tool.poetry.group.test] optional = true @@ -39,9 +40,9 @@ numpy = [ { version = "^1.26.0", python = ">=3.12" }, ] google-api-python-client = "^2.117.0" -langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain" } -langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } -langchain-standard-tests = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests"} +langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain", branch = "v0.3rc" } +langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core", branch = "v0.3rc" } +langchain-standard-tests = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/standard-tests", branch = "v0.3rc"} [tool.codespell] @@ -68,8 +69,8 @@ optional = true numexpr = { version = "^2.8.8", python = ">=3.9,<4.0" } google-api-python-client = "^2.114.0" google-cloud-datastore = "^2.19.0" -langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain" } -langchain-mistralai = 
"^0.1.12" +langchain = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/langchain", branch = "v0.3rc" } +langchain-mistralai = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/partners/mistralai", branch = "v0.3rc"} [tool.poetry.group.lint] optional = true @@ -83,12 +84,12 @@ mypy = "^1" types-google-cloud-ndb = "^2.2.0.20240106" types-protobuf = "^4.24.0.4" types-requests = "^2.31.0" -langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } +langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core", branch = "v0.3rc" } [tool.poetry.group.dev.dependencies] -langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core" } +langchain-core = { git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core", branch = "v0.3rc" } [tool.ruff.lint] select = [ diff --git a/libs/vertexai/scripts/check_pydantic.sh b/libs/vertexai/scripts/check_pydantic.sh deleted file mode 100755 index 06b5bb81..00000000 --- a/libs/vertexai/scripts/check_pydantic.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# -# This script searches for lines starting with "import pydantic" or "from pydantic" -# in tracked files within a Git repository. -# -# Usage: ./scripts/check_pydantic.sh /path/to/repository - -# Check if a path argument is provided -if [ $# -ne 1 ]; then - echo "Usage: $0 /path/to/repository" - exit 1 -fi - -repository_path="$1" - -# Search for lines matching the pattern within the specified repository -result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic') - -# Check if any matching lines were found -if [ -n "$result" ]; then - echo "ERROR: The following lines need to be updated:" - echo "$result" - echo "Please replace the code with an import from langchain_core.pydantic_v1." 
- echo "For example, replace 'from pydantic import BaseModel'" - echo "with 'from langchain_core.pydantic_v1 import BaseModel'" - exit 1 -fi diff --git a/libs/vertexai/tests/integration_tests/test_chains.py b/libs/vertexai/tests/integration_tests/test_chains.py index e2e3560f..295219f2 100644 --- a/libs/vertexai/tests/integration_tests/test_chains.py +++ b/libs/vertexai/tests/integration_tests/test_chains.py @@ -5,7 +5,7 @@ AIMessage, ) from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel, Field +from pydantic import BaseModel, Field from langchain_google_vertexai import ChatVertexAI, create_structured_runnable from tests.integration_tests.conftest import _DEFAULT_MODEL_NAME diff --git a/libs/vertexai/tests/integration_tests/test_chat_models.py b/libs/vertexai/tests/integration_tests/test_chat_models.py index e58a3223..7965a5b7 100644 --- a/libs/vertexai/tests/integration_tests/test_chat_models.py +++ b/libs/vertexai/tests/integration_tests/test_chat_models.py @@ -22,9 +22,9 @@ ) from langchain_core.outputs import ChatGeneration, LLMResult from langchain_core.prompts import ChatPromptTemplate -from langchain_core.pydantic_v1 import BaseModel from langchain_core.rate_limiters import InMemoryRateLimiter from langchain_core.tools import tool +from pydantic import BaseModel from langchain_google_vertexai import ( ChatVertexAI, @@ -302,7 +302,7 @@ def get_climate_info(query: str): prompt=prompt_template, ) agent_executor = agents.AgentExecutor( # type: ignore[call-arg] - agent=agent, # type: ignore[arg-type] + agent=agent, tools=tools, # type: ignore[arg-type] verbose=False, stream_runnable=False, @@ -616,7 +616,11 @@ class MyModel(BaseModel): assert response == MyModel(name="Erick", age=27) model = llm.with_structured_output( - {"name": "MyModel", "description": "MyModel", "parameters": MyModel.schema()} + { + "name": "MyModel", + "description": "MyModel", + "parameters": MyModel.model_json_schema(), + } ) response = model.invoke([message]) assert response == { diff --git a/libs/vertexai/tests/integration_tests/test_model_garden.py b/libs/vertexai/tests/integration_tests/test_model_garden.py index d0a2a781..a0946a6c 100644 --- a/libs/vertexai/tests/integration_tests/test_model_garden.py +++ b/libs/vertexai/tests/integration_tests/test_model_garden.py @@ -11,8 +11,8 @@ SystemMessage, ) from langchain_core.outputs import LLMResult -from langchain_core.pydantic_v1 import BaseModel from langchain_core.tools import tool +from pydantic import BaseModel from langchain_google_vertexai.model_garden import ( ChatAnthropicVertex, diff --git a/libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr new file mode 100644 index 00000000..65c7176a --- /dev/null +++ b/libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr @@ -0,0 +1,133 @@ +# serializer version: 1 +# name: TestGeminiAIStandard.test_serdes[serialized] + dict({ + 'graph': dict({ + 'edges': list([ + dict({ + 'source': 0, + 'target': 1, + }), + dict({ + 'source': 1, + 'target': 2, + }), + ]), + 'nodes': list([ + dict({ + 'data': 'ChatVertexAIInput', + 'id': 0, + 'type': 'schema', + }), + dict({ + 'data': dict({ + 'id': list([ + 'langchain', + 'chat_models', + 'vertexai', + 'ChatVertexAI', + ]), + 'name': 'ChatVertexAI', + }), + 'id': 1, + 'type': 'runnable', + }), + dict({ + 'data': 'ChatVertexAIOutput', + 'id': 2, + 'type': 'schema', + }), + ]), + }), + 'id': list([ + 'langchain', + 'chat_models', + 'vertexai', + 
diff --git a/libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr b/libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr
new file mode 100644
index 00000000..65c7176a
--- /dev/null
+++ b/libs/vertexai/tests/unit_tests/__snapshots__/test_standard.ambr
@@ -0,0 +1,133 @@
+# serializer version: 1
+# name: TestGeminiAIStandard.test_serdes[serialized]
+  dict({
+    'graph': dict({
+      'edges': list([
+        dict({
+          'source': 0,
+          'target': 1,
+        }),
+        dict({
+          'source': 1,
+          'target': 2,
+        }),
+      ]),
+      'nodes': list([
+        dict({
+          'data': 'ChatVertexAIInput',
+          'id': 0,
+          'type': 'schema',
+        }),
+        dict({
+          'data': dict({
+            'id': list([
+              'langchain',
+              'chat_models',
+              'vertexai',
+              'ChatVertexAI',
+            ]),
+            'name': 'ChatVertexAI',
+          }),
+          'id': 1,
+          'type': 'runnable',
+        }),
+        dict({
+          'data': 'ChatVertexAIOutput',
+          'id': 2,
+          'type': 'schema',
+        }),
+      ]),
+    }),
+    'id': list([
+      'langchain',
+      'chat_models',
+      'vertexai',
+      'ChatVertexAI',
+    ]),
+    'kwargs': dict({
+      'default_metadata': list([
+      ]),
+      'full_model_name': 'projects/test-project/locations/us-central1/publishers/google/models/gemini-1.0-pro-001',
+      'location': 'us-central1',
+      'max_output_tokens': 100,
+      'max_retries': 2,
+      'model_family': '1',
+      'model_name': 'gemini-1.0-pro-001',
+      'n': 1,
+      'project': 'test-project',
+      'request_parallelism': 5,
+      'stop': list([
+      ]),
+      'temperature': 0.0,
+    }),
+    'lc': 1,
+    'name': 'ChatVertexAI',
+    'type': 'constructor',
+  })
+# ---
+# name: TestGemini_15_AIStandard.test_serdes[serialized]
+  dict({
+    'graph': dict({
+      'edges': list([
+        dict({
+          'source': 0,
+          'target': 1,
+        }),
+        dict({
+          'source': 1,
+          'target': 2,
+        }),
+      ]),
+      'nodes': list([
+        dict({
+          'data': 'ChatVertexAIInput',
+          'id': 0,
+          'type': 'schema',
+        }),
+        dict({
+          'data': dict({
+            'id': list([
+              'langchain',
+              'chat_models',
+              'vertexai',
+              'ChatVertexAI',
+            ]),
+            'name': 'ChatVertexAI',
+          }),
+          'id': 1,
+          'type': 'runnable',
+        }),
+        dict({
+          'data': 'ChatVertexAIOutput',
+          'id': 2,
+          'type': 'schema',
+        }),
+      ]),
+    }),
+    'id': list([
+      'langchain',
+      'chat_models',
+      'vertexai',
+      'ChatVertexAI',
+    ]),
+    'kwargs': dict({
+      'default_metadata': list([
+      ]),
+      'full_model_name': 'projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro-001',
+      'location': 'us-central1',
+      'max_output_tokens': 100,
+      'max_retries': 2,
+      'model_family': '2',
+      'model_name': 'gemini-1.5-pro-001',
+      'n': 1,
+      'project': 'test-project',
+      'request_parallelism': 5,
+      'stop': list([
+      ]),
+      'temperature': 0.0,
+    }),
+    'lc': 1,
+    'name': 'ChatVertexAI',
+    'type': 'constructor',
+  })
+# ---
diff --git a/libs/vertexai/tests/unit_tests/test_chat_models.py b/libs/vertexai/tests/unit_tests/test_chat_models.py
index 27727acb..5df10ad5 100644
--- a/libs/vertexai/tests/unit_tests/test_chat_models.py
+++ b/libs/vertexai/tests/unit_tests/test_chat_models.py
@@ -28,7 +28,7 @@
 from langchain_core.output_parsers.openai_tools import (
     PydanticToolsParser,
 )
-from langchain_core.pydantic_v1 import BaseModel
+from pydantic import BaseModel
 from vertexai.language_models import (  # type: ignore
     ChatMessage,
     InputOutputTextPair,
diff --git a/libs/vertexai/tests/unit_tests/test_embeddings.py b/libs/vertexai/tests/unit_tests/test_embeddings.py
index b47d8b5b..e1b3ceb8 100644
--- a/libs/vertexai/tests/unit_tests/test_embeddings.py
+++ b/libs/vertexai/tests/unit_tests/test_embeddings.py
@@ -2,7 +2,8 @@
 from unittest.mock import MagicMock
 
 import pytest
-from langchain_core.pydantic_v1 import root_validator
+from pydantic import model_validator
+from typing_extensions import Self
 
 from langchain_google_vertexai import VertexAIEmbeddings
 from langchain_google_vertexai.embeddings import GoogleEmbeddingModelType
@@ -41,7 +42,7 @@ def __init__(self, model_name, **kwargs: Any) -> None:
     def _init_vertexai(cls, values: Dict) -> None:
         pass
 
-    @root_validator(pre=False, skip_on_failure=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        values["client"] = MagicMock()
-        return values
+    @model_validator(mode="after")
+    def validate_environment(self) -> Self:
+        self.client = MagicMock()
+        return self
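The `test_embeddings.py` hunk shows the validator migration pattern this PR applies throughout: a `@root_validator(pre=False, skip_on_failure=True)` classmethod that mutated a `values` dict becomes a `@model_validator(mode="after")` instance method that mutates `self` and returns it. A minimal runnable sketch of the pattern (the `FakeEmbeddingsClient` model is hypothetical, not code from this diff):

```python
from typing import Any
from unittest.mock import MagicMock

from pydantic import BaseModel, ConfigDict, model_validator
from typing_extensions import Self


class FakeEmbeddingsClient(BaseModel):
    """Hypothetical model illustrating the v1 -> v2 validator migration."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    model_name: str
    client: Any = None

    # v1: a classmethod that received and returned a `values` dict.
    # v2: mode="after" runs on the fully built instance, so attributes are
    # set directly and the instance itself is returned.
    @model_validator(mode="after")
    def validate_environment(self) -> Self:
        self.client = MagicMock()
        return self


emb = FakeEmbeddingsClient(model_name="textembedding-gecko")
assert isinstance(emb.client, MagicMock)
```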
diff --git a/libs/vertexai/tests/unit_tests/test_function_utils.py b/libs/vertexai/tests/unit_tests/test_function_utils.py
index d9cab148..96a04860 100644
--- a/libs/vertexai/tests/unit_tests/test_function_utils.py
+++ b/libs/vertexai/tests/unit_tests/test_function_utils.py
@@ -1,19 +1,26 @@
 import json
 from enum import Enum
-from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
 from unittest.mock import Mock, patch
 
 import google.cloud.aiplatform_v1beta1.types as gapic
 import pytest
 import vertexai.generative_models as vertexai  # type: ignore
-from langchain_core.pydantic_v1 import BaseModel, Field
 from langchain_core.tools import BaseTool, tool
 from langchain_core.utils.json_schema import dereference_refs
+from pydantic import BaseModel, Field
+from pydantic.v1 import (
+    BaseModel as BaseModelV1,
+)
+from pydantic.v1 import (
+    Field as FieldV1,
+)
 
 from langchain_google_vertexai.functions_utils import (
     _format_base_tool_to_function_declaration,
     _format_dict_to_function_declaration,
     _format_json_schema_to_gapic,
+    _format_json_schema_to_gapic_v1,
     _format_pydantic_to_function_declaration,
     _format_to_gapic_function_declaration,
     _format_to_gapic_tool,
@@ -27,6 +34,28 @@
 
 
 def test_format_json_schema_to_gapic():
+    # Simple case
+    class RecordPerson(BaseModel):
+        """Record some identifying information about a person."""
+
+        name: str
+        age: Optional[int]
+
+    schema = RecordPerson.model_json_schema()
+    result = _format_json_schema_to_gapic(schema)
+    expected = {
+        "title": "RecordPerson",
+        "type": "OBJECT",
+        "description": "Record some identifying information about a person.",
+        "properties": {
+            "name": {"title": "Name", "type": "STRING"},
+            "age": {"type": "INTEGER", "title": "Age"},
+        },
+        "required": ["name"],
+    }
+    assert result == expected
+
+    # Nested case
     class StringEnum(str, Enum):
         pear = "pear"
         banana = "banana"
@@ -39,14 +68,120 @@ class A(BaseModel):
     class B(BaseModel):
         object_field: Optional[A] = Field(description="Class A")
         array_field: Sequence[A]
-        int_field: int = Field(description="int field", minimum=1, maximum=10)
+        int_field: int = Field(description="int field", ge=1, le=10)
         str_field: str = Field(
+            min_length=1,
+            max_length=10,
+            pattern="^[A-Z]{1,10}$",
+            json_schema_extra={"example": "ABCD"},
+        )
+        str_enum_field: StringEnum
+
+    schema = B.model_json_schema()
+    result = _format_json_schema_to_gapic(dereference_refs(schema))
+
+    expected = {
+        "properties": {
+            "object_field": {
+                "description": "Class A",
+                "properties": {"int_field": {"type": "INTEGER", "title": "Int Field"}},
+                "required": [],
+                "title": "A",
+                "type": "OBJECT",
+            },
+            "array_field": {
+                "items": {
+                    "description": "Class A",
+                    "properties": {
+                        "int_field": {"type": "INTEGER", "title": "Int Field"}
+                    },
+                    "required": [],
+                    "title": "A",
+                    "type": "OBJECT",
+                },
+                "type": "ARRAY",
+                "title": "Array Field",
+            },
+            "int_field": {
+                "description": "int field",
+                "maximum": 10,
+                "minimum": 1,
+                "title": "Int Field",
+                "type": "INTEGER",
+            },
+            "str_field": {
+                "example": "ABCD",
+                "maxLength": 10,
+                "minLength": 1,
+                "pattern": "^[A-Z]{1,10}$",
+                "title": "Str Field",
+                "type": "STRING",
+            },
+            "str_enum_field": {
+                "enum": ["pear", "banana"],
+                "title": "StringEnum",
+                "type": "STRING",
+            },
+        },
+        "type": "OBJECT",
+        "title": "B",
+        "required": ["array_field", "int_field", "str_field", "str_enum_field"],
+    }
+    assert result == expected
+
+    gapic_schema = cast(gapic.Schema, gapic.Schema.from_json(json.dumps(result)))
+    assert gapic_schema.type_ == gapic.Type.OBJECT
+    assert gapic_schema.title == expected["title"]
+    assert gapic_schema.required == expected["required"]
+    assert (
+        gapic_schema.properties["str_field"].example
+        == expected["properties"]["str_field"]["example"]  # type: ignore
+    )
+
+
+def test_format_json_schema_to_gapic_v1():
+    # Simple case
+    class RecordPerson(BaseModelV1):
+        """Record some identifying information about a person."""
+
+        name: str
+        age: Optional[int]
+
+    schema = RecordPerson.schema()
+    result = _format_json_schema_to_gapic_v1(schema)
+    expected = {
+        "title": "RecordPerson",
+        "type": "OBJECT",
+        "description": "Record some identifying information about a person.",
+        "properties": {
+            "name": {"title": "Name", "type": "STRING"},
+            "age": {"type": "INTEGER", "title": "Age"},
+        },
+        "required": ["name"],
+    }
+    assert result == expected
+
+    # Nested case
+    class StringEnum(str, Enum):
+        pear = "pear"
+        banana = "banana"
+
+    class A(BaseModelV1):
+        """Class A"""
+
+        int_field: Optional[int]
+
+    class B(BaseModelV1):
+        object_field: Optional[A] = FieldV1(description="Class A")
+        array_field: Sequence[A]
+        int_field: int = FieldV1(description="int field", minimum=1, maximum=10)
+        str_field: str = FieldV1(
             min_length=1, max_length=10, pattern="^[A-Z]{1,10}$", example="ABCD"
         )
         str_enum_field: StringEnum
 
     schema = B.schema()
-    result = _format_json_schema_to_gapic(dereference_refs(schema))
+    result = _format_json_schema_to_gapic_v1(dereference_refs(schema))
 
     expected = {
         "properties": {
@@ -106,6 +241,27 @@ class B(BaseModel):
     )
 
 
+def test_format_json_schema_to_gapic_union_types() -> None:
+    """Test that union types are consistent between v1 and v2."""
+
+    class RecordPerson_v1(BaseModelV1):
+        name: str
+        age: Union[int, str]
+
+    class RecordPerson(BaseModel):
+        name: str
+        age: Union[int, str]
+
+    schema_v1 = RecordPerson_v1.schema()
+    schema_v2 = RecordPerson.model_json_schema()
+
+    result_v1 = _format_json_schema_to_gapic_v1(schema_v1)
+    result_v2 = _format_json_schema_to_gapic(schema_v2)
+    result_v1["title"] = "RecordPerson"
+
+    assert result_v1 == result_v2
+
+
 # reusable test inputs
 def search(question: str) -> str:
     """Search tool"""
@@ -164,7 +320,7 @@ class SearchModel(BaseModel):
     question: str
 
 
-search_model_schema = SearchModel.schema()
+search_model_schema = SearchModel.model_json_schema()
 search_model_dict = {
     "name": search_model_schema["title"],
     "description": search_model_schema["description"],
diff --git a/libs/vertexai/tests/unit_tests/test_llm.py b/libs/vertexai/tests/unit_tests/test_llm.py
index e9271b41..ac6afdba 100644
--- a/libs/vertexai/tests/unit_tests/test_llm.py
+++ b/libs/vertexai/tests/unit_tests/test_llm.py
@@ -2,7 +2,8 @@
 from unittest import TestCase
 from unittest.mock import MagicMock, patch
 
-from langchain_core.pydantic_v1 import root_validator
+from pydantic import model_validator
+from typing_extensions import Self
 
 from langchain_google_vertexai._base import _BaseVertexAIModelGarden
 from langchain_google_vertexai.llms import VertexAI
@@ -80,9 +81,9 @@ def test_vertexai_args_passed() -> None:
 
 def test_extract_response() -> None:
     class FakeModelGarden(_BaseVertexAIModelGarden):
-        @root_validator(pre=False, skip_on_failure=True)
-        def validate_environment(cls, values: Dict) -> Dict:
-            return values
+        @model_validator(mode="after")
+        def validate_environment(self) -> Self:
+            return self
 
     prompts_results = [
         ("a prediction", "a prediction"),
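The new `test_format_json_schema_to_gapic_v1` and union-type tests above rely on pydantic 2's bundled `pydantic.v1` compatibility namespace, which lets one process hold models from both generations side by side. A small standalone sketch of that coexistence (the model names are illustrative):

```python
from pydantic import BaseModel
from pydantic.v1 import BaseModel as BaseModelV1


class PersonV2(BaseModel):
    name: str


class PersonV1(BaseModelV1):
    name: str


# The two generations emit slightly different JSON-schema dialects (e.g. v2
# nests shared models under "$defs" where v1 used "definitions"), which is
# why the library keeps separate v1/v2 gapic formatting paths.
v2_schema = PersonV2.model_json_schema()
v1_schema = PersonV1.schema()

assert v2_schema["properties"]["name"]["type"] == "string"
assert v1_schema["properties"]["name"]["type"] == "string"
```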