Dynamic loading examples #420

Merged
6 commits merged on Sep 26, 2024
Changes from 5 commits
41 changes: 5 additions & 36 deletions taskweaver/code_interpreter/code_interpreter/code_generator.py
@@ -1,19 +1,18 @@
import datetime
import json
import os
from typing import List, Optional, Tuple
from typing import List, Optional

from injector import inject

from taskweaver.code_interpreter.plugin_selection import PluginSelector, SelectedPluginPool
from taskweaver.llm import LLMApi
from taskweaver.llm.util import ChatMessageType, format_chat_message
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Attachment, Conversation, Memory, Post, Round, RoundCompressor
from taskweaver.memory import Attachment, Memory, Post, Round, RoundCompressor
from taskweaver.memory.attachment import AttachmentType
from taskweaver.memory.experience import Experience, ExperienceGenerator
from taskweaver.memory.experience import ExperienceGenerator
from taskweaver.memory.plugin import PluginEntry, PluginRegistry
from taskweaver.misc.example import load_examples
from taskweaver.module.event_emitter import PostEventProxy, SessionEventEmitter
from taskweaver.module.tracing import Tracing, tracing_decorator
from taskweaver.role import PostTranslator, Role
@@ -26,21 +25,13 @@ def _configure(self) -> None:
self._set_name("code_generator")
self.role_name = self._get_str("role_name", "ProgramApe")
self.load_plugin = self._get_bool("load_plugin", True)
self.load_example = self._get_bool("load_example", True)
self.prompt_file_path = self._get_path(
"prompt_file_path",
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"code_generator_prompt.yaml",
),
)
self.example_base_path = self._get_path(
"example_base_path",
os.path.join(
self.src.app_base_path,
"codeinterpreter_examples",
),
)
self.prompt_compression = self._get_bool("prompt_compression", False)
self.compression_prompt_path = self._get_path(
"compression_prompt_path",
@@ -89,7 +80,6 @@ def __init__(
self.query_requirements_template = self.prompt_data["requirements"]
self.response_json_schema = json.loads(self.prompt_data["response_json_schema"])

self.examples = None
self.code_verification_on: bool = False
self.allowed_modules: List[str] = []

@@ -157,12 +147,10 @@ def compose_prompt(
self,
rounds: List[Round],
plugins: List[PluginEntry],
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
planning_enrichments: Optional[List[str]] = None,
) -> List[ChatMessageType]:
experiences = self.format_experience(
template=self.prompt_data["experience_instruction"],
experiences=selected_experiences,
)

chat_history = [
@@ -172,8 +160,6 @@
),
]

if self.examples is None:
self.examples = self.load_examples()
for i, example in enumerate(self.examples):
chat_history.extend(
self.compose_conversation(example.rounds, example.plugins, add_requirements=False),
@@ -358,21 +344,14 @@ def reply(
if self.config.enable_auto_plugin_selection:
self.plugin_pool = self.select_plugins_for_prompt(query)

exp_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")

if exp_sub_paths:
self.tracing.set_span_attribute("experience_sub_path", str(exp_sub_paths))
exp_sub_path = exp_sub_paths[0].content
else:
exp_sub_path = ""
selected_experiences = self.load_experience(query=query, sub_path=exp_sub_path)
self.role_load_experience(query=query, memory=memory)
self.role_load_example(memory=memory, role_set={self.alias, "Planner"})

planning_enrichments = memory.get_shared_memory_entries(entry_type="plan")

prompt = self.compose_prompt(
rounds,
self.plugin_pool,
selected_experiences,
planning_enrichments=[pe.content for pe in planning_enrichments],
)

@@ -440,16 +419,6 @@ def format_plugins(
)
return ""

def load_examples(
self,
) -> List[Conversation]:
if self.config.load_example:
return load_examples(
folder=self.config.example_base_path,
role_set={self.alias, "Planner"},
)
return []

def get_plugin_pool(self) -> List[PluginEntry]:
return self.plugin_pool

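Taken together, the code_generator.py changes drop the role-local example cache (self.examples plus load_examples) and the inline experience_sub_path lookup in favor of the shared role_load_experience and role_load_example helpers, called on every reply. Below is a minimal, self-contained sketch of that pattern; Memory, SharedMemoryEntry, and load_examples are simplified stand-ins rather than the real TaskWeaver classes, and only the call shape of role_load_example is taken from the diff.

```python
# Simplified stand-ins (not the actual TaskWeaver implementations) illustrating the
# per-reply dynamic loading pattern this PR moves to.
from dataclasses import dataclass, field
from typing import List, Set


@dataclass
class SharedMemoryEntry:
    entry_type: str
    content: str


@dataclass
class Memory:
    entries: List[SharedMemoryEntry] = field(default_factory=list)

    def get_shared_memory_entries(self, entry_type: str) -> List[SharedMemoryEntry]:
        return [e for e in self.entries if e.entry_type == entry_type]


def load_examples(folder: str, sub_path: str, role_set: Set[str]) -> List[str]:
    # Stand-in for taskweaver.misc.example.load_examples: the real function reads
    # YAML conversations from folder (or folder/sub_path) and filters them by role_set.
    return [f"{folder}/{sub_path or '.'} for roles {sorted(role_set)}"]


class Role:
    def __init__(self, example_base_path: str) -> None:
        self.example_base_path = example_base_path
        self.examples: List[str] = []

    def role_load_example(self, memory: Memory, role_set: Set[str]) -> None:
        # Resolve an optional per-conversation sub-folder from shared memory, then
        # reload examples on every reply instead of caching them at construction time.
        sub_paths = memory.get_shared_memory_entries(entry_type="example_sub_path")
        sub_path = sub_paths[0].content if sub_paths else ""
        self.examples = load_examples(self.example_base_path, sub_path, role_set)


memory = Memory(entries=[SharedMemoryEntry("example_sub_path", "sales_reports")])
generator = Role("code_generator_examples")
generator.role_load_example(memory, role_set={"CodeInterpreter", "Planner"})
print(generator.examples)
```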
20 changes: 8 additions & 12 deletions taskweaver/memory/experience.py
@@ -242,7 +242,7 @@ def load_experience(self):
exp_ids = [os.path.splitext(os.path.basename(exp_file))[0].split("_")[2] for exp_file in original_exp_files]
if len(exp_ids) == 0:
self.logger.warning(
"No experience found."
"No experience found.",
)
return

@@ -253,18 +253,14 @@

exp_file = f"exp_{exp_id}.yaml"
exp_file_path = os.path.join(exp_dir, exp_file)
assert os.path.exists(exp_file_path), (
f"Experience {exp_file} not found. "
)
assert os.path.exists(exp_file_path), f"Experience {exp_file} not found. "

experience = read_yaml(exp_file_path)

assert len(experience["embedding"]) > 0, (
f"Experience {exp_file} has no embedding."
)
assert experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model, (
f"Experience {exp_file} has different embedding model."
)
assert len(experience["embedding"]) > 0, f"Experience {exp_file} has no embedding."
assert (
experience["embedding_model"] == self.llm_api.embedding_service.config.embedding_model
), f"Experience {exp_file} has different embedding model."

self.experience_list.append(Experience(**experience))

@@ -326,13 +322,13 @@ def delete_handcrafted_experience(self, exp_id: str):
@staticmethod
def format_experience_in_prompt(
prompt_template: str,
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
selected_experiences: Optional[List[Experience,]] = None,
):
if selected_experiences is not None and len(selected_experiences) > 0:
return prompt_template.format(
experiences="===================\n"
+ "\n===================\n".join(
[exp.experience_text for exp, _ in selected_experiences],
[exp.experience_text for exp in selected_experiences],
),
)
else:
2 changes: 1 addition & 1 deletion taskweaver/memory/type_vars.py
@@ -2,5 +2,5 @@

RoleName = str
RoundState = Literal["finished", "failed", "created"]
SharedMemoryEntryType = Literal["plan", "experience_sub_path"]
SharedMemoryEntryType = Literal["plan", "experience_sub_path", "example_sub_path"]
SharedMemoryEntryScope = Literal["round", "conversation"]
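The widened Literal registers the new "example_sub_path" entry type alongside "experience_sub_path", so roles can be pointed at a per-conversation example sub-folder through shared memory. A small illustrative snippet follows; the per-type descriptions are interpretive summaries, not text from the codebase.

```python
# Illustrative only: how the widened Literal constrains shared-memory entry types.
from typing import Literal

SharedMemoryEntryType = Literal["plan", "experience_sub_path", "example_sub_path"]


def describe(entry_type: SharedMemoryEntryType) -> str:
    # A Literal-typed parameter lets a type checker reject unknown entry types;
    # the descriptions here are interpretive summaries, not code from the repo.
    return {
        "plan": "planning enrichments shared by the Planner",
        "experience_sub_path": "sub-folder to load experiences from",
        "example_sub_path": "sub-folder to load example conversations from",
    }[entry_type]


print(describe("example_sub_path"))
```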
11 changes: 11 additions & 0 deletions taskweaver/misc/example.py
@@ -7,15 +7,26 @@

def load_examples(
folder: str,
sub_path: Optional[str] = None,
role_set: Optional[Set[str]] = None,
) -> List[Conversation]:
"""
Load all the examples from a folder.

Args:
folder: the folder path.
sub_path: the sub-folder path.
role_set: the roles should be included in the examples.
"""
if sub_path:
folder = path.join(folder, sub_path)
if not path.exists(folder):
raise FileNotFoundError(
f"Folder {folder} does not exist. The default example folder of CodeInterpreter "
f"has been changed from `codeinterpreter_examples` to `code_generator_examples`. "
f"If this is the cause, please either rename the folder or change the config.",
)

example_file_list: List[str] = glob.glob(path.join(folder, "*.yaml"))
example_conv_pool: List[Conversation] = []
for yaml_path in example_file_list:
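A hedged usage sketch of the updated loader follows; the folder, sub-path, and role names are placeholders, and the except branch reflects the new FileNotFoundError raised when the resolved folder is missing, for example after the rename from codeinterpreter_examples to code_generator_examples.

```python
# Illustrative call; paths and role names are placeholders.
from taskweaver.misc.example import load_examples

try:
    examples = load_examples(
        folder="./code_generator_examples",       # default folder, renamed from codeinterpreter_examples
        sub_path="sales_reports",                 # optional per-conversation sub-folder
        role_set={"CodeInterpreter", "Planner"},  # placeholder role names used for filtering
    )
except FileNotFoundError as err:
    # Raised when folder (or folder/sub_path) does not exist.
    print(err)
```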
51 changes: 11 additions & 40 deletions taskweaver/planner/planner.py
@@ -3,18 +3,17 @@
import os
import types
from json import JSONDecodeError
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Optional

from injector import inject

from taskweaver.llm import LLMApi
from taskweaver.llm.util import ChatMessageType, format_chat_message
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Conversation, Memory, Post, Round, RoundCompressor
from taskweaver.memory import Memory, Post, Round, RoundCompressor
from taskweaver.memory.attachment import AttachmentType
from taskweaver.memory.experience import Experience, ExperienceGenerator
from taskweaver.memory.experience import ExperienceGenerator
from taskweaver.memory.memory import SharedMemoryEntry
from taskweaver.misc.example import load_examples
from taskweaver.module.event_emitter import SessionEventEmitter
from taskweaver.module.tracing import Tracing, tracing_decorator
from taskweaver.role import PostTranslator, Role
@@ -25,22 +24,13 @@
class PlannerConfig(RoleConfig):
def _configure(self) -> None:
self._set_name("planner")
app_dir = self.src.app_base_path
self.use_example = self._get_bool("use_example", True)
self.prompt_file_path = self._get_path(
"prompt_file_path",
os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"planner_prompt.yaml",
),
)
self.example_base_path = self._get_path(
"example_base_path",
os.path.join(
app_dir,
"planner_examples",
),
)
self.prompt_compression = self._get_bool("prompt_compression", False)
self.compression_prompt_path = self._get_path(
"compression_prompt_path",
@@ -82,9 +72,6 @@ def __init__(

self.prompt_data = read_yaml(self.config.prompt_file_path)

if self.config.use_example:
self.examples = self.get_examples()

self.instruction_template = self.prompt_data["instruction_template"]

self.response_json_schema = json.loads(self.prompt_data["response_json_schema"])
@@ -211,11 +198,9 @@ def get_env_context(self) -> str:
def compose_prompt(
self,
rounds: List[Round],
selected_experiences: Optional[List[Tuple[Experience, float]]] = None,
) -> List[ChatMessageType]:
experiences = self.format_experience(
template=self.prompt_data["experience_instruction"],
experiences=selected_experiences,
)

chat_history = [
@@ -225,12 +210,11 @@ def compose_prompt(
),
]

if self.config.use_example and len(self.examples) != 0:
for conv_example in self.examples:
conv_example_in_prompt = self.compose_conversation_for_prompt(
conv_example.rounds,
)
chat_history += conv_example_in_prompt
for conv_example in self.examples:
conv_example_in_prompt = self.compose_conversation_for_prompt(
conv_example.rounds,
)
chat_history += conv_example_in_prompt

summary = None
if self.config.prompt_compression and self.round_compressor is not None:
Expand Down Expand Up @@ -266,19 +250,13 @@ def reply(
self.tracing.set_span_attribute("user_query", user_query)
self.tracing.set_span_attribute("use_experience", self.config.use_experience)

exp_sub_paths = memory.get_shared_memory_entries(entry_type="experience_sub_path")

if exp_sub_paths:
self.tracing.set_span_attribute("experience_sub_path", str(exp_sub_paths))
exp_sub_path = exp_sub_paths[0].content
else:
exp_sub_path = ""
selected_experiences = self.load_experience(query=user_query, sub_path=exp_sub_path)
self.role_load_experience(query=user_query, memory=memory)
self.role_load_example(role_set=set(self.recipient_alias_set) | {self.alias, "User"}, memory=memory)

post_proxy = self.event_emitter.create_post_proxy(self.alias)

post_proxy.update_status("composing prompt")
chat_history = self.compose_prompt(rounds, selected_experiences)
chat_history = self.compose_prompt(rounds)

def check_post_validity(post: Post):
missing_elements: List[str] = []
Expand Down Expand Up @@ -392,10 +370,3 @@ def stream_filter(s: Iterable[ChatMessageType]):
self.tracing.set_span_attribute("out.attachments", str(reply_post.attachment_list))

return reply_post

def get_examples(self) -> List[Conversation]:
example_conv_list = load_examples(
self.config.example_base_path,
role_set=set(self.recipient_alias_set) | {self.alias, "User"},
)
return example_conv_list
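The planner now follows the same per-reply loading pattern; the main difference is its role filter, built from the roles it can address plus its own alias and the user. A tiny sketch with an assumed recipient set:

```python
# Illustrative only: the Planner's example filter is the union of the roles it can
# address, its own alias, and the user.
recipient_alias_set = {"CodeInterpreter"}  # assumed value; supplied by the session in practice
alias = "Planner"

role_set = set(recipient_alias_set) | {alias, "User"}
print(sorted(role_set))  # ['CodeInterpreter', 'Planner', 'User']
```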