
Commit d639b2f

Merge branch 'main' into feat/nilauth-credit
2 parents: 4bd80a4 + df2a37b

File tree: 13 files changed, +745 −9 lines

docker/compose/docker-compose.llama-8b-gpu.yml

Lines changed: 4 additions & 3 deletions

@@ -20,21 +20,22 @@ services:
         condition: service_healthy
     command: >
       --model meta-llama/Llama-3.1-8B-Instruct
-      --gpu-memory-utilization 0.20
+      --gpu-memory-utilization 0.95
       --max-model-len 10000
       --max-num-batched-tokens 10000
       --tensor-parallel-size 1
-      --enable-auto-tool-choice
      --tool-call-parser llama3_json
      --uvicorn-log-level warning
+      --enable-auto-tool-choice
+      --chat-template /opt/vllm/templates/llama3.1_tool_json.jinja
     environment:
       - SVC_HOST=llama_8b_gpu
       - SVC_PORT=8000
       - ETCD_HOST=etcd
       - ETCD_PORT=2379
       - TOOL_SUPPORT=true
     volumes:
-      - hugging_face_models:/root/.cache/huggingface # cache models
+      - hugging_face_models:/root/.cache/huggingface
     healthcheck:
       test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
       interval: 30s
docker/compose/docker-compose.nilai-prod-2.yml

Lines changed: 3 additions & 1 deletion

@@ -32,14 +32,16 @@ services:
       --enable-auto-tool-choice
       --tool-call-parser llama3_json
       --uvicorn-log-level warning
+      --enable-auto-tool-choice
+      --chat-template /opt/vllm/templates/llama3.1_tool_json.jinja
     environment:
       - SVC_HOST=llama_8b_gpu
       - SVC_PORT=8000
       - ETCD_HOST=etcd
       - ETCD_PORT=2379
       - TOOL_SUPPORT=true
     volumes:
-      - hugging_face_models:/root/.cache/huggingface # cache models
+      - hugging_face_models:/root/.cache/huggingface
     healthcheck:
       test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
       interval: 30s

docker/vllm.Dockerfile

Lines changed: 1 addition & 0 deletions

@@ -10,6 +10,7 @@ FROM vllm/vllm-openai:v0.10.1
 # ENV EXEC_PATH=nilai_models.models.${MODEL_NAME}:app

 COPY --link . /daemon/
+COPY --link vllm_templates /opt/vllm/templates

 WORKDIR /daemon/nilai-models/

nilai-api/pyproject.toml

Lines changed: 1 addition & 0 deletions

@@ -37,6 +37,7 @@ dependencies = [
     "trafilatura>=1.7.0",
     "secretvaults",
     "nilauth-credit-middleware>=0.1.0",
+    "e2b-code-interpreter>=1.0.3",
 ]

nilai-api/src/nilai_api/handlers/tools/code_execution.py (new file; path inferred from the imports in this commit)

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
from __future__ import annotations

import asyncio
import logging

from e2b_code_interpreter import Sandbox

logger = logging.getLogger(__name__)


def _run_in_sandbox_sync(code: str) -> str:
    """Execute Python code in an e2b sandbox and return the textual output, or stdout if available."""
    try:
        with Sandbox.create() as sandbox:
            exec_ = sandbox.run_code(code)
            if exec_.text:
                return exec_.text
            if getattr(exec_, "logs", None) and getattr(exec_.logs, "stdout", None):
                return "\n".join(exec_.logs.stdout)
            return ""
    except Exception as e:
        logger.error("Error executing code in sandbox: %s", e)
        raise


async def execute_python(code: str) -> str:
    """Execute Python code in an e2b Code Interpreter sandbox and return the textual output.

    This function is async-safe and runs the blocking execution in a thread.
    """
    logger.info("Executing Python code asynchronously")
    try:
        result = await asyncio.to_thread(_run_in_sandbox_sync, code)
        logger.info("Python code execution completed successfully")
        return result
    except Exception as e:
        logger.error(f"Error in async Python code execution: {e}")
        raise
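
A quick usage sketch for the helper above (illustrative only, not part of the commit; it assumes the package path inferred earlier and that an E2B API key is available in the environment for the e2b SDK to read):

import asyncio

from nilai_api.handlers.tools import code_execution  # path assumed, see note above


async def main() -> None:
    # Runs the snippet in a fresh e2b sandbox; returns the result text or captured stdout.
    output = await code_execution.execute_python("print(6 * 7)")
    print(output)  # expected to contain "42"


if __name__ == "__main__":
    asyncio.run(main())
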
nilai-api/src/nilai_api/handlers/tools/tool_router.py (new file; path inferred from the import in private.py below)

Lines changed: 189 additions & 0 deletions

@@ -0,0 +1,189 @@
from __future__ import annotations

import json
import uuid
from typing import List, Optional, Tuple, cast
from nilai_common import (
    Message,
    MessageAdapter,
    ChatRequest,
    ChatCompletion,
    ChatCompletionMessage,
    ChatCompletionMessageToolCall,
    ChatToolFunction,
)

from . import code_execution
from openai import AsyncOpenAI

import logging

logger = logging.getLogger(__name__)


async def route_and_execute_tool_call(
    tool_call: ChatCompletionMessageToolCall,
) -> Message:
    """Route a single tool call to its implementation and return a tool message.

    The returned message is a dict compatible with OpenAI's ChatCompletionMessageParam
    with role="tool".
    """
    func_name = tool_call.function.name
    arguments = tool_call.function.arguments or "{}"

    if func_name == "execute_python":
        # arguments is a JSON string
        try:
            args = json.loads(arguments)
        except Exception:
            args = {}
        code = args.get("code", "")
        result = await code_execution.execute_python(code)
        logger.info(f"[tool] execute_python result: {result}")
        return MessageAdapter.new_tool_message(
            name="execute_python",
            content=result,
            tool_call_id=tool_call.id,
        )

    # Unknown tool: return an error message to the model
    return MessageAdapter.new_tool_message(
        name=func_name,
        content=f"Tool '{func_name}' not implemented",
        tool_call_id=tool_call.id,
    )


async def process_tool_calls(
    tool_calls: List[ChatCompletionMessageToolCall],
) -> List[Message]:
    """Process a list of tool calls and return their corresponding tool messages.

    Routes each tool call to its implementation and collects the results as
    tool messages that can be appended to the conversation history.
    """
    msgs: List[Message] = []
    for tc in tool_calls:
        msg = await route_and_execute_tool_call(tc)
        msgs.append(msg)
    return msgs


def extract_tool_calls_from_response_message(
    response_message: ChatCompletionMessage,
) -> List[ChatCompletionMessageToolCall]:
    """Return tool calls from a ChatCompletionMessage, parsing content if needed.

    Many models may emit function-calling either via the structured `tool_calls`
    field or encode it as JSON in the assistant `content`. This helper returns a
    normalized list of `ChatCompletionMessageToolCall` objects, using a
    best-effort parse of the content when `tool_calls` is empty.
    """
    if response_message.tool_calls:
        return cast(List[ChatCompletionMessageToolCall], response_message.tool_calls)

    try:
        adapter = MessageAdapter(
            raw=cast(
                Message,
                response_message.model_dump(exclude_unset=True),
            )
        )
        content: Optional[str] = adapter.extract_text()
    except Exception:
        content = response_message.content

    if not content:
        return []

    try:
        data = json.loads(content)
    except Exception:
        return []

    if not isinstance(data, dict):
        return []

    # Support multiple possible schemas
    fn = data.get("function")
    if isinstance(fn, dict) and "name" in fn:
        name = fn.get("name")
        args = fn.get("parameters", {})
    else:
        # Fallbacks for other schemas
        name = data.get("name") or data.get("tool") or data.get("function_name")
        raw_args = data.get("arguments")
        try:
            args = (
                (json.loads(raw_args) if isinstance(raw_args, str) else raw_args)
                or data.get("parameters", {})
                or {}
            )
        except Exception:
            args = data.get("parameters", {}) or {}

    if not isinstance(name, str) or not name:
        return []

    try:
        tool_call = ChatCompletionMessageToolCall(
            id=f"call_{uuid.uuid4()}",
            type="function",
            function=ChatToolFunction(name=name, arguments=json.dumps(args)),
        )
    except Exception:
        return []

    return [tool_call]


async def handle_tool_workflow(
    client: AsyncOpenAI,
    req: ChatRequest,
    current_messages: List[Message],
    first_response: ChatCompletion,
) -> Tuple[ChatCompletion, int, int]:
    """Execute the tool workflow if requested and return the final completion and usage.

    - Extracts tool calls from the first response (structured or JSON in content)
    - Executes tools and appends tool messages
    - Runs a follow-up completion providing tool outputs
    - Returns the final ChatCompletion and aggregated usage (prompt, completion)
    """
    logger.info("[tools] evaluating tool workflow for response")

    prompt_tokens = first_response.usage.prompt_tokens if first_response.usage else 0
    completion_tokens = (
        first_response.usage.completion_tokens if first_response.usage else 0
    )

    response_message = first_response.choices[0].message
    tool_calls = extract_tool_calls_from_response_message(response_message)
    logger.info(f"[tools] extracted tool_calls: {tool_calls}")

    if not tool_calls:
        return first_response, 0, 0

    assistant_tool_call_msg = MessageAdapter.new_assistant_tool_call_message(tool_calls)
    current_messages = [*current_messages, assistant_tool_call_msg]

    tool_messages = await process_tool_calls(tool_calls)
    current_messages.extend(tool_messages)

    request_kwargs = {
        "model": req.model,
        "messages": current_messages, # type: ignore[arg-type]
        "top_p": req.top_p,
        "temperature": req.temperature,
        "max_tokens": req.max_tokens,
        "tool_choice": "none",
    }

    logger.info("[tools] performing follow-up completion with tool outputs")
    second: ChatCompletion = await client.chat.completions.create(**request_kwargs) # type: ignore
    if second.usage:
        prompt_tokens += second.usage.prompt_tokens
        completion_tokens += second.usage.completion_tokens

    return second, prompt_tokens, completion_tokens
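
The content-parsing fallback in extract_tool_calls_from_response_message accepts a couple of loose JSON shapes when the model writes the call into assistant content instead of tool_calls. A small sketch of payloads it would normalize (values are illustrative):

import json

# Shape 1: nested "function" object carrying "parameters"
payload_a = {"function": {"name": "execute_python", "parameters": {"code": "print(2 + 2)"}}}

# Shape 2: flat "name" plus "arguments", where "arguments" may itself be a JSON string
payload_b = {"name": "execute_python", "arguments": json.dumps({"code": "print(2 + 2)"})}

# Either form is converted into a single ChatCompletionMessageToolCall whose
# function.name is "execute_python", whose function.arguments is the
# JSON-encoded argument dict, and whose id is a generated "call_<uuid>".
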

nilai-api/src/nilai_api/routers/private.py

Lines changed: 29 additions & 3 deletions

@@ -9,6 +9,7 @@
 from nilai_api.credit import LLMMeter, LLMUsage
 from nilai_api.handlers.nilrag import handle_nilrag
 from nilai_api.handlers.web_search import handle_web_search
+from nilai_api.handlers.tools.tool_router import handle_tool_workflow

 from fastapi import APIRouter, Body, Depends, HTTPException, status, Request
 from fastapi.responses import StreamingResponse
@@ -158,7 +159,7 @@ async def chat_completion_web_search_rate_limit(request: Request) -> bool:
         chat_request = ChatRequest(**body)
     except ValueError:
         raise HTTPException(status_code=400, detail="Invalid request body")
-    return getattr(chat_request, "web_search", False)
+    return bool(chat_request.web_search)


 @router.post("/v1/chat/completions", tags=["Chat"], response_model=None)
@@ -402,6 +403,8 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
     }
     if req.tools:
         request_kwargs["tools"] = req.tools # type: ignore
+        request_kwargs["tool_choice"] = req.tool_choice
+
     logger.info(f"[chat] call start request_id={request_id}")
     logger.info(f"[chat] call message: {current_messages}")
     t_call = time.monotonic()
@@ -410,11 +413,20 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
         f"[chat] call done request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}"
     )
     logger.info(f"[chat] call response: {response}")
+
+    # Handle tool workflow fully inside tools.router
+    (
+        final_completion,
+        agg_prompt_tokens,
+        agg_completion_tokens,
+    ) = await handle_tool_workflow(client, req, current_messages, response)
+    logger.info(f"[chat] call final_completion: {final_completion}")
     model_response = SignedChatCompletion(
-        **response.model_dump(),
+        **final_completion.model_dump(),
         signature="",
         sources=sources,
     )
+
     logger.info(
         f"[chat] model_response request_id={request_id} duration_ms={(time.monotonic() - t_call) * 1000:.0f}"
     )
@@ -424,7 +436,21 @@ async def chat_completion_stream_generator() -> AsyncGenerator[str, None]:
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="Model response does not contain usage statistics",
         )
-    # Update token usage
+
+    if agg_prompt_tokens or agg_completion_tokens:
+        total_prompt_tokens = response.usage.prompt_tokens
+        total_completion_tokens = response.usage.completion_tokens
+
+        total_prompt_tokens += agg_prompt_tokens
+        total_completion_tokens += agg_completion_tokens
+
+        model_response.usage.prompt_tokens = total_prompt_tokens
+        model_response.usage.completion_tokens = total_completion_tokens
+        model_response.usage.total_tokens = (
+            total_prompt_tokens + total_completion_tokens
+        )
+
+    # Update token usage in DB
     await UserManager.update_token_usage(
         auth_info.user.userid,
         prompt_tokens=model_response.usage.prompt_tokens,
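
With these router changes, a client can declare the execute_python tool in its request and let the API run the tool workflow (first completion, sandboxed execution, follow-up completion). A hedged sketch of such a request follows; the base URL, auth handling, and exact tool schema are assumptions for illustration, not taken from this commit:

from openai import OpenAI

# Hypothetical client setup; the real base URL and auth depend on the deployment.
client = OpenAI(base_url="https://nilai.example.com/v1", api_key="<user-api-key>")

tools = [
    {
        "type": "function",
        "function": {
            "name": "execute_python",
            "description": "Execute Python code in a sandbox and return its output",
            "parameters": {
                "type": "object",
                "properties": {"code": {"type": "string"}},
                "required": ["code"],
            },
        },
    }
]

resp = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Compute 2**32 with the tool."}],
    tools=tools,
    tool_choice="auto",
)
# The returned completion is the follow-up answer produced after the tool ran;
# its usage reflects the aggregated prompt/completion tokens handled above.
print(resp.choices[0].message.content)
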

packages/nilai-common/src/nilai_common/__init__.py

Lines changed: 8 additions & 0 deletions

@@ -3,6 +3,10 @@
     ChatRequest,
     SignedChatCompletion,
     Choice,
+    ChatCompletion,
+    ChatCompletionMessage,
+    ChatCompletionMessageToolCall,
+    ChatToolFunction,
     HealthCheckResponse,
     ModelEndpoint,
     ModelMetadata,
@@ -29,6 +33,10 @@
     "ChatRequest",
     "SignedChatCompletion",
     "Choice",
+    "ChatCompletion",
+    "ChatCompletionMessage",
+    "ChatCompletionMessageToolCall",
+    "ChatToolFunction",
     "ModelMetadata",
     "Usage",
     "AttestationReport",
