Skip to content

Commit 494701a

Browse files
authored
Merge pull request #65 from Supahands/feat/advanced-options
feat: added advanced options
2 parents 2140cc9 + 314cbb6 commit 494701a

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

ai_router.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ def from_response(cls, response_obj):
141141
response_time=response_obj.usage.response_time
142142
)
143143

144+
class ModelConfig(BaseModel):
    """Per-request generation options forwarded to the completion call."""

    # Instruction text sent as the system message for the conversation.
    system_prompt: str
    # Sampling temperature forwarded to the provider.
    temperature: float
    # Nucleus-sampling (top-p) cutoff forwarded to the provider.
    top_p: float
    # Upper bound on the number of tokens the model may generate.
    max_tokens: int
    # When true, the provider is asked to return a JSON-object response.
    json_format: bool
150+
144151

145152
class ModelResponse(BaseModel):
146153
id: str
@@ -155,6 +162,7 @@ class ModelResponse(BaseModel):
155162
class MessageRequest(BaseModel):
156163
model: str = Field(..., description="Name of the model to use.")
157164
message: str = Field(..., description="Message text to send to the model.")
165+
config: ModelConfig
158166
openai_api_key: Optional[str] = Field(
159167
None, description="API key if required by the openai provider."
160168
)
@@ -195,7 +203,7 @@ def redact_words(model_name, text):
195203
return text
196204

197205
async def handle_completion(
198-
model_name: str, message: str, api_base: Optional[str] = None
206+
model_name: str, message: str, config: ModelConfig, api_base: Optional[str] = None,
199207
):
200208
try:
201209
start_time = time.time()
@@ -204,21 +212,48 @@ async def handle_completion(
204212
logging.info(f"Using API base: {api_base}")
205213
response_obj = completion(
206214
model=model_name,
207-
messages=[{"content": message, "role": "user"}],
208-
api_base=api_base,
215+
messages=[
216+
{
217+
"content": config.system_prompt + ". Please generate the response in JSON" if config.json_format else "",
218+
"role": "system",
219+
},
220+
{
221+
"content": message,
222+
"role": "user",
223+
}
224+
],
225+
api_base=api_base + "/v1",
209226
timeout=180.00,
210227
metadata = {
211228
"generation_name": model_name, # set langfuse generation name
212-
}
229+
},
230+
temperature=config.temperature,
231+
max_tokens=config.max_tokens,
232+
top_p=config.top_p,
233+
response_format= {"type": "json_object"} if config.json_format else None,
234+
api_key="None",
213235
)
214236
else:
215237
response_obj = completion(
216238
model=model_name,
217-
messages=[{"content": message, "role": "user"}],
239+
messages=[
240+
{
241+
"content": config.system_prompt + ". Please generate the response in JSON" if config.json_format else "",
242+
"role": "system",
243+
},
244+
{
245+
"content": message,
246+
"role": "user",
247+
}
248+
],
218249
timeout=180.00,
219250
metadata = {
220251
"generation_name": model_name, # set langfuse generation name
221-
}
252+
},
253+
temperature=config.temperature,
254+
max_tokens=config.max_tokens,
255+
top_p=config.top_p,
256+
response_format= {"type": "json_object"} if config.json_format == True else None
222257
)
223258

224259
end_time = time.time()
@@ -275,7 +310,10 @@ async def messaging(request: MessageRequest):
275310
message = request.message
276311
openai_api_key = request.openai_api_key
277312
anthropic_api_key = request.anthropic_api_key
313+
config = request.config
314+
278315
logging.info(f"Requested model name: {model_name}")
316+
logging.info(f"Config: {config}")
279317
logging.info(f"Message: {message}")
280318
# Fetch models from Supabase
281319
models = await fetch_models_from_supabase()
@@ -285,36 +323,36 @@ async def messaging(request: MessageRequest):
285323
if openai_model and openai_api_key:
286324
logging.info(f"Using OpenAI provider with model_id: {openai_model['model_id']}")
287325
with temporary_env_var("OPENAI_API_KEY", openai_api_key):
288-
return await handle_completion(openai_model['model_id'], message)
326+
return await handle_completion(openai_model['model_id'], message, config=config)
289327

290328
# Anthropic provider check
291329
anthropic_model = next((m for m in models if m["model_name"] == model_name and m["provider"] == "anthropic"), None)
292330
if anthropic_model and anthropic_api_key:
293331
logging.info(f"Using Anthropic provider with model_id: {anthropic_model['model_id']}")
294332
with temporary_env_var("ANTHROPIC_API_KEY", anthropic_api_key):
295-
return await handle_completion(anthropic_model['model_id'], message)
333+
return await handle_completion(anthropic_model['model_id'], message, config=config)
296334

297335
# GitHub provider check
298336
github_model = next((m for m in models if m["model_name"] == model_name and m["provider"] == "github"), None)
299337
if github_model:
300338
logging.info(f"Using GitHub provider with model_id: {github_model['model_id']}")
301339
model_id = f"{github_model['model_id']}"
302-
return await handle_completion(model_id, message)
340+
return await handle_completion(model_id, message, config=config)
303341

304342
# Hugging Face provider check
305343
huggingface_model = next((m for m in models if m["model_name"] == model_name and m["provider"] == "huggingface"), None)
306344
if huggingface_model:
307345
logging.info(f"Using Hugging Face provider with model_id: {huggingface_model['model_id']}")
308346
model_id = f"{huggingface_model['model_id']}"
309-
return await handle_completion(model_id, message)
347+
return await handle_completion(model_id, message, config=config)
310348

311349
# Ollama provider check
312350
ollama_model = next((m for m in models if m["model_name"] == model_name and m["provider"] == "ollama"), None)
313351
if ollama_model:
314352
logging.info(f"Using Ollama provider with model_id: {ollama_model['model_id']}")
315-
model_id = f"ollama/{ollama_model['model_id']}"
316-
api_url = os.environ["OLLAMA_API_URL"]
317-
return await handle_completion(model_id, message, api_base=api_url)
353+
model_id = f"openai/{ollama_model['model_id']}"
354+
api_url = "https://supa-dev--llm-comparison-api-ollama-api-dev.modal.run"
355+
return await handle_completion(model_id, message, config=config, api_base=api_url)
318356

319357
# Error handling
320358
model_info = next((m for m in models if m["model_name"] == model_name), None)

ollama_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ async def _response():
265265
headers=dict(response.headers)
266266
)
267267

268-
if request.url.path in ("/api/generate", "/api/chat"):
268+
if request.url.path in ("/v1/chat/completions"):
269269
return await _streaming_response()
270270
return await _response()
271271

0 commit comments

Comments (0)