|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import json |
| 4 | +import uuid |
| 5 | +from typing import List, Optional, Tuple, cast |
| 6 | +from nilai_common import ( |
| 7 | + Message, |
| 8 | + MessageAdapter, |
| 9 | + ChatRequest, |
| 10 | + ChatCompletion, |
| 11 | + ChatCompletionMessage, |
| 12 | + ChatCompletionMessageToolCall, |
| 13 | + ChatToolFunction, |
| 14 | +) |
| 15 | + |
| 16 | +from . import code_execution |
| 17 | +from openai import AsyncOpenAI |
| 18 | + |
| 19 | +import logging |
| 20 | + |
| 21 | +logger = logging.getLogger(__name__) |
| 22 | + |
| 23 | + |
async def route_and_execute_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> Message:
    """Route a single tool call to its implementation and return a tool message.

    The returned message is a dict compatible with OpenAI's ChatCompletionMessageParam
    with role="tool".

    Args:
        tool_call: Structured tool call emitted by the model.

    Returns:
        A tool-role ``Message`` carrying the tool's output (or an error string
        for unknown tools), tagged with the originating ``tool_call.id``.
    """
    func_name = tool_call.function.name
    arguments = tool_call.function.arguments or "{}"

    if func_name == "execute_python":
        # `arguments` is a model-produced JSON string; tolerate malformed
        # payloads by falling back to empty args instead of raising.
        # json.loads raises ValueError (JSONDecodeError) on bad input.
        try:
            args = json.loads(arguments)
        except ValueError:
            args = {}
        if not isinstance(args, dict):
            # The model may emit a JSON list/scalar; `.get()` below would
            # raise AttributeError on those, so normalize to a dict.
            args = {}
        code = args.get("code", "")
        result = await code_execution.execute_python(code)
        # Lazy %-formatting: the message is only rendered if INFO is enabled.
        logger.info("[tool] execute_python result: %s", result)
        return MessageAdapter.new_tool_message(
            name="execute_python",
            content=result,
            tool_call_id=tool_call.id,
        )

    # Unknown tool: surface an explicit error message to the model so it can
    # recover rather than silently dropping the call.
    return MessageAdapter.new_tool_message(
        name=func_name,
        content=f"Tool '{func_name}' not implemented",
        tool_call_id=tool_call.id,
    )
| 56 | + |
| 57 | + |
async def process_tool_calls(
    tool_calls: List[ChatCompletionMessageToolCall],
) -> List[Message]:
    """Execute every tool call and collect the resulting tool messages.

    Each call is routed to its implementation in order; the returned
    tool-role messages are ready to be appended to the conversation history.
    """
    return [await route_and_execute_tool_call(call) for call in tool_calls]
| 71 | + |
| 72 | + |
def extract_tool_calls_from_response_message(
    response_message: ChatCompletionMessage,
) -> List[ChatCompletionMessageToolCall]:
    """Return tool calls from a ChatCompletionMessage, parsing content if needed.

    Many models may emit function-calling either via the structured `tool_calls`
    field or encode it as JSON in the assistant `content`. This helper returns a
    normalized list of `ChatCompletionMessageToolCall` objects, using a
    best-effort parse of the content when `tool_calls` is empty.

    Returns an empty list whenever the content cannot be interpreted as a
    tool call; it never raises.
    """
    # Fast path: the model used the structured tool_calls field.
    if response_message.tool_calls:
        return cast(List[ChatCompletionMessageToolCall], response_message.tool_calls)

    # Best-effort text extraction; MessageAdapter is opaque here, so keep the
    # broad guard and fall back to the raw content attribute on any failure.
    try:
        adapter = MessageAdapter(
            raw=cast(
                Message,
                response_message.model_dump(exclude_unset=True),
            )
        )
        content: Optional[str] = adapter.extract_text()
    except Exception:
        content = response_message.content

    if not content:
        return []

    # Only catch parse failures (json.loads raises ValueError/JSONDecodeError);
    # anything else would indicate a real bug and should propagate.
    try:
        data = json.loads(content)
    except ValueError:
        # Plain prose, not an encoded tool call.
        return []

    if not isinstance(data, dict):
        return []

    # Support multiple possible schemas.
    # Schema A: {"function": {"name": ..., "parameters": {...}}}
    fn = data.get("function")
    if isinstance(fn, dict) and "name" in fn:
        name = fn.get("name")
        args = fn.get("parameters", {})
    else:
        # Schema B: {"name"|"tool"|"function_name": ..., "arguments"|"parameters": ...}
        name = data.get("name") or data.get("tool") or data.get("function_name")
        raw_args = data.get("arguments")
        try:
            # "arguments" may itself be a JSON-encoded string.
            args = (
                (json.loads(raw_args) if isinstance(raw_args, str) else raw_args)
                or data.get("parameters", {})
                or {}
            )
        except ValueError:
            args = data.get("parameters", {}) or {}

    if not isinstance(name, str) or not name:
        return []

    # Reconstruct a structured tool call with a synthetic id. Construction can
    # fail on validation (or json.dumps on unserializable args) — treat either
    # as "no tool call" rather than crashing the request.
    try:
        tool_call = ChatCompletionMessageToolCall(
            id=f"call_{uuid.uuid4()}",
            type="function",
            function=ChatToolFunction(name=name, arguments=json.dumps(args)),
        )
    except Exception:
        return []

    return [tool_call]
| 139 | + |
| 140 | + |
async def handle_tool_workflow(
    client: AsyncOpenAI,
    req: ChatRequest,
    current_messages: List[Message],
    first_response: ChatCompletion,
) -> Tuple[ChatCompletion, int, int]:
    """Execute tool workflow if requested and return final completion and usage.

    - Extracts tool calls from the first response (structured or JSON in content)
    - Executes tools and appends tool messages
    - Runs a follow-up completion providing tool outputs
    - Returns the final ChatCompletion and aggregated usage (prompt, completion)

    NOTE(review): the two return paths report usage inconsistently. The tool
    path returns first+second usage combined, but the no-tool path returns
    (first_response, 0, 0), discarding first_response's own usage. Confirm the
    caller accounts for first_response usage separately on the no-tool path;
    otherwise one of the two paths under- or over-counts.
    """
    logger.info("[tools] evaluating tool workflow for response")

    # Seed the aggregate counters with the first completion's usage
    # (0 when the provider omitted the usage object).
    prompt_tokens = first_response.usage.prompt_tokens if first_response.usage else 0
    completion_tokens = (
        first_response.usage.completion_tokens if first_response.usage else 0
    )

    response_message = first_response.choices[0].message
    tool_calls = extract_tool_calls_from_response_message(response_message)
    logger.info(f"[tools] extracted tool_calls: {tool_calls}")

    if not tool_calls:
        # No tools requested: hand back the original completion unchanged.
        # (See NOTE(review) in the docstring about the zero usage here.)
        return first_response, 0, 0

    # Record the assistant's tool-call turn, then append each tool's output so
    # the follow-up request sees the full call/result exchange. A new list is
    # built first to avoid mutating the caller's `current_messages` for the
    # assistant turn (though `.extend` below does mutate the new list).
    assistant_tool_call_msg = MessageAdapter.new_assistant_tool_call_message(tool_calls)
    current_messages = [*current_messages, assistant_tool_call_msg]

    tool_messages = await process_tool_calls(tool_calls)
    current_messages.extend(tool_messages)

    # tool_choice="none" forbids further tool calls, so the workflow
    # terminates after this single follow-up round.
    request_kwargs = {
        "model": req.model,
        "messages": current_messages,  # type: ignore[arg-type]
        "top_p": req.top_p,
        "temperature": req.temperature,
        "max_tokens": req.max_tokens,
        "tool_choice": "none",
    }

    logger.info("[tools] performing follow-up completion with tool outputs")
    second: ChatCompletion = await client.chat.completions.create(**request_kwargs)  # type: ignore
    # Fold the follow-up completion's usage into the aggregate.
    if second.usage:
        prompt_tokens += second.usage.prompt_tokens
        completion_tokens += second.usage.completion_tokens

    return second, prompt_tokens, completion_tokens
0 commit comments