Skip to content

Commit

Permalink
Merge pull request #12 from lordlinus:structured_output
Browse files Browse the repository at this point in the history
make structured output
  • Loading branch information
lordlinus authored Nov 10, 2024
2 parents 9912c69 + 09828bc commit 1bc335a
Show file tree
Hide file tree
Showing 14 changed files with 865 additions and 344 deletions.
60 changes: 25 additions & 35 deletions backend/agents/ext_agents.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,33 @@
from typing import List, Optional

from autogen_core.base import MessageContext
from autogen_core.components import (DefaultTopicId, RoutedAgent,
default_subscription, message_handler,
type_subscription)
from autogen_core.components import (
DefaultTopicId,
RoutedAgent,
default_subscription,
message_handler,
type_subscription,
)
from llama_index.core.agent.runner.base import AgentRunner
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.memory.types import BaseMemory
from pydantic import BaseModel

from ..data_types import AgentResponse, EndUserMessage
from ..data_types import AgentStructuredResponse, EndUserMessage, Resource
from ..otlp_tracing import logger


class Resource(BaseModel):
    """A single retrieved resource node attached to a chat response.

    Attributes:
        content: Text payload of the retrieved node.
        node_id: Identifier of the underlying node.
        score: Relevance score assigned by the retriever, if one was produced.
    """

    content: str
    node_id: str
    score: Optional[float] = None
# class Message(BaseModel):
# """
# Represents a message exchanged during the chat.

# Attributes:
# content (str): The textual content of the message.
# sources (Optional[List[Resource]]): List of resources associated with the message.
# """

class Message(BaseModel):
"""
Represents a message exchanged during the chat.
Attributes:
content (str): The textual content of the message.
sources (Optional[List[Resource]]): List of resources associated with the message.
"""

content: str
sources: Optional[List[Resource]] = None
# content: str
# sources: Optional[List[Resource]] = None


@default_subscription
Expand Down Expand Up @@ -116,17 +104,19 @@ async def handle_user_message(
resources.extend(tools)
logger.info(response.response)
await self.publish_message(
AgentResponse(
source="LlamaIndexAgent",
content=f"\n{response.response}\n",
AgentStructuredResponse(
agent_type="default_agent",
data=None,
message=f"\n{response.response}\n",
),
DefaultTopicId(type="user_proxy", source=self._session_id),
)
else:
await self.publish_message(
AgentResponse(
source="LlamaIndexAgent",
content="I'm sorry, I don't have an answer for you.",
AgentStructuredResponse(
agent_type="default_agent",
data=None,
message="I'm sorry, I don't have an answer for you.",
),
DefaultTopicId(type="user_proxy", source=self._session_id),
)
8 changes: 5 additions & 3 deletions backend/agents/travel_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GroupChatMessage,
HandoffMessage,
TravelRequest,
AgentStructuredResponse,
)
from ..otlp_tracing import logger

Expand Down Expand Up @@ -171,9 +172,10 @@ async def handle_message(

# Publish the response to the group chat manager
await self.publish_message(
AgentResponse(
source=self.id.type,
content=activities_structured.model_dump_json(),
AgentStructuredResponse(
agent_type=self.id.type,
data=activities_structured,
message=f"Activities processed successfully for query - {message.content}",
),
DefaultTopicId(type="user_proxy", source=ctx.topic_id.source),
)
Expand Down
131 changes: 55 additions & 76 deletions backend/agents/travel_car.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import datetime
import random
from typing import Dict, List

from typing import List
from autogen_core.components.tools import FunctionTool, Tool
from autogen_core.base import MessageContext
from autogen_core.components import (
DefaultTopicId,
Expand All @@ -12,13 +12,13 @@
)
from autogen_core.components.models import LLMMessage, SystemMessage
from typing_extensions import Annotated

from ..data_types import (
AgentResponse,
AgentStructuredResponse,
EndUserMessage,
GroupChatMessage,
HandoffMessage,
TravelRequest,
CarRental,
)
from ..otlp_tracing import logger

Expand All @@ -31,33 +31,7 @@ async def simulate_car_rental_booking(
rental_end_date: Annotated[
str, "The end date of the car rental in the format 'YYYY-MM-DD'."
],
) -> Dict[str, str | int]:
"""
Simulate a car rental booking process.
This function simulates the process of booking a car rental by randomly selecting a car option,
calculating the rental duration and total price, and generating a booking reference number.
Args:
rental_city (str): The city where the car rental will take place.
rental_start_date (str): The start date of the car rental in the format 'YYYY-MM-DD'.
rental_end_date (str): The end date of the car rental in the format 'YYYY-MM-DD'.
Returns:
Dict[str, str | int]: A dictionary containing the car rental details, including the rental city,
start and end dates, car type, company, total price, and booking reference.
Example:
{
"rental_city": "New York",
"rental_start_date": "2023-10-01",
"rental_end_date": "2023-10-07",
"car_type": "SUV",
"company": "Hertz",
"total_price": 560,
"booking_reference": "CR-1234-NYC"
}
"""
# Simulate available car options
) -> CarRental:
car_options = [
{"car_type": "Sedan", "company": "Avis", "price_per_day": 50},
{"car_type": "SUV", "company": "Hertz", "price_per_day": 80},
Expand All @@ -76,36 +50,37 @@ async def simulate_car_rental_booking(
{"car_type": "Hatchback", "company": "City Rentals", "price_per_day": 45},
]

# Randomly select a car option
selected_car = random.choice(car_options)

# Calculate rental duration
start_date = datetime.datetime.strptime(rental_start_date, "%Y-%m-%d")
end_date = datetime.datetime.strptime(rental_end_date, "%Y-%m-%d")
rental_days = (end_date - start_date).days

# Calculate total price
total_price = rental_days * selected_car["price_per_day"]

# Create a booking reference number
booking_reference = f"CR-{random.randint(1000, 9999)}-{rental_city[:3].upper()}"

# Simulate car rental details
car_rental_details = {
"rental_city": rental_city,
"rental_start_date": rental_start_date,
"rental_end_date": rental_end_date,
"car_type": selected_car["car_type"],
"company": selected_car["company"],
"total_price": total_price,
"booking_reference": booking_reference,
}
# Induce an artificial delay to simulate network latency
await asyncio.sleep(3)
car_rental_details = CarRental(
rental_city=rental_city,
rental_start_date=rental_start_date,
rental_end_date=rental_end_date,
car_type=selected_car["car_type"],
company=selected_car["company"],
total_price=total_price,
booking_reference=booking_reference,
)

await asyncio.sleep(2)
return car_rental_details


# Car Rental Agent
def get_car_rental_tool() -> List[Tool]:
    """Build the tool list exposed by the car-rental agent.

    Returns:
        A one-element list wrapping ``simulate_car_rental_booking`` as a
        ``FunctionTool`` so the agent runtime can invoke it.
    """
    booking_tool = FunctionTool(
        name="simulate_car_rental_booking",
        func=simulate_car_rental_booking,
        description="Simulates a car rental booking process based on user preferences.",
    )
    return [booking_tool]


@type_subscription("car_rental")
class CarRentalAgent(RoutedAgent):
def __init__(self) -> None:
Expand All @@ -117,32 +92,35 @@ def __init__(self) -> None:
)
]

async def _process_request(self, requirements: dict) -> dict:
# Simulate car rental booking logic
return await simulate_car_rental_booking(
requirements.get("rental_city", "Unknown"),
requirements.get("rental_start_date", "Unknown"),
requirements.get("rental_end_date", "Unknown"),
)

@message_handler
async def handle_message(
self, message: EndUserMessage, ctx: MessageContext
) -> None:
logger.info(f"CarRentalAgent received message: {message.content}")
if "travel plan" in message.content.lower():
# Cannot handle complex travel plans, hand off back to router
await self.publish_message(
HandoffMessage(content=message.content, source=self.id.type),
DefaultTopicId(type="router", source=ctx.topic_id.source),
)
return
# Extract requirements and process the car rental request
requirements = self.extract_requirements(message.content)
response = await self._process_request(requirements)

# You would typically call a LLM to extract the requirement or have a function call here
requirements = {
"rental_city": (
"New York" if "new york" in message.content.lower() else "Unknown"
),
"rental_start_date": "2023-12-21",
"rental_end_date": "2023-12-26",
}
response = await simulate_car_rental_booking(
requirements["rental_city"],
requirements["rental_start_date"],
requirements["rental_end_date"],
)
await self.publish_message(
AgentResponse(
source=self.id.type,
AgentStructuredResponse(
agent_type=self.id.type,
data=response,
content=f"Car rented: {response}",
),
DefaultTopicId(type="user_proxy", source=ctx.topic_id.source),
Expand All @@ -155,18 +133,19 @@ async def handle_travel_request(
logger.info(
f"CarRentalAgent received travel request: TravelRequest - {message.content}"
)
requirements = self.extract_requirements(message.content)
response = await self._process_request(requirements)
requirements = {
"rental_city": (
"New York" if "new york" in message.content.lower() else "Unknown"
),
"rental_start_date": "2023-12-21",
"rental_end_date": "2023-12-26",
}
response = await simulate_car_rental_booking(
requirements["rental_city"],
requirements["rental_start_date"],
requirements["rental_end_date"],
)
return GroupChatMessage(
source=self.id.type,
content=f"Car rented: {response}",
)

def extract_requirements(self, user_input: str) -> dict:
    """Extract car-rental requirements from free-form user text.

    Uses simple keyword matching: when "new york" appears (case-insensitive)
    in the input, the city and hard-coded demo rental dates are returned;
    otherwise the result is an empty dict.

    Args:
        user_input: Raw message text from the end user.

    Returns:
        dict: Contains ``rental_city``, ``rental_start_date`` and
        ``rental_end_date`` when the city keyword matches, else ``{}``.
    """
    # Simple keyword-based extraction
    requirements = {}
    if "new york" in user_input.lower():
        # NOTE(review): indentation was lost in extraction of this snippet;
        # the two date assignments are assumed to sit inside this branch
        # (callers read all three keys via dict.get with defaults, which is
        # consistent with an empty dict on no match) — confirm against the
        # original file before relying on this.
        requirements["rental_city"] = "New York"
        requirements["rental_start_date"] = "2023-12-21"
        requirements["rental_end_date"] = "2023-12-26"
    return requirements
29 changes: 19 additions & 10 deletions backend/agents/travel_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
from typing import List

from autogen_core.base import MessageContext
from autogen_core.components import (DefaultTopicId, RoutedAgent,
message_handler, type_subscription)
from autogen_core.components.models import (LLMMessage, SystemMessage,
UserMessage)
from autogen_core.components import (
DefaultTopicId,
RoutedAgent,
message_handler,
type_subscription,
)
from autogen_core.components.models import LLMMessage, SystemMessage, UserMessage
from autogen_ext.models import AzureOpenAIChatCompletionClient

from ..data_types import (AgentResponse, DestinationInfo, EndUserMessage,
GroupChatMessage, TravelRequest)
from ..data_types import (
AgentStructuredResponse,
DestinationInfo,
EndUserMessage,
GroupChatMessage,
TravelRequest,
)
from ..otlp_tracing import logger


Expand All @@ -33,7 +41,7 @@ async def handle_message(
self, message: EndUserMessage, ctx: MessageContext
) -> None:
logger.info(
f"DestinationAgent received travel request: EndUserMessage{message.content}"
f"DestinationAgent received travel request: EndUserMessage {message.content}"
)
# Provide destination information
try:
Expand All @@ -55,9 +63,10 @@ async def handle_message(
pass

await self.publish_message(
AgentResponse(
source=self.id.type,
content=destination_info_structured.model_dump_json(),
AgentStructuredResponse(
agent_type=self.id.type,
data=destination_info_structured,
message=message.content,
),
DefaultTopicId(type="user_proxy", source=ctx.topic_id.source),
)
Expand Down
Loading

0 comments on commit 1bc335a

Please sign in to comment.