Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pydatalab/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,6 @@ chat = [
"langchain >= 0.2.6, < 0.3",
"langchain-openai ~= 0.1",
"langchain-anthropic ~= 0.1",
"tiktoken ~= 0.7",
"transformers ~= 4.42",
]
deploy = ["gunicorn ~= 23.0"]
all = ["datalab-server[apps,server,chat]"]
Expand Down
21 changes: 3 additions & 18 deletions pydatalab/src/pydatalab/apps/chat/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from langchain_anthropic import ChatAnthropic
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
from langchain_openai import ChatOpenAI

from pydatalab.blocks.base import DataBlock
Expand Down Expand Up @@ -165,9 +166,7 @@ def render(self):

try:
model_name = self.data["model"]

model_dict = self.data["available_models"][model_name]
LOGGER.warning(f"Initializing chatblock with model: {model_name}")
LOGGER.debug(f"Initializing chatblock with model: {model_name}")

if model_name.startswith("claude"):
self.chat_client = ChatAnthropic(
Expand All @@ -183,8 +182,6 @@ def render(self):
LOGGER.debug(
f'submitting request to API for completion with last message role "{self.data["messages"][-1]["role"]}" (message = {self.data["messages"][-1:]}). Temperature = {self.data["temperature"]} (type {type(self.data["temperature"])})'
)
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

# Convert your messages to the required format
langchain_messages = []
for message in self.data["messages"]:
Expand All @@ -195,24 +192,12 @@ def render(self):
else:
langchain_messages.append(AIMessage(content=message["content"]))

token_count = self.chat_client.get_num_tokens_from_messages(langchain_messages)

self.data["token_count"] = token_count

if token_count >= model_dict["context_window"]:
self.data["error_message"] = (
f"""This conversation has reached its maximum context size and the chatbot won't be able to respond further ({token_count} tokens, max: {model_dict["context_window"]}). Please make a new chat block to start fresh, or use a model with a larger context window"""
)
return

# Call the chat client with the invoke method
response = self.chat_client.invoke(langchain_messages)
self.data["token_count"] = response.usage_metadata["total_tokens"]

langchain_messages.append(response)

token_count = self.chat_client.get_num_tokens_from_messages(langchain_messages)

self.data["token_count"] = token_count
self.data["messages"].append({"role": "assistant", "content": response.content})
self.data["error_message"] = None

Expand Down
Loading