From dd4de02435355e9bda47597ec01cebe0c85a545b Mon Sep 17 00:00:00 2001 From: madclaws Date: Tue, 18 Nov 2025 03:55:56 +0530 Subject: [PATCH 1/2] Added thinking in the cli itself --- Cargo.lock | 7 +++ Cargo.toml | 1 + server/api.py | 108 +++++++++++++++++++++++-------------- server/cache_utils.py | 1 - server/main.py | 41 ++++++++++---- server/mem_agent/engine.py | 4 +- server/mem_agent/utils.py | 2 + server/system_prompt.txt | 1 + src/runner/mlx.rs | 84 ++++++++++++++++++++++++++--- 9 files changed, 188 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 8ca5da0..592c5d7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -752,6 +752,12 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "owo-colors" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c6901729fa79e91a0913333229e9ca5dc725089d1c363b2f4b4760709dc4a52" + [[package]] name = "percent-encoding" version = "2.3.2" @@ -1128,6 +1134,7 @@ dependencies = [ "anyhow", "clap", "nom", + "owo-colors", "reqwest", "serde", "serde_json", diff --git a/Cargo.toml b/Cargo.toml index 80faa4b..dff6ff4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,4 @@ serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" anyhow = "1.0" tokio = { version = "1" , features = ["macros", "rt-multi-thread"]} +owo-colors = "4" diff --git a/server/api.py b/server/api.py index 62a1a53..675a9e0 100644 --- a/server/api.py +++ b/server/api.py @@ -22,7 +22,7 @@ from fastapi import FastAPI, HTTPException from .config import SYSTEM_PROMPT - +import logging import json import time import uuid @@ -40,6 +40,8 @@ from server.mem_agent.utils import extract_python_code, extract_reply, extract_thoughts, create_memory_if_not_exists, format_results from server.mem_agent.engine import execute_sandboxed_code # Global model cache and configuration + +logger = logging.getLogger("app") _model_cache: Dict[str, MLXRunner] = {} _current_model_path: Optional[str] = None _default_max_tokens: Optional[int] = None # Use dynamic model-aware limits by default @@ -67,6 +69,7 @@ class ChatMessage(BaseModel): class ChatCompletionRequest(BaseModel): model: str messages: List[ChatMessage] + chat_start: bool max_tokens: Optional[int] = None temperature: Optional[float] = 0.7 top_p: Optional[float] = 0.9 @@ -135,22 +138,21 @@ def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner: """Get model from cache or load it if not cached.""" global _model_cache, _current_model_path - print(model_spec) # Use the existing model path resolution from cache_utils try: model_path, model_name, commit_hash = get_model_path(model_spec) if not model_path.exists(): + logger.info(f"Model {model_spec} not found in cache") raise HTTPException(status_code=404, detail=f"Model {model_spec} not found in cache") except Exception as e: + logger.info(f"Model {model_spec} not found in: {str(e)}") raise HTTPException(status_code=404, detail=f"Model {model_spec} not found: {str(e)}") # Check if it's an MLX model model_path_str = str(model_path) - print(_current_model_path) - print(model_path_str) # Check if we need to load a different model if _current_model_path != model_path_str: # Proactively clean up any previously loaded runner to release memory @@ -168,11 +170,14 @@ def get_or_load_model(model_spec: str, verbose: bool = False) -> MLXRunner: if verbose: print(f"Loading model: {model_name}") + logger.info(f"Loading model: {model_name}") runner = MLXRunner(model_path_str, verbose=verbose) runner.load_model() 
_model_cache[model_path_str] = runner _current_model_path = model_path_str + else: + logger.info(f"Model {model_name} already in memory") return _model_cache[model_path_str] @@ -196,9 +201,10 @@ async def ping(): async def start_model(request: StartRequest): """Load the model and start the agent""" global _messages, _runner,_memory_path - print(str(request)) + _messages = [ChatMessage(role="system", content=SYSTEM_PROMPT)] _memory_path = request.memory_path + try: _runner = get_or_load_model(request.model) return {"message": "Model loaded"} @@ -226,7 +232,8 @@ async def create_chat_completion(request: ChatCompletionRequest): # Convert messages to dict format for runner # _messages.append(system_message) - _messages.extend(request.messages) + if request.chat_start: + _messages.extend(request.messages) message_dicts = format_chat_messages_for_runner(_messages) # Let the runner format with chat templates prompt = runner._format_conversation(message_dicts, use_chat_template=True) @@ -241,14 +248,17 @@ async def create_chat_completion(request: ChatCompletionRequest): ) # Token counting - # total_prompt = "\n\n".join([msg.content for msg in request.messages]) - # prompt_tokens = count_tokens(total_prompt) - # completion_tokens = count_tokens(generated_text) + total_prompt = "\n\n".join([msg.content for msg in request.messages]) + prompt_tokens = count_tokens(total_prompt) + completion_tokens = count_tokens(generated_text) + + logger.info(f"prompt_token\n{prompt_tokens}") + logger.info(f"completion_tokens\n{completion_tokens}") thoughts = extract_thoughts(generated_text) reply = extract_reply(generated_text) python_code = extract_python_code(generated_text) - print(generated_text) + result = ({}, "") if python_code: create_memory_if_not_exists() @@ -258,36 +268,52 @@ async def create_chat_completion(request: ChatCompletionRequest): import_module="server.mem_agent.tools", ) - print(reply) - print(str(result)) - - remaining_tool_turns = _max_tool_turns - while remaining_tool_turns > 0 and not reply: - _messages.append(ChatMessage(role="user", content=format_results(result[0], result[1]))) - message_dicts = format_chat_messages_for_runner(_messages) - # Let the runner format with chat templates - prompt = runner._format_conversation(message_dicts, use_chat_template=True) - generated_text = runner.generate_batch( - prompt=prompt - ) - print(generated_text) - # Extract the thoughts, reply and python code from the response - thoughts = extract_thoughts(generated_text) - reply = extract_reply(generated_text) - python_code = extract_python_code(generated_text) - - _messages.append(ChatMessage(role="assistant", content=generated_text)) - if python_code: - create_memory_if_not_exists() - result = execute_sandboxed_code( - code=python_code, - allowed_path=_memory_path, - import_module="server.mem_agent.tools", - ) - else: - # Reset result when no Python code is executed - result = ({}, "") - remaining_tool_turns -= 1 + logger.info(f"Model thoughts\n{thoughts}") + logger.info(f"Model reply\n{reply}") + logger.info(f"Model python\n{python_code}") + logger.info(f"executed python result\n{str(result)}") + + # while remaining_tool_turns > 0 and not reply: + # logger.info(f"Turn count\n{remaining_tool_turns}") + _messages.append(ChatMessage(role="user", content=format_results(result[0], result[1]))) + message_dicts = format_chat_messages_for_runner(_messages) + # # Let the runner format with chat templates + # prompt = runner._format_conversation(message_dicts, use_chat_template=True) + # generated_text = 
runner.generate_batch( + # prompt=prompt + # ) + + # total_prompt = "\n\n".join([msg.content for msg in _messages]) + # prompt_tokens = count_tokens(total_prompt) + # completion_tokens = count_tokens(generated_text) + + # logger.info(f"prompt_token\n{prompt_tokens}") + # logger.info(f"completion_tokens\n{completion_tokens}") + + # # print(generated_text) + # # Extract the thoughts, reply and python code from the response + # thoughts = extract_thoughts(generated_text) + # reply = extract_reply(generated_text) + # python_code = extract_python_code(generated_text) + + # logger.info(f"Model thoughts\n{thoughts}") + # logger.info(f"Model reply\n{reply}") + # logger.info(f"Model python\n{python_code}") + + # _messages.append(ChatMessage(role="assistant", content=generated_text)) + # if python_code: + # create_memory_if_not_exists() + # result = execute_sandboxed_code( + # code=python_code, + # allowed_path=_memory_path, + # import_module="server.mem_agent.tools", + # ) + # logger.info(f"executed python result\n{str(result)}") + # else: + # # Reset result when no Python code is executed + # result = ({}, "") + # logger.info(f"executed python result\n{str(result)}") + # remaining_tool_turns -= 1 return ChatCompletionResponse( id=completion_id, @@ -298,7 +324,7 @@ async def create_chat_completion(request: ChatCompletionRequest): "index": 0, "message": { "role": "assistant", - "content": reply + "content": generated_text }, "finish_reason": "stop" } diff --git a/server/cache_utils.py b/server/cache_utils.py index cf6bc29..78f8ec0 100644 --- a/server/cache_utils.py +++ b/server/cache_utils.py @@ -182,7 +182,6 @@ def resolve_single_model(model_spec): def get_model_path(model_spec): model_name, commit_hash = parse_model_spec(model_spec) base_cache_dir = MODEL_CACHE / hf_to_cache_dir(model_name) - print(base_cache_dir) if not base_cache_dir.exists(): return None, model_name, commit_hash if commit_hash: diff --git a/server/main.py b/server/main.py index b52ef7e..f4e9d7b 100644 --- a/server/main.py +++ b/server/main.py @@ -1,19 +1,40 @@ -# import os import uvicorn from .api import app from .config import PORT +import logging +import sys +from fastapi import Request -def run(): - # Write PID file - # PID_FILE.write_text(str(os.getpid())) +# --- logging setup --- +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", + handlers=[logging.StreamHandler(sys.stdout)], +) +logger = logging.getLogger("app") + + +# --- middleware for request logging --- +@app.middleware("http") +async def log_requests(request: Request, call_next): + try: + body = await request.json() + except Exception: + body = None - # try: + logger.info({ + "method": request.method, + "url": str(request.url), + "client": request.client.host, + "body": body, + }) + + response = await call_next(request) + logger.info(f"<-- {request.method} {request.url.path} {response.status_code}") + return response + +def run(): uvicorn.run(app, host="127.0.0.1", port=PORT) - # finally: - # if PID_FILE.exists(): - # PID_FILE.unlink() - # - # print("hello from main!") if __name__ == "__main__": run() diff --git a/server/mem_agent/engine.py b/server/mem_agent/engine.py index 7db783b..0187629 100644 --- a/server/mem_agent/engine.py +++ b/server/mem_agent/engine.py @@ -300,8 +300,8 @@ def execute_sandboxed_code( if result.returncode != 0: return None, result.stderr.decode().strip() - print("stderr:", result.stderr.decode()) - print("stdout:", result.stdout[:200]) + # print("stderr:", result.stderr.decode()) + # 
print("stdout:", result.stdout[:200]) try: local_vars, error_msg = pickle.loads(result.stdout) diff --git a/server/mem_agent/utils.py b/server/mem_agent/utils.py index 6477e10..034525e 100644 --- a/server/mem_agent/utils.py +++ b/server/mem_agent/utils.py @@ -188,6 +188,8 @@ def extract_thoughts(response: str) -> str: """ if "" in response and "" in response: return response.split("")[1].split("")[0] + elif "" in response: + return response.split("")[0] else: return "" diff --git a/server/system_prompt.txt b/server/system_prompt.txt index 1b08f4b..7ff8327 100644 --- a/server/system_prompt.txt +++ b/server/system_prompt.txt @@ -40,6 +40,7 @@ You are an LLM agent with a self-managed, Obsidian-like memory system. You inter ``` **CRITICAL: Always close ALL tags! Missing , , or will cause errors!** +**CRITICAL: Think block MUST have opening and closing . This block should not only have **NEVER:** - Skip the `` block diff --git a/src/runner/mlx.rs b/src/runner/mlx.rs index adb2aed..00d4529 100644 --- a/src/runner/mlx.rs +++ b/src/runner/mlx.rs @@ -1,13 +1,20 @@ +use crate::core::modelfile::Modelfile; use anyhow::{Context, Result}; +use owo_colors::OwoColorize; use reqwest::Client; use serde_json::{Value, json}; use std::io::Write; use std::path::PathBuf; use std::process::Stdio; +use std::str::FromStr; use std::{env, fs}; use std::{io, process::Command}; -use crate::core::modelfile::Modelfile; +pub struct ChatResponse { + think: String, + reply: String, + code: String, +} pub async fn run(modelfile: Modelfile) { let model = modelfile.from.as_ref().unwrap(); @@ -141,10 +148,28 @@ async fn run_model_with_server(modelfile: Modelfile) -> reqwest::Result<()> { break; } _ => { - if let Ok(response) = chat(input, modelname).await { - println!(">> {}", response) - } else { - println!(">> failed to respond") + let mut remaining_count = 6; + let mut g_reply: String = "".to_owned(); + loop { + if remaining_count > 0 { + let chat_start = if remaining_count == 6 { true } else { false }; + if let Ok(response) = chat(input, modelname, chat_start).await { + if response.reply.is_empty() { + remaining_count = remaining_count - 1; + println!("{}", response.think.dimmed()) + } else { + g_reply = response.reply.clone(); + println!(">> {}", response.reply.trim()); + break; + } + } else { + println!(">> failed to respond"); + break; + } + } + } + if g_reply.is_empty() { + println!(">> No reply") } } } @@ -178,10 +203,12 @@ async fn load_model(model_name: &str, memory_path: &str) -> Result<(), String> { } } -async fn chat(input: &str, model_name: &str) -> Result { +async fn chat(input: &str, model_name: &str, chat_start: bool) -> Result { let client = Client::new(); + let body = json!({ "model": model_name, + "chat_start": chat_start, "messages": [{"role": "user", "content": input}] }); let res = client @@ -197,12 +224,55 @@ async fn chat(input: &str, model_name: &str) -> Result { let content = v["choices"][0]["message"]["content"] .as_str() .unwrap_or(""); - Ok(content.to_owned()) + + Ok(convert_to_chat_response(content)) } else { Err(String::from("request failed")) } } +fn convert_to_chat_response(content: &str) -> ChatResponse { + // content.split() + ChatResponse { + think: extract_think(content), + reply: extract_reply(content), + code: extract_python(content), + } +} + +fn extract_reply(content: &str) -> String { + if content.contains("") && content.contains("") { + let list_a = content.split("").collect::>(); + let list_b = list_a[1].split("").collect::>(); + list_b[0].to_owned() + } else { + "".to_owned() + } 
+} + +fn extract_python(content: &str) -> String { + if content.contains("") && content.contains("") { + let list_a = content.split("").collect::>(); + let list_b = list_a[1].split("").collect::>(); + list_b[0].to_owned() + } else { + "".to_owned() + } +} + +fn extract_think(content: &str) -> String { + if content.contains("") && content.contains("") { + let list_a = content.split("").collect::>(); + let list_b = list_a[1].split("").collect::>(); + list_b[0].to_owned() + } else if content.contains("").collect::>(); + list_a[0].to_owned() + } else { + "".to_owned() + } +} + fn get_memory_path() -> Result { let tiles_config_dir = get_config_dir()?; let tiles_data_dir = get_data_dir()?; From 4ccb2841750ba304c9a7f2a21bf84fc82d4c0118 Mon Sep 17 00:00:00 2001 From: madclaws Date: Sun, 23 Nov 2025 16:04:12 +0530 Subject: [PATCH 2/2] feat: sloppy streaming from cli --- Cargo.lock | 28 ++++ Cargo.toml | 3 +- server/api.py | 334 +++++++++++++++++++++++++++++++--------------- src/runner/mlx.rs | 81 ++++++++--- 4 files changed, 320 insertions(+), 126 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 592c5d7..9064ba5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -266,6 +266,17 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -286,6 +297,7 @@ checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ "futures-core", "futures-io", + "futures-macro", "futures-sink", "futures-task", "memchr", @@ -848,12 +860,14 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-util", "tower", "tower-http", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", ] @@ -1133,6 +1147,7 @@ version = "0.1.0" dependencies = [ "anyhow", "clap", + "futures-util", "nom", "owo-colors", "reqwest", @@ -1418,6 +1433,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.81" diff --git a/Cargo.toml b/Cargo.toml index dff6ff4..cbcdbad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,9 +6,10 @@ edition = "2024" [dependencies] clap = { version = "4.5.48", features = ["derive"] } nom = "8" -reqwest = { version = "0.12", features = ["json", "blocking"] } +reqwest = { version = "0.12", features = ["json", "blocking", "stream"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" anyhow = "1.0" tokio = { version = "1" , features = ["macros", "rt-multi-thread"]} owo-colors = "4" +futures-util = "0.3" diff --git a/server/api.py b/server/api.py index 675a9e0..9d502a7 100644 --- a/server/api.py +++ b/server/api.py @@ -70,6 +70,7 @@ class ChatCompletionRequest(BaseModel): model: str messages: List[ChatMessage] chat_start: bool + python_code: str max_tokens: Optional[int] = None temperature: Optional[float] = 0.7 top_p: Optional[float] = 0.9 @@ -218,122 +219,237 @@ async def 
create_chat_completion(request: ChatCompletionRequest): try: runner = get_or_load_model(request.model) - # if request.stream: - # # Streaming response - # return StreamingResponse( - # generate_chat_stream(runner, request.messages, request), - # media_type="text/plain", - # headers={"Cache-Control": "no-cache"} - # ) - # else: - # Non-streaming response - completion_id = f"chatcmpl-{uuid.uuid4()}" - created = int(time.time()) - - # Convert messages to dict format for runner - # _messages.append(system_message) - if request.chat_start: - _messages.extend(request.messages) - message_dicts = format_chat_messages_for_runner(_messages) - # Let the runner format with chat templates - prompt = runner._format_conversation(message_dicts, use_chat_template=True) - - generated_text = runner.generate_batch( + if request.stream: + result = ({}, "") + if request.python_code: + create_memory_if_not_exists() + result = execute_sandboxed_code( + code=request.python_code, + allowed_path=_memory_path, + import_module="server.mem_agent.tools", + ) + + _messages.append(ChatMessage(role="user", content=format_results(result[0], result[1]))) + + # Streaming response + return StreamingResponse( + generate_chat_stream(runner, request.messages, request), + media_type="text/plain", + headers={"Cache-Control": "no-cache"} + ) + else: + # Non-streaming response + completion_id = f"chatcmpl-{uuid.uuid4()}" + created = int(time.time()) + + # Convert messages to dict format for runner + # _messages.append(system_message) + if request.chat_start: + _messages.extend(request.messages) + message_dicts = format_chat_messages_for_runner(_messages) + # Let the runner format with chat templates + prompt = runner._format_conversation(message_dicts, use_chat_template=True) + + generated_text = runner.generate_batch( + prompt=prompt, + max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False), + temperature=request.temperature, + top_p=request.top_p, + repetition_penalty=request.repetition_penalty, + use_chat_template=False # Already applied in _format_conversation + ) + + # Token counting + total_prompt = "\n\n".join([msg.content for msg in request.messages]) + prompt_tokens = count_tokens(total_prompt) + completion_tokens = count_tokens(generated_text) + + logger.info(f"prompt_token\n{prompt_tokens}") + logger.info(f"completion_tokens\n{completion_tokens}") + + thoughts = extract_thoughts(generated_text) + reply = extract_reply(generated_text) + python_code = extract_python_code(generated_text) + + result = ({}, "") + if python_code: + create_memory_if_not_exists() + result = execute_sandboxed_code( + code=python_code, + allowed_path=_memory_path, + import_module="server.mem_agent.tools", + ) + + logger.info(f"Model thoughts\n{thoughts}") + logger.info(f"Model reply\n{reply}") + logger.info(f"Model python\n{python_code}") + logger.info(f"executed python result\n{str(result)}") + + # while remaining_tool_turns > 0 and not reply: + # logger.info(f"Turn count\n{remaining_tool_turns}") + _messages.append(ChatMessage(role="user", content=format_results(result[0], result[1]))) + message_dicts = format_chat_messages_for_runner(_messages) + # # Let the runner format with chat templates + # prompt = runner._format_conversation(message_dicts, use_chat_template=True) + # generated_text = runner.generate_batch( + # prompt=prompt + # ) + + # total_prompt = "\n\n".join([msg.content for msg in _messages]) + # prompt_tokens = count_tokens(total_prompt) + # completion_tokens = 
count_tokens(generated_text) + + # logger.info(f"prompt_token\n{prompt_tokens}") + # logger.info(f"completion_tokens\n{completion_tokens}") + + # # print(generated_text) + # # Extract the thoughts, reply and python code from the response + # thoughts = extract_thoughts(generated_text) + # reply = extract_reply(generated_text) + # python_code = extract_python_code(generated_text) + + # logger.info(f"Model thoughts\n{thoughts}") + # logger.info(f"Model reply\n{reply}") + # logger.info(f"Model python\n{python_code}") + + # _messages.append(ChatMessage(role="assistant", content=generated_text)) + # if python_code: + # create_memory_if_not_exists() + # result = execute_sandboxed_code( + # code=python_code, + # allowed_path=_memory_path, + # import_module="server.mem_agent.tools", + # ) + # logger.info(f"executed python result\n{str(result)}") + # else: + # # Reset result when no Python code is executed + # result = ({}, "") + # logger.info(f"executed python result\n{str(result)}") + # remaining_tool_turns -= 1 + + return ChatCompletionResponse( + id=completion_id, + created=created, + model=request.model, + choices=[ + { + "index": 0, + "message": { + "role": "assistant", + "content": generated_text + }, + "finish_reason": "stop" + } + ], + # usage={ + # "prompt_tokens": prompt_tokens, + # "completion_tokens": completion_tokens, + # "total_tokens": prompt_tokens + completion_tokens + # } + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + +async def generate_chat_stream( + runner: MLXRunner, + messages: List[ChatMessage], + request: ChatCompletionRequest +) -> AsyncGenerator[str, None]: + """Generate streaming chat completion response.""" + + global _messages + completion_id = f"chatcmpl-{uuid.uuid4()}" + created = int(time.time()) + + if request.chat_start: + _messages.extend(request.messages) + # Convert messages to dict format for runner + message_dicts = format_chat_messages_for_runner(_messages) + + # Let the runner format with chat templates + prompt = runner._format_conversation(message_dicts, use_chat_template=True) + + # Yield initial response + initial_response = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": request.model, + "choices": [ + { + "index": 0, + "delta": {"role": "assistant", "content": ""}, + "finish_reason": None + } + ] + } + + yield f"data: {json.dumps(initial_response)}\n\n" + + # Stream tokens + try: + for token in runner.generate_streaming( prompt=prompt, max_tokens=runner.get_effective_max_tokens(request.max_tokens or _default_max_tokens, interactive=False), temperature=request.temperature, top_p=request.top_p, repetition_penalty=request.repetition_penalty, - use_chat_template=False # Already applied in _format_conversation - ) - - # Token counting - total_prompt = "\n\n".join([msg.content for msg in request.messages]) - prompt_tokens = count_tokens(total_prompt) - completion_tokens = count_tokens(generated_text) - - logger.info(f"prompt_token\n{prompt_tokens}") - logger.info(f"completion_tokens\n{completion_tokens}") - - thoughts = extract_thoughts(generated_text) - reply = extract_reply(generated_text) - python_code = extract_python_code(generated_text) - - result = ({}, "") - if python_code: - create_memory_if_not_exists() - result = execute_sandboxed_code( - code=python_code, - allowed_path=_memory_path, - import_module="server.mem_agent.tools", - ) + use_chat_template=False, # Already applied in _format_conversation + use_chat_stop_tokens=False # Server mode shouldn't stop on 
chat markers + ): + chunk_response = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": request.model, + "choices": [ + { + "index": 0, + "delta": {"content": token}, + "finish_reason": None + } + ] + } + + yield f"data: {json.dumps(chunk_response)}\n\n" + + # Check for stop sequences + if request.stop: + stop_sequences = request.stop if isinstance(request.stop, list) else [request.stop] + if any(stop in token for stop in stop_sequences): + break - logger.info(f"Model thoughts\n{thoughts}") - logger.info(f"Model reply\n{reply}") - logger.info(f"Model python\n{python_code}") - logger.info(f"executed python result\n{str(result)}") - - # while remaining_tool_turns > 0 and not reply: - # logger.info(f"Turn count\n{remaining_tool_turns}") - _messages.append(ChatMessage(role="user", content=format_results(result[0], result[1]))) - message_dicts = format_chat_messages_for_runner(_messages) - # # Let the runner format with chat templates - # prompt = runner._format_conversation(message_dicts, use_chat_template=True) - # generated_text = runner.generate_batch( - # prompt=prompt - # ) - - # total_prompt = "\n\n".join([msg.content for msg in _messages]) - # prompt_tokens = count_tokens(total_prompt) - # completion_tokens = count_tokens(generated_text) - - # logger.info(f"prompt_token\n{prompt_tokens}") - # logger.info(f"completion_tokens\n{completion_tokens}") - - # # print(generated_text) - # # Extract the thoughts, reply and python code from the response - # thoughts = extract_thoughts(generated_text) - # reply = extract_reply(generated_text) - # python_code = extract_python_code(generated_text) - - # logger.info(f"Model thoughts\n{thoughts}") - # logger.info(f"Model reply\n{reply}") - # logger.info(f"Model python\n{python_code}") - - # _messages.append(ChatMessage(role="assistant", content=generated_text)) - # if python_code: - # create_memory_if_not_exists() - # result = execute_sandboxed_code( - # code=python_code, - # allowed_path=_memory_path, - # import_module="server.mem_agent.tools", - # ) - # logger.info(f"executed python result\n{str(result)}") - # else: - # # Reset result when no Python code is executed - # result = ({}, "") - # logger.info(f"executed python result\n{str(result)}") - # remaining_tool_turns -= 1 - - return ChatCompletionResponse( - id=completion_id, - created=created, - model=request.model, - choices=[ + except Exception as e: + error_response = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": request.model, + "choices": [ { "index": 0, - "message": { - "role": "assistant", - "content": generated_text - }, - "finish_reason": "stop" + "delta": {}, + "finish_reason": "error" } ], - # usage={ - # "prompt_tokens": prompt_tokens, - # "completion_tokens": completion_tokens, - # "total_tokens": prompt_tokens + completion_tokens - # } - ) - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) + "error": str(e) + } + yield f"data: {json.dumps(error_response)}\n\n" + + # Final response + final_response = { + "id": completion_id, + "object": "chat.completion.chunk", + "created": created, + "model": request.model, + "choices": [ + { + "index": 0, + "delta": {}, + "finish_reason": "stop" + } + ] + } + + yield f"data: {json.dumps(final_response)}\n\n" + yield "data: [DONE]\n\n" diff --git a/src/runner/mlx.rs b/src/runner/mlx.rs index 00d4529..fbad45c 100644 --- a/src/runner/mlx.rs +++ b/src/runner/mlx.rs @@ -1,5 +1,6 @@ use crate::core::modelfile::Modelfile; use 
anyhow::{Context, Result}; +use futures_util::StreamExt; use owo_colors::OwoColorize; use reqwest::Client; use serde_json::{Value, json}; @@ -9,7 +10,6 @@ use std::process::Stdio; use std::str::FromStr; use std::{env, fs}; use std::{io, process::Command}; - pub struct ChatResponse { think: String, reply: String, @@ -136,6 +136,7 @@ async fn run_model_with_server(modelfile: Modelfile) -> reqwest::Result<()> { let modelname = modelfile.from.as_ref().unwrap(); load_model(modelname, &memory_path).await.unwrap(); println!("Running in interactive mode"); + // TODO: Handle "enter" key press or any key press when repl is processing an input loop { print!(">> "); stdout.flush().unwrap(); @@ -150,20 +151,24 @@ async fn run_model_with_server(modelfile: Modelfile) -> reqwest::Result<()> { _ => { let mut remaining_count = 6; let mut g_reply: String = "".to_owned(); + let mut python_code: String = "".to_owned(); loop { if remaining_count > 0 { let chat_start = if remaining_count == 6 { true } else { false }; - if let Ok(response) = chat(input, modelname, chat_start).await { + if let Ok(response) = chat(input, modelname, chat_start, &python_code).await + { if response.reply.is_empty() { + if !response.code.is_empty() { + python_code = response.code; + } remaining_count = remaining_count - 1; - println!("{}", response.think.dimmed()) } else { g_reply = response.reply.clone(); - println!(">> {}", response.reply.trim()); + println!("\n>> {}", response.reply.trim()); break; } } else { - println!(">> failed to respond"); + println!("\n>> failed to respond"); break; } } @@ -203,12 +208,19 @@ async fn load_model(model_name: &str, memory_path: &str) -> Result<(), String> { } } -async fn chat(input: &str, model_name: &str, chat_start: bool) -> Result { +async fn chat( + input: &str, + model_name: &str, + chat_start: bool, + python_code: &str, +) -> Result { let client = Client::new(); let body = json!({ "model": model_name, "chat_start": chat_start, + "stream": true, + "python_code": python_code, "messages": [{"role": "user", "content": input}] }); let res = client @@ -217,18 +229,55 @@ async fn chat(input: &str, model_name: &str, chat_start: bool) -> Result"); - Ok(convert_to_chat_response(content)) - } else { - Err(String::from("request failed")) + let mut stream = res.bytes_stream(); + let mut accumulated = String::new(); + let mut chat_response = ChatResponse { + think: String::new(), + reply: String::new(), + code: String::new(), + }; + // let mut inside_python = false; + // let mut tag_buffer = String::new(); + print!("\n"); + while let Some(chunk) = stream.next().await { + let chunk = chunk.unwrap(); + let s = String::from_utf8_lossy(&chunk); + for line in s.lines() { + if !line.starts_with("data: ") { + continue; + } + + let data = line.trim_start_matches("data: "); + + if data == "[DONE]" { + chat_response = convert_to_chat_response(&accumulated); + return Ok(chat_response); + } + // Parse JSON + let v: Value = serde_json::from_str(data).unwrap(); + if let Some(delta) = v["choices"][0]["delta"]["content"].as_str() { + accumulated.push_str(delta); + print!("{}", delta.dimmed()); + use std::io::Write; + std::io::stdout().flush().ok(); + } + } } + // println!("{:?}", res); + // if res.status() == 200 { + // let text = res.text().await.unwrap(); + // let v: Value = serde_json::from_str(&text).unwrap(); + // let content = v["choices"][0]["message"]["content"] + // .as_str() + // .unwrap_or(""); + + // // Ok(convert_to_chat_response(content)) + // } else { + // // Err(String::from("request failed")) + // } + 
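Note on the stream handler above: each network chunk is parsed independently, but an SSE `data:` line can be split across two chunks, and the `unwrap` on `serde_json::from_str` will panic on the resulting partial JSON. A sketch of a line-buffered variant, assuming the same `reqwest` byte stream (with the `stream` feature) and the `data: ... / data: [DONE]` framing emitted by the server:

```rust
use futures_util::StreamExt;
use serde_json::Value;

// Sketch: accumulate bytes in a buffer and only process complete lines, so a
// "data: {...}" event split across two HTTP chunks is still parsed correctly.
async fn read_sse_content(res: reqwest::Response) -> Result<String, String> {
    let mut stream = res.bytes_stream();
    let mut buffer = String::new();
    let mut accumulated = String::new();

    while let Some(chunk) = stream.next().await {
        let chunk = chunk.map_err(|e| e.to_string())?;
        buffer.push_str(&String::from_utf8_lossy(&chunk));

        // Drain complete lines; keep any trailing partial line buffered.
        while let Some(newline) = buffer.find('\n') {
            let line = buffer[..newline].trim_end().to_owned();
            buffer.drain(..=newline);

            let Some(data) = line.strip_prefix("data: ") else { continue };
            if data == "[DONE]" {
                return Ok(accumulated);
            }
            if let Ok(v) = serde_json::from_str::<Value>(data) {
                if let Some(delta) = v["choices"][0]["delta"]["content"].as_str() {
                    accumulated.push_str(delta);
                }
            }
        }
    }
    Err(String::from("stream ended before [DONE]"))
}
```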
// unimplemented!() + Err(String::from("request failed")) } fn convert_to_chat_response(content: &str) -> ChatResponse {
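The `convert_to_chat_response` call above feeds the accumulated stream text through three `extract_*` helpers that repeat the same split-on-marker logic. A minimal sketch of a single shared helper, assuming the model wraps its output in the think / reply / python blocks the system prompt describes (the exact tag strings are an assumption here):

```rust
/// Returns the text between `open` and `close`, or "" if either marker is missing.
fn extract_between(content: &str, open: &str, close: &str) -> String {
    match (content.find(open), content.find(close)) {
        (Some(start), Some(end)) if start + open.len() <= end => {
            content[start + open.len()..end].to_owned()
        }
        _ => String::new(),
    }
}

// Uses the ChatResponse struct defined earlier in src/runner/mlx.rs.
// The tag strings below are assumptions about the agent's output format.
fn convert_to_chat_response_sketch(content: &str) -> ChatResponse {
    ChatResponse {
        think: extract_between(content, "<think>", "</think>"),
        reply: extract_between(content, "<reply>", "</reply>"),
        code: extract_between(content, "<python>", "</python>"),
    }
}
```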