From c021a542b5f0cb6a4c8e89c39884b3662cf84d4b Mon Sep 17 00:00:00 2001 From: Andreas Happe Date: Sat, 19 Oct 2024 17:46:50 +0200 Subject: [PATCH] convert more examples to the log-helper --- src/helper/log.py | 48 +++++++++++++++++++++++++++++------------ src/plan_and_execute.py | 19 +++++++--------- src/switch-to-react.py | 26 +++++++++++----------- 3 files changed, 54 insertions(+), 39 deletions(-) diff --git a/src/helper/log.py b/src/helper/log.py index 49ce9b6..8b24994 100644 --- a/src/helper/log.py +++ b/src/helper/log.py @@ -25,10 +25,22 @@ def __init__(self): self.console = Console() # todo: create log file path - def capture_event(self, event): - self.events.append(event) - # todo: write data to logfile for long-term tracing + def process_single_message(self, message): + if isinstance(message, ToolMessage): + self.console.print(Panel(message.content, title=f"{message.name} answers")) + elif isinstance(message, AIMessage): + for call in message.tool_calls: + self.console.print(Panel(Pretty(call['args']), title=f"Outgoing Tool to {call['name']}")) + elif isinstance(message, HumanMessage): + self.console.print(Panel(message.content, title="Initial (Human?) Query")) + else: + self.console.print(Panel(Pretty(message), title="Unknown Message Type!")) + + def process_messages(self, messages): + for message in messages: + self.process_single_message(message) + def process_debug_event(self, event): if event['type'] == 'task': task = Task(event['timestamp'], event['step'], event['payload']['id'], event['payload']['name'], event['payload']['input']) self.open_tasks[task.payload_id] = task @@ -48,12 +60,10 @@ def capture_event(self, event): self.console.log(f"finshed task {task.name}") if task.name == 'tools': for (type, messages) in event['payload']['result']: - assert(type == 'messages') in_there = False for message in messages: in_there = True - if isinstance(message, ToolMessage): - self.console.print(Panel(message.content, title=f"{message.name} answers")) + self.process_single_message(message) if not in_there: self.console.log(Pretty(messages)) elif 'messages' in event['payload']['result']: @@ -62,19 +72,29 @@ def capture_event(self, event): else: in_there = False for (type, messages) in event['payload']['result']: - in_there = True - assert(type == 'messages') - for message in messages: - if isinstance(message, AIMessage): - for call in message.tool_calls: - self.console.print(Panel(Pretty(call['args']), title=f"Outgoing Tool to {call['name']}")) - else: - self.console.log(Pretty(message)) + if type == 'plan': + in_there = True + self.process_single_message(messages) + else: + for message in messages: + in_there = True + self.process_single_message(message) if not in_there: self.console.log(task.result) else: self.console.print(Pretty(event)) + def capture_event(self, event): + self.events.append(event) + # todo: write data to logfile for long-term tracing + + if 'type' in event: + self.process_debug_event(event) + elif 'messages' in event: + self.process_messages(event['messages']) + else: + self.console.print(Panel(Pretty(event), title="Unknown Event Type!")) + def print_message(self, message): if isinstance(message, AIMessage): if len(message.tool_calls) > 0 and len(message.content) == 0: diff --git a/src/plan_and_execute.py b/src/plan_and_execute.py index 6cf2321..a968585 100644 --- a/src/plan_and_execute.py +++ b/src/plan_and_execute.py @@ -1,17 +1,14 @@ -import time - from dotenv import load_dotenv from graphs.plan_and_execute import PlanExecute -from rich.console import Console from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI from helper.common import get_or_fail -from helper.ui import print_event -from tools.ssh import get_ssh_connection_from_env, SshTestCredentialsTool, SshExecuteTool +from helper.log import RichLogger from graphs.initial_version import create_chat_tool_agent_graph from graphs.plan_and_execute import create_plan_and_execute_graph +from tools.ssh import get_ssh_connection_from_env, SshTestCredentialsTool, SshExecuteTool # setup configuration from environment variables load_dotenv() @@ -19,8 +16,8 @@ conn = get_ssh_connection_from_env() conn.connect() -# prepare console -console = Console() +# prepare logging +logger = RichLogger() # initialize the ChatOpenAI model and register the tool (ssh connection) llm = ChatOpenAI(model="gpt-4o", temperature=0) @@ -48,12 +45,12 @@ def execute_step(state: PlanExecute): events = graph.stream( {"messages": [("user", template)]}, - stream_mode='values' + stream_mode='debug' ) agent_response = None for event in events: - print_event(console, event) + logger.capture_event(event) agent_response = event return { @@ -76,9 +73,9 @@ def execute_step(state: PlanExecute): events = app.stream( input = {"input": template }, config = {"recursion_limit": 50}, - stream_mode = "values" + stream_mode = "debug" ) # output all occurring logs for event in events: - print_event(console, event) \ No newline at end of file + logger.capture_event(event) \ No newline at end of file diff --git a/src/switch-to-react.py b/src/switch-to-react.py index ca9cbf8..bda665d 100644 --- a/src/switch-to-react.py +++ b/src/switch-to-react.py @@ -1,12 +1,11 @@ from dotenv import load_dotenv -from rich.console import Console from langchain_core.prompts import PromptTemplate from langchain_openai import ChatOpenAI from langgraph.prebuilt import create_react_agent from helper.common import get_or_fail -from helper.ui import print_event_stream +from helper.log import RichLogger from tools.ssh import SshExecuteTool, SshTestCredentialsTool,get_ssh_connection_from_env # setup configuration from environment variables @@ -31,18 +30,17 @@ Do not repeat already tried escalation attacks. You should focus upon enumeration and privilege escalation. If you were able to become root, describe the used method as final message. """).format(username=conn.username, password=conn.password) -if __name__ == '__main__': - console = Console() +logger = RichLogger() - events = agent_executor.stream( - { - "messages": [ - ("user", template), - ] - }, - stream_mode="values", - ) +events = agent_executor.stream( + { + "messages": [ + ("user", template), + ] + }, + stream_mode="debug", +) - # output all the events that we're getting from the agent - print_event_stream(console, events) \ No newline at end of file +for event in events: + logger.capture_event(event) \ No newline at end of file