Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make structured output #12

Merged
merged 1 commit into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 25 additions & 35 deletions backend/agents/ext_agents.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,33 @@
from typing import List, Optional

from autogen_core.base import MessageContext
from autogen_core.components import (DefaultTopicId, RoutedAgent,
default_subscription, message_handler,
type_subscription)
from autogen_core.components import (
DefaultTopicId,
RoutedAgent,
default_subscription,
message_handler,
type_subscription,
)
from llama_index.core.agent.runner.base import AgentRunner
from llama_index.core.base.llms.types import ChatMessage, MessageRole
from llama_index.core.chat_engine.types import AgentChatResponse
from llama_index.core.memory.types import BaseMemory
from pydantic import BaseModel

from ..data_types import AgentResponse, EndUserMessage
from ..data_types import AgentStructuredResponse, EndUserMessage, Resource
from ..otlp_tracing import logger


class Resource(BaseModel):
"""
Represents a resource node retrieved during chat interactions.

Attributes:
content (str): The textual content of the resource.
node_id (str): The identifier of the node.
score (Optional[float]): Score representing the relevance of the resource.
"""

content: str
node_id: str
score: Optional[float] = None
# class Message(BaseModel):
# """
# Represents a message exchanged during the chat.

# Attributes:
# content (str): The textual content of the message.
# sources (Optional[List[Resource]]): List of resources associated with the message.
# """

class Message(BaseModel):
"""
Represents a message exchanged during the chat.

Attributes:
content (str): The textual content of the message.
sources (Optional[List[Resource]]): List of resources associated with the message.
"""

content: str
sources: Optional[List[Resource]] = None
# content: str
# sources: Optional[List[Resource]] = None


@default_subscription
Expand Down Expand Up @@ -116,17 +104,19 @@ async def handle_user_message(
resources.extend(tools)
logger.info(response.response)
await self.publish_message(
AgentResponse(
source="LlamaIndexAgent",
content=f"\n{response.response}\n",
AgentStructuredResponse(
agent_type="default_agent",
data=None,
message=f"\n{response.response}\n",
),
DefaultTopicId(type="user_proxy", source=self._session_id),
)
else:
await self.publish_message(
AgentResponse(
source="LlamaIndexAgent",
content="I'm sorry, I don't have an answer for you.",
AgentStructuredResponse(
agent_type="default_agent",
data=None,
message="I'm sorry, I don't have an answer for you.",
),
DefaultTopicId(type="user_proxy", source=self._session_id),
)
8 changes: 5 additions & 3 deletions backend/agents/travel_activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
GroupChatMessage,
HandoffMessage,
TravelRequest,
AgentStructuredResponse,
)
from ..otlp_tracing import logger

Expand Down Expand Up @@ -171,9 +172,10 @@ async def handle_message(

# Publish the response to the group chat manager
await self.publish_message(
AgentResponse(
source=self.id.type,
content=activities_structured.model_dump_json(),
AgentStructuredResponse(
agent_type=self.id.type,
data=activities_structured,
message=f"Activities processed successfully for query - {message.content}",
),
DefaultTopicId(type="user_proxy", source=ctx.topic_id.source),
)
Expand Down
131 changes: 55 additions & 76 deletions backend/agents/travel_car.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import datetime
import random
from typing import Dict, List

from typing import List
from autogen_core.components.tools import FunctionTool, Tool
from autogen_core.base import MessageContext
from autogen_core.components import (
DefaultTopicId,
Expand All @@ -12,13 +12,13 @@
)
from autogen_core.components.models import LLMMessage, SystemMessage
from typing_extensions import Annotated

from ..data_types import (
AgentResponse,
AgentStructuredResponse,
EndUserMessage,
GroupChatMessage,
HandoffMessage,
TravelRequest,
CarRental,
)
from ..otlp_tracing import logger

Expand All @@ -31,33 +31,7 @@ async def simulate_car_rental_booking(
rental_end_date: Annotated[
str, "The end date of the car rental in the format 'YYYY-MM-DD'."
],
) -> Dict[str, str | int]:
"""
Simulate a car rental booking process.

This function simulates the process of booking a car rental by randomly selecting a car option,
calculating the rental duration and total price, and generating a booking reference number.

Args:
rental_city (str): The city where the car rental will take place.
rental_start_date (str): The start date of the car rental in the format 'YYYY-MM-DD'.
rental_end_date (str): The end date of the car rental in the format 'YYYY-MM-DD'.

Returns:
Dict[str, str | int]: A dictionary containing the car rental details, including the rental city,
start and end dates, car type, company, total price, and booking reference.
Example:
{
"rental_city": "New York",
"rental_start_date": "2023-10-01",
"rental_end_date": "2023-10-07",
"car_type": "SUV",
"company": "Hertz",
"total_price": 560,
"booking_reference": "CR-1234-NYC"
}
"""
# Simulate available car options
) -> CarRental:
car_options = [
{"car_type": "Sedan", "company": "Avis", "price_per_day": 50},
{"car_type": "SUV", "company": "Hertz", "price_per_day": 80},
Expand All @@ -76,36 +50,37 @@ async def simulate_car_rental_booking(
{"car_type": "Hatchback", "company": "City Rentals", "price_per_day": 45},
]

# Randomly select a car option
selected_car = random.choice(car_options)

# Calculate rental duration
start_date = datetime.datetime.strptime(rental_start_date, "%Y-%m-%d")
end_date = datetime.datetime.strptime(rental_end_date, "%Y-%m-%d")
rental_days = (end_date - start_date).days

# Calculate total price
total_price = rental_days * selected_car["price_per_day"]

# Create a booking reference number
booking_reference = f"CR-{random.randint(1000, 9999)}-{rental_city[:3].upper()}"

# Simulate car rental details
car_rental_details = {
"rental_city": rental_city,
"rental_start_date": rental_start_date,
"rental_end_date": rental_end_date,
"car_type": selected_car["car_type"],
"company": selected_car["company"],
"total_price": total_price,
"booking_reference": booking_reference,
}
# Induce an artificial delay to simulate network latency
await asyncio.sleep(3)
car_rental_details = CarRental(
rental_city=rental_city,
rental_start_date=rental_start_date,
rental_end_date=rental_end_date,
car_type=selected_car["car_type"],
company=selected_car["company"],
total_price=total_price,
booking_reference=booking_reference,
)

await asyncio.sleep(2)
return car_rental_details


# Car Rental Agent
def get_car_rental_tool() -> List[Tool]:
return [
FunctionTool(
name="simulate_car_rental_booking",
func=simulate_car_rental_booking,
description="Simulates a car rental booking process based on user preferences.",
)
]


@type_subscription("car_rental")
class CarRentalAgent(RoutedAgent):
def __init__(self) -> None:
Expand All @@ -117,32 +92,35 @@ def __init__(self) -> None:
)
]

async def _process_request(self, requirements: dict) -> dict:
# Simulate car rental booking logic
return await simulate_car_rental_booking(
requirements.get("rental_city", "Unknown"),
requirements.get("rental_start_date", "Unknown"),
requirements.get("rental_end_date", "Unknown"),
)

@message_handler
async def handle_message(
self, message: EndUserMessage, ctx: MessageContext
) -> None:
logger.info(f"CarRentalAgent received message: {message.content}")
if "travel plan" in message.content.lower():
# Cannot handle complex travel plans, hand off back to router
await self.publish_message(
HandoffMessage(content=message.content, source=self.id.type),
DefaultTopicId(type="router", source=ctx.topic_id.source),
)
return
# Extract requirements and process the car rental request
requirements = self.extract_requirements(message.content)
response = await self._process_request(requirements)

# You would typically call a LLM to extract the requirement or have a function call here
requirements = {
"rental_city": (
"New York" if "new york" in message.content.lower() else "Unknown"
),
"rental_start_date": "2023-12-21",
"rental_end_date": "2023-12-26",
}
response = await simulate_car_rental_booking(
requirements["rental_city"],
requirements["rental_start_date"],
requirements["rental_end_date"],
)
await self.publish_message(
AgentResponse(
source=self.id.type,
AgentStructuredResponse(
agent_type=self.id.type,
data=response,
content=f"Car rented: {response}",
),
DefaultTopicId(type="user_proxy", source=ctx.topic_id.source),
Expand All @@ -155,18 +133,19 @@ async def handle_travel_request(
logger.info(
f"CarRentalAgent received travel request: TravelRequest - {message.content}"
)
requirements = self.extract_requirements(message.content)
response = await self._process_request(requirements)
requirements = {
"rental_city": (
"New York" if "new york" in message.content.lower() else "Unknown"
),
"rental_start_date": "2023-12-21",
"rental_end_date": "2023-12-26",
}
response = await simulate_car_rental_booking(
requirements["rental_city"],
requirements["rental_start_date"],
requirements["rental_end_date"],
)
return GroupChatMessage(
source=self.id.type,
content=f"Car rented: {response}",
)

def extract_requirements(self, user_input: str) -> dict:
# Simple keyword-based extraction
requirements = {}
if "new york" in user_input.lower():
requirements["rental_city"] = "New York"
requirements["rental_start_date"] = "2023-12-21"
requirements["rental_end_date"] = "2023-12-26"
return requirements
29 changes: 19 additions & 10 deletions backend/agents/travel_destination.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,22 @@
from typing import List

from autogen_core.base import MessageContext
from autogen_core.components import (DefaultTopicId, RoutedAgent,
message_handler, type_subscription)
from autogen_core.components.models import (LLMMessage, SystemMessage,
UserMessage)
from autogen_core.components import (
DefaultTopicId,
RoutedAgent,
message_handler,
type_subscription,
)
from autogen_core.components.models import LLMMessage, SystemMessage, UserMessage
from autogen_ext.models import AzureOpenAIChatCompletionClient

from ..data_types import (AgentResponse, DestinationInfo, EndUserMessage,
GroupChatMessage, TravelRequest)
from ..data_types import (
AgentStructuredResponse,
DestinationInfo,
EndUserMessage,
GroupChatMessage,
TravelRequest,
)
from ..otlp_tracing import logger


Expand All @@ -33,7 +41,7 @@ async def handle_message(
self, message: EndUserMessage, ctx: MessageContext
) -> None:
logger.info(
f"DestinationAgent received travel request: EndUserMessage{message.content}"
f"DestinationAgent received travel request: EndUserMessage {message.content}"
)
# Provide destination information
try:
Expand All @@ -55,9 +63,10 @@ async def handle_message(
pass

await self.publish_message(
AgentResponse(
source=self.id.type,
content=destination_info_structured.model_dump_json(),
AgentStructuredResponse(
agent_type=self.id.type,
data=destination_info_structured,
message=message.content,
),
DefaultTopicId(type="user_proxy", source=ctx.topic_id.source),
)
Expand Down
Loading
Loading