diff --git a/README.md b/README.md index 4444457..748278d 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Built on top of LangChain's [`SQLDatabase`](https://docs.langchain.com/oss/pytho - **Data Visualization**: Generate charts and graphs from query results using natural language (e.g., "show me a bar chart") - **Configurable Agents**: YAML-based configuration for adding new data sources - **A2A Protocol**: Agent-to-Agent interoperability for integration with other A2A-compliant systems +- **MCP Protocol**: Model Context Protocol support for Claude Desktop, VS Code, and other MCP clients ## Architecture @@ -51,7 +52,9 @@ Generates, validates, and executes SQL queries with retry logic. - [Database Setup](docs/DATABASE_SETUP.md) - [Configuration](docs/CONFIGURATION.md) - [Data Visualization](docs/VISUALIZATION.md) +- [Prompts & Dialects](docs/PROMPTS.md) - [A2A Protocol](docs/A2A.md) +- [MCP Protocol](docs/MCP.md) ## Quick Start diff --git a/docs/CONFIGURATION.md b/docs/CONFIGURATION.md index f2f2c27..cf79cd2 100644 --- a/docs/CONFIGURATION.md +++ b/docs/CONFIGURATION.md @@ -1,19 +1,6 @@ # Configuration -Data agents are configured via YAML files. See `src/data_agent/config/contoso.yaml` for a complete example. - -## Intent Detection - -```yaml -intent_detection_agent: - llm: - model: gpt-4o - provider: azure_openai - temperature: 0.0 - system_prompt: | - You are an intent detection assistant... - {agent_descriptions} -``` +Data agents are configured via YAML files. See `src/data_agent/agents/contoso.yaml` for a complete example. ## Data Agent Definition diff --git a/docs/MCP.md b/docs/MCP.md new file mode 100644 index 0000000..3ffb2da --- /dev/null +++ b/docs/MCP.md @@ -0,0 +1,147 @@ +# MCP Protocol Support + +The Data Agent supports the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/), enabling integration with Claude Desktop, VS Code, Cursor, and other MCP-compatible clients. + +## Quick Start + +```bash +# Start MCP server with SSE transport (default) +uv run data-agent mcp + +# Start with a specific config +uv run data-agent mcp --config contoso + +# Start with stdio transport (for Claude Desktop) +uv run data-agent mcp --transport stdio + +# Start on a custom port +uv run data-agent mcp --port 9000 +``` + +## Server Options + +| Option | Default | Description | +|--------|---------|-------------| +| `--config, -c` | all | Configuration name (e.g., `contoso`). Loads all configs if not specified. 
| +| `--transport, -t` | sse | Transport: `sse` for HTTP clients (VS Code, Cursor), `stdio` for Claude Desktop | +| `--port, -p` | 8002 | Port for SSE transport | +| `--log-level` | warning | Logging level | + +## Available Tools + +The MCP server exposes the following tools: + +| Tool | Description | +|------|-------------| +| `query` | Execute natural language queries against datasources | +| `list_datasources` | List all configured datasources with descriptions | +| `list_tables` | List tables for a specific datasource | +| `get_schema` | Get database schema for a specific datasource | +| `validate_sql` | Validate SQL syntax without executing | + +## Available Resources + +| Resource URI | Description | +|--------------|-------------| +| `datasources://list` | List of available datasources | +| `schema://{datasource}` | Database schema for a datasource | +| `tables://{datasource}` | List of tables for a datasource | + +## Client Configuration + +### Claude Desktop + +Add to `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS) or `%APPDATA%\Claude\claude_desktop_config.json` (Windows): + +```json +{ + "mcpServers": { + "data-agent": { + "command": "uv", + "args": ["run", "data-agent-mcp"], + "cwd": "/path/to/langchain_data_agent" + } + } +} +``` + +### VS Code + +Add to `.vscode/mcp.json` in your workspace: + +```json +{ + "servers": { + "data-agent": { + "type": "sse", + "url": "http://127.0.0.1:8002/sse" + } + } +} +``` + +> **Note:** Start the MCP server first with `uv run data-agent mcp` before connecting. + +Or for stdio transport (runs server automatically): + +```json +{ + "servers": { + "data-agent": { + "type": "stdio", + "command": "uv", + "args": ["run", "data-agent-mcp", "--transport", "stdio"] + } + } +} +``` + +### Cursor / Windsurf + +Similar configuration to VS Code. Check your IDE's MCP documentation. + +## Example Usage + +Once configured, you can interact with the Data Agent directly from your AI client: + +``` +User: What datasources are available? +AI: [calls list_datasources] → Shows contoso, adventure_works, amex + +User: What's the schema for the contoso database? +AI: [calls get_schema("contoso")] → Shows tables, columns, types + +User: Show me the top 5 products by sales in Q4 2024 +AI: [calls query("top 5 products by sales Q4 2024")] → Returns results +``` + +## Programmatic Client Example + +```python +from mcp import ClientSession, StdioServerParameters +from mcp.client.stdio import stdio_client + +async def main(): + server_params = StdioServerParameters( + command="uv", + args=["run", "data-agent-mcp"] + ) + + async with stdio_client(server_params) as (read, write): + async with ClientSession(read, write) as session: + await session.initialize() + + # List available tools + tools = await session.list_tools() + print("Available tools:", [t.name for t in tools.tools]) + + # Execute a query + result = await session.call_tool( + "query", + arguments={"question": "What are the top selling products?"} + ) + print(result.content) + +import asyncio +asyncio.run(main()) +``` diff --git a/docs/PROMPTS.md b/docs/PROMPTS.md new file mode 100644 index 0000000..eafa749 --- /dev/null +++ b/docs/PROMPTS.md @@ -0,0 +1,173 @@ +# Prompts Module + +This module manages system prompts for the Data Agent's LLM interactions. It provides a modular, extensible prompt architecture that supports multiple database dialects and customization through configuration. 
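+
+At a glance, a datasource-specific system prompt is assembled with a single call. The snippet below is a minimal sketch using the `build_prompt` signature documented in the Usage section; the exact prompt text depends on the templates in `defaults.py` and the guidelines in `dialects.py`:
+
+```python
+from data_agent.prompts import build_prompt
+
+# Assemble a SQL-generation prompt for a BigQuery datasource; the
+# "bigquery" dialect guidelines are appended automatically by builder.py.
+prompt = build_prompt(
+    datasource_type="bigquery",
+    schema_context="Tables: transactions (id, account_id, amount, created_at)",
+    few_shot_examples="Q: How many transactions?\nA: SELECT COUNT(*) FROM transactions",
+)
+print(prompt)
+```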
+ +## Architecture + +``` +src/data_agent/prompts/ +├── __init__.py # Public exports +├── builder.py # Prompt assembly logic +├── defaults.py # Default prompt templates +└── dialects.py # Database-specific SQL guidelines +``` + +## Components + +### defaults.py - Core Prompt Templates + +Contains the default system prompts used across the agent: + +| Prompt | Purpose | +|--------|---------| +| `DEFAULT_INTENT_DETECTION_PROMPT` | Routes user questions to the appropriate data agent | +| `DEFAULT_GENERAL_CHAT_PROMPT` | Handles greetings and capability questions | +| `DEFAULT_SQL_PROMPT` | Guides SQL generation with schema context | +| `DEFAULT_RESPONSE_PROMPT` | Formats query results into natural language | +| `VISUALIZATION_SYSTEM_PROMPT` | Generates matplotlib visualization code | +| `COSMOS_PROMPT_ADDENDUM` | Cosmos DB-specific constraints and best practices | + +### dialects.py - Database-Specific Guidelines + +Provides SQL dialect guidelines that are automatically appended based on datasource type: + +| Dialect | Datasource Types | +|---------|------------------| +| BigQuery | `bigquery` | +| PostgreSQL | `postgres`, `postgresql` | +| Azure SQL / SQL Server | `azure_sql`, `mssql`, `sqlserver` | +| Azure Synapse | `synapse` | +| Databricks | `databricks` | +| Cosmos DB | `cosmos`, `cosmosdb` | + +Each dialect includes: +- Syntax conventions (date functions, data types, quoting) +- Aggregation function usage +- String manipulation functions +- Performance best practices + +### builder.py - Prompt Assembly + +The `build_prompt()` function assembles the final system prompt: + +``` +┌─────────────────────────────────────┐ +│ Date Context (current date) │ +├─────────────────────────────────────┤ +│ Base Prompt (custom or default) │ +│ - Schema context │ +│ - Few-shot examples │ +├─────────────────────────────────────┤ +│ Dialect Guidelines │ +│ (based on datasource type) │ +├─────────────────────────────────────┤ +│ Cosmos Addendum (if applicable) │ +│ - Partition key constraints │ +└─────────────────────────────────────┘ +``` + +## Usage + +### Basic Prompt Building + +```python +from data_agent.prompts import build_prompt + +# Build a prompt for PostgreSQL +prompt = build_prompt( + datasource_type="postgres", + schema_context="Tables: customers (id, name, email), orders (id, customer_id, total)", + few_shot_examples="Q: How many customers?\nA: SELECT COUNT(*) FROM customers", +) +``` + +### Custom Prompts via Configuration + +Teams can override default prompts in their agent YAML configuration using `system_prompt` and `response_prompt`: + +```yaml +data_agents: + - name: my_agent + description: E-commerce sales database + datasource: + type: postgres + # ... + system_prompt: | + You are a SQL expert for our e-commerce database. + Focus on sales metrics and customer behavior. + + {schema_context} + + {few_shot_examples} + response_prompt: | + Provide insights focused on business impact. + Always mention revenue implications. + table_schemas: + # ... 
+``` + +### Getting Dialect Guidelines + +```python +from data_agent.prompts import get_dialect_guidelines + +# Get BigQuery-specific SQL guidelines +guidelines = get_dialect_guidelines("bigquery") +``` + +## Prompt Template Variables + +The following variables are automatically substituted: + +| Variable | Description | Used In | +|----------|-------------|---------| +| `{schema_context}` | Database schema information | SQL prompt | +| `{few_shot_examples}` | Example Q&A pairs | SQL prompt | +| `{agent_descriptions}` | Available data agents | Intent detection, general chat | +| `{partition_key}` | Cosmos DB partition key | Cosmos addendum | + +## Extending + +### Adding a New Dialect + +1. Add guidelines constant to `dialects.py`: + +```python +MY_DATABASE_GUIDELINES = """## My Database SQL Guidelines + +1. **Syntax conventions:** + - Use MY_DATE_FUNC() for date operations + - ... +""" +``` + +2. Register in `DIALECT_GUIDELINES_MAP`: + +```python +DIALECT_GUIDELINES_MAP: dict[str, str] = { + # ... existing entries + "mydatabase": MY_DATABASE_GUIDELINES, +} +``` + +### Adding a New Prompt Type + +1. Add the template to `defaults.py`: + +```python +MY_NEW_PROMPT = """You are a specialized assistant for... + +{custom_variable} +""" +``` + +2. Export in `__init__.py`: + +```python +from data_agent.prompts.defaults import MY_NEW_PROMPT + +__all__ = [ + # ... existing exports + "MY_NEW_PROMPT", +] +``` diff --git a/docs/langchain.png b/docs/langchain.png new file mode 100644 index 0000000..b959e54 Binary files /dev/null and b/docs/langchain.png differ diff --git a/pyproject.toml b/pyproject.toml index 6e957c9..e842d06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,12 +47,14 @@ dependencies = [ "langchain-azure-dynamic-sessions>=0.2.0", "matplotlib>=3.10.8", "tabulate>=0.9.0", + "mcp>=1.25.0", ] [project.scripts] data-agent = "data_agent.cli:main" data-agent-ui = "data_agent.ui:main" data-agent-a2a = "data_agent.a2a.server:main" +data-agent-mcp = "data_agent.mcp.server:main" [project.optional-dependencies] dev = [ diff --git a/src/data_agent/agent.py b/src/data_agent/agent.py index 0b9fd02..f0db36b 100644 --- a/src/data_agent/agent.py +++ b/src/data_agent/agent.py @@ -11,6 +11,7 @@ from uuid import uuid4 from langchain_community.utilities.sql_database import SQLDatabase +from langchain_core.language_models import BaseChatModel from langchain_core.messages import AIMessage, HumanMessage, SystemMessage from langgraph.checkpoint.memory import InMemorySaver from langgraph.graph import END, START, StateGraph @@ -32,6 +33,11 @@ from data_agent.graph import create_data_agent from data_agent.llm import get_llm from data_agent.models.state import AgentState, InputState, OutputState +from data_agent.prompts.defaults import ( + DEFAULT_GENERAL_CHAT_PROMPT, + DEFAULT_INTENT_DETECTION_PROMPT, + DEFAULT_QUERY_REWRITE_PROMPT, +) from data_agent.utils.callbacks import AgentCallback from data_agent.utils.message_utils import get_recent_history @@ -83,28 +89,16 @@ def __init__( self._agent_descriptions: dict[str, str] = {} self._shared_db = shared_db - self._intent_llm = get_llm( - provider=self.config.intent_detection.llm_config.provider, + self._workflow_llm: BaseChatModel = get_llm( + provider="azure_openai", azure_endpoint=azure_endpoint, api_key=api_key, - deployment_name=self.config.intent_detection.llm_config.model - or deployment_name, - api_version=self.config.intent_detection.llm_config.api_version - or api_version, - temperature=self.config.intent_detection.llm_config.temperature, + 
deployment_name=deployment_name, + api_version=api_version, ) - self._default_llm_settings = { - "azure_endpoint": azure_endpoint, - "api_key": api_key, - "deployment_name": deployment_name, - "api_version": api_version, - } - self._callback = AgentCallback(agent_name="data_agent_flow") - self._initialize_agents() - self._graph = self._build_workflow() def _initialize_agents(self) -> None: @@ -210,18 +204,26 @@ def _create_datasource(self, name: str, ds: Any) -> Datasource | None: return None def _create_agent_graph(self, name: str, agent_config: DataAgentConfig) -> None: - """Create the LangGraph agent for a data agent.""" + """Create the LangGraph agent for a data agent. + + Uses the workflow LLM by default. If the agent has custom LLM settings + defined in its YAML, creates a new LLM instance. + + Args: + name: Name of the agent. + agent_config: Data agent configuration from YAML. + """ llm_cfg = agent_config.llm_config - agent_llm = get_llm( - provider=llm_cfg.provider or "azure_openai", - azure_endpoint=self._default_llm_settings["azure_endpoint"], - api_key=self._default_llm_settings["api_key"], - deployment_name=llm_cfg.model - or self._default_llm_settings["deployment_name"], - api_version=llm_cfg.api_version - or self._default_llm_settings["api_version"], - temperature=llm_cfg.temperature, - ) + if llm_cfg: + agent_llm = get_llm( + provider=llm_cfg.provider or "azure_openai", + deployment_name=llm_cfg.model, + api_version=llm_cfg.api_version, + temperature=llm_cfg.temperature, + ) + else: + agent_llm = self._workflow_llm + self.data_agents[name] = create_data_agent( llm=agent_llm, datasource=self.datasources[name], @@ -262,8 +264,9 @@ def _build_workflow(self) -> CompiledStateGraph: """ agent_names = list(self.data_agents.keys()) agent_descriptions = self._agent_descriptions - intent_llm = self._intent_llm - intent_system_prompt = self.config.intent_detection.system_prompt + + # Valid intents include all agent names plus "general_chat" + valid_intents = agent_names + ["general_chat"] def intent_detection_node(state: AgentState) -> dict[str, Any]: """Detect user intent and select the appropriate data agent.""" @@ -273,7 +276,9 @@ def intent_detection_node(state: AgentState) -> dict[str, Any]: [f"- {name}: {desc}" for name, desc in agent_descriptions.items()] ) - system_content = intent_system_prompt.format(agent_descriptions=agent_list) + system_content = DEFAULT_INTENT_DETECTION_PROMPT.format( + agent_descriptions=agent_list + ) history = get_recent_history(state.get("messages", []), max_messages=4) @@ -283,7 +288,7 @@ def intent_detection_node(state: AgentState) -> dict[str, Any]: HumanMessage(content=question), ] - response = intent_llm.invoke(messages) + response = self._workflow_llm.invoke(messages) content = response.content selected_agent = ( content.strip() @@ -291,7 +296,7 @@ def intent_detection_node(state: AgentState) -> dict[str, Any]: else str(content[0]).strip() if content else "" ) - if selected_agent not in agent_names: + if selected_agent not in valid_intents: clarification = interrupt( { "type": "clarification_needed", @@ -310,14 +315,14 @@ def intent_detection_node(state: AgentState) -> dict[str, Any]: SystemMessage(content=system_content), HumanMessage(content=clarified_question), ] - response = intent_llm.invoke(messages) + response = self._workflow_llm.invoke(messages) content = response.content selected_agent = ( content.strip() if isinstance(content, str) else str(content[0]).strip() if content else "" ) - if selected_agent in agent_names: + if 
selected_agent in valid_intents: return { "question": clarified_question, "datasource_name": selected_agent, @@ -365,30 +370,14 @@ def query_rewrite_node(state: AgentState) -> dict[str, Any]: content = getattr(msg, "content", str(msg))[:500] conversation_context += f"- {msg_type}: {content}\n" - rewrite_prompt = f"""You are a query rewriter. Your job is to rewrite user questions to be more specific and clear for a database query system. - -## Target Agent -{agent_desc} - -## Conversation Context -{conversation_context} - -## Instructions -1. Keep the original intent of the question -2. If this is a follow-up question (e.g., "what's the average?", "show me the same for X", "filter those by Y"), use the conversation history to expand the question with the relevant context -3. For follow-up questions, make the implicit references explicit (e.g., "What's the average?" → "What is the average transaction amount?" if previous query was about transactions) -4. Make the question more specific if needed -5. If the question is already clear and specific, return it unchanged -6. Do NOT add information that wasn't implied by the original question or conversation - -## Original Question -{question} - -## Rewritten Question -Respond with ONLY the rewritten question, nothing else.""" + rewrite_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format( + agent_description=agent_desc, + conversation_context=conversation_context, + question=question, + ) messages = [HumanMessage(content=rewrite_prompt)] - response = intent_llm.invoke(messages) + response = self._workflow_llm.invoke(messages) content = response.content rewritten = ( content.strip() @@ -414,6 +403,40 @@ def route_to_agent(state: AgentState) -> str: return "out_of_scope" return datasource + def general_chat_node(state: AgentState) -> dict[str, Any]: + """Handle general conversation like greetings and capability questions.""" + question = state["question"] + + agent_list = "\n".join( + [f"- **{name}**: {desc}" for name, desc in agent_descriptions.items()] + ) + + system_content = DEFAULT_GENERAL_CHAT_PROMPT.format( + agent_descriptions=agent_list + ) + + messages = [ + SystemMessage(content=system_content), + HumanMessage(content=question), + ] + + response = self._workflow_llm.invoke(messages) + content = response.content + response_text = ( + content.strip() + if isinstance(content, str) + else str(content[0]).strip() if content else "" + ) + + return { + "final_response": response_text, + "error": None, + "messages": [ + HumanMessage(content=question, name="user"), + AIMessage(content=response_text, name="general_chat"), + ], + } + def out_of_scope_node(state: AgentState) -> dict[str, Any]: """Handle out-of-scope requests.""" agent_list = "\n".join( @@ -440,6 +463,7 @@ def out_of_scope_node(state: AgentState) -> dict[str, Any]: workflow.add_node("intent_detection", intent_detection_node) workflow.add_node("query_rewrite", query_rewrite_node) + workflow.add_node("general_chat", general_chat_node) workflow.add_node("out_of_scope", out_of_scope_node) for name, agent in self.data_agents.items(): @@ -448,16 +472,18 @@ def out_of_scope_node(state: AgentState) -> dict[str, Any]: workflow.add_edge(START, "intent_detection") def route_after_intent(state: AgentState) -> str: - """Route to query_rewrite or out_of_scope after intent detection.""" + """Route to query_rewrite, general_chat, or out_of_scope after intent detection.""" datasource = state.get("datasource_name", "") if not datasource or state.get("error") == "out_of_scope": return "out_of_scope" + if 
datasource == "general_chat": + return "general_chat" return "query_rewrite" workflow.add_conditional_edges( "intent_detection", route_after_intent, - path_map=["query_rewrite", "out_of_scope"], + path_map=["query_rewrite", "general_chat", "out_of_scope"], ) routing_map: dict[str, str] = {name: name for name in agent_names} @@ -468,6 +494,7 @@ def route_after_intent(state: AgentState) -> str: path_map=list(routing_map.keys()), ) + workflow.add_edge("general_chat", END) workflow.add_edge("out_of_scope", END) for name in agent_names: workflow.add_edge(name, END) diff --git a/src/data_agent/config/adventure_works.yaml b/src/data_agent/agents/adventure_works.yaml similarity index 96% rename from src/data_agent/config/adventure_works.yaml rename to src/data_agent/agents/adventure_works.yaml index 78b6dbb..9dda68f 100644 --- a/src/data_agent/config/adventure_works.yaml +++ b/src/data_agent/agents/adventure_works.yaml @@ -1,28 +1,3 @@ -intent_detection_agent: - llm: - model: gpt-5-mini - provider: azure_openai - api_version: 2024-12-01-preview - temperature: 1.0 - max_tokens: 500 - system_prompt: | - You are an intent detection assistant responsible for routing user questions to the appropriate data agent. - - ## Available Data Agents - - {agent_descriptions} - - ## Instructions - - 1. Analyze the user's question to understand what data they are asking about. - 2. Match the question to the most relevant data agent based on the domain and data types. - 3. If the question is ambiguous, choose the agent most likely to have the relevant data. - 4. If no agent is a clear match, respond with "unknown". - - ## Response Format - - Respond with ONLY the agent name (e.g., "contoso_hr" or "hotel_analytics"). Do not include any explanation. - data_agents: - name: contoso_hr datasource: diff --git a/src/data_agent/config/amex.yaml b/src/data_agent/agents/amex.yaml similarity index 57% rename from src/data_agent/config/amex.yaml rename to src/data_agent/agents/amex.yaml index c8ec50a..666fbe3 100644 --- a/src/data_agent/config/amex.yaml +++ b/src/data_agent/agents/amex.yaml @@ -1,28 +1,3 @@ -intent_detection_agent: - llm: - model: gpt-5-mini - provider: azure_openai - api_version: 2024-12-01-preview - temperature: 1.0 - max_tokens: 500 - system_prompt: | - You are an intent detection assistant responsible for routing user questions to the appropriate data agent. - - ## Available Data Agents - - {agent_descriptions} - - ## Instructions - - 1. Analyze the user's question to understand what data they are asking about. - 2. Match the question to the most relevant data agent based on the domain and data types. - 3. If the question is ambiguous, choose the agent most likely to have the relevant data. - 4. If no agent is a clear match, respond with "unknown". - - ## Response Format - - Respond with ONLY the agent name (e.g., "financial_transactions"). Do not include any explanation. - data_agents: - name: financial_transactions description: Financial transactions, accounts, customers, and fraud alerts @@ -57,24 +32,6 @@ data_agents: {few_shot_examples} - ## BigQuery SQL Generation Guidelines - - 1. **Use only tables and columns defined in the schema above.** Never reference tables or columns not listed. - 2. 
**Use BigQuery SQL syntax.** This includes: - - DATE_TRUNC, DATE_ADD, DATE_SUB for date operations - - CURRENT_DATE(), CURRENT_TIMESTAMP() for current time - - EXTRACT(YEAR FROM date), EXTRACT(MONTH FROM date) for date parts - - STRING, INT64, FLOAT64, NUMERIC, BOOL data types - - TIMESTAMP for datetime values - 3. **Always qualify column names** with table aliases to avoid ambiguity. - 4. **Use appropriate JOINs** when combining data from multiple tables. - 5. **Include WHERE clauses** to filter data when the question implies filtering. - 6. **Use GROUP BY** for aggregations and ORDER BY for sorting results. - 7. **Use LIMIT** to restrict results (e.g., LIMIT 100) unless the user specifies otherwise. - 8. **Handle NULL values** appropriately with IFNULL, COALESCE, or IS NOT NULL checks. - 9. **Use fully qualified table names** in format `project.dataset.table` or just `dataset.table`. - 10. **DISTINCT with ORDER BY limitation**: When using DISTINCT inside aggregate functions (e.g., ARRAY_AGG, STRING_AGG), ORDER BY can only reference the column being aggregated, not other columns. - ## Response Format Provide your response as JSON with these fields: @@ -82,27 +39,8 @@ data_agents: - "sql_query": The generated BigQuery SQL query - "explanation": Brief explanation of what the query does - "visualization_requested": Set to true if the user is asking for a chart, graph, plot, bar chart, pie chart, line chart, histogram, or any visual representation of the data. Set to false for plain data queries. - response_prompt: | - You are a helpful financial analyst for banking and transaction data. - - Given the user's question, the SQL query that was executed, and the results, - provide a clear and concise natural language response. - - Be conversational but precise. Include relevant numbers, percentages, and insights. - Format currency values with $ and commas. Format large numbers for readability. - When discussing fraud or risk, be clear about severity levels and recommended actions. - If the results are empty, explain what that means in context. - - ## Data Presentation - - When the query returns tabular data (multiple rows/columns), ALWAYS include a formatted markdown table showing the results. - - Use proper markdown table syntax with headers - - Align numeric columns to the right - - Format currency with $ and commas (e.g., $1,234.56) - - Format dates in readable format (e.g., Dec 16, 2025) - - Limit tables to 20 rows max; if more rows exist, show first 20 and note "... and X more rows" - - After the table, provide a brief summary or insight about the data - # table_schemas omitted - will use dynamic schema discovery from BigQuery + # response_prompt: omitted - will use default prompt + # table_schemas: omitted - will use dynamic schema discovery from BigQuery few_shot_examples: - question: What are the total deposits by customer segment this month? answer: This query shows total deposits grouped by customer segment (VIP, Premium, Standard) for the current month. 
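
The inline BigQuery guidelines and `response_prompt` removed from `amex.yaml` above are now supplied by the shared prompts module (`src/data_agent/prompts/`, documented in `docs/PROMPTS.md`). As a minimal sketch of the replacement mechanism, assuming the `get_dialect_guidelines` helper documented there, the same BigQuery guidance can be inspected directly:

```python
from data_agent.prompts import get_dialect_guidelines

# The financial_transactions agent uses a BigQuery datasource, so
# build_prompt() appends these shared guidelines to its SQL prompt;
# they can also be retrieved on their own:
print(get_dialect_guidelines("bigquery"))
```
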
diff --git a/src/data_agent/config/contoso.yaml b/src/data_agent/agents/contoso.yaml similarity index 96% rename from src/data_agent/config/contoso.yaml rename to src/data_agent/agents/contoso.yaml index 05b6b37..ad38ab0 100644 --- a/src/data_agent/config/contoso.yaml +++ b/src/data_agent/agents/contoso.yaml @@ -1,28 +1,3 @@ -intent_detection_agent: - llm: - model: gpt-5-mini - provider: azure_openai - api_version: 2024-12-01-preview - temperature: 1.0 - max_tokens: 500 - system_prompt: | - You are an intent detection assistant responsible for routing user questions to the appropriate data agent. - - ## Available Data Agents - - {agent_descriptions} - - ## Instructions - - 1. Analyze the user's question to understand what data they are asking about. - 2. Match the question to the most relevant data agent based on the domain and data types. - 3. If the question is ambiguous, choose the agent most likely to have the relevant data. - 4. If no agent is a clear match, respond with "unknown". - - ## Response Format - - Respond with ONLY the agent name (e.g., "contoso_sales", "contoso_products" or "contoso_inventory"). Do not include any explanation. - data_agents: - name: contoso_sales description: Sales transactions, customers, and revenue analytics diff --git a/src/data_agent/config/schema/agent_config.schema.json b/src/data_agent/agents/schema/agent_config.schema.json similarity index 96% rename from src/data_agent/config/schema/agent_config.schema.json rename to src/data_agent/agents/schema/agent_config.schema.json index 98e601f..674d54f 100644 --- a/src/data_agent/config/schema/agent_config.schema.json +++ b/src/data_agent/agents/schema/agent_config.schema.json @@ -6,9 +6,6 @@ "type": "object", "additionalProperties": false, "properties": { - "intent_detection_agent": { - "$ref": "#/$defs/intent_detection_config" - }, "data_agents": { "type": "array", "description": "List of data agent configurations", @@ -72,20 +69,6 @@ "model" ] }, - "intent_detection_config": { - "type": "object", - "description": "Configuration for intent detection agent", - "additionalProperties": false, - "properties": { - "llm": { - "$ref": "#/$defs/llm_config" - }, - "system_prompt": { - "type": "string", - "description": "System prompt for intent detection. Use {agent_descriptions} placeholder." - } - } - }, "column_schema": { "type": "object", "description": "Schema definition for a table column", diff --git a/src/data_agent/agents/wmt_retail.yaml b/src/data_agent/agents/wmt_retail.yaml new file mode 100644 index 0000000..a5532cf --- /dev/null +++ b/src/data_agent/agents/wmt_retail.yaml @@ -0,0 +1,1009 @@ +data_agents: + - name: wmt_retail_sales + description: Walmart US retail point-of-sale transactions, items, stores, channels, and financial data + datasource: + type: bigquery + project_id: wmt-edw-prod + dataset: US_FIN_SALES_DL_RPT_VM + location: US + llm: + model: gpt-5-mini + provider: azure_openai + api_version: 2024-12-01-preview + temperature: 1.0 + max_tokens: 2000 + validation: + max_rows: 5000 + blocked_functions: + - session_user + - external_query + system_prompt: | + You are an expert SQL assistant for a Walmart US retail point-of-sale database running on Google BigQuery. + + ## Your Role + + Generate accurate, efficient BigQuery SQL queries based on natural language questions. You have access to comprehensive retail transaction data including stores, items, channels, fulfillment, and financial hierarchies. 
+ + ## Database Context + + {schema_context} + + ## Examples + + {few_shot_examples} + + ## Key Business Concepts + + ### Transaction Identifiers + - **VISIT_NBR**: 29-character unique identifier for B&M register transactions (format: void_flag + country + store + register + transaction + datetime) + - **ORDER_NBR**: Walmart.com order number for online-initiated sales + + ### Sales Channels + - **CHNL_TYPE_ID**: High-level channel (1=Store Sale, 30=Layaway, 86=Online Initiated) + - **FULFMT_TYPE_ID**: Fulfillment method (6=Scheduled Pickup, 7=Scheduled Delivery, etc.) + - **CHNL_CONS_HIER_LVL_***: Channel consumption hierarchy for reporting + + ### Item Hierarchy (Walmart) + Department → Category Group → Category → Subcategory → Fineline + + ### Financial Hierarchy + FIN_GROUP → FIN_SEG → FIN_SUBGROUP → FIN_PORTF + + ### Store Classifications + - **COMP_CD**: Comparable store indicator (C=Comparable, L/T=New, S/E=Expansion, M/R=Relocation) + - **BANNER_CD**: Store banner (A1=Supercenter, etc.) + - **STORE_TYPE_CD**: Physical building type + + ### Fiscal Calendar + - Walmart fiscal year starts in February + - **WM_YR_WK_NBR**: YYYYWW format fiscal year-week + - **FISCAL_MTH_NBR**: 1=February through 12=January + + ## Data Asset Information + + ### Timeframe & Timeliness + - **Historical Data**: VISIT_LOCAL_DT from 2017-01-01 to current + - **Data Refresh**: Updates daily with yesterday's data by 6:00 AM CDT + - **Table Grain**: Store/Day/Register/Transaction/Item level + + ### Data Quality + - Quality variance: Less than +/- 1 basis point to SAP per month at Walmart segment level + - Completeness validations ensure data equals source and all Data Lake layers are consistent + - Primary source: CTH (Customer Transaction History) + + ### Data Metrics Available + - Retail Sales & Returns (SALES_AMT) + - Unit Quantity (UOM_QTY, OPS_UNIT_QTY) + - Associate Discount (ASSOC_DISC_AMT) + + ### SAP Alignment Rules (CRITICAL) + - **OTHER_INCOME_IND = 0**: Aligns to GL accounts 4101010 + 4101030 + 4102001 (Parent IDs: FN105 + BRUS975) + - **OTHER_INCOME_IND = 1**: Revenue from items posting to other GL accounts via FN105 + - **For top-line merchandise sales only**: Use `OTHER_INCOME_IND = 0 AND VENDOR_NBR != 481890` + - Note: Bottle deposit revenue (state-imposed) is flagged OII=0 because it posts to 4101010, exclude via VENDOR_NBR filter if needed + + ## Response Format + + Provide your response as JSON with these fields: + - "thinking": Step-by-step reasoning about how to construct the query + - "sql_query": The generated BigQuery SQL query + - "explanation": Brief explanation of what the query does + - "visualization_requested": Set to true if the user is asking for a chart, graph, plot, or any visual representation + + response_prompt: | + You are a helpful retail analyst for Walmart US sales data. + + Given the user's question, the SQL query that was executed, and the results, + provide a clear and concise natural language response. + + Be conversational but precise. Include relevant numbers, percentages, and insights. + Format currency values with $ and commas. Format large numbers for readability. + When discussing sales performance, provide context about comparable stores, channels, and time periods. + If the results are empty, explain what that means in context. + + ## Data Presentation + + When the query returns tabular data (multiple rows/columns), ALWAYS include a formatted markdown table showing the results. 
+ - Use proper markdown table syntax with headers + - Align numeric columns to the right + - Format currency with $ and commas (e.g., $1,234.56) + - Format dates in readable format (e.g., Jun 21, 2025) + - Limit tables to 20 rows max; if more rows exist, show first 20 and note "... and X more rows" + - After the table, provide a brief summary or insight about the data + + table_schemas: + - table_name: WMT_STORE_SALES_DTL_D + table_description: Point-of-sale item-level transaction data for Walmart US stores including online-initiated store-fulfilled orders. + columns: + # Transaction Identifiers + - column_name: STORE_NBR + data_type: INT64 + description: "Store Number - Unique numeric identifier for retail locations, DCs, Home Office departments" + - column_name: VISIT_NBR + data_type: STRING + description: "Visit Number - 29-character unique identifier for B&M register transactions. Format: void(1) + country(2) + store(5) + register(3) + transaction(4) + datetime(14)" + - column_name: SCAN_ID + data_type: INT64 + description: "Scan ID - Company-assigned number for each item sold. Interpretation varies by SCAN_TYPE" + - column_name: SCAN_TYPE + data_type: INT64 + description: "Scan Type - Item type indicator: 0=Corporate Item, 1=Dept Hand-Keyed, 2=Local Store Item, 3=Rx, 4=Old Number, -999=Unknown" + - column_name: SCAN_TYPE_DESC + data_type: STRING + description: "Scan Type Description" + - column_name: SEQ_LINE_NBR + data_type: INT64 + description: "Sequence Line Number - Position of item in register transaction or eComm order" + + # Channel & Fulfillment + - column_name: CHNL_TYPE_ID + data_type: INT64 + description: "Channel Type ID - 1=Store Sale, 30=Layaway, 86=Online Initiated Sale" + - column_name: CHNL_TYPE_DESC + data_type: STRING + description: "Channel Type Description" + - column_name: CHNL_SUBTYPE_ID + data_type: INT64 + description: "Channel Subtype ID - Granular channel code (VISIT_SUBTYPE_CD for B&M, SVC_ID for eComm)" + - column_name: CHNL_SUBTYPE_DESC + data_type: STRING + description: "Channel Subtype Description" + - column_name: FULFMT_TYPE_ID + data_type: INT64 + description: "Fulfillment Type ID - 1=Site to Home, 6=Scheduled Pickup, 7=Scheduled Delivery, 8=Express Pickup, 9=Express Delivery, 16=Store" + - column_name: FULFMT_TYPE_DESC + data_type: STRING + description: "Fulfillment Type Description" + - column_name: ORDER_NBR + data_type: STRING + description: "Walmart.com order number for online-initiated store-fulfilled sales" + - column_name: CHNL_ACTV_GRP_NBR + data_type: INT64 + description: "Channel Activity Group Number for business channel identification" + - column_name: CHNL_ACTV_GRP_DESC + data_type: STRING + description: "Channel Activity Group Description" + + # Channel Consumption Hierarchy + - column_name: CHNL_CONS_HIER_LVL_0_DESC + data_type: STRING + description: "Channel Hierarchy Level 0 - Total Channel Consumption" + - column_name: CHNL_CONS_HIER_LVL_1_DESC + data_type: STRING + description: "Channel Hierarchy Level 1" + - column_name: CHNL_CONS_HIER_LVL_2_DESC + data_type: STRING + description: "Channel Hierarchy Level 2" + - column_name: CHNL_CONS_HIER_LVL_3_DESC + data_type: STRING + description: "Channel Hierarchy Level 3" + - column_name: CHNL_CONS_HIER_LVL_4_DESC + data_type: STRING + description: "Channel Hierarchy Level 4" + - column_name: CHNL_BASE_DESC + data_type: STRING + description: "Channel Base Description" + - column_name: SRC_OF_ORIGN_CD + data_type: STRING + description: "Source of Origin Code" + + # Sales Metrics + - 
column_name: SALES_AMT + data_type: NUMERIC + description: "Sales Amount - Retail amount charged or refunded" + - column_name: UNIT_COST_AMT + data_type: NUMERIC + description: "Unit Cost Amount - Synthetic cost metric (multiply by UOM_QTY for total cost)" + - column_name: UOM_QTY + data_type: NUMERIC + description: "Unit of Measure Quantity - Packages sold (or pounds/yards for variable-priced items)" + - column_name: SCAN_CNT + data_type: INT64 + description: "Scan Count - Number of items physically scanned" + - column_name: OPS_UNIT_QTY + data_type: NUMERIC + description: "Operations Unit Quantity - Business-defined package count metric for all items" + - column_name: ASSOC_DISC_AMT + data_type: NUMERIC + description: "Associate Discount Amount" + - column_name: SVC_VAL_AMT + data_type: NUMERIC + description: "Services Value Amount - Company service income (auto services, warranty, financial services)" + - column_name: PRC_OVERRIDEN_AMT + data_type: NUMERIC + description: "Price Overridden Amount - Original unit price before override" + + # Report & Pricing Codes + - column_name: RPT_CD + data_type: STRING + description: "Report Code - Sell status: 0=Regular, 1=Store Tab, 2=Company Tab, 7=Rollback, 8=Clearance" + - column_name: RPT_CD_DESC + data_type: STRING + description: "Report Code Description" + + # Transaction Indicators + - column_name: RTN_ITEM_IND + data_type: INT64 + description: "Return Item Indicator - 1 if item was returned" + - column_name: EXCHANGE_ITEM_IND + data_type: INT64 + description: "Exchange Item Indicator - 1 if even exchange for returned item" + - column_name: VOID_ITEM_IND + data_type: INT64 + description: "Void Item Indicator - 1 on void records (negative SALES_AMT)" + - column_name: VOID_CMPLT_TRANS_IND + data_type: INT64 + description: "Void Complete Transaction Indicator - 1 if entire transaction voided after completion" + - column_name: OTHER_INCOME_IND + data_type: INT64 + description: "Other Income Indicator - 1 if item routes to non-sales GL accounts" + - column_name: SELF_CHKOUT_IND + data_type: INT64 + description: "Self Checkout Indicator - 1 if sold on self-checkout register" + - column_name: FUEL_CONV_STORE_IND + data_type: INT64 + description: "Fuel Convenience Store Indicator - 1 if store has company-owned gas station" + - column_name: DSV_IND + data_type: INT64 + description: "Drop Ship Vendor Indicator - 1 if shipped directly from vendor" + - column_name: INSTACART_IND + data_type: INT64 + description: "Instacart Indicator - 1 if purchased by Instacart shopper" + + # Item Identifiers + - column_name: MDS_FAM_ID + data_type: INT64 + description: "Merchandise Family ID - Unique corporate item identifier" + - column_name: UPC_SRC_NBR + data_type: NUMERIC + description: "UPC Number from source" + - column_name: UPC_DESC + data_type: STRING + description: "UPC Description - 12-character receipt description" + - column_name: ITEM_NBR + data_type: INT64 + description: "Item Number - Vendor-specific company-assigned number" + - column_name: ITEM_DESC_1 + data_type: STRING + description: "Item Description 1 - Primary item description" + - column_name: ITEM_DESC_2 + data_type: STRING + description: "Item Description 2 - Secondary item description" + - column_name: ITEM_STATUS_CD + data_type: STRING + description: "Item Status Code - A=Active, I=Inactive, D=Delete" + - column_name: SIGNING_DESC + data_type: STRING + description: "Signing Description - Shelf sign description" + - column_name: UOM_CD + data_type: STRING + description: "Unit of Measure Code 
- EA, LB, YD, etc." + + # Item Hierarchy (Walmart) + - column_name: DEPT_NBR + data_type: INT64 + description: "Department Number - Merchandising department" + - column_name: DEPT_DESC + data_type: STRING + description: "Department Description" + - column_name: ACCTG_DEPT_NBR + data_type: INT64 + description: "Accounting Department Number - SAP posting department" + - column_name: ACCTG_DEPT_DESC + data_type: STRING + description: "Accounting Department Description" + - column_name: DEPT_CATG_GRP_NBR + data_type: INT64 + description: "Department Category Group Number" + - column_name: DEPT_CATG_GRP_DESC + data_type: STRING + description: "Department Category Group Description" + - column_name: DEPT_CATG_NBR + data_type: INT64 + description: "Department Category Number" + - column_name: DEPT_CATG_DESC + data_type: STRING + description: "Department Category Description" + - column_name: DEPT_SUBCATG_NBR + data_type: INT64 + description: "Department Subcategory Number" + - column_name: DEPT_SUBCATG_DESC + data_type: STRING + description: "Department Subcategory Description" + - column_name: FINELINE_NBR + data_type: INT64 + description: "Fineline Number - Lowest hierarchy level for Walmart" + - column_name: FINELINE_DESC + data_type: STRING + description: "Fineline Description" + - column_name: SUBCLASS_NBR + data_type: INT64 + description: "Subclass Number - Sam's hierarchy (not used for Walmart)" + - column_name: SUBCLASS_DESC + data_type: STRING + description: "Subclass Description" + + # Financial Hierarchy + - column_name: FIN_GROUP_NBR + data_type: INT64 + description: "Finance Group Number" + - column_name: FIN_GROUP_DESC + data_type: STRING + description: "Finance Group Description" + - column_name: FIN_SEG_NBR + data_type: INT64 + description: "Finance Segment Number (SBU)" + - column_name: FIN_SEG_DESC + data_type: STRING + description: "Finance Segment Description" + - column_name: FIN_SUBGROUP_NBR + data_type: INT64 + description: "Finance Subgroup Number" + - column_name: FIN_SUBGROUP_DESC + data_type: STRING + description: "Finance Subgroup Description" + - column_name: FIN_PORTF_NBR + data_type: INT64 + description: "Finance Portfolio Number" + - column_name: FIN_PORTF_DESC + data_type: STRING + description: "Finance Portfolio Description" + - column_name: GL_ACCT_ALGN_NBR + data_type: INT64 + description: "GL Account Aligned Number - SAP GL account alignment" + - column_name: GL_ACCT_ALGN_NM + data_type: STRING + description: "GL Account Aligned Name" + + # Brand & Vendor + - column_name: BRAND_ID + data_type: INT64 + description: "Brand ID" + - column_name: BRAND_NM + data_type: STRING + description: "Brand Name" + - column_name: BRAND_OWNR_ID + data_type: INT64 + description: "Brand Owner ID" + - column_name: BRAND_OWNR_NM + data_type: STRING + description: "Brand Owner Name" + - column_name: VENDOR_NBR + data_type: INT64 + description: "Vendor Number" + - column_name: VENDOR_NM + data_type: STRING + description: "Vendor Name" + + # Store Attributes + - column_name: STORE_NM + data_type: STRING + description: "Store Name" + - column_name: BANNER_CD + data_type: STRING + description: "Banner Code - A1=Supercenter, etc." 
+ - column_name: BANNER_DESC + data_type: STRING + description: "Banner Description" + - column_name: STORE_TYPE_CD + data_type: STRING + description: "Store Type Code - R=Regular, U=Supercenter/NHM, S=Sam's" + - column_name: STORE_TYPE_DESC + data_type: STRING + description: "Store Type Description" + - column_name: STORE_SIZE_SQFT + data_type: INT64 + description: "Store Size in Square Feet" + - column_name: GRAND_OPEN_DT + data_type: DATE + description: "Grand Open Date" + - column_name: EXPND_OPEN_DT + data_type: DATE + description: "Expansion/Relocation Open Date" + - column_name: OPEN_STATUS_CD + data_type: STRING + description: "Open Status Code - 0=Scheduled, 1=New<13mo, 2=Open>13mo, 7=Closed" + - column_name: OPEN_STATUS_DESC + data_type: STRING + description: "Open Status Description" + + # Comparable Store + - column_name: COMP_CD + data_type: STRING + description: "Comparable Code - C=Comparable, L/T=New, S/E=Expansion, M/R=Relocation, N/X=Closed" + - column_name: COMP_DESC + data_type: STRING + description: "Comparable Description" + + # Geographic Hierarchy + - column_name: STATE_PROV_CD + data_type: STRING + description: "State/Province Code" + - column_name: POSTAL_CD + data_type: STRING + description: "Postal Code (5 digits, zero-padded)" + - column_name: LAT_DGR + data_type: NUMERIC + description: "Latitude Degrees" + - column_name: LONG_DGR + data_type: NUMERIC + description: "Longitude Degrees" + - column_name: TZ_CD + data_type: STRING + description: "Timezone Code - EST, CST, MST, PST, etc." + - column_name: MARKET_NBR + data_type: INT64 + description: "Market Number" + - column_name: MARKET_NM + data_type: STRING + description: "Market Name" + - column_name: REGION_NBR + data_type: INT64 + description: "Region Number" + - column_name: REGION_NM + data_type: STRING + description: "Region Name" + - column_name: SUBDIV_NBR + data_type: STRING + description: "Subdivision Number" + - column_name: SUBDIV_NM + data_type: STRING + description: "Subdivision Name" + - column_name: BUO_AREA_NBR + data_type: INT64 + description: "Business Unit Area Number" + - column_name: BUO_AREA_NM + data_type: STRING + description: "Business Unit Area Name" + - column_name: MDSE_MAJ_ZONE_NBR + data_type: INT64 + description: "Merchandise Major Zone Number" + - column_name: MDSE_SUB_ZONE_NBR + data_type: INT64 + description: "Merchandise Sub Zone Number" + + # Price Investment + - column_name: PRICE_INVST_IND + data_type: INT64 + description: "Price Investment Indicator" + - column_name: PRICE_INVST_COMP_IND + data_type: INT64 + description: "Price Investment Comparable Indicator" + - column_name: PRICE_INVST_START_DT + data_type: DATE + description: "Price Investment Start Date" + - column_name: PRICE_INVST_WAVE_DESC + data_type: STRING + description: "Price Investment Wave Description" + + # Fiscal Calendar + - column_name: POSTG_DT + data_type: DATE + description: "Posting Date - Financial business date" + - column_name: VISIT_LOCAL_DT + data_type: DATE + description: "Visit Local Date - Store local date from POS" + - column_name: VISIT_LOCAL_TM + data_type: STRING + description: "Visit Local Time - Store local time (hh:mm:ss)" + - column_name: VISIT_UTC_DT + data_type: DATE + description: "Visit UTC Date" + - column_name: VISIT_UTC_TM + data_type: STRING + description: "Visit UTC Time" + - column_name: FISCAL_FULL_YR_NBR + data_type: INT64 + description: "Fiscal Full Year Number" + - column_name: FISCAL_MTH_NBR + data_type: INT64 + description: "Fiscal Month Number (1=Feb, 12=Jan)" + - 
column_name: FISCAL_MTH_ABBR + data_type: STRING + description: "Fiscal Month Abbreviation" + - column_name: FISCAL_QTR_NBR + data_type: INT64 + description: "Fiscal Quarter Number" + - column_name: WM_FULL_YR_NBR + data_type: INT64 + description: "Walmart Full Year Number" + - column_name: WM_MTH_NBR + data_type: INT64 + description: "Walmart Month Number" + - column_name: WM_MTH_ABBR + data_type: STRING + description: "Walmart Month Abbreviation" + - column_name: WM_QTR_NBR + data_type: INT64 + description: "Walmart Quarter Number" + - column_name: WM_WK_NBR + data_type: INT64 + description: "Walmart Week Number" + - column_name: WM_YR_WK_NBR + data_type: INT64 + description: "Walmart Year Week Number (YYYYWW)" + - column_name: WM_WK_DAY_NBR + data_type: INT64 + description: "Walmart Weekday Number (1=Sat, 7=Fri)" + - column_name: CAL_WK_DAY_ABBR + data_type: STRING + description: "Calendar Week Day Abbreviation" + + # Item Seasonal + - column_name: SEASN_CD + data_type: INT64 + description: "Season Code - 0=Basic, 1=Spring, 2=Summer, 3=BTS/Fall, 4=Winter" + - column_name: SEASN_DESC + data_type: STRING + description: "Season Description" + - column_name: SEASN_YR + data_type: INT64 + description: "Season Year" + + # Item Linkage + - column_name: ALL_LINKS_ITEM_NBR + data_type: INT64 + description: "All Links Item Number - Prime item at top of linkage chains" + - column_name: ALL_LINKS_MDSE_FAM_ID + data_type: INT64 + description: "All Links Merchandise Family ID" + - column_name: REPL_GROUP_NBR + data_type: NUMERIC + description: "Replenishment Group Number" + + # Distribution Center + - column_name: PRMRY_DC_NBR + data_type: INT64 + description: "Primary Distribution Center Number" + - column_name: UPSTRM_DC_NBR + data_type: INT64 + description: "Upstream Distribution Center Number" + - column_name: WHSE_ALGN_TYPE_CD + data_type: STRING + description: "Warehouse Align Type Code" + - column_name: POS_DEPT_NBR + data_type: INT64 + description: "POS Department Number" + - column_name: BUYG_REGION_CD + data_type: INT64 + description: "Buying Region Code - 1=Alaska, 2=Hawaii, 3=Puerto Rico" + - column_name: BUYG_REGION_DESC + data_type: STRING + description: "Buying Region Description" + + # Register Details + - column_name: REG_NBR + data_type: INT64 + description: "Register Number" + - column_name: TRANS_NBR + data_type: INT64 + description: "Transaction Number" + - column_name: OPERATOR_NBR + data_type: INT64 + description: "Operator Number" + + # Return Details + - column_name: RTN_RSN_CD + data_type: INT64 + description: "Return Reason Code - 0=Unknown, 1=Doesn't Work, 2=Damaged, 3=Incorrect, 4=Poor Quality, 5=Changed Mind" + - column_name: RTN_RSN_DESC + data_type: STRING + description: "Return Reason Description" + - column_name: RCPT_SEQ_NBR + data_type: INT64 + description: "Receipt Sequence Number for returns" + - column_name: DEFECTIVE_ITEM_IND + data_type: INT64 + description: "Defective Item Indicator" + + # Tax Indicators + - column_name: TAX_XMPT_IND + data_type: INT64 + description: "Tax Exempt Indicator" + - column_name: TAX_XMPT_TYPE_DESC + data_type: STRING + description: "Tax Exempt Type Description" + - column_name: TAX_1_IND + data_type: INT64 + description: "Tax 1 Indicator" + - column_name: TAX_2_IND + data_type: INT64 + description: "Tax 2 Indicator" + - column_name: FOODSTAMP_ELIGIBLE_IND + data_type: INT64 + description: "Foodstamp/SNAP Eligible Indicator" + + # WIC (Women, Infants, Children) + - column_name: WIC_CATG_NBR + data_type: INT64 + description: "WIC 
Category Number" + - column_name: WIC_SUBCATG_NBR + data_type: INT64 + description: "WIC Subcategory Number" + - column_name: WIC_REDEEMED_IND + data_type: INT64 + description: "WIC Redeemed Indicator" + - column_name: WIC_PARTIAL_REDEEM_IND + data_type: INT64 + description: "WIC Partial Redemption Indicator" + + # Price Override Indicators + - column_name: PRICE_OVERRIDE_IND + data_type: INT64 + description: "Price Override Indicator" + - column_name: CSC_CUST_SATISFY_IND + data_type: INT64 + description: "CSC Customer Satisfaction Indicator" + - column_name: MUMD1_PRICE_ERROR_IND + data_type: INT64 + description: "Price Error Indicator" + - column_name: MUMD2_COMP_AD_IND + data_type: INT64 + description: "Comp Ad Indicator" + - column_name: MUMD3_BATTERY_IND + data_type: INT64 + description: "Battery Exchange Indicator" + - column_name: MUMD4_MISC_IND + data_type: INT64 + description: "Miscellaneous Override Indicator" + - column_name: DSIM_IND + data_type: INT64 + description: "Dynamic Store Initiated Markdown Indicator" + + # Item Scan Indicators + - column_name: NOT_ON_FILE_IND + data_type: INT64 + description: "Not On File Indicator" + - column_name: KEYED_ITEM_IND + data_type: INT64 + description: "Keyed Item Indicator" + - column_name: QTY_ENTERED_IND + data_type: INT64 + description: "Quantity Entered Indicator" + - column_name: MEASURE_ITEM_IND + data_type: INT64 + description: "Measured Item Indicator" + - column_name: EMBEDDED_PRICE_IND + data_type: INT64 + description: "Embedded Price Indicator" + - column_name: VAR_WT_IND + data_type: INT64 + description: "Variable Weight Indicator" + + # eComm Order Details + - column_name: ORDER_LINE_NBR + data_type: INT64 + description: "Walmart.com Order Line Number" + - column_name: PG_CUST_ID + data_type: STRING + description: "Pangaea Customer ID for online sales" + + # Source & Operational + - column_name: SRC_TYPE_NM + data_type: STRING + description: "Source Type Name - STORE, STORE-ONLINE-INIT, ECOMMERCE" + - column_name: OP_CMPNY_CD + data_type: STRING + description: "Operational Company Code - WMT-US, SAMS-US, etc." 
+ - column_name: FIN_RPT_CD + data_type: STRING + description: "Financial Report Code" + - column_name: FIN_RPT_DESC + data_type: STRING + description: "Financial Report Description" + - column_name: PICKER_TYPE_NM + data_type: STRING + description: "Picker Type Name for last-mile delivery" + - column_name: CUST_ACCT_TYPE_NM + data_type: STRING + description: "Customer Account Type Name (B2B, etc.)" + - column_name: SVC_TYPE_CD + data_type: STRING + description: "Service Type Code" + + # Surrogate Keys + - column_name: CAL_KEY + data_type: INT64 + description: "Calendar Dimension Key" + - column_name: ITEM_KEY + data_type: INT64 + description: "Item Dimension Key" + - column_name: ITEM_CURR_KEY + data_type: INT64 + description: "Item Current Dimension Key" + - column_name: STORE_KEY + data_type: INT64 + description: "Store Dimension Key" + - column_name: STORE_CURR_KEY + data_type: INT64 + description: "Store Current Dimension Key" + - column_name: FIN_DEPT_KEY + data_type: INT64 + description: "Financial Department Key" + - column_name: FIN_CURR_DEPT_KEY + data_type: INT64 + description: "Financial Current Department Key" + - column_name: DC_ALGN_KEY + data_type: INT64 + description: "DC Alignment Key" + - column_name: COMP_KEY + data_type: INT64 + description: "Comp Status Dimension Key" + - column_name: STORE_DAY_KEY + data_type: INT64 + description: "Store Day Key" + - column_name: CHNL_KEY + data_type: INT64 + description: "Channel Key" + - column_name: CHNL_BASE_ID + data_type: INT64 + description: "Channel Base ID" + - column_name: REG_TYPE_KEY + data_type: INT64 + description: "Register Type Key - 1=Regular, 2=Fuel, 3=SCO, 4=Fuel+SCO" + + # Timestamps + - column_name: LOAD_TS + data_type: TIMESTAMP + description: "Load Timestamp" + - column_name: UPD_TS + data_type: TIMESTAMP + description: "Updated Timestamp" + + few_shot_examples: + - question: Show me all line items for a specific register transaction + answer: | + This query retrieves all item-level details for a single register transaction. Note that some items like gift cards may have multiple records due to separate GL account postings (e.g., Starbucks gift card splits between 3rd party settlement and commission income). Associate discounts are shown at the item level. + sql_query: | + SELECT + STORE_NBR, + VISIT_LOCAL_DT, + VISIT_LOCAL_TM, + VISIT_NBR, + REG_NBR, + TRANS_NBR, + SEQ_LINE_NBR, + ACCTG_DEPT_NBR, + UPC_DESC, + UPC_SRC_NBR, + ITEM_DESC_1, + OTHER_INCOME_IND, + GL_ACCT_ALGN_NBR, + GL_ACCT_ALGN_NM, + UOM_QTY, + ASSOC_DISC_AMT, + SALES_AMT + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT = '2025-01-31' -- Data partitioned on this column + AND STORE_NBR = 5260 + AND REG_NBR = 3 + AND TRANS_NBR = 3963 + ORDER BY STORE_NBR, VISIT_LOCAL_DT, REG_NBR, TRANS_NBR, SEQ_LINE_NBR + + - question: Compare store sales vs online-initiated store-fulfilled (OISF) sales + answer: | + This query compares traditional in-store sales (BIS - Brick & Mortar In-Store) versus online-initiated store-fulfilled (OISF) sales like pickup and delivery. + The SRC_TYPE_NM column distinguishes between STORE (pure in-store) and STORE-ONLINE-INIT (online orders fulfilled by stores). 
+ sql_query: | + SELECT + CAST(FORMAT_TIMESTAMP('%Y%m', VISIT_LOCAL_DT) AS INT64) AS CAL_YR_MO, + SRC_TYPE_NM, + ROUND(SUM(SALES_AMT), 0) AS SALES_AMT + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT BETWEEN '2025-01-01' AND '2025-01-31' -- Data partitioned on this column + AND OTHER_INCOME_IND = 0 -- Returns just top-line merchandise sales + GROUP BY ALL + ORDER BY CAL_YR_MO, SRC_TYPE_NM + + - question: Show me hourly sales distribution for a store + answer: | + This query analyzes sales by hour of day for a specific store. Useful for staffing optimization, understanding peak shopping times, and identifying overnight activity patterns. Hours are in 24-hour format based on store local time. + sql_query: | + SELECT + X.STORE_NBR, + X.CAL_YR_MO, + X.VISIT_LOCAL_HR_NBR, + ROUND(SUM(X.SALES_AMT), 0) AS SALES_AMT + FROM ( + SELECT + STORE_NBR, + CAST(FORMAT_TIMESTAMP('%Y%m', VISIT_LOCAL_DT) AS INT64) AS CAL_YR_MO, + SUBSTRING(VISIT_LOCAL_TM, 1, 2) AS VISIT_LOCAL_HR_NBR, + SUM(SALES_AMT) AS SALES_AMT + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT BETWEEN '2025-01-01' AND '2025-01-31' -- Data partitioned on this column + AND STORE_NBR = 5260 + AND OTHER_INCOME_IND = 0 -- Returns just top-line merchandise sales + GROUP BY ALL + ) AS X + GROUP BY ALL + ORDER BY X.STORE_NBR, X.CAL_YR_MO, X.VISIT_LOCAL_HR_NBR + + - question: What items are frequently bought together with Sleepwear and Dairy? + answer: | + This basket analysis query identifies transactions where customers purchased items from both Sleepwear (Dept 29) and Dairy (Dept 90) departments. This pattern analysis helps understand cross-category shopping behavior and can inform store layout and promotional strategies. + sql_query: | + SELECT + A.STORE_NBR, + A.VISIT_LOCAL_DT, + ROUND(SUM(A.SLEEPWEAR_SALES_AMT), 0) AS SLEEPWEAR_SALES_AMT, + ROUND(SUM(B.DAIRY_SALES_AMT), 0) AS DAIRY_SALES_AMT + FROM ( + SELECT + STORE_NBR, + VISIT_LOCAL_DT, + VISIT_NBR, + SUM(SALES_AMT) AS SLEEPWEAR_SALES_AMT + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE OP_CMPNY_CD = 'WMT-US' + AND VISIT_LOCAL_DT BETWEEN '2025-01-01' AND '2025-01-31' -- Data partitioned on this column + AND OTHER_INCOME_IND = 0 + AND STORE_NBR = 5260 + AND ACCTG_DEPT_NBR = 29 -- Sleepwear + GROUP BY ALL + HAVING SLEEPWEAR_SALES_AMT != 0 + ) AS A + INNER JOIN ( + SELECT + STORE_NBR, + VISIT_LOCAL_DT, + VISIT_NBR, + SUM(SALES_AMT) AS DAIRY_SALES_AMT + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE OP_CMPNY_CD = 'WMT-US' + AND VISIT_LOCAL_DT BETWEEN '2025-01-01' AND '2025-01-31' -- Data partitioned on this column + AND OTHER_INCOME_IND = 0 + AND STORE_NBR = 5260 + AND ACCTG_DEPT_NBR = 90 -- Dairy + GROUP BY ALL + HAVING DAIRY_SALES_AMT != 0 + ) AS B + ON A.VISIT_NBR = B.VISIT_NBR + GROUP BY ALL + ORDER BY A.STORE_NBR, A.VISIT_LOCAL_DT + + - question: Show sales by GL account for SAP reconciliation + answer: | + This query summarizes sales by GL account alignment for SAP reconciliation. The GL_ACCT_ALGN_NBR indicates which SAP GL account the sales amount aligns to. Note that VISIT_LOCAL_DT reflects the actual transaction date, which may differ from SAP posting date due to timing differences. 
+ sql_query: | + SELECT + VISIT_LOCAL_DT, + GL_ACCT_ALGN_NBR, + GL_ACCT_ALGN_NM, + ROUND(SUM(SALES_AMT), 0) AS SALES_AMT + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT = '2025-06-23' -- Data partitioned on this column + AND OP_CMPNY_CD = 'WMT-US' + GROUP BY VISIT_LOCAL_DT, GL_ACCT_ALGN_NBR, GL_ACCT_ALGN_NM + ORDER BY GL_ACCT_ALGN_NBR + + - question: What are total sales by department for last week? + answer: This query shows total sales amount and units by department for the last 7 days. + sql_query: | + SELECT + ACCTG_DEPT_NBR, + ACCTG_DEPT_DESC, + COUNT(DISTINCT VISIT_NBR) AS transaction_count, + ROUND(SUM(SALES_AMT), 0) AS total_sales, + SUM(UOM_QTY) AS units_sold + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT >= DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY) -- Data partitioned on this column + AND VOID_ITEM_IND = 0 + AND OTHER_INCOME_IND = 0 + GROUP BY ACCTG_DEPT_NBR, ACCTG_DEPT_DESC + ORDER BY total_sales DESC + LIMIT 20 + + - question: Show me pickup vs delivery sales breakdown + answer: This query compares sales by fulfillment type for online-initiated orders. + sql_query: | + SELECT + FULFMT_TYPE_ID, + FULFMT_TYPE_DESC, + CHNL_CONS_HIER_LVL_3_DESC AS channel_level_3, + COUNT(DISTINCT ORDER_NBR) AS order_count, + ROUND(SUM(SALES_AMT), 0) AS total_sales, + SUM(OPS_UNIT_QTY) AS units_sold, + ROUND(SUM(SALES_AMT) / NULLIF(COUNT(DISTINCT ORDER_NBR), 0), 2) AS avg_order_value + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE SRC_TYPE_NM = 'STORE-ONLINE-INIT' + AND VISIT_LOCAL_DT >= DATE_TRUNC(CURRENT_DATE(), MONTH) -- Data partitioned on this column + AND VOID_ITEM_IND = 0 + GROUP BY FULFMT_TYPE_ID, FULFMT_TYPE_DESC, CHNL_CONS_HIER_LVL_3_DESC + ORDER BY total_sales DESC + + - question: Show comparable store sales by region + answer: This query shows sales for comparable stores (COMP_CD = 'C') grouped by region. + sql_query: | + SELECT + REGION_NBR, + REGION_NM, + COUNT(DISTINCT STORE_NBR) AS comp_store_count, + ROUND(SUM(SALES_AMT), 0) AS total_sales, + SUM(OPS_UNIT_QTY) AS total_units, + ROUND(SUM(SALES_AMT) / NULLIF(COUNT(DISTINCT STORE_NBR), 0), 2) AS avg_store_sales + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE COMP_CD = 'C' -- Comparable stores only + AND VISIT_LOCAL_DT >= DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY) -- Data partitioned on this column + AND VOID_ITEM_IND = 0 + AND OTHER_INCOME_IND = 0 + GROUP BY REGION_NBR, REGION_NM + ORDER BY total_sales DESC + + - question: What is the return rate by department? + answer: This query calculates return rate (returns as percentage of gross sales) by department. 
+ sql_query: | + SELECT + ACCTG_DEPT_NBR, + ACCTG_DEPT_DESC, + ROUND(SUM(CASE WHEN RTN_ITEM_IND = 0 THEN SALES_AMT ELSE 0 END), 0) AS gross_sales, + ROUND(SUM(CASE WHEN RTN_ITEM_IND = 1 THEN ABS(SALES_AMT) ELSE 0 END), 0) AS return_amount, + ROUND( + SAFE_DIVIDE( + SUM(CASE WHEN RTN_ITEM_IND = 1 THEN ABS(SALES_AMT) ELSE 0 END), + SUM(CASE WHEN RTN_ITEM_IND = 0 THEN SALES_AMT ELSE 0 END) + ) * 100, 2 + ) AS return_rate_pct + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT >= DATE_SUB(CURRENT_DATE(), INTERVAL 30 DAY) -- Data partitioned on this column + AND VOID_ITEM_IND = 0 + AND OTHER_INCOME_IND = 0 + GROUP BY ACCTG_DEPT_NBR, ACCTG_DEPT_DESC + HAVING gross_sales > 0 + ORDER BY return_rate_pct DESC + + - question: Show me self-checkout vs regular checkout sales + answer: This query compares sales between self-checkout (SCO) and regular staffed checkout lanes. + sql_query: | + SELECT + CASE + WHEN SELF_CHKOUT_IND = 1 THEN 'Self Checkout' + ELSE 'Regular Checkout' + END AS checkout_type, + COUNT(DISTINCT VISIT_NBR) AS transaction_count, + ROUND(SUM(SALES_AMT), 0) AS total_sales, + ROUND(SUM(SALES_AMT) / NULLIF(COUNT(DISTINCT VISIT_NBR), 0), 2) AS avg_basket_size + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE SRC_TYPE_NM = 'STORE' -- In-store sales only + AND VISIT_LOCAL_DT >= DATE_SUB(CURRENT_DATE(), INTERVAL 7 DAY) -- Data partitioned on this column + AND VOID_ITEM_IND = 0 + AND OTHER_INCOME_IND = 0 + GROUP BY checkout_type + ORDER BY total_sales DESC + + - question: What are sales by Walmart fiscal week? + answer: This query shows weekly sales trend using Walmart fiscal calendar (fiscal year starts in February). + sql_query: | + SELECT + WM_YR_WK_NBR, + FISCAL_FULL_YR_NBR, + WM_WK_NBR, + MIN(VISIT_LOCAL_DT) AS week_start, + MAX(VISIT_LOCAL_DT) AS week_end, + ROUND(SUM(SALES_AMT), 0) AS total_sales, + SUM(OPS_UNIT_QTY) AS total_units, + COUNT(DISTINCT STORE_NBR) AS stores_with_sales + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE FISCAL_FULL_YR_NBR = 2026 + AND VOID_ITEM_IND = 0 + AND OTHER_INCOME_IND = 0 + GROUP BY WM_YR_WK_NBR, FISCAL_FULL_YR_NBR, WM_WK_NBR + ORDER BY WM_YR_WK_NBR + + - question: Show clearance vs regular price sales + answer: This query compares sales by report code (pricing status) - Regular, Rollback, Clearance, etc. + sql_query: | + SELECT + RPT_CD, + RPT_CD_DESC, + ROUND(SUM(SALES_AMT), 0) AS total_sales, + SUM(OPS_UNIT_QTY) AS units_sold, + COUNT(DISTINCT MDS_FAM_ID) AS unique_items, + ROUND(SUM(SALES_AMT) / NULLIF(SUM(OPS_UNIT_QTY), 0), 2) AS avg_unit_price + FROM `wmt-edw-prod.US_FIN_SALES_DL_RPT_VM.WMT_STORE_SALES_DTL_D` + WHERE VISIT_LOCAL_DT >= DATE_SUB(CURRENT_DATE(), INTERVAL 30 DAY) -- Data partitioned on this column + AND VOID_ITEM_IND = 0 + AND RTN_ITEM_IND = 0 + AND OTHER_INCOME_IND = 0 + GROUP BY RPT_CD, RPT_CD_DESC + ORDER BY total_sales DESC diff --git a/src/data_agent/cli/app.py b/src/data_agent/cli/app.py index f2eef82..907ca95 100644 --- a/src/data_agent/cli/app.py +++ b/src/data_agent/cli/app.py @@ -507,5 +507,82 @@ def a2a( ) +@app.command() +def mcp( + config: Annotated[ + str | None, + typer.Option( + "--config", + "-c", + help="Configuration name to load. 
Loads all configs if not specified.",
+        ),
+    ] = None,
+    transport: Annotated[
+        str,
+        typer.Option(
+            "--transport",
+            "-t",
+            help="Transport mechanism: sse (for HTTP clients such as VS Code/Cursor) or stdio (for Claude Desktop).",
+        ),
+    ] = "sse",
+    port: Annotated[
+        int,
+        typer.Option(
+            "--port",
+            "-p",
+            help="Port for SSE transport.",
+        ),
+    ] = 8002,
+    log_level: Annotated[
+        str,
+        typer.Option(
+            "--log-level",
+            help="Logging level.",
+        ),
+    ] = "warning",
+) -> None:
+    """Start the MCP (Model Context Protocol) server.
+
+    This exposes the NL2SQL Data Agent via MCP, enabling integration
+    with Claude Desktop, VS Code, Cursor, and other MCP clients.
+
+    Available tools:
+    - query: Execute natural language queries
+    - list_datasources: List available datasources
+    - get_schema: Get database schema for a datasource
+
+    Examples:
+        # Start with SSE transport (default, for VS Code/Cursor)
+        data-agent mcp
+
+        # Start with stdio transport (for Claude Desktop)
+        data-agent mcp --transport stdio
+
+        # Start with a specific config
+        data-agent mcp --config contoso
+    """
+    from data_agent.mcp import create_mcp_server
+
+    if config:
+        validate_config(config)
+
+    console.print("[cyan]Starting MCP server...[/cyan]")
+    console.print(f"  Transport: [green]{transport}[/green]")
+    console.print(f"  Config: [green]{config or 'all'}[/green]")
+    if transport == "sse":
+        console.print(f"  Port: [green]{port}[/green]")
+    console.print()
+
+    mcp_server = create_mcp_server(config_name=config)
+
+    if transport == "stdio":
+        mcp_server.run(transport="stdio")
+    else:
+        # Set port for SSE transport
+        mcp_server.settings.host = "127.0.0.1"
+        mcp_server.settings.port = port
+        mcp_server.run(transport="sse")
+
+
 if __name__ == "__main__":
     app()
diff --git a/src/data_agent/config.py b/src/data_agent/config.py
index 173cff0..9ff785d 100644
--- a/src/data_agent/config.py
+++ b/src/data_agent/config.py
@@ -11,7 +11,7 @@
 from pydantic import Field
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
-CONFIG_DIR = Path(__file__).resolve().parent / "config"
+CONFIG_DIR = Path(__file__).resolve().parent / "agents"
 
 DatasourceType = Literal[
     "databricks", "cosmos", "postgres", "azure_sql", "synapse", "bigquery"
@@ -362,27 +362,9 @@ class DataAgentConfig:
     few_shot_examples: list[FewShotExample] = field(default_factory=list)
 
 
-@dataclass
-class IntentDetectionConfig:
-    """Configuration for intent detection agent."""
-
-    llm_config: LLMConfig = field(default_factory=LLMConfig)
-    system_prompt: str = ""
-
-    @classmethod
-    def from_dict(cls, data: dict[str, Any]) -> "IntentDetectionConfig":
-        return cls(
-            llm_config=LLMConfig.from_dict(data.get("llm", {})),
-            system_prompt=data.get("system_prompt", ""),
-        )
-
-
 @dataclass
 class AgentConfig:
     """Complete agent configuration."""
 
-    intent_detection: IntentDetectionConfig = field(
-        default_factory=IntentDetectionConfig
-    )
     data_agents: list[DataAgentConfig] = field(default_factory=list)
     max_retries: int = 3
 
diff --git a/src/data_agent/config_loader.py b/src/data_agent/config_loader.py
index c109912..da469be 100644
--- a/src/data_agent/config_loader.py
+++ b/src/data_agent/config_loader.py
@@ -19,7 +19,6 @@
     DataAgentConfig,
     Datasource,
     FewShotExample,
-    IntentDetectionConfig,
     LLMConfig,
     TableSchema,
     ValidationConfig,
@@ -126,7 +125,6 @@ def load_all(cls, validate: bool = True) -> AgentConfig:
         """Load and merge all configuration files from the config directory.
 
         Combines data_agents from all configs into a single AgentConfig.
-        Uses intent_detection settings from the first config found.
Args: validate: Whether to validate against JSON schema (default True). @@ -150,9 +148,6 @@ def load_all(cls, validate: bool = True) -> AgentConfig: def _parse_config(cls, data: dict[str, Any]) -> AgentConfig: """Parse raw config dict into AgentConfig.""" return AgentConfig( - intent_detection=IntentDetectionConfig.from_dict( - data.get("intent_detection_agent", {}) - ), data_agents=[cls._parse_data_agent(a) for a in data.get("data_agents", [])], max_retries=data.get("max_retries", 3), ) diff --git a/src/data_agent/core/logging.py b/src/data_agent/core/logging.py index 71690e4..b7ba49a 100644 --- a/src/data_agent/core/logging.py +++ b/src/data_agent/core/logging.py @@ -1,13 +1,18 @@ -""" -Logging configuration for the Terminal Agent. -""" +"""Logging configuration for the Data Agent.""" import logging import logging.config def setup_logging(default_level: int = logging.INFO) -> None: - """Configure structured logging for the entire package.""" + """Configure structured logging for the entire package. + + Args: + default_level: Default logging level (e.g., logging.INFO). + + Returns: + None + """ level_name = logging.getLevelName(default_level) logging_config = { "version": 1, @@ -41,6 +46,7 @@ def setup_logging(default_level: int = logging.INFO) -> None: "httpx", "azure.identity", "chainlit", + "mcp.server.lowlevel.server", ] for logger_name in noisy_loggers: logging.getLogger(logger_name).setLevel(logging.WARNING) diff --git a/src/data_agent/llm/provider.py b/src/data_agent/llm/provider.py index 803c289..509c563 100644 --- a/src/data_agent/llm/provider.py +++ b/src/data_agent/llm/provider.py @@ -35,10 +35,17 @@ def create_llm(self, **kwargs: Any) -> BaseChatModel: Returns: Configured AzureChatOpenAI instance. """ - return AzureChatOpenAI( - azure_endpoint=kwargs.get("azure_endpoint"), - api_key=kwargs.get("api_key"), - azure_deployment=kwargs.get("deployment_name"), - api_version=kwargs.get("api_version", "2024-08-01-preview"), - temperature=kwargs.get("temperature", 0), - ) + # Build kwargs dict, only including non-None values to allow env var fallback + llm_kwargs: dict[str, Any] = { + "api_version": kwargs.get("api_version", "2024-08-01-preview"), + "temperature": kwargs.get("temperature", 0), + } + + if kwargs.get("azure_endpoint"): + llm_kwargs["azure_endpoint"] = kwargs["azure_endpoint"] + if kwargs.get("api_key"): + llm_kwargs["api_key"] = kwargs["api_key"] + if kwargs.get("deployment_name"): + llm_kwargs["azure_deployment"] = kwargs["deployment_name"] + + return AzureChatOpenAI(**llm_kwargs) diff --git a/src/data_agent/mcp/__init__.py b/src/data_agent/mcp/__init__.py new file mode 100644 index 0000000..93230d7 --- /dev/null +++ b/src/data_agent/mcp/__init__.py @@ -0,0 +1,16 @@ +"""MCP (Model Context Protocol) server for the NL2SQL Data Agent. 
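+
+The package exposes ``create_mcp_server`` for programmatic use and ``main`` as the
+command-line entry point.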
+ +The MCP server exposes: +- Tools: query, list_datasources, list_tables, get_schema, validate_sql +- Resources: datasources://list, schema://{datasource}, tables://{datasource} +""" + +from data_agent.mcp.context import MCPServerContext, set_context +from data_agent.mcp.server import create_mcp_server, main + +__all__ = [ + "create_mcp_server", + "main", + "MCPServerContext", + "set_context", +] diff --git a/src/data_agent/mcp/context.py b/src/data_agent/mcp/context.py new file mode 100644 index 0000000..87588f7 --- /dev/null +++ b/src/data_agent/mcp/context.py @@ -0,0 +1,60 @@ +"""MCP server context management with thread-safe state.""" + +import logging +from contextvars import ContextVar + +from data_agent.agent import DataAgentFlow +from data_agent.config import AgentConfig + +logger = logging.getLogger(__name__) + + +class MCPServerContext: + """Context holding initialized server components. + + Manages the lifecycle of the Data Agent and its connections, + providing thread-safe access to shared resources. + """ + + def __init__(self, config: AgentConfig): + """Initialize the MCP server context. + + Args: + config: Agent configuration with datasource definitions. + """ + self.config = config + self.agent = DataAgentFlow(config=config) + self._connected = False + + async def ensure_connected(self) -> None: + """Connect to all datasources if not already connected.""" + if not self._connected: + logger.info("Connecting to datasources...") + await self.agent.connect() + self._connected = True + logger.info("Connected to datasources") + + async def disconnect(self) -> None: + """Cleanup connections.""" + self._connected = False + logger.info("Disconnected from datasources") + + @property + def is_connected(self) -> bool: + """Check if datasources are connected.""" + return self._connected + + +# Thread-safe context variable for the server context +_context_var: ContextVar[MCPServerContext | None] = ContextVar( + "mcp_server_context", default=None +) + + +def set_context(ctx: MCPServerContext) -> None: + """Set the MCP server context. + + Args: + ctx: The MCPServerContext instance to set. + """ + _context_var.set(ctx) diff --git a/src/data_agent/mcp/resources.py b/src/data_agent/mcp/resources.py new file mode 100644 index 0000000..7b966de --- /dev/null +++ b/src/data_agent/mcp/resources.py @@ -0,0 +1,89 @@ +"""MCP resource definitions for the Data Agent.""" + +import logging + +from mcp.server.fastmcp import FastMCP + +from data_agent.mcp.context import MCPServerContext + +logger = logging.getLogger(__name__) + + +def register_resources(mcp: FastMCP, ctx: MCPServerContext) -> None: + """Register all MCP resources for the Data Agent. + + Args: + mcp: FastMCP server instance. + ctx: Server context with agent and configuration. + """ + + @mcp.resource("datasources://list") + async def datasources_list() -> str: + """List of available datasources as a resource. + + Returns: + Formatted list of datasources with descriptions. + """ + lines = ["Available Datasources:", ""] + for agent_cfg in ctx.config.data_agents: + desc = agent_cfg.description or "No description" + lines.append(f"- {agent_cfg.name}: {desc}") + + return "\n".join(lines) + + @mcp.resource("schema://{datasource}") + async def schema_resource(datasource: str) -> str: + """Database schema as a readable resource. + + Args: + datasource: Name of the datasource. + + Returns: + Schema information for the datasource. + """ + if datasource not in ctx.agent.datasources: + return f"Datasource '{datasource}' not found." 
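+
+        # The datasource may be a LangChain SQLDatabase or a CosmosAdapter; the
+        # isinstance checks below pick the right way to describe its schema.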
+ + ds = ctx.agent.datasources[datasource] + + from langchain_community.utilities.sql_database import SQLDatabase + + from data_agent.adapters import CosmosAdapter + + if isinstance(ds, SQLDatabase): + schema_info = ds.get_table_info() + return schema_info or f"No schema information for '{datasource}'." + elif isinstance(ds, CosmosAdapter): + return ( + f"Cosmos DB container: {ds.container_name}\n" + f"Partition key: {ds.partition_key_path}" + ) + + return f"Schema not available for '{datasource}'." + + @mcp.resource("tables://{datasource}") + async def tables_resource(datasource: str) -> str: + """List of tables in a datasource as a resource. + + Args: + datasource: Name of the datasource. + + Returns: + List of table names. + """ + if datasource not in ctx.agent.datasources: + return f"Datasource '{datasource}' not found." + + ds = ctx.agent.datasources[datasource] + + from langchain_community.utilities.sql_database import SQLDatabase + + from data_agent.adapters import CosmosAdapter + + if isinstance(ds, SQLDatabase): + tables = ds.get_usable_table_names() + return "\n".join(sorted(tables)) if tables else "No tables found." + elif isinstance(ds, CosmosAdapter): + return ds.container_name + + return f"Table listing not available for '{datasource}'." diff --git a/src/data_agent/mcp/server.py b/src/data_agent/mcp/server.py new file mode 100644 index 0000000..b64ccaa --- /dev/null +++ b/src/data_agent/mcp/server.py @@ -0,0 +1,150 @@ +"""MCP server for the NL2SQL Data Agent. + +This module provides the main MCP server implementation, integrating +tools, resources, and prompts for natural language to SQL queries. +""" + +import argparse +import logging + +from dotenv import load_dotenv + +load_dotenv() + +from mcp.server.fastmcp import FastMCP + +from data_agent.config import CONFIG_DIR +from data_agent.config_loader import ConfigLoader +from data_agent.core.logging import setup_logging +from data_agent.mcp.context import MCPServerContext, set_context +from data_agent.mcp.resources import register_resources +from data_agent.mcp.tools import register_tools + +setup_logging() +logger = logging.getLogger(__name__) + + +def create_mcp_server( + config_path: str | None = None, + config_name: str | None = None, +) -> FastMCP: + """Create and configure the MCP server. + + Args: + config_path: Path to agent configuration file. + config_name: Name of config to load from config directory. + If neither provided, loads all configs. + + Returns: + Configured FastMCP server instance. + """ + # Load configuration + if config_path: + config = ConfigLoader.load(config_path) + elif config_name: + config = ConfigLoader.load_by_name(config_name) + else: + config = ConfigLoader.load_all() + + # Create and set context (thread-safe) + ctx = MCPServerContext(config) + set_context(ctx) + + # Create the MCP server + mcp = FastMCP( + "data-agent", + instructions="""Data Agent is a natural language to SQL platform. +You can query databases using natural language, list available datasources, +and inspect database schemas. Use the 'query' tool to ask questions about data. + +IMPORTANT: Always include the SQL query that was executed in your response to the user. +Format results clearly with the data, SQL query used, and any relevant insights. 
+ +Available tools: +- query: Execute natural language queries against databases +- list_datasources: See all available data sources +- list_tables: Quick list of tables in a datasource +- get_schema: Get detailed schema information +- validate_sql: Validate SQL syntax without executing + +Available resources: +- datasources://list: List of configured datasources +- schema://{datasource}: Database schema for a datasource +- tables://{datasource}: Tables in a datasource""", + ) + + # Register all components + register_tools(mcp, ctx) + register_resources(mcp, ctx) + + logger.info( + f"MCP server created with {len(config.data_agents)} datasource(s) configured" + ) + + return mcp + + +def get_config_choices() -> list[str]: + """Get available configuration file names. + + Returns: + List of config names (without .yaml extension). + """ + return [f.stem for f in CONFIG_DIR.glob("*.yaml")] + + +def main() -> None: + """Main entry point for the MCP server.""" + parser = argparse.ArgumentParser( + description="Data Agent MCP Server", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "-c", + "--config", + choices=get_config_choices() or None, + default=None, + help="Configuration name to load (loads all if not specified)", + ) + parser.add_argument( + "--config-path", + type=str, + default=None, + help="Path to configuration file", + ) + parser.add_argument( + "--transport", + choices=["stdio", "sse"], + default="sse", + help="Transport mechanism (default: sse for VS Code/Cursor)", + ) + parser.add_argument( + "--port", + type=int, + default=8002, + help="Port for SSE transport (default: 8002)", + ) + parser.add_argument( + "--log-level", + choices=["debug", "info", "warning", "error"], + default="warning", + help="Logging level (default: warning)", + ) + + args = parser.parse_args() + + mcp = create_mcp_server( + config_path=args.config_path, + config_name=args.config, + ) + + if args.transport == "stdio": + mcp.run(transport="stdio") + else: + mcp.settings.host = "127.0.0.1" + mcp.settings.port = args.port + mcp.run(transport="sse") + + +if __name__ == "__main__": + main() diff --git a/src/data_agent/mcp/tools.py b/src/data_agent/mcp/tools.py new file mode 100644 index 0000000..8ef875c --- /dev/null +++ b/src/data_agent/mcp/tools.py @@ -0,0 +1,267 @@ +"""MCP tool definitions for the Data Agent.""" + +import logging +from typing import Any + +from mcp.server.fastmcp import FastMCP + +from data_agent.mcp.context import MCPServerContext +from data_agent.models.state import OutputState + +logger = logging.getLogger(__name__) + + +def register_tools(mcp: FastMCP, ctx: MCPServerContext) -> None: + """Register all MCP tools for the Data Agent. + + Args: + mcp: FastMCP server instance. + ctx: Server context with agent and configuration. + """ + + @mcp.tool() + async def query(question: str, datasource: str | None = None) -> str: + """Execute a natural language query against the configured datasources. + + Args: + question: Natural language question to answer (e.g., "What are the top 10 customers by revenue?") + datasource: Optional specific datasource name to target. If not provided, the agent will auto-detect. + + Returns: + Query results as formatted text, including the SQL query used and the data returned. 
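+
+        Example (illustrative):
+            query("Which stores had the highest sales last week?")
+            query("Total sales by department", datasource="contoso")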
+ """ + logger.debug(f"MCP query tool called with question: {question}") + logger.info(f"Executing query {question} on datasource: {datasource}") + + try: + result = await ctx.agent.run(question) + logger.info("Query execution completed") + return _format_query_result(result) + + except Exception as e: + logger.exception("Error executing query") + return f"Error executing query: {e}" + + @mcp.tool() + async def list_datasources() -> str: + """List all configured datasources available for querying. + + Returns: + Formatted list of available datasources with their descriptions. + """ + datasources = [] + for agent_cfg in ctx.config.data_agents: + ds_type = "unknown" + if agent_cfg.datasource: + ds_type = ( + type(agent_cfg.datasource) + .__name__.replace("Datasource", "") + .lower() + ) + + ds_info = f"- **{agent_cfg.name}** ({ds_type})" + if agent_cfg.description: + ds_info += f": {agent_cfg.description}" + + if agent_cfg.table_schemas: + tables = [schema.name for schema in agent_cfg.table_schemas] + ds_info += f"\n Tables: {', '.join(tables)}" + + datasources.append(ds_info) + + if not datasources: + return "No datasources configured." + + return "**Available Datasources:**\n\n" + "\n\n".join(datasources) + + @mcp.tool() + async def list_tables(datasource: str) -> str: + """List all tables available in a specific datasource. + + Args: + datasource: Name of the datasource to list tables for. + + Returns: + List of table names in the datasource. + """ + if datasource not in ctx.agent.datasources: + available = ", ".join(ctx.agent.datasources.keys()) + return f"Datasource '{datasource}' not found. Available: {available}" + + ds = ctx.agent.datasources[datasource] + + try: + from langchain_community.utilities.sql_database import SQLDatabase + + from data_agent.adapters import CosmosAdapter + + if isinstance(ds, SQLDatabase): + tables = ds.get_usable_table_names() + if tables: + return f"**Tables in {datasource}:**\n\n" + "\n".join( + f"- {t}" for t in sorted(tables) + ) + return f"No tables found in '{datasource}'." + elif isinstance(ds, CosmosAdapter): + return f"**Container in {datasource}:**\n\n- {ds.container_name}" + else: + return f"Table listing not supported for datasource '{datasource}'." + + except Exception as e: + logger.exception("Error listing tables") + return f"Error listing tables: {e}" + + @mcp.tool() + async def get_schema(datasource: str) -> str: + """Get the database schema for a specific datasource. + + Args: + datasource: Name of the datasource to get schema for (use list_datasources to see available options) + + Returns: + Database schema information including tables, columns, and their types. + """ + if datasource not in ctx.agent.datasources: + available = ", ".join(ctx.agent.datasources.keys()) + return f"Datasource '{datasource}' not found. Available: {available}" + + ds = ctx.agent.datasources[datasource] + + try: + from langchain_community.utilities.sql_database import SQLDatabase + + from data_agent.adapters import CosmosAdapter + + if isinstance(ds, SQLDatabase): + schema_info = ds.get_table_info() + if schema_info: + return f"**Schema for {datasource}:**\n\n{schema_info}" + return f"No schema information available for '{datasource}'." + elif isinstance(ds, CosmosAdapter): + return ( + f"**Schema for {datasource}:**\n\n" + f"Cosmos DB container: {ds.container_name}\n" + f"Partition key: {ds.partition_key_path}\n\n" + "Note: Cosmos DB is a NoSQL database. Use queries like 'SELECT * FROM c' to explore data." 
+ ) + else: + return f"Schema inspection not supported for datasource '{datasource}'." + + except Exception as e: + logger.exception("Error getting schema") + return f"Error retrieving schema: {e}" + + @mcp.tool() + async def validate_sql(sql: str, datasource: str) -> str: + """Validate SQL syntax without executing the query. + + Args: + sql: The SQL query to validate. + datasource: Name of the datasource to validate against (for dialect detection). + + Returns: + Validation result indicating if the SQL is valid, with any errors or warnings. + """ + if datasource not in ctx.agent.datasources: + available = ", ".join(ctx.agent.datasources.keys()) + return f"Datasource '{datasource}' not found. Available: {available}" + + try: + from data_agent.validators.sql_validator import ( + SQLValidator, + ValidationStatus, + ) + + # Get dialect from datasource config + dialect = "postgres" # default + for agent_cfg in ctx.config.data_agents: + if agent_cfg.name == datasource and agent_cfg.datasource: + ds_type = type(agent_cfg.datasource).__name__.lower() + if "cosmos" in ds_type: + dialect = "cosmosdb" + elif "synapse" in ds_type or "azuresql" in ds_type: + dialect = "tsql" + elif "bigquery" in ds_type: + dialect = "bigquery" + elif "databricks" in ds_type: + dialect = "databricks" + break + + validator = SQLValidator(dialect=dialect) + result = validator.validate(sql) + + response_parts = [f"**SQL Validation Result:**\n"] + + if result.status == ValidationStatus.VALID: + response_parts.append("✅ **Status:** Valid\n") + if result.query != sql: + response_parts.append( + f"**Transformed Query:**\n```sql\n{result.query}\n```\n" + ) + elif result.status == ValidationStatus.INVALID: + response_parts.append("❌ **Status:** Invalid\n") + else: + response_parts.append("⚠️ **Status:** Unsafe\n") + + if result.errors: + response_parts.append( + f"**Errors:**\n" + "\n".join(f"- {e}" for e in result.errors) + ) + + if result.warnings: + response_parts.append( + f"\n**Warnings:**\n" + "\n".join(f"- {w}" for w in result.warnings) + ) + + return "\n".join(response_parts) + + except Exception as e: + logger.exception("Error validating SQL") + return f"Error validating SQL: {e}" + + +def _format_query_result(result: OutputState | dict[str, Any]) -> str: + """Format query result into a readable response. + + Args: + result: Query result dictionary or OutputState from the agent. + + Returns: + Formatted string response. 
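+
+    Example (illustrative):
+        _format_query_result({"final_response": "Found 3 rows.", "generated_sql": "SELECT 1"})
+        # -> the answer text followed by the SQL query in a fenced ```sql block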
+ """ + if not isinstance(result, dict): + result = dict(result) + + response_parts = [] + + if result.get("final_response"): + response_parts.append(str(result.get("final_response"))) + + if result.get("generated_sql"): + response_parts.append( + f"\n**SQL Query:**\n```sql\n{result.get('generated_sql')}\n```" + ) + + if result.get("result") and not result.get("final_response"): + response_parts.append(f"\n**Results:**\n{result.get('result')}") + + if result.get("visualization_image"): + img_data = result.get("visualization_image") + response_parts.append( + f"\n**Visualization:**\n![Chart](data:image/png;base64,{img_data})" + ) + + if result.get("visualization_code"): + response_parts.append( + f"\n**Visualization Code:**\n```python\n{result.get('visualization_code')}\n```" + ) + + if result.get("visualization_error"): + response_parts.append( + f"\n**Visualization Error:** {result.get('visualization_error')}" + ) + + if result.get("error") and result.get("error") != "out_of_scope": + response_parts.append(f"\n**Error:** {result.get('error')}") + + return "\n".join(response_parts) if response_parts else "No results returned." diff --git a/src/data_agent/nodes/data_nodes.py b/src/data_agent/nodes/data_nodes.py index ff9a1ee..6ffd27b 100644 --- a/src/data_agent/nodes/data_nodes.py +++ b/src/data_agent/nodes/data_nodes.py @@ -19,7 +19,7 @@ SQLValidationOutput, ) from data_agent.utils.message_utils import get_recent_history -from data_agent.utils.sql_utils import build_date_context, clean_sql_query +from data_agent.utils.sql_utils import clean_sql_query from data_agent.validators.sql_validator import SQLValidator, ValidationStatus if TYPE_CHECKING: @@ -28,7 +28,7 @@ from data_agent.adapters.azure.cosmos import CosmosAdapter from data_agent.models.state import AgentState -from data_agent.prompts import COSMOS_PROMPT_ADDENDUM, DEFAULT_SQL_PROMPT +from data_agent.prompts.builder import build_prompt logger = logging.getLogger(__name__) @@ -128,31 +128,29 @@ def _get_schema_context(self) -> str: return "" def _build_prompt(self) -> str: - """Build system prompt, adding Cosmos constraints if needed. + """Build system prompt using the centralized prompt builder. Returns: - Formatted system prompt with schema context and date. + Formatted system prompt with all components. """ schema_context = self._get_schema_context() few_shot = SchemaFormatter.format_few_shot_examples(self._config) - base_prompt = self._config.system_prompt or DEFAULT_SQL_PROMPT - formatted = base_prompt.format( + # Get partition key for Cosmos if applicable + partition_key = None + if self._is_cosmos and self._config.datasource: + partition_key = getattr( + self._config.datasource, "partition_key_path", "/id" + ) + + return build_prompt( + datasource_type=self._dialect, + user_prompt=self._config.system_prompt, schema_context=schema_context, few_shot_examples=few_shot, + partition_key=partition_key, ) - # Add Cosmos-specific constraints - if self._is_cosmos: - partition_key = "/id" - if self._config.datasource: - partition_key = getattr( - self._config.datasource, "partition_key_path", "/id" - ) - formatted += COSMOS_PROMPT_ADDENDUM.format(partition_key=partition_key) - - return build_date_context() + formatted - async def generate_sql(self, state: "AgentState") -> dict[str, Any]: """Generate query from natural language question. 
diff --git a/src/data_agent/nodes/response.py b/src/data_agent/nodes/response.py index 348a61f..a996c8f 100644 --- a/src/data_agent/nodes/response.py +++ b/src/data_agent/nodes/response.py @@ -17,7 +17,7 @@ if TYPE_CHECKING: from data_agent.models.state import AgentState -from data_agent.prompts import DEFAULT_RESPONSE_PROMPT +from data_agent.prompts.builder import build_response_prompt logger = logging.getLogger(__name__) @@ -58,7 +58,7 @@ def generate_response(self, state: "AgentState") -> dict[str, Any]: Returns: State update with final response and messages. """ - prompt = self._config.response_prompt or DEFAULT_RESPONSE_PROMPT + prompt = build_response_prompt(self._config.response_prompt) question = state["question"] sql = state.get("generated_sql", "") diff --git a/src/data_agent/prompts/__init__.py b/src/data_agent/prompts/__init__.py index ce50e04..b8b41f7 100644 --- a/src/data_agent/prompts/__init__.py +++ b/src/data_agent/prompts/__init__.py @@ -2,14 +2,22 @@ from data_agent.prompts.defaults import ( COSMOS_PROMPT_ADDENDUM, + DEFAULT_GENERAL_CHAT_PROMPT, + DEFAULT_INTENT_DETECTION_PROMPT, + DEFAULT_QUERY_REWRITE_PROMPT, DEFAULT_RESPONSE_PROMPT, DEFAULT_SQL_PROMPT, VISUALIZATION_SYSTEM_PROMPT, ) +from data_agent.prompts.dialects import get_dialect_guidelines __all__ = [ "COSMOS_PROMPT_ADDENDUM", + "DEFAULT_GENERAL_CHAT_PROMPT", + "DEFAULT_INTENT_DETECTION_PROMPT", + "DEFAULT_QUERY_REWRITE_PROMPT", "DEFAULT_RESPONSE_PROMPT", "DEFAULT_SQL_PROMPT", "VISUALIZATION_SYSTEM_PROMPT", + "get_dialect_guidelines", ] diff --git a/src/data_agent/prompts/builder.py b/src/data_agent/prompts/builder.py new file mode 100644 index 0000000..5e05bb7 --- /dev/null +++ b/src/data_agent/prompts/builder.py @@ -0,0 +1,84 @@ +"""Prompt builder for assembling system prompts from components. + +This module provides the `build_prompt` function that assembles +the final system prompt by appending all components in order. +""" + +import logging + +from data_agent.prompts.defaults import ( + COSMOS_PROMPT_ADDENDUM, + DEFAULT_RESPONSE_PROMPT, + DEFAULT_SQL_PROMPT, +) +from data_agent.prompts.dialects import get_dialect_guidelines +from data_agent.utils.sql_utils import build_date_context + +logger = logging.getLogger(__name__) + + +def build_prompt( + datasource_type: str, + user_prompt: str | None = None, + schema_context: str = "", + few_shot_examples: str | None = None, + partition_key: str | None = None, +) -> str: + """Build the complete prompt by appending all components. + + Args: + datasource_type: Database type (e.g., 'bigquery', 'azure_sql', 'postgres', 'cosmosdb'). + user_prompt: Team's custom prompt from YAML (optional). + schema_context: Schema information (auto-discovered or from YAML). + few_shot_examples: Formatted examples string (optional). + partition_key: Cosmos DB partition key path (only for Cosmos datasources). + + Returns: + Complete prompt with all components assembled. 
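+
+    Example (illustrative):
+        build_prompt("bigquery", schema_context="CREATE TABLE sales (...)")
+        # -> date context, then the default SQL prompt with the schema substituted,
+        #    then the BigQuery dialect guidelines appended.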
+ """ + sections: list[str] = [] + + sections.append(build_date_context().strip()) + + base_prompt = user_prompt.strip() if user_prompt else DEFAULT_SQL_PROMPT.strip() + + formatted_prompt = base_prompt.format( + schema_context=schema_context or "", + few_shot_examples=few_shot_examples or "", + ) + sections.append(formatted_prompt) + + dialect_guidelines = get_dialect_guidelines(datasource_type) + if dialect_guidelines: + sections.append(dialect_guidelines.strip()) + + if datasource_type.lower() in ("cosmos", "cosmosdb"): + cosmos_addendum = COSMOS_PROMPT_ADDENDUM.format( + partition_key=partition_key or "/id" + ) + sections.append(cosmos_addendum.strip()) + + prompt = "\n\n".join(sections) + + logger.debug( + "Built system prompt for %s (%d chars):\n%s", + datasource_type, + len(prompt), + prompt, + ) + + return prompt + + +def build_response_prompt(user_prompt: str | None = None) -> str: + """Build the response generation prompt. + + Args: + user_prompt: Team's custom response prompt from YAML (optional). + + Returns: + Response prompt (custom or default). + """ + if user_prompt: + return user_prompt.strip() + return DEFAULT_RESPONSE_PROMPT.strip() diff --git a/src/data_agent/prompts/defaults.py b/src/data_agent/prompts/defaults.py index 9aaa6e1..cf10879 100644 --- a/src/data_agent/prompts/defaults.py +++ b/src/data_agent/prompts/defaults.py @@ -1,5 +1,60 @@ """Default system prompts for data agent nodes.""" +DEFAULT_INTENT_DETECTION_PROMPT = """You are an intent detection assistant responsible for routing user questions to the appropriate data agent. + +## Available Data Agents + +{agent_descriptions} + +## Instructions + +1. Analyze the user's question to understand what data they are asking about. +2. Match the question to the most relevant data agent based on the domain and data types. +3. If the question is ambiguous, choose the agent most likely to have the relevant data. +4. If the user is greeting you (e.g., "hello", "hi", "hey"), asking about your capabilities (e.g., "what can you do?", "help"), or engaging in general conversation that doesn't require data queries, respond with "general_chat". +5. If no agent is a clear match AND it's not general chat, respond with "unknown". + +## Response Format + +Respond with ONLY the agent name (e.g., "financial_transactions") or "general_chat". Do not include any explanation.""" + +DEFAULT_GENERAL_CHAT_PROMPT = """You are a friendly and helpful data assistant. Respond conversationally to the user's greeting or question about your capabilities. + +## Your Capabilities +You help users query and analyze data from the following domains: + +{agent_descriptions} + +## Instructions +- If the user greets you, respond with a friendly greeting and briefly mention what you can help with. +- If the user asks what you can do, explain your capabilities and list the available data domains. +- Keep responses concise, friendly, and helpful. +- Guide users toward asking data-related questions. +""" + +DEFAULT_QUERY_REWRITE_PROMPT = """You are a query rewriter. Your job is to rewrite user questions to be more specific and clear for a database query system. + +## Target Agent +{agent_description} + +## Conversation Context +{conversation_context} + +## Instructions +1. Keep the original intent of the question +2. If this is a follow-up question (e.g., "what's the average?", "show me the same for X", "filter those by Y"), use the conversation history to expand the question with the relevant context +3. 
For follow-up questions, make the implicit references explicit (e.g., "What's the average?" → "What is the average transaction amount?" if previous query was about transactions) +4. Make the question more specific if needed +5. If the question is already clear and specific, return it unchanged +6. Do NOT add information that wasn't implied by the original question or conversation + +## Original Question +{question} + +## Rewritten Question +Respond with ONLY the rewritten question, nothing else. +""" + DEFAULT_SQL_PROMPT = """You are a SQL expert. Generate a syntactically correct SQL query. Limit results to 10 unless specified. Only select relevant columns. @@ -11,7 +66,8 @@ {schema_context} -{few_shot_examples}""" +{few_shot_examples} +""" COSMOS_PROMPT_ADDENDUM = """ Key Cosmos DB constraints: @@ -24,11 +80,25 @@ 7. Max 4MB response per page; use continuation tokens for large results. """ -DEFAULT_RESPONSE_PROMPT = """You are a helpful data analyst. Given the user's question, -the SQL query that was executed, and the results, provide a clear and concise natural -language response that answers the user's question. +DEFAULT_RESPONSE_PROMPT = """You are a helpful retail analyst for Walmart US sales data. +Given the user's question, the SQL query that was executed, and the results, +provide a clear and concise natural language response. -Be conversational but precise. Include relevant numbers and insights from the data.""" +Be conversational but precise. Include relevant numbers, percentages, and insights. +Format currency values with $ and commas. Format large numbers for readability. +When discussing sales performance, provide context about comparable stores, channels, and time periods. +If the results are empty, explain what that means in context. + +## Data Presentation + +When the query returns tabular data (multiple rows/columns), ALWAYS include a formatted markdown table showing the results. +- Use proper markdown table syntax with headers +- Align numeric columns to the right +- Format currency with $ and commas (e.g., $1,234.56) +- Format dates in readable format (e.g., Jun 21, 2025) +- Limit tables to 20 rows max; if more rows exist, show first 20 and note "... and X more rows" +- After the table, provide a brief summary or insight about the data. +""" VISUALIZATION_SYSTEM_PROMPT = """You are a data visualization expert. Generate Python code using matplotlib to create a chart. diff --git a/src/data_agent/prompts/dialects.py b/src/data_agent/prompts/dialects.py new file mode 100644 index 0000000..4b12ef2 --- /dev/null +++ b/src/data_agent/prompts/dialects.py @@ -0,0 +1,162 @@ +"""SQL dialect-specific guidelines for query generation. + +This module provides dialect-specific SQL guidelines that are automatically +appended to system prompts based on the datasource type. +""" + +BIGQUERY_GUIDELINES = """## BigQuery SQL Guidelines + +1. **Use BigQuery SQL syntax:** + - DATE_TRUNC, DATE_ADD, DATE_SUB for date operations + - CURRENT_DATE(), CURRENT_TIMESTAMP() for current time + - EXTRACT(YEAR FROM date), EXTRACT(MONTH FROM date) for date parts + - STRING, INT64, FLOAT64, NUMERIC, BOOL data types + - Use backticks for table names: `project.dataset.table` + +2. **Aggregation functions:** + - SUM(), AVG(), COUNT(), MIN(), MAX() + - COUNTIF(), SUMIF() for conditional aggregations + - APPROX_COUNT_DISTINCT() for large cardinality counts + +3. 
**String functions:** + - CONCAT(), SUBSTR(), UPPER(), LOWER(), TRIM() + - REGEXP_CONTAINS(), REGEXP_EXTRACT() for regex + - FORMAT() for string formatting + +4. **Best practices:** + - Always qualify column names with table aliases + - Use LIMIT to restrict results unless user specifies otherwise + - Use fully qualified table names: `project.dataset.table` + - Partition filters improve performance (e.g., WHERE partition_date >= ...) +""" + +POSTGRES_GUIDELINES = """## PostgreSQL SQL Guidelines + +1. **Use PostgreSQL syntax:** + - DATE_TRUNC(), DATE_PART() for date operations + - NOW(), CURRENT_DATE, CURRENT_TIMESTAMP for current time + - EXTRACT(YEAR FROM date) for date parts + - TEXT, INTEGER, BIGINT, NUMERIC, BOOLEAN data types + - Use double quotes for identifiers with special characters + +2. **Aggregation functions:** + - SUM(), AVG(), COUNT(), MIN(), MAX() + - COUNT(*) FILTER (WHERE condition) for conditional counts + - ARRAY_AGG(), STRING_AGG() for aggregation + +3. **String functions:** + - CONCAT(), SUBSTRING(), UPPER(), LOWER(), TRIM() + - ~ operator or SIMILAR TO for regex matching + - FORMAT() for string formatting + +4. **Best practices:** + - Always qualify column names with table aliases + - Use LIMIT to restrict results unless user specifies otherwise + - Use schema-qualified table names: schema.table +""" + +AZURE_SQL_GUIDELINES = """## Azure SQL / SQL Server Guidelines + +1. **Use T-SQL syntax:** + - DATEPART(), DATEDIFF(), DATEADD() for date operations + - GETDATE(), GETUTCDATE() for current time + - VARCHAR, NVARCHAR, INT, BIGINT, DECIMAL, BIT data types + - Use square brackets for identifiers: [schema].[table] + +2. **Aggregation functions:** + - SUM(), AVG(), COUNT(), MIN(), MAX() + - COUNT(*) with CASE for conditional counts + - STRING_AGG() for string aggregation (SQL Server 2017+) + +3. **String functions:** + - CONCAT(), SUBSTRING(), UPPER(), LOWER(), LTRIM(), RTRIM() + - LIKE with wildcards for pattern matching + - FORMAT() for string formatting + +4. **Best practices:** + - Always qualify column names with table aliases + - Use TOP N instead of LIMIT + - Use schema-qualified table names: [schema].[table] +""" + +SYNAPSE_GUIDELINES = """## Azure Synapse Analytics Guidelines + +1. **Use Synapse SQL syntax:** + - Similar to T-SQL with distributed query optimizations + - DATEPART(), DATEDIFF(), DATEADD() for date operations + - GETDATE(), GETUTCDATE() for current time + - VARCHAR, NVARCHAR, INT, BIGINT, DECIMAL data types + +2. **Aggregation functions:** + - SUM(), AVG(), COUNT(), MIN(), MAX() + - APPROX_COUNT_DISTINCT() for large tables + +3. **Best practices:** + - Use TOP N instead of LIMIT + - Filter on distribution columns when possible + - Use schema-qualified table names: [schema].[table] + - Avoid SELECT * on large tables +""" + +DATABRICKS_GUIDELINES = """## Databricks SQL Guidelines + +1. **Use Databricks SQL syntax:** + - DATE_TRUNC(), DATE_ADD(), DATE_SUB() for date operations + - CURRENT_DATE(), CURRENT_TIMESTAMP() for current time + - STRING, INT, BIGINT, DOUBLE, DECIMAL, BOOLEAN data types + +2. **Aggregation functions:** + - SUM(), AVG(), COUNT(), MIN(), MAX() + - APPROX_COUNT_DISTINCT() for large cardinality + - COLLECT_LIST(), COLLECT_SET() for arrays + +3. **Best practices:** + - Always qualify column names with table aliases + - Use LIMIT to restrict results + - Use catalog.schema.table naming convention + - Delta Lake tables support time travel: SELECT * FROM table@v1 +""" + +COSMOS_GUIDELINES = """## Azure Cosmos DB SQL Guidelines + +1. 
**Use Cosmos DB SQL syntax:** + - SELECT, FROM, WHERE, ORDER BY, TOP + - No JOINs between containers (only within documents) + - Array functions: ARRAY_CONTAINS(), ARRAY_LENGTH() + +2. **Query limitations:** + - Always include partition key in WHERE clause for efficiency + - Cross-partition queries are expensive + - No GROUP BY or aggregations without partition key filter + +3. **Best practices:** + - Filter by partition key first + - Use TOP instead of LIMIT + - Prefer point reads over queries when possible +""" + +# Map datasource types to their guidelines +DIALECT_GUIDELINES_MAP: dict[str, str] = { + "bigquery": BIGQUERY_GUIDELINES, + "postgres": POSTGRES_GUIDELINES, + "postgresql": POSTGRES_GUIDELINES, + "azure_sql": AZURE_SQL_GUIDELINES, + "mssql": AZURE_SQL_GUIDELINES, + "sqlserver": AZURE_SQL_GUIDELINES, + "synapse": SYNAPSE_GUIDELINES, + "databricks": DATABRICKS_GUIDELINES, + "cosmos": COSMOS_GUIDELINES, + "cosmosdb": COSMOS_GUIDELINES, +} + + +def get_dialect_guidelines(datasource_type: str) -> str: + """Get SQL guidelines for a specific datasource type. + + Args: + datasource_type: Database type (e.g., 'bigquery', 'azure_sql', 'postgres'). + + Returns: + Dialect-specific SQL guidelines, or empty string if not found. + """ + return DIALECT_GUIDELINES_MAP.get(datasource_type.lower(), "") diff --git a/uv.lock b/uv.lock index 6dda5dd..b02c21a 100644 --- a/uv.lock +++ b/uv.lock @@ -819,7 +819,7 @@ wheels = [ [[package]] name = "data-agent" -version = "0.3.0" +version = "0.3.1" source = { editable = "." } dependencies = [ { name = "a2a-sdk", extra = ["http-server"] }, @@ -841,6 +841,7 @@ dependencies = [ { name = "langgraph-api" }, { name = "langgraph-cli", extra = ["inmem"] }, { name = "matplotlib" }, + { name = "mcp" }, { name = "pandas" }, { name = "psycopg", extra = ["binary"] }, { name = "psycopg2" }, @@ -905,6 +906,7 @@ requires-dist = [ { name = "langgraph-api", specifier = ">=0.5.42" }, { name = "langgraph-cli", extras = ["inmem"], specifier = ">=0.4.10" }, { name = "matplotlib", specifier = ">=3.10.8" }, + { name = "mcp", specifier = ">=1.25.0" }, { name = "pandas", specifier = ">=2.0.0" }, { name = "poethepoet", marker = "extra == 'dev'" }, { name = "pre-commit", marker = "extra == 'dev'" }, @@ -2463,7 +2465,7 @@ wheels = [ [[package]] name = "mcp" -version = "1.24.0" +version = "1.25.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2481,9 +2483,9 @@ dependencies = [ { name = "typing-inspection" }, { name = "uvicorn", marker = "sys_platform != 'emscripten'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/d6/2c/db9ae5ab1fcdd9cd2bcc7ca3b7361b712e30590b64d5151a31563af8f82d/mcp-1.24.0.tar.gz", hash = "sha256:aeaad134664ce56f2721d1abf300666a1e8348563f4d3baff361c3b652448efc", size = 604375, upload-time = "2025-12-12T14:19:38.205Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d5/2d/649d80a0ecf6a1f82632ca44bec21c0461a9d9fc8934d38cb5b319f2db5e/mcp-1.25.0.tar.gz", hash = "sha256:56310361ebf0364e2d438e5b45f7668cbb124e158bb358333cd06e49e83a6802", size = 605387, upload-time = "2025-12-19T10:19:56.985Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/61/0d/5cf14e177c8ae655a2fd9324a6ef657ca4cafd3fc2201c87716055e29641/mcp-1.24.0-py3-none-any.whl", hash = "sha256:db130e103cc50ddc3dffc928382f33ba3eaef0b711f7a87c05e7ded65b1ca062", size = 232896, upload-time = "2025-12-12T14:19:36.14Z" }, + { url = 
"https://files.pythonhosted.org/packages/e2/fc/6dc7659c2ae5ddf280477011f4213a74f806862856b796ef08f028e664bf/mcp-1.25.0-py3-none-any.whl", hash = "sha256:b37c38144a666add0862614cc79ec276e97d72aa8ca26d622818d4e278b9721a", size = 233076, upload-time = "2025-12-19T10:19:55.416Z" }, ] [[package]]