Skip to content

Commit 700c26a

Browse files
committed
fix: CI and added new attestation endpoint
1 parent 6febba3 commit 700c26a

File tree

16 files changed

+107
-67
lines changed

16 files changed

+107
-67
lines changed

docker-compose.dev.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,8 @@ services:
107107
depends_on:
108108
nilauth-postgres:
109109
condition: service_healthy
110+
volumes:
111+
- ./scripts/credit-init.sql:/app/migrations/20251015000006_seed_test_data.sql
110112
healthcheck:
111113
test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://127.0.0.1:3000/health"]
112114
interval: 30s

nilai-api/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ dependencies = [
3535
"trafilatura>=1.7.0",
3636
"secretvaults",
3737
"e2b-code-interpreter>=1.0.3",
38-
"nilauth-credit-middleware>=0.1.0",
38+
"nilauth-credit-middleware>=0.1.1",
3939
]
4040

4141

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,31 @@
11
from fastapi import HTTPException
22
import httpx
3-
from nilai_common import Nonce, AttestationReport, SETTINGS
3+
from nilai_common import AttestationReport
44
from nilai_common.logger import setup_logger
55

66
logger = setup_logger(__name__)
77

8+
ATTESTATION_URL = "http://nilcc-attester/v2/report"
89

9-
async def get_attestation_report(
10-
nonce: Nonce | None,
11-
) -> AttestationReport:
12-
"""Get the attestation report for the given nonce"""
13-
14-
try:
15-
attestation_url = f"http://{SETTINGS.attestation_host}:{SETTINGS.attestation_port}/attestation/report"
16-
async with httpx.AsyncClient() as client:
17-
response: httpx.Response = await client.get(attestation_url, params=nonce)
18-
report = AttestationReport(**response.json())
19-
return report
20-
except Exception as e:
21-
raise HTTPException(status_code=500, detail=str(e))
2210

11+
async def get_attestation_report() -> AttestationReport:
12+
"""Get the attestation report"""
2313

24-
async def verify_attestation_report(attestation_report: AttestationReport) -> bool:
25-
"""Verify the attestation report"""
2614
try:
27-
attestation_url = f"http://{SETTINGS.attestation_host}:{SETTINGS.attestation_port}/attestation/verify"
2815
async with httpx.AsyncClient() as client:
29-
response: httpx.Response = await client.get(
30-
attestation_url, params=attestation_report.model_dump()
16+
response: httpx.Response = await client.get(ATTESTATION_URL)
17+
response_json = response.json()
18+
return AttestationReport(
19+
gpu_attestation=response_json["report"],
20+
cpu_attestation=response_json["gpu_token"],
21+
verifying_key="", # Added later by the API
3122
)
32-
return response.json()
23+
except httpx.HTTPStatusError as e:
24+
raise HTTPException(
25+
status_code=e.response.status_code,
26+
detail=str("Error getting attestation report" + str(e)),
27+
)
3328
except Exception as e:
34-
raise HTTPException(status_code=500, detail=str(e))
29+
raise HTTPException(
30+
status_code=500, detail=str("Error getting attestation report" + str(e))
31+
)

nilai-api/src/nilai_api/credit.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ class LLMResponse(BaseModel):
9292
)
9393

9494

95-
def user_id_extractor() -> Callable[[Request], Awaitable[str]]:
95+
def credential_extractor() -> Callable[[Request], Awaitable[str]]:
9696
if CONFIG.auth.auth_strategy == "nuc":
9797
return from_nuc_bearer_root_token()
9898
else:
@@ -145,7 +145,8 @@ async def calculator(request: Request, response_data: dict) -> float:
145145

146146

147147
LLMMeter = create_metering_dependency(
148-
user_id_extractor=user_id_extractor(),
148+
credential_extractor=credential_extractor(),
149149
estimated_cost=2.0,
150150
cost_calculator=llm_cost_calculator(MyCostDictionary),
151+
public_identifiers=CONFIG.auth.auth_strategy == "nuc",
151152
)

nilai-api/src/nilai_api/routers/private.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
ModelMetadata,
3838
MessageAdapter,
3939
SignedChatCompletion,
40-
Nonce,
4140
Source,
4241
Usage,
4342
)
@@ -95,7 +94,6 @@ async def get_usage(auth_info: AuthenticationInfo = Depends(get_auth_info)) -> U
9594

9695
@router.get("/v1/attestation/report", tags=["Attestation"])
9796
async def get_attestation(
98-
nonce: Optional[Nonce] = None,
9997
auth_info: AuthenticationInfo = Depends(get_auth_info),
10098
) -> AttestationReport:
10199
"""
@@ -114,7 +112,7 @@ async def get_attestation(
114112
Provides cryptographic proof of the service's integrity and environment.
115113
"""
116114

117-
attestation_report = await get_attestation_report(nonce)
115+
attestation_report = await get_attestation_report()
118116
attestation_report.verifying_key = state.b64_public_key
119117
return attestation_report
120118

nilai-api/src/nilai_api/routers/public.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
from nilai_api.state import state
44

55
# Internal libraries
6-
from nilai_common import HealthCheckResponse, AttestationReport
7-
from nilai_api.attestation import verify_attestation_report
6+
from nilai_common import HealthCheckResponse
87

98
router = APIRouter()
109

@@ -42,14 +41,3 @@ async def health_check() -> HealthCheckResponse:
4241
```
4342
"""
4443
return HealthCheckResponse(status="ok", uptime=state.uptime)
45-
46-
47-
@router.post("/attestation/verify", tags=["Attestation"])
48-
async def post_attestation(attestation_report: AttestationReport) -> bool:
49-
"""
50-
Verify a cryptographic attestation report.
51-
52-
- **attestation_report**: Attestation report to verify
53-
- **Returns**: True if the attestation report is valid, False otherwise
54-
"""
55-
return await verify_attestation_report(attestation_report)

nilai-models/src/nilai_models/lmstudio_announcer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,10 @@ async def main():
177177
os.getenv("LMSTUDIO_SUPPORTED_FEATURES", "chat_completion")
178178
) or ["chat_completion"]
179179

180-
tool_default = to_bool(os.getenv("LMSTUDIO_TOOL_SUPPORT_DEFAULT", "false"))
180+
tool_default = to_bool(os.getenv("LMSTUDIO_TOOL_SUPPORT_DEFAULT", "true"))
181181
tool_models = set(_parse_csv(os.getenv("LMSTUDIO_TOOL_SUPPORT_MODELS", "")))
182182

183-
multimodal_default = to_bool(os.getenv("LMSTUDIO_MULTIMODAL_DEFAULT", "false"))
183+
multimodal_default = to_bool(os.getenv("LMSTUDIO_MULTIMODAL_DEFAULT", "true"))
184184
multimodal_models = set(_parse_csv(os.getenv("LMSTUDIO_MULTIMODAL_MODELS", "")))
185185

186186
version = os.getenv("LMSTUDIO_MODEL_VERSION", "local")

packages/nilai-common/src/nilai_common/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
HealthCheckResponse,
1111
ModelEndpoint,
1212
ModelMetadata,
13-
Nonce,
1413
AMDAttestationToken,
1514
NVAttestationToken,
1615
SearchResult,
@@ -22,10 +21,10 @@
2221
Topic,
2322
Message,
2423
MessageAdapter,
24+
Usage,
2525
)
2626
from nilai_common.config import SETTINGS, MODEL_SETTINGS
2727
from nilai_common.discovery import ModelServiceDiscovery
28-
from openai.types.completion_usage import CompletionUsage as Usage
2928

3029
__all__ = [
3130
"Message",
@@ -43,7 +42,6 @@
4342
"HealthCheckResponse",
4443
"ModelEndpoint",
4544
"ModelServiceDiscovery",
46-
"Nonce",
4745
"AMDAttestationToken",
4846
"NVAttestationToken",
4947
"SETTINGS",
@@ -55,4 +53,5 @@
5553
"WebSearchEnhancedMessages",
5654
"WebSearchContext",
5755
"ResultContent",
56+
"Usage",
5857
]

packages/nilai-common/src/nilai_common/api_model.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,9 @@
3131
from openai.types.chat.chat_completion_content_part_image_param import (
3232
ChatCompletionContentPartImageParam,
3333
)
34+
35+
from openai.types.completion_usage import CompletionUsage as Usage
36+
3437
from openai.types.chat.chat_completion import Choice as OpenaAIChoice
3538
from pydantic import BaseModel, Field
3639

@@ -41,6 +44,37 @@
4144
TextContent: TypeAlias = ChatCompletionContentPartTextParam
4245
Message: TypeAlias = ChatCompletionMessageParam # SDK union of message shapes
4346

47+
# Explicitly re-export OpenAI types that are part of our public API
48+
__all__ = [
49+
"ChatCompletion",
50+
"ChatCompletionMessage",
51+
"ChatCompletionMessageToolCall",
52+
"ChatToolFunction",
53+
"Function",
54+
"ImageContent",
55+
"TextContent",
56+
"Message",
57+
"ResultContent",
58+
"Choice",
59+
"Source",
60+
"SearchResult",
61+
"Topic",
62+
"TopicResponse",
63+
"TopicQuery",
64+
"MessageAdapter",
65+
"WebSearchEnhancedMessages",
66+
"WebSearchContext",
67+
"ChatRequest",
68+
"SignedChatCompletion",
69+
"ModelMetadata",
70+
"ModelEndpoint",
71+
"HealthCheckResponse",
72+
"AttestationReport",
73+
"AMDAttestationToken",
74+
"NVAttestationToken",
75+
"Usage",
76+
]
77+
4478

4579
# ---------- Domain-specific objects for web search ----------
4680
class ResultContent(BaseModel):
@@ -364,14 +398,6 @@ class HealthCheckResponse(BaseModel):
364398

365399

366400
# ---------- Attestation ----------
367-
Nonce = Annotated[
368-
str,
369-
Field(
370-
max_length=64,
371-
min_length=64,
372-
description="The nonce to be used for the attestation",
373-
),
374-
]
375401

376402
AMDAttestationToken = Annotated[
377403
str, Field(description="The attestation token from AMD's attestation service")
@@ -383,7 +409,6 @@ class HealthCheckResponse(BaseModel):
383409

384410

385411
class AttestationReport(BaseModel):
386-
nonce: Nonce
387412
verifying_key: Annotated[str, Field(description="PEM encoded public key")]
388413
cpu_attestation: AMDAttestationToken
389414
gpu_attestation: NVAttestationToken

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ dev = [
2424
"uvicorn>=0.32.1",
2525
"pytest-asyncio>=1.2.0",
2626
"testcontainers>=4.13.0",
27-
"pyright>=1.1.405",
27+
"pyright>=1.1.406",
2828
"pre-commit>=4.1.0",
2929
"httpx>=0.28.1",
3030
]

0 commit comments

Comments
 (0)