Skip to content

Commit

Permalink
名称修改
Browse files Browse the repository at this point in the history
  • Loading branch information
glide-the committed Jun 25, 2024
1 parent e44290e commit b88b77b
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 42 deletions.
2 changes: 1 addition & 1 deletion langchain_zhipuai/agents/all_tools_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _perform_agent_action(
# We then call the tool on the tool input to get an observation
# TODO: platform adapter tool for all tools,
# view tools binding langchain_zhipuai/agents/zhipuai_all_tools/base.py:188
if "code_interpreter" in agent_action.tool:
if agent_action.tool in AdapterAllToolStructType.__members__.values():
observation = tool.run(
{
"agent_action": agent_action,
Expand Down
10 changes: 8 additions & 2 deletions langchain_zhipuai/agents/format_scratchpad/all_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,19 @@ def format_to_zhipuai_all_tool_messages(

elif isinstance(agent_action, DrawingToolAgentAction):
if isinstance(observation, DrawingToolOutput):
messages.append(AIMessage(content=str(observation)))
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
]
messages.extend([new for new in new_messages if new not in messages])
else:
raise ValueError(f"Unknown observation type: {type(observation)}")

elif isinstance(agent_action, WebBrowserAgentAction):
if isinstance(observation, WebBrowserToolOutput):
messages.append(AIMessage(content=str(observation)))
new_messages = list(agent_action.message_log) + [
_create_tool_message(agent_action, observation)
]
messages.extend([new for new in new_messages if new not in messages])
else:
raise ValueError(f"Unknown observation type: {type(observation)}")

Expand Down
2 changes: 1 addition & 1 deletion langchain_zhipuai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
logger = logging.getLogger(__name__)


class ZhipuAIAIEmbeddings(BaseModel, Embeddings):
class ZhipuAIEmbeddings(BaseModel, Embeddings):
"""ZhipuAI embedding models.
To use, you should have the
Expand Down
20 changes: 1 addition & 19 deletions tests/integration_tests/all_tools/test_alltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ async def test_all_tools_code_interpreter(logging_conf):

agent_executor = ZhipuAIAllToolsRunnable.create_agent_executor(
model_name="glm-4-alltools",
tools=[{"type": "code_interpreter"}, shell],
tools=[shell],
)
chat_iterator = agent_executor.invoke(
chat_input="看下本地文件有哪些,告诉我你用的是什么文件,查看当前目录"
Expand All @@ -56,24 +56,6 @@ async def test_all_tools_code_interpreter(logging_conf):
if item.status == AgentStatus.llm_end:
print("llm_end:" + item.text)

chat_iterator = agent_executor.invoke(chat_input="打印下test_alltools.py")
async for item in chat_iterator:
if isinstance(item, AllToolsAction):
print("AllToolsAction:" + str(item.to_json()))

elif isinstance(item, AllToolsFinish):
print("AllToolsFinish:" + str(item.to_json()))

elif isinstance(item, AllToolsActionToolStart):
print("AllToolsActionToolStart:" + str(item.to_json()))

elif isinstance(item, AllToolsActionToolEnd):
print("AllToolsActionToolEnd:" + str(item.to_json()))
elif isinstance(item, AllToolsLLMStatus):
if item.status == AgentStatus.llm_end:
print("llm_end:" + item.text)


@pytest.mark.asyncio
async def test_all_tools_code_interpreter_sandbox_none(logging_conf):
logging.config.dictConfig(logging_conf) # type: ignore
Expand Down
38 changes: 19 additions & 19 deletions tests/integration_tests/embeddings/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,25 @@
"""Test openai embeddings."""
"""Test zhipuai embeddings."""
import numpy as np
import pytest

from langchain_zhipuai.embeddings.base import ZhipuAIAIEmbeddings
from langchain_zhipuai.embeddings.base import ZhipuAIEmbeddings


@pytest.mark.scheduled
def test_openai_embedding_documents() -> None:
"""Test openai embeddings."""
def test_zhipuai_embedding_documents() -> None:
"""Test zhipuai embeddings."""
documents = ["foo bar"]
embedding = ZhipuAIAIEmbeddings()
embedding = ZhipuAIEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 1024


@pytest.mark.scheduled
def test_openai_embedding_documents_multiple() -> None:
"""Test openai embeddings."""
def test_zhipuai_embedding_documents_multiple() -> None:
"""Test zhipuai embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = ZhipuAIAIEmbeddings(chunk_size=2)
embedding = ZhipuAIEmbeddings(chunk_size=2)
embedding.embedding_ctx_length = 8191
output = embedding.embed_documents(documents)
assert len(output) == 3
Expand All @@ -29,10 +29,10 @@ def test_openai_embedding_documents_multiple() -> None:


@pytest.mark.scheduled
async def test_openai_embedding_documents_async_multiple() -> None:
"""Test openai embeddings."""
async def test_zhipuai_embedding_documents_async_multiple() -> None:
"""Test zhipuai embeddings."""
documents = ["foo bar", "bar foo", "foo"]
embedding = ZhipuAIAIEmbeddings(chunk_size=2)
embedding = ZhipuAIEmbeddings(chunk_size=2)
embedding.embedding_ctx_length = 8191
output = await embedding.aembed_documents(documents)
assert len(output) == 3
Expand All @@ -42,30 +42,30 @@ async def test_openai_embedding_documents_async_multiple() -> None:


@pytest.mark.scheduled
def test_openai_embedding_query() -> None:
"""Test openai embeddings."""
def test_zhipuai_embedding_query() -> None:
"""Test zhipuai embeddings."""
document = "foo bar"
embedding = ZhipuAIAIEmbeddings()
embedding = ZhipuAIEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1024


@pytest.mark.scheduled
async def test_openai_embedding_async_query() -> None:
"""Test openai embeddings."""
async def test_zhipuai_embedding_async_query() -> None:
"""Test zhipuai embeddings."""
document = "foo bar"
embedding = ZhipuAIAIEmbeddings()
embedding = ZhipuAIEmbeddings()
output = await embedding.aembed_query(document)
assert len(output) == 1024


@pytest.mark.scheduled
def test_embed_documents_normalized() -> None:
output = ZhipuAIAIEmbeddings().embed_documents(["foo walked to the market"])
output = ZhipuAIEmbeddings().embed_documents(["foo walked to the market"])
assert np.isclose(np.linalg.norm(output[0]), 1.0)


@pytest.mark.scheduled
def test_embed_query_normalized() -> None:
output = ZhipuAIAIEmbeddings().embed_query("foo walked to the market")
output = ZhipuAIEmbeddings().embed_query("foo walked to the market")
assert np.isclose(np.linalg.norm(output), 1.0)

0 comments on commit b88b77b

Please sign in to comment.