Skip to content

Commit

Permalink
兼容PYDANTIC_V2
Browse files Browse the repository at this point in the history
  • Loading branch information
glide-the committed Jul 17, 2024
1 parent ce72ee3 commit 5ac6485
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
2 changes: 1 addition & 1 deletion langchain_glm/agents/output_parsers/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def parse_ai_message_to_tool_action(
actions.append(function_tool_result_stack.popleft())
else:
for too_call in tool_calls:
if "function" == too_call["name"]:
if too_call["name"] not in AdapterAllToolStructType.__members__.values():
actions.append(function_tool_result_stack.popleft())
elif too_call["name"] == AdapterAllToolStructType.CODE_INTERPRETER:
actions.append(code_interpreter_action_result_stack.popleft())
Expand Down
4 changes: 2 additions & 2 deletions langchain_glm/agents/zhipuai_all_tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ class ZhipuAIAllToolsRunnable(RunnableSerializable[Dict, OutputType]):
agent_executor: AgentExecutor
"""ZhipuAI AgentExecutor."""

model_name: str = Field(default="tob-alltools-api-dev")
model_name: str = Field(default="glm-4-alltools")
"""工具模型"""
callback: AgentExecutorAsyncIteratorCallbackHandler
"""ZhipuAI AgentExecutor callback."""
Expand Down Expand Up @@ -193,7 +193,7 @@ def create_agent_executor(
streaming=True,
verbose=True,
callbacks=callbacks,
model_name=model_name,
model=model_name,
temperature=temperature,
**kwargs,
)
Expand Down
19 changes: 8 additions & 11 deletions langchain_glm/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import zhipuai
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
Expand All @@ -32,8 +33,6 @@
get_from_dict_or_env,
get_pydantic_field_names,
)
from typing_extensions import ClassVar
from zhipuai.core import PYDANTIC_V2, BaseModel, ConfigDict

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -84,16 +83,14 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
http_client: Union[Any, None] = None
"""Optional httpx.Client."""

if PYDANTIC_V2:
model_config: ClassVar[ConfigDict] = ConfigDict(
extra="forbid", populate_by_name=True
)
else:

class Config:
allow_population_by_field_name = True
class Config:
"""Configuration for this pydantic object."""

extra = Extra.forbid
allow_population_by_field_name = True

@root_validator(pre=True)
@root_validator(pre=True, allow_reuse=True)
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Build extra kwargs from additional params that were passed in."""
all_required_field_names = get_pydantic_field_names(cls)
Expand All @@ -119,7 +116,7 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
values["model_kwargs"] = extra
return values

@root_validator()
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
zhipuai_api_key = get_from_dict_or_env(
Expand Down

0 comments on commit 5ac6485

Please sign in to comment.