Skip to content

Commit

Permalink
History clean up (All-Hands-AI#2849)
Browse files Browse the repository at this point in the history
* clean up add_history

* refactor last agent message
  • Loading branch information
enyst authored Jul 8, 2024
1 parent c6aa507 commit 2df1d67
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 49 deletions.
12 changes: 2 additions & 10 deletions evaluation/EDA/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import run_agent_controller
from opendevin.events.action import MessageAction
from opendevin.llm.llm import LLM

game = None
Expand All @@ -44,10 +43,7 @@ def codeact_user_response_eda(state: State) -> str:

# retrieve the latest model message from history
if state.history:
for event in state.history.get_events(reverse=True):
if isinstance(event, MessageAction) and event.source == 'agent':
model_guess = event.content
break
model_guess = state.history.get_last_agent_message()

assert game is not None, 'Game is not initialized.'
msg = game.generate_user_response(model_guess)
Expand Down Expand Up @@ -150,11 +146,7 @@ def process_instance(
if state is None:
raise ValueError('State should not be None.')

final_message = ''
for event in state.history.get_events(reverse=True):
if isinstance(event, MessageAction) and event.source == 'agent':
final_message = event.content
break
final_message = state.history.get_last_agent_message()

logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}')
test_result = game.reward()
Expand Down
6 changes: 1 addition & 5 deletions evaluation/gorilla/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,8 @@ def process_instance(agent, question_id, question, metadata, reset_logger: bool
if state is None:
raise ValueError('State should not be None.')

model_answer_raw = ''

# retrieve the last message from the agent
for event in state.history.get_events(reverse=True):
if isinstance(event, MessageAction) and event.source == 'agent':
model_answer_raw = event
model_answer_raw = state.history.get_last_agent_message()

# attempt to parse model_answer
_, _, ast_eval = get_data(metadata['hub'])
Expand Down
12 changes: 2 additions & 10 deletions evaluation/gpqa/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import run_agent_controller
from opendevin.events.action import MessageAction
from opendevin.llm.llm import LLM

AGENT_CLS_TO_FAKE_USER_RESPONSE_FN = {
Expand Down Expand Up @@ -209,15 +208,8 @@ def process_instance(
assert state is not None, 'State should not be None.'

# ======= Attempt to evaluate the agent's edits =======
# get the final message from the state history (default to None if not found)
final_message = next(
(
act.content
for act in state.history.get_events(reverse=True)
if isinstance(act, MessageAction)
),
None,
)
# get the final message from the state history (default to empty if not found)
final_message = state.history.get_last_agent_message()

logger.info(f'Final message generated by the agent: {final_message}')

Expand Down
8 changes: 1 addition & 7 deletions evaluation/toolqa/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from opendevin.core.logger import get_console_handler
from opendevin.core.logger import opendevin_logger as logger
from opendevin.core.main import run_agent_controller
from opendevin.events.action import MessageAction
from opendevin.llm.llm import LLM

from .utils import download_data, download_tools, encode_question, eval_answer, get_data
Expand Down Expand Up @@ -95,13 +94,8 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool =
if state is None:
raise ValueError('State should not be None.')

model_answer_raw = ''

# retrieve the last message from the agent
for event in state.history.get_events(reverse=True):
if isinstance(event, MessageAction) and event.source == 'agent':
model_answer_raw = event.content
break
model_answer_raw = state.history.get_last_agent_message()

# attempt to parse model_answer
correct = eval_answer(str(model_answer_raw), str(answer))
Expand Down
16 changes: 0 additions & 16 deletions opendevin/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
AgentStateChangedObservation,
CmdOutputObservation,
ErrorObservation,
NullObservation,
Observation,
)

Expand Down Expand Up @@ -128,13 +127,6 @@ async def report_error(self, message: str, exception: Exception | None = None):
self.state.last_error += f': {exception}'
self.event_stream.add_event(ErrorObservation(message), EventSource.AGENT)

async def add_history(self, action: Action, observation: Observation):
if isinstance(action, NullAction) and isinstance(observation, NullObservation):
return
logger.debug(
f'Adding history ({type(action).__name__} with id={action.id}, {type(observation).__name__} with id={observation.id})'
)

async def _start_step_loop(self):
logger.info(f'[Agent Controller {self.id}] Starting step loop...')
while True:
Expand All @@ -160,7 +152,6 @@ async def on_event(self, event: Event):
await self.set_agent_state_to(event.agent_state) # type: ignore
elif isinstance(event, MessageAction):
if event.source == EventSource.USER:
await self.add_history(event, NullObservation(''))
if self.get_agent_state() != AgentState.RUNNING:
await self.set_agent_state_to(AgentState.RUNNING)
elif event.source == EventSource.AGENT and event.wait_for_response:
Expand All @@ -179,18 +170,14 @@ async def on_event(self, event: Event):
await self.set_agent_state_to(AgentState.REJECTED)
elif isinstance(event, Observation):
if self._pending_action and self._pending_action.id == event.cause:
await self.add_history(self._pending_action, event)
self._pending_action = None
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, CmdOutputObservation):
await self.add_history(NullAction(), event)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, AgentDelegateObservation):
await self.add_history(NullAction(), event)
self.state.history.on_event(event)
logger.info(event, extra={'msg_type': 'OBSERVATION'})
elif isinstance(event, ErrorObservation):
await self.add_history(NullAction(), event)
logger.info(event, extra={'msg_type': 'OBSERVATION'})

def reset_task(self):
Expand Down Expand Up @@ -359,9 +346,6 @@ async def _step(self):
if not isinstance(action, NullAction):
self.event_stream.add_event(action, EventSource.AGENT)

if not action.runnable:
await self.add_history(action, NullObservation(''))

await self.update_state_after_step()
logger.info(action, extra={'msg_type': 'ACTION'})

Expand Down
19 changes: 18 additions & 1 deletion opendevin/memory/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def get_last_observation(self, end_id: int = -1) -> Observation | None:

def get_last_user_message(self) -> str:
"""
Return the latest user message from the event stream.
Return the content of the last user message from the event stream.
"""

last_user_message = next(
Expand All @@ -141,6 +141,23 @@ def get_last_user_message(self) -> str:

return last_user_message if last_user_message is not None else ''

def get_last_agent_message(self) -> str:
"""
Return the content of the last agent message from the event stream.
"""

last_agent_message = next(
(
event.content
for event in self._event_stream.get_events(reverse=True)
if isinstance(event, MessageAction)
and event.source == EventSource.AGENT
),
None,
)

return last_agent_message if last_agent_message is not None else ''

def get_last_events(self, n: int) -> list[Event]:
"""
Return the last n events from the event stream.
Expand Down

0 comments on commit 2df1d67

Please sign in to comment.