Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Liqun/add context #384

Merged
merged 4 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Binary file not shown.
Binary file added auto_eval/cases/rag/knowledge_base/index.faiss
Binary file not shown.
Binary file added auto_eval/cases/rag/knowledge_base/index.pkl
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ config_var:
code_verification.code_verification_on: false
execution_service.kernel_mode: "local"
verbose: True
task_description: use ARIMA model to forecast QQQ in next 7 days
task_description: |-
use ARIMA model to forecast QQQ in next 7 days
If the agent asks for the data, you can suggest it to download using yfinance library.
scoring_points:
- score_point: "There should be 7 predicted stock prices in the output"
weight: 1
Expand Down
16 changes: 16 additions & 0 deletions auto_eval/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from langchain.load import dumps
from langchain.schema.messages import AIMessage, HumanMessage, SystemMessage
from langchain_community.chat_models import ChatOpenAI
from langchain_community.chat_models.azureml_endpoint import AzureMLChatOnlineEndpoint, CustomOpenAIChatContentFormatter
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import AzureChatOpenAI

EVALUATOR_PROMPT_FILE_PATH = os.path.join(os.path.dirname(__file__), "evaluator_prompt.yaml")
Expand Down Expand Up @@ -56,6 +58,20 @@ def config_llm(config: Dict[str, str]) -> Union[ChatOpenAI, AzureChatOpenAI]:
temperature=0,
verbose=True,
)
elif api_type == "google_ai":
os.environ["GOOGLE_API_KEY"] = get_config(config, "llm.api_key")
model = ChatGoogleGenerativeAI(
temperature=0,
model=get_config(config, "llm.model"),
verbose=True,
convert_system_message_to_human=True,
)
elif api_type == "azure_ml":
model = AzureMLChatOnlineEndpoint(
endpoint_url=get_config(config, "llm.api_base"),
endpoint_api_key=get_config(config, "llm.api_key"),
content_formatter=CustomOpenAIChatContentFormatter(),
)
else:
raise ValueError("Invalid API type. Please check your config file.")
return model
Expand Down
15 changes: 15 additions & 0 deletions auto_eval/taskweaver_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def batch_auto_evaluate_for_taskweaver(
result_file_path: str,
eval_case_root: str,
flush_result_file: bool = False,
sleep_time: int = 0,
):
if not os.path.exists(result_file_path):
df = pd.DataFrame(columns=["case_file", "score", "normalized_score"])
Expand Down Expand Up @@ -147,6 +148,12 @@ def batch_auto_evaluate_for_taskweaver(

results.to_csv(result_file_path, index=False)

if sleep_time > 0:
print(f"Sleeping for {sleep_time} seconds...")
import time

time.sleep(sleep_time)


if __name__ == "__main__":
import argparse
Expand Down Expand Up @@ -186,6 +193,13 @@ def batch_auto_evaluate_for_taskweaver(
action="store_true",
help="Flush the result file",
)
parser.add_argument(
"-s",
"--sleep",
type=int,
default=0,
help="Sleep time between evaluations",
)

args = parser.parse_args()

Expand All @@ -197,4 +211,5 @@ def batch_auto_evaluate_for_taskweaver(
args.result,
args.path,
flush_result_file=args.fresh,
sleep_time=args.sleep,
)
6 changes: 5 additions & 1 deletion project/planner_examples/example-planner-echo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,8 @@ rounds:
content: |-
1. Ask Echo to echo the user's input, 'Hello World'
- type: current_plan_step
content: 1. Ask Echo to echo the user's input, 'Hello World'
content: 1. Ask Echo to echo the user's input, 'Hello World'
- type: review
content: |-
The user query is successfully answered

29 changes: 17 additions & 12 deletions project/planner_examples/example-planner.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@ rounds:
attachment_list:
- type: init_plan
content: |-
1. load the data file
2. count the rows of the loaded data <sequentially depends on 1>
3. report the result to the user <interactively depends on 2>
1. Load the data file
2. Count the rows of the loaded data <sequentially depends on 1>
3. Check the execution result and report the result to the user <interactively depends on 2>
- type: plan
content: |-
1. instruct CodeInterpreter to load the data file and count the rows of the loaded data
2. report the result to the user
1. Instruct CodeInterpreter to load the data file and count the rows of the loaded data
2. Check the execution result and report the result to the user if it is correct
- type: current_plan_step
content: 1. instruct CodeInterpreter to load the data file and count the rows of the loaded data
content: 1. Instruct CodeInterpreter to load the data file and count the rows of the loaded data
- message: Load the data file /home/data.csv successfully and there are 100 rows in the data file
send_from: CodeInterpreter
send_to: Planner
Expand All @@ -32,12 +32,17 @@ rounds:
attachment_list:
- type: init_plan
content: |-
1. load the data file
2. count the rows of the loaded data <sequentially depends on 1>
3. report the result to the user <interactively depends on 2>
1. Load the data file
2. Count the rows of the loaded data <sequentially depends on 1>
3. Check the execution result and report the result to the user <interactively depends on 2>
- type: plan
content: |-
1. instruct CodeInterpreter to load the data file and count the rows of the loaded data
2. report the result to the user
1. Instruct CodeInterpreter to load the data file and count the rows of the loaded data
2. Check the execution result and report the result to the user if it is correct
- type: current_plan_step
content: 2. report the result to the user
content: 2. report the result to the user
- type: review
content: |-
The data file /home/data.csv is loaded and there are 100 rows in the data file
The execution result is correct
The user query is successfully answered
27 changes: 21 additions & 6 deletions taskweaver/code_interpreter/code_interpreter/code_generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import json
import os
from typing import List, Optional
Expand Down Expand Up @@ -93,11 +94,6 @@ def __init__(
self.code_verification_on: bool = False
self.allowed_modules: List[str] = []

self.instruction = self.instruction_template.format(
ROLE_NAME=self.role_name,
RESPONSE_JSON_SCHEMA=json.dumps(self.response_json_schema),
)

self.round_compressor: RoundCompressor = round_compressor
self.compression_template = read_yaml(self.config.compression_prompt_path)["content"]

Expand Down Expand Up @@ -151,6 +147,20 @@ def compose_verification_requirements(
)
return "\n".join(requirements)

def compose_sys_prompt(self, context: str):
return self.instruction_template.format(
ENVIRONMENT_CONTEXT=context,
ROLE_NAME=self.role_name,
RESPONSE_JSON_SCHEMA=json.dumps(self.response_json_schema),
)

def get_env_context(self):
# get date and time
now = datetime.datetime.now()
current_time = now.strftime("%Y-%m-%d %H:%M:%S")

return f"- Current time: {current_time}"

def compose_prompt(
self,
rounds: List[Round],
Expand All @@ -166,7 +176,12 @@ def compose_prompt(
else ""
)

chat_history = [format_chat_message(role="system", message=f"{self.instruction}\n{experiences}")]
chat_history = [
format_chat_message(
role="system",
message=f"{self.compose_sys_prompt(context=self.get_env_context())}" f"\n{experiences}",
),
]

if self.examples is None:
self.examples = self.load_examples()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
version: 0.1
version: 0.2
content: |-
## On current environment context:
{ENVIRONMENT_CONTEXT}

## On conversations:
- Each conversation starts with "==============================\n## Conversation Start"
- Each conversation has multiple rounds, each round starts with "-----------------------------"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def reply(
llm_response["code"] = f"powershell -Command {llm_response['code']}"

post_proxy.update_attachment(llm_response["description"], AttachmentType.thought)
post_proxy.update_attachment(llm_response["code"], AttachmentType.python)
post_proxy.update_attachment("python", AttachmentType.reply_type)
post_proxy.update_attachment(llm_response["code"], AttachmentType.reply_content)

self.tracing.set_span_attribute("code", llm_response["code"])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,10 @@ def __init__(
self.return_index = 0

self.logger.info(f"{self.alias} initialized successfully.")

def update_session_variables(self, session_variables: dict) -> None:
assert False, "Not implemented"

@tracing_decorator
def reply(
self,
Expand All @@ -58,7 +58,7 @@ def reply(
prompt_log_path=prompt_log_path,
)

code = post_proxy.post.get_attachment(type=AttachmentType.python)[0]
code = post_proxy.post.get_attachment(type=AttachmentType.reply_content)[0]
if len(code) == 0:
post_proxy.update_message(post_proxy.post.get_attachment(type=AttachmentType.thought)[0], is_end=True)
return post_proxy.end()
Expand Down
1 change: 1 addition & 0 deletions taskweaver/memory/attachment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class AttachmentType(Enum):
init_plan = "init_plan"
plan = "plan"
current_plan_step = "current_plan_step"
review = "review"

# CodeInterpreter - generate code
thought = "thought"
Expand Down
21 changes: 17 additions & 4 deletions taskweaver/planner/planner.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import json
import os
import types
Expand Down Expand Up @@ -92,8 +93,6 @@ def __init__(
self.recipient_alias_set,
) + ["User"]

self.instruction = self.compose_sys_prompt()

self.ask_self_cnt = 0
self.max_self_ask_num = 3

Expand All @@ -111,7 +110,7 @@ def __init__(

self.logger.info("Planner initialized successfully")

def compose_sys_prompt(self):
def compose_sys_prompt(self, context: str):
worker_description = ""
for alias, role in self.workers.items():
worker_description += (
Expand All @@ -122,6 +121,7 @@ def compose_sys_prompt(self):
)

instruction = self.instruction_template.format(
environment_context=context,
response_json_schema=json.dumps(self.response_json_schema),
worker_intro=worker_description,
)
Expand Down Expand Up @@ -207,6 +207,13 @@ def compose_conversation_for_prompt(

return conversation

def get_env_context(self) -> str:
ShilinHe marked this conversation as resolved.
Show resolved Hide resolved
# get the current time
now = datetime.datetime.now()
current_time = now.strftime("%Y-%m-%d %H:%M:%S")

return f"- Current time: {current_time}"

def compose_prompt(
self,
rounds: List[Round],
Expand All @@ -220,7 +227,13 @@ def compose_prompt(
if self.config.use_experience
else ""
)
chat_history = [format_chat_message(role="system", message=f"{self.instruction}\n{experiences}")]

chat_history = [
format_chat_message(
role="system",
message=f"{self.compose_sys_prompt(context=self.get_env_context())}" f"\n{experiences}",
),
]

if self.config.use_example and len(self.examples) != 0:
for conv_example in self.examples:
Expand Down
22 changes: 17 additions & 5 deletions taskweaver/planner/planner_prompt.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
version: 0.2
version: 0.3
instruction_template: |-
You are the Planner who can coordinate Workers to finish the user task.

## About the current environment context
{environment_context}

## About conversation history
- There could be multiple Conversations in the chat history
- Each Conversation starts with the User query "Let's start a new conversation!".
Expand All @@ -27,6 +30,7 @@ instruction_template: |-
- Planner should ignore the permission or file access issues since Workers are powerful and can handle them.
- Planner needs to inform Workers on the User's request and the current step.
- Planner must reject the User's request if it contains potential security risks or illegal activities.
- Planner must check the Worker's response and provide feedback to the Worker if the response is incorrect or incomplete.
ShilinHe marked this conversation as resolved.
Show resolved Hide resolved

## Planner's planning process
You need to make a step-by-step plan to complete the User's task. The planning process includes 2 phases: "init_plan" and "plan".
Expand Down Expand Up @@ -54,29 +58,33 @@ instruction_template: |-
init_plan:
1. Read ./data.csv file
2. Count the rows of the loaded data <sequential depend on 1>
3. Check the execution result and report the result to the user <interactively depends on 2>
plan:
1. Read ./data.csv file and count the rows of the loaded data
2. Check the execution result and report the result to the user

[Example 2]
User: Read a manual file and follow the instructions in it.
init_plan:
1. Read the file content and show its content to the user
2. Follow the instructions based on the file content. <interactively depends on 1>
3. Confirm the completion of the instructions and report the result to the user <interactively depends on 2>
plan:
1. Read the file content and show its content to the user
2. follow the instructions based on the file content.
3. Confirm the completion of the instructions and report the result to the user

[Example 3]
User: detect anomaly on ./data.csv
init_plan:
1. Read the ./data.csv and show me the top 5 rows to understand the data schema
2. Confirm the columns to be detected anomalies <sequentially depends on 1>
3. Detect anomalies on the loaded data <interactively depends on 2>
4. Report the detected anomalies to the user <interactively depends on 3>
4. Check the execution result and report the detected anomalies to the user <interactively depends on 3>
plan:
1. Read the ./data.csv and show me the top 5 rows to understand the data schema and confirm the columns to be detected anomalies
2. Detect anomalies on the loaded data
3. Report the detected anomalies to the user
3. Check the execution result and report the detected anomalies to the user

[Example 4]
User: read a.csv and b.csv and join them together
Expand All @@ -85,11 +93,11 @@ instruction_template: |-
2. Load b.csv as dataframe and show me the top 5 rows to understand the data schema
3. Ask which column to join <sequentially depends on 1, 2>
4. Join the two dataframes <interactively depends on 3>
5. report the result to the user <interactively depends on 4>
5. Check the execution result and report the joined data to the user <interactively depends on 4>
plan:
1. Load a.csv and b.csv as dataframes, show me the top 5 rows to understand the data schema, and ask which column to join
2. Join the two dataframes
3. report the result to the user
3. Check the execution result and report the joined data to the user

## Useful Tips
- When the request involves loading a file, Planner should always set the first subtask to reading the file content to understand the structure or schema of the data.
Expand Down Expand Up @@ -123,6 +131,10 @@ response_json_schema: |-
"type": "string",
"description": "The current step Planner is executing."
},
"review": {
"type": "string",
"description": "The review of the current step. If the Worker's response is incorrect or incomplete, Planner should provide feedback to the Worker."
},
"send_to": {
"type": "string",
"description": "The name of character (User or name of the Worker) that Planner wants to speak to."
Expand Down