Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vertex[major]: upgrade pydantic #475

Merged
merged 37 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
3fd9569
update deps
ccurme Sep 5, 2024
36e7f9f
delete check_pydantic script
ccurme Sep 5, 2024
b6615cc
to_pydantic_2
ccurme Sep 5, 2024
c9fc32d
model_before_rewrite
ccurme Sep 5, 2024
669a9cd
model_after_rewrite
ccurme Sep 5, 2024
ca5a4de
Self
ccurme Sep 5, 2024
0d49776
format
ccurme Sep 5, 2024
9f7d7b8
clean up
ccurme Sep 5, 2024
52c1f61
model_before_rewrite
ccurme Sep 5, 2024
8a1cd84
change VertexAI.validate_environment to pre
ccurme Sep 5, 2024
69c279e
lint
ccurme Sep 5, 2024
a1105a3
update chat and embeddings validation to pre
ccurme Sep 5, 2024
574a51b
update some features to pydantic 2
ccurme Sep 5, 2024
6b8171d
remove unused type ignores
ccurme Sep 5, 2024
bfed90d
fix validate_environment in llm and chat models
ccurme Sep 5, 2024
f282bfe
more validation updates
ccurme Sep 5, 2024
0188eef
change maas model garden validation to post
ccurme Sep 5, 2024
a0d9d0b
add protected namespaces to embeddings
ccurme Sep 5, 2024
ddf04c3
fix embeddings init
ccurme Sep 5, 2024
973a304
update docstrings
ccurme Sep 5, 2024
2db3c58
Merge branch 'main' into cc/upgrade_pydantic
lkuligin Sep 6, 2024
f0b8513
Merge branch 'main' into cc/upgrade_pydantic
ccurme Sep 6, 2024
9a0a6b7
update _format_json_schema_to_gapic
ccurme Sep 8, 2024
03a72e7
support v1 function
ccurme Sep 8, 2024
20b395f
add test for union types
ccurme Sep 8, 2024
0063ff6
vertex: fix gapic conversion (#482)
ccurme Sep 8, 2024
82a78e3
Merge branch 'v0.3rc' into cc/upgrade_pydantic
ccurme Sep 8, 2024
d3f6b28
increment version to 2.0.0.dev1
ccurme Sep 9, 2024
1c1894f
Merge branch 'cc/upgrade_pydantic' of github.com:langchain-ai/langcha…
ccurme Sep 9, 2024
29b7a2a
add snapshots for serialization standard test
ccurme Sep 9, 2024
dd5d8fe
bump core dep
ccurme Sep 9, 2024
36e9ec4
json_schema_extra in test
ccurme Sep 9, 2024
ec4bc72
protected namespaces
ccurme Sep 9, 2024
1dc420c
fix warnings
ccurme Sep 9, 2024
e598e27
fix warnings
ccurme Sep 9, 2024
0c4c5b2
fix mistral dependency
ccurme Sep 9, 2024
97b32d8
fix mistral dep and lock
ccurme Sep 9, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion libs/vertexai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ integration_test integration_tests: TEST_FILE = tests/integration_tests/
test tests integration_test integration_tests:
poetry run pytest --release $(TEST_FILE)

test_watch:
poetry run ptw --snapshot-update --now . -- -vv $(TEST_FILE)

# Run unit tests and generate a coverage report.
coverage:
poetry run pytest --cov \
Expand All @@ -33,7 +36,6 @@ lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
./scripts/check_pydantic.sh .
./scripts/lint_imports.sh
poetry run ruff check .
poetry run ruff format $(PYTHON_FILES) --diff
Expand Down
7 changes: 4 additions & 3 deletions libs/vertexai/langchain_google_vertexai/_anthropic_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
from langchain_core.messages.tool import tool_call
from langchain_core.output_parsers import BaseGenerationOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel, ConfigDict


class ToolsOutputParser(BaseGenerationOutputParser):
first_tool_only: bool = False
args_only: bool = False
pydantic_schemas: Optional[List[Type[BaseModel]]] = None

class Config:
extra = "forbid"
model_config = ConfigDict(
extra="forbid",
)

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
"""Parse a list of candidate model Generations into a specific format.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
ToolMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.tools import BaseTool
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel

if TYPE_CHECKING:
from anthropic.types import RawMessageStreamEvent # type: ignore
Expand Down
34 changes: 18 additions & 16 deletions libs/vertexai/langchain_google_vertexai/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
from google.protobuf import json_format
from google.protobuf.struct_pb2 import Value
from langchain_core.outputs import Generation, LLMResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator
from typing_extensions import Self
from vertexai.generative_models._generative_models import ( # type: ignore
SafetySettingsType,
)
Expand Down Expand Up @@ -91,14 +92,15 @@ class _VertexAIBase(BaseModel):
"when making API calls. If not provided, credentials will be ascertained from "
"the environment."

class Config:
"""Configuration for this pydantic object."""
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
protected_namespaces=(),
)

allow_population_by_field_name = True
arbitrary_types_allowed = True

@root_validator(pre=True)
def validate_params_base(cls, values: dict) -> dict:
@model_validator(mode="before")
@classmethod
def validate_params_base(cls, values: dict) -> Any:
if "model" in values and "model_name" not in values:
values["model_name"] = values.pop("model")
if values.get("project") is None:
Expand All @@ -108,7 +110,7 @@ def validate_params_base(cls, values: dict) -> dict:
if values.get("api_endpoint"):
api_endpoint = values["api_endpoint"]
else:
location = values.get("location", cls.__fields__["location"].default)
location = values.get("location", cls.model_fields["location"].default)
api_endpoint = f"{location}-{constants.PREDICTION_API_BASE_PATH}"
client_options = ClientOptions(api_endpoint=api_endpoint)
if values.get("client_cert_source"):
Expand Down Expand Up @@ -311,26 +313,26 @@ class _BaseVertexAIModelGarden(_VertexAIBase):
single_example_per_request: bool = True
"LLM endpoint currently serves only the first example in the request"

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that the python package exists in environment."""

if not values["project"]:
if not self.project:
raise ValueError(
"A GCP project should be provided to run inference on Model Garden!"
)

client_options = ClientOptions(
api_endpoint=f"{values['location']}-aiplatform.googleapis.com"
api_endpoint=f"{self.location}-aiplatform.googleapis.com"
)
client_info = get_client_info(module="vertex-ai-model-garden")
values["client"] = PredictionServiceClient(
self.client = PredictionServiceClient(
client_options=client_options, client_info=client_info
)
values["async_client"] = PredictionServiceAsyncClient(
self.async_client = PredictionServiceAsyncClient(
client_options=client_options, client_info=client_info
)
return values
return self

@property
def endpoint_path(self) -> str:
Expand Down
11 changes: 8 additions & 3 deletions libs/vertexai/langchain_google_vertexai/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
StrOutputParser,
)
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
from pydantic import BaseModel

from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser

Expand Down Expand Up @@ -51,7 +51,12 @@ def _create_structured_runnable_extra_step(
*,
prompt: Optional[BasePromptTemplate] = None,
) -> Runnable:
names = [schema.schema()["title"] for schema in functions]
names = [
schema.model_json_schema()["title"]
if hasattr(schema, "model_json_schema")
else schema.schema()["title"]
for schema in functions
]
if hasattr(llm, "is_gemini_advanced") and llm._is_gemini_advanced: # type: ignore
llm_with_functions = llm.bind(
functions=functions,
Expand Down Expand Up @@ -111,7 +116,7 @@ def create_structured_runnable(

from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field


class RecordPerson(BaseModel):
Expand Down
86 changes: 48 additions & 38 deletions libs/vertexai/langchain_google_vertexai/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
)
from langchain_core.output_parsers.openai_tools import parse_tool_calls
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, root_validator, Field
from pydantic import BaseModel, Field, model_validator
from langchain_core.runnables import Runnable, RunnablePassthrough, RunnableGenerator
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.pydantic import is_basemodel_subclass
Expand Down Expand Up @@ -124,6 +124,9 @@
_format_to_gapic_tool,
_ToolType,
)
from pydantic import ConfigDict
from typing_extensions import Self


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -762,7 +765,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
Tool calling:
.. code-block:: python

from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field

class GetWeather(BaseModel):
'''Get the current weather in a given location'''
Expand Down Expand Up @@ -800,7 +803,7 @@ class GetPopulation(BaseModel):

from typing import Optional

from langchain_core.pydantic_v1 import BaseModel, Field
from pydantic import BaseModel, Field

class Joke(BaseModel):
'''Joke to tell user.'''
Expand Down Expand Up @@ -1024,11 +1027,10 @@ def __init__(self, *, model_name: Optional[str] = None, **kwargs: Any) -> None:
kwargs["model_name"] = model_name
super().__init__(**kwargs)

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True
arbitrary_types_allowed = True
model_config = ConfigDict(
populate_by_name=True,
arbitrary_types_allowed=True,
)

@classmethod
def is_lc_serializable(self) -> bool:
Expand All @@ -1039,57 +1041,65 @@ def get_lc_namespace(cls) -> List[str]:
"""Get the namespace of the langchain object."""
return ["langchain", "chat_models", "vertexai"]

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validate that the python package exists in environment."""
safety_settings = values.get("safety_settings")
tuned_model_name = values.get("tuned_model_name")
values["model_family"] = GoogleModelFamily(values["model_name"])
safety_settings = self.safety_settings
tuned_model_name = self.tuned_model_name
self.model_family = GoogleModelFamily(self.model_name)

if values["model_name"] == "chat-bison-default":
if self.model_name == "chat-bison-default":
logger.warning(
"Model_name will become a required arg for VertexAIEmbeddings "
"starting from Sep-01-2024. Currently the default is set to "
"chat-bison"
)
values["model_name"] = "chat-bison"
self.model_name = "chat-bison"

if values.get("full_model_name") is not None:
if self.full_model_name is not None:
pass
elif values.get("tuned_model_name") is not None:
values["full_model_name"] = _format_model_name(
values["tuned_model_name"],
location=values["location"],
project=values["project"],
elif self.tuned_model_name is not None:
self.full_model_name = _format_model_name(
self.tuned_model_name,
location=self.location,
project=cast(str, self.project),
)
else:
values["full_model_name"] = _format_model_name(
values["model_name"],
location=values["location"],
project=values["project"],
self.full_model_name = _format_model_name(
self.model_name,
location=self.location,
project=cast(str, self.project),
)

if safety_settings and not is_gemini_model(values["model_family"]):
if safety_settings and not is_gemini_model(self.model_family):
raise ValueError("Safety settings are only supported for Gemini models")

if tuned_model_name:
generative_model_name = values["tuned_model_name"]
generative_model_name = self.tuned_model_name
else:
generative_model_name = values["model_name"]

if not is_gemini_model(values["model_family"]):
cls._init_vertexai(values)
if values["model_family"] == GoogleModelFamily.CODEY:
generative_model_name = self.model_name

if not is_gemini_model(self.model_family):
values = {
"project": self.project,
"location": self.location,
"credentials": self.credentials,
"api_transport": self.api_transport,
"api_endpoint": self.api_endpoint,
"default_metadata": self.default_metadata,
}
self._init_vertexai(values)
if self.model_family == GoogleModelFamily.CODEY:
model_cls = CodeChatModel
model_cls_preview = PreviewCodeChatModel
else:
model_cls = ChatModel
model_cls_preview = PreviewChatModel
values["client"] = model_cls.from_pretrained(generative_model_name)
values["client_preview"] = model_cls_preview.from_pretrained(
self.client = model_cls.from_pretrained(generative_model_name)
self.client_preview = model_cls_preview.from_pretrained(
generative_model_name
)
return values
return self

@property
def _is_gemini_advanced(self) -> bool:
Expand Down Expand Up @@ -1647,7 +1657,7 @@ def with_structured_output(
Example: Pydantic schema, exclude raw:
.. code-block:: python

from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_google_vertexai import ChatVertexAI

class AnswerWithJustification(BaseModel):
Expand All @@ -1666,7 +1676,7 @@ class AnswerWithJustification(BaseModel):
Example: Pydantic schema, include raw:
.. code-block:: python

from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_google_vertexai import ChatVertexAI

class AnswerWithJustification(BaseModel):
Expand All @@ -1687,7 +1697,7 @@ class AnswerWithJustification(BaseModel):
Example: Dict schema, exclude raw:
.. code-block:: python

from langchain_core.pydantic_v1 import BaseModel
from pydantic import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_function
from langchain_google_vertexai import ChatVertexAI

Expand Down
36 changes: 23 additions & 13 deletions libs/vertexai/langchain_google_vertexai/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from google.cloud.aiplatform import telemetry
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import create_base_retry_decorator
from langchain_core.pydantic_v1 import root_validator
from pydantic import ConfigDict, model_validator
from typing_extensions import Self
from vertexai.language_models import ( # type: ignore
TextEmbeddingInput,
TextEmbeddingModel,
Expand Down Expand Up @@ -100,24 +101,33 @@ class VertexAIEmbeddings(_VertexAICommon, Embeddings):
# Instance context
instance: Dict[str, Any] = {} #: :meta private:

@root_validator(pre=False, skip_on_failure=True)
def validate_environment(cls, values: Dict) -> Dict:
model_config = ConfigDict(
extra="forbid",
protected_namespaces=(),
)

@model_validator(mode="after")
def validate_environment(self) -> Self:
"""Validates that the python package exists in environment."""
cls._init_vertexai(values)
_, user_agent = get_user_agent(f"{cls.__name__}_{values['model_name']}") # type: ignore
values = {
"project": self.project,
"location": self.location,
"credentials": self.credentials,
"api_transport": self.api_transport,
"api_endpoint": self.api_endpoint,
"default_metadata": self.default_metadata,
}
self._init_vertexai(values)
_, user_agent = get_user_agent(f"{self.__class__.__name__}_{self.model_name}")
with telemetry.tool_context_manager(user_agent):
if (
GoogleEmbeddingModelType(values["model_name"])
GoogleEmbeddingModelType(self.model_name)
== GoogleEmbeddingModelType.MULTIMODAL
):
values["client"] = MultiModalEmbeddingModel.from_pretrained(
values["model_name"]
)
self.client = MultiModalEmbeddingModel.from_pretrained(self.model_name)
else:
values["client"] = TextEmbeddingModel.from_pretrained(
values["model_name"]
)
return values
self.client = TextEmbeddingModel.from_pretrained(self.model_name)
return self

def __init__(
self,
Expand Down
Loading
Loading