diff --git a/.dockerignore b/.dockerignore index fd64c09b3..1b85f8c9a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -2,6 +2,7 @@ _skbuild/ .envrc +# LLMs - comment if you'd like to bake the model into the image models/ # Byte-compiled / optimized / DLL files diff --git a/dev.Dockerfile b/dev.Dockerfile new file mode 100644 index 000000000..24f5be727 --- /dev/null +++ b/dev.Dockerfile @@ -0,0 +1,44 @@ +# Define the image argument and provide a default value +ARG IMAGE=python:3.11.8 + +# Use the image as specified +FROM ${IMAGE} + +# Re-declare the ARG after FROM +ARG IMAGE + +# Update and upgrade the existing packages +RUN apt-get update && apt-get upgrade -y && apt-get install -y --no-install-recommends \ + python3 \ + python3-pip \ + ninja-build \ + libopenblas-dev \ + build-essential \ + git + +RUN mkdir /app +WORKDIR /app +COPY . /app + +RUN python3 -m pip install --upgrade pip + +RUN make deps && make build && make clean + +# Set environment variable for the host +ENV GH_TOKEN=$GH_TOKEN +ENV HOST=0.0.0.0 +ENV PORT=8000 +ENV MODEL=/app/models/mistral-7b-openorca.Q5_K_M.gguf + +# # Install depencencies +# RUN python3 -m pip install --upgrade pip pytest cmake scikit-build setuptools fastapi uvicorn sse-starlette pydantic-settings starlette-context psutil prometheus_client + +# # Install llama-cpp-python (build with METAL) +# RUN CMAKE_ARGS="-DLLAMA_METAL=on" FORCE_CMAKE=1 pip install git+https://${GH_TOKEN}@github.com/ZenHubHQ/llama-cpp-python.git --force-reinstall --upgrade --no-cache-dir --verbose + +# Expose a port for the server +EXPOSE 8000 + +# Run the server start script +CMD ["/bin/sh", "/app/docker/simple/run.sh"] +# CMD python3 -m llama_cpp.server --n_gpu_layers -1 diff --git a/dev.docker-compose b/dev.docker-compose new file mode 100644 index 000000000..7b21e468a --- /dev/null +++ b/dev.docker-compose @@ -0,0 +1,15 @@ +version: '3' +services: + dev-llama-cpp-python: + build: + context: . + dockerfile: dev.Dockerfile + ports: + - 8000:8000 + volumes: + - ./llama_cpp:/app/llama_cpp + networks: + - zh-service-network +networks: + zh-service-network: + external: true \ No newline at end of file diff --git a/docker/simple/run.sh b/docker/simple/run.sh index c85e73d2b..d4fd489a0 100644 --- a/docker/simple/run.sh +++ b/docker/simple/run.sh @@ -1,4 +1,5 @@ #!/bin/bash make build -uvicorn --factory llama_cpp.server.app:create_app --host $HOST --port $PORT +# uvicorn --factory llama_cpp.server.app:create_app --host $HOST --port $PORT --reload +python3 -m llama_cpp.server --model $MODEL --n_gpu_layers -1 \ No newline at end of file diff --git a/llama_cpp/_utils.py b/llama_cpp/_utils.py index 7ab94964b..db5b6eb68 100644 --- a/llama_cpp/_utils.py +++ b/llama_cpp/_utils.py @@ -3,7 +3,7 @@ import psutil import subprocess -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union # Avoid "LookupError: unknown encoding: ascii" when open() called in a destructor outnull_file = open(os.devnull, "w") @@ -112,7 +112,7 @@ def get_gpu_info_by_pid(pid) -> float: pass return 0.0 -def get_gpu_general_info() -> tuple[float, float, float]: +def get_gpu_general_info() -> Tuple[float, float, float]: """ GPU general info (if GPU is available) """ @@ -123,25 +123,3 @@ def get_gpu_general_info() -> tuple[float, float, float]: except (subprocess.CalledProcessError, FileNotFoundError): pass return 0.0, 0.0, 0.0 - -def infer_service_from_prompt(prompt: str | List[str]): - """ - Infer the service for which a completion request is sent based on the prompt. 
- """ - LABEL_SUGGESTIONS_TASK = "Your task is to select the most relevant labels for a GitHub issue title from a list of labels provided." - ACCEPTANCE_CRITERIA_TASK = "Your task is to write the acceptance criteria for a GitHub issue." - SPRINT_REVIEW_TASK = "You are helping me prepare a sprint review." - - if isinstance(prompt, list): - prompt = " ".join(prompt) - - if LABEL_SUGGESTIONS_TASK in prompt: - return "label-suggestions" - - elif ACCEPTANCE_CRITERIA_TASK in prompt: - return "acceptance-criteria" - - elif SPRINT_REVIEW_TASK in prompt: - return "sprint-review" - - return "not-specified" diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 432f4db3b..e0a24ce85 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -41,7 +41,6 @@ from llama_cpp.llama_metrics import Metrics, MetricsExporter from llama_cpp._utils import ( - infer_service_from_prompt, get_cpu_usage, get_ram_usage, get_gpu_info_by_pid, @@ -70,6 +69,7 @@ class Llama: """High-level Python wrapper for a llama.cpp model.""" __backend_initialized = False + __prometheus_metrics = MetricsExporter() def __init__( self, @@ -464,7 +464,7 @@ def __init__( print(f"Using fallback chat format: {chat_format}", file=sys.stderr) # Prometheus metrics - self.metrics = MetricsExporter() + self.metrics = self.__prometheus_metrics @property def ctx(self) -> llama_cpp.llama_context_p: @@ -960,6 +960,7 @@ def _create_completion( logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, logit_bias: Optional[Dict[str, float]] = None, + ai_service: Optional[str] = None ) -> Union[ Iterator[CreateCompletionResponse], Iterator[CreateCompletionStreamResponse] ]: @@ -974,8 +975,10 @@ def _create_completion( _ttft_start = time.time() _pid = os.getpid() _tpot_metrics = [] + if not ai_service: + raise ValueError("ai_service must be provided") _labels = { - "service": infer_service_from_prompt(prompt), # Infer the service for which the completion is being generated + "service": ai_service if ai_service is not None else "not-specified", "request_type": "chat/completions", } # Get CPU usage before generating completion so it can be used to calculate CPU when called after completing the process @@ -1278,6 +1281,14 @@ def logit_bias_processor( token_end_position = 0 for token in remaining_tokens: + # Record TTFT metric (once) + if idx == 0: + _metrics_dict["time_to_first_token"] = time.time() - _ttft_start + # Record TPOT metric + else: + _tpot_metrics.append(time.time() - _tpot_start) + _tpot_start = time.time() # reset + token_end_position += len(self.detokenize([token], prev_tokens=prompt_tokens + completion_tokens[:returned_tokens])) logprobs_or_none: Optional[CompletionLogprobs] = None @@ -1371,6 +1382,53 @@ def logit_bias_processor( print("Llama._create_completion: cache save", file=sys.stderr) self.cache[prompt_tokens + completion_tokens] = self.save_state() print("Llama._create_completion: cache saved", file=sys.stderr) + + ## PROMETHEUS METRICS IN STREAMING MODE ## + # Record TTFT metric -- Setting to None if no tokens were generated + if not _metrics_dict.get("time_to_first_token"): + _metrics_dict["time_to_first_token"] = None + + # Record TPOT metrics (per generated token) + _metrics_dict["time_per_output_token"] = _tpot_metrics + + # Record metrics from the C++ backend (converted to seconds) + _timings = llama_cpp.llama_get_timings(self._ctx.ctx) + _metrics_dict["load_time"] = round(_timings.t_load_ms / 1e3, 2) + _metrics_dict["sample_time"] = round(_timings.t_sample_ms / 1e3, 2) + 
_metrics_dict["sample_throughput"] = round(1e3 / _timings.t_sample_ms * _timings.n_sample, 2) if _timings.t_sample_ms > 0 else 0.0 + _metrics_dict["prompt_eval_time"] = round(_timings.t_p_eval_ms / 1e3, 2) + _metrics_dict["prompt_eval_throughput"] = round(1e3 / _timings.t_p_eval_ms * _timings.n_p_eval, 2) if _timings.t_p_eval_ms > 0 else 0.0 + _metrics_dict["completion_eval_time"] = round(_timings.t_eval_ms / 1e3, 2) + _metrics_dict["completion_eval_throughput"] = round(1e3 / _timings.t_eval_ms * _timings.n_eval, 2) if _timings.t_eval_ms > 0 else 0.0 + _metrics_dict["end_to_end_latency"] = round((_timings.t_end_ms - _timings.t_start_ms) / 1e3, 2) + + # Record prefill and generation token metrics + _metrics_dict["prefill_tokens"] = len(prompt_tokens) + _metrics_dict["generation_tokens"] = len(completion_tokens) + + # Record system info + _gpu_utilization, _gpu_memory_used, _gpu_memory_free = get_gpu_general_info() + _metrics_dict["cpu_utilization"] = get_cpu_usage(_pid) # TODO: Returning always 0.0 -> check + _metrics_dict["cpu_ram_pid"] = get_ram_usage(_pid) + _metrics_dict["gpu_utilization"] = _gpu_utilization + _metrics_dict["gpu_ram_usage"] = _gpu_memory_used + _metrics_dict["gpu_ram_free"] = _gpu_memory_free + _metrics_dict["gpu_ram_pid"] = get_gpu_info_by_pid(_pid) + _metrics_dict["state_size"] = llama_cpp.llama_get_state_size(self._ctx.ctx) + _metrics_dict["kv_cache_usage_ratio"] = round(1. * llama_cpp.llama_get_kv_cache_used_cells(self._ctx.ctx) / self.n_ctx(), 2) + _metrics_dict["system_info"] = { + "model": model_name, + "n_params": str(llama_cpp.llama_model_n_params(self.model)), + "n_embd": str(self.n_embd()), + "n_ctx": str(self.n_ctx()), + "n_vocab": str(self.n_vocab()), + "n_threads": str(self.n_threads) + } + + # Log metrics to Prometheus + _all_metrics = Metrics(**_metrics_dict) + self.metrics.log_metrics(_all_metrics, labels=_labels) + return if self.cache: @@ -1446,6 +1504,11 @@ def logit_bias_processor( "top_logprobs": top_logprobs, } + ## PROMETHEUS METRICS IN CHAT COMPLETION MODE ## + # Record TTFT metric -- Setting to None if no tokens were generated + if not _metrics_dict.get("time_to_first_token"): + _metrics_dict["time_to_first_token"] = None + # Record TPOT metrics (per generated token) _metrics_dict["time_per_output_token"] = _tpot_metrics @@ -1484,7 +1547,6 @@ def logit_bias_processor( } # Log metrics to Prometheus - #print(_metrics_dict, file=sys.stderr) _all_metrics = Metrics(**_metrics_dict) self.metrics.log_metrics(_all_metrics, labels=_labels) @@ -1493,6 +1555,7 @@ def logit_bias_processor( "object": "text_completion", "created": created, "model": model_name, + "service": ai_service, "choices": [ { "text": text_str, @@ -1535,6 +1598,7 @@ def create_completion( logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, logit_bias: Optional[Dict[str, float]] = None, + ai_service: Optional[str] = None ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: """Generate text from a prompt. 
@@ -1598,6 +1662,7 @@ def create_completion( logits_processor=logits_processor, grammar=grammar, logit_bias=logit_bias, + ai_service=ai_service ) if stream: chunks: Iterator[CreateCompletionStreamResponse] = completion_or_chunks @@ -1632,6 +1697,7 @@ def __call__( logits_processor: Optional[LogitsProcessorList] = None, grammar: Optional[LlamaGrammar] = None, logit_bias: Optional[Dict[str, float]] = None, + ai_service: Optional[str] = None ) -> Union[CreateCompletionResponse, Iterator[CreateCompletionStreamResponse]]: """Generate text from a prompt. @@ -1695,6 +1761,7 @@ def __call__( logits_processor=logits_processor, grammar=grammar, logit_bias=logit_bias, + ai_service=ai_service ) def create_chat_completion( @@ -1727,6 +1794,7 @@ def create_chat_completion( logit_bias: Optional[Dict[str, float]] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, + ai_service: Optional[str] = None ) -> Union[ CreateChatCompletionResponse, Iterator[CreateChatCompletionStreamResponse] ]: @@ -1796,6 +1864,7 @@ def create_chat_completion( logits_processor=logits_processor, grammar=grammar, logit_bias=logit_bias, + ai_service=ai_service ) def create_chat_completion_openai_v1( diff --git a/llama_cpp/llama_chat_format.py b/llama_cpp/llama_chat_format.py index 3ab94e0d3..d5194bb91 100644 --- a/llama_cpp/llama_chat_format.py +++ b/llama_cpp/llama_chat_format.py @@ -87,6 +87,7 @@ def __call__( grammar: Optional[llama.LlamaGrammar] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, + ai_service: Optional[str] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -535,6 +536,7 @@ def chat_completion_handler( logit_bias: Optional[Dict[str, float]] = None, logprobs: Optional[bool] = None, top_logprobs: Optional[int] = None, + ai_service: Optional[str] = None, **kwargs, # type: ignore ) -> Union[ llama_types.CreateChatCompletionResponse, @@ -625,6 +627,7 @@ def chat_completion_handler( stopping_criteria=stopping_criteria, grammar=grammar, logit_bias=logit_bias, + ai_service=ai_service ) if tool is not None: tool_name = tool["function"]["name"] @@ -1715,6 +1718,7 @@ def functionary_v1_v2_chat_handler( model: Optional[str] = None, logits_processor: Optional[llama.LogitsProcessorList] = None, grammar: Optional[llama.LlamaGrammar] = None, + ai_service: Optional[str] = None, **kwargs, # type: ignore ) -> Union[llama_types.ChatCompletion, Iterator[llama_types.ChatCompletionChunk]]: SYSTEM_MESSAGE = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. 
The assistant calls functions with appropriate input when necessary""" @@ -1931,6 +1935,7 @@ def prepare_messages_for_inference( model=model, logits_processor=logits_processor, grammar=grammar, + ai_service=ai_service ) if stream is False: completion_or_completion_chunks["choices"][0]["text"] = completion_or_completion_chunks["choices"][0]["text"].lstrip() diff --git a/llama_cpp/llama_metrics.py b/llama_cpp/llama_metrics.py index 7334c7140..105346fed 100644 --- a/llama_cpp/llama_metrics.py +++ b/llama_cpp/llama_metrics.py @@ -45,7 +45,7 @@ class MetricsExporter: def __init__(self): self.labels = LABELS # One-time metrics - self._histrogram_load_time = Histogram( + self._histogram_load_time = Histogram( name="llama_cpp_python:load_t_seconds", documentation="Histogram of load time in seconds", labelnames=self.labels, @@ -194,9 +194,10 @@ def log_metrics(self, metrics: Metrics, labels: Dict[str, str]): """ Log the metrics using the Prometheus client. """ - self._histrogram_load_time.labels(**labels).observe(metrics.load_time) + self._histogram_load_time.labels(**labels).observe(metrics.load_time) self._histogram_sample_time.labels(**labels).observe(metrics.sample_time) - self._histogram_time_to_first_token.labels(**labels).observe(metrics.time_to_first_token) + if metrics.time_to_first_token: + self._histogram_time_to_first_token.labels(**labels).observe(metrics.time_to_first_token) for _tpot in metrics.time_per_output_token: self._histogram_time_per_output_token.labels(**labels).observe(_tpot) self._histogram_prompt_eval_time.labels(**labels).observe(metrics.prompt_eval_time) diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index cb3a30582..5d9abf22f 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -386,6 +386,7 @@ async def create_chat_completion( {"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}, ], + "ai_service": "copilot" }, }, "json_mode": { @@ -454,6 +455,11 @@ async def create_chat_completion( "user", } kwargs = body.model_dump(exclude=exclude) + + # Adds the ai_service value from the request body to the kwargs + # to be passed downstream to the llama_cpp.ChatCompletion object + kwargs["ai_service"] = body.ai_service + llama = llama_proxy(body.model) if body.logit_bias is not None: kwargs["logit_bias"] = ( diff --git a/llama_cpp/server/types.py b/llama_cpp/server/types.py index a20b3940f..f3fa5fa73 100644 --- a/llama_cpp/server/types.py +++ b/llama_cpp/server/types.py @@ -259,6 +259,9 @@ class CreateChatCompletionRequest(BaseModel): } } + # AI service added as request body parameter by Client + ai_service: Optional[str] = None + class ModelData(TypedDict): id: str diff --git a/tests/test_llama.py b/tests/test_llama.py index 469ef91ca..aca1745a8 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -153,7 +153,10 @@ def mock_kv_cache_seq_add( def test_llama_patch(mock_llama): n_ctx = 128 + ai_service_completion = "test-label-suggestions" + ai_service_streaming = "test-acceptance-criteria" llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, n_ctx=n_ctx) + n_vocab = llama_cpp.llama_n_vocab(llama._model.model) assert n_vocab == 32000 @@ -163,32 +166,32 @@ def test_llama_patch(mock_llama): ## Test basic completion from bos until eos mock_llama(llama, all_text) - completion = llama.create_completion("", max_tokens=36) + completion = llama.create_completion("", max_tokens=36, ai_service=ai_service_completion) assert completion["choices"][0]["text"] == all_text assert 
completion["choices"][0]["finish_reason"] == "stop" ## Test basic completion until eos mock_llama(llama, all_text) - completion = llama.create_completion(text, max_tokens=20) + completion = llama.create_completion(text, max_tokens=20, ai_service=ai_service_completion) assert completion["choices"][0]["text"] == output_text assert completion["choices"][0]["finish_reason"] == "stop" ## Test streaming completion until eos mock_llama(llama, all_text) - chunks = list(llama.create_completion(text, max_tokens=20, stream=True)) + chunks = list(llama.create_completion(text, max_tokens=20, stream=True, ai_service=ai_service_streaming)) assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == output_text assert chunks[-1]["choices"][0]["finish_reason"] == "stop" ## Test basic completion until stop sequence mock_llama(llama, all_text) - completion = llama.create_completion(text, max_tokens=20, stop=["lazy"]) + completion = llama.create_completion(text, max_tokens=20, stop=["lazy"], ai_service=ai_service_completion) assert completion["choices"][0]["text"] == " jumps over the " assert completion["choices"][0]["finish_reason"] == "stop" ## Test streaming completion until stop sequence mock_llama(llama, all_text) chunks = list( - llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"]) + llama.create_completion(text, max_tokens=20, stream=True, stop=["lazy"], ai_service=ai_service_streaming) ) assert ( "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps over the " @@ -197,13 +200,13 @@ def test_llama_patch(mock_llama): ## Test basic completion until length mock_llama(llama, all_text) - completion = llama.create_completion(text, max_tokens=2) + completion = llama.create_completion(text, max_tokens=2, ai_service=ai_service_completion) assert completion["choices"][0]["text"] == " jumps" assert completion["choices"][0]["finish_reason"] == "length" ## Test streaming completion until length mock_llama(llama, all_text) - chunks = list(llama.create_completion(text, max_tokens=2, stream=True)) + chunks = list(llama.create_completion(text, max_tokens=2, stream=True, ai_service=ai_service_streaming)) assert "".join(chunk["choices"][0]["text"] for chunk in chunks) == " jumps" assert chunks[-1]["choices"][0]["finish_reason"] == "length" @@ -230,15 +233,16 @@ def test_utf8(mock_llama): llama = llama_cpp.Llama(model_path=MODEL, vocab_only=True, logits_all=True) output_text = "😀" + ai_service = "label-suggestions" ## Test basic completion with utf8 multibyte mock_llama(llama, output_text) - completion = llama.create_completion("", max_tokens=4) + completion = llama.create_completion("", max_tokens=4, ai_service=ai_service) assert completion["choices"][0]["text"] == output_text ## Test basic completion with incomplete utf8 multibyte mock_llama(llama, output_text) - completion = llama.create_completion("", max_tokens=1) + completion = llama.create_completion("", max_tokens=1, ai_service=ai_service) assert completion["choices"][0]["text"] == "" @@ -266,6 +270,22 @@ def test_llama_server(): } +def test_metrics_endpoint(): + from fastapi.testclient import TestClient + from llama_cpp.server.app import create_app, Settings + + settings = Settings( + model=MODEL, + vocab_only=True, + ) + app = create_app(settings) + client = TestClient(app) + response = client.get("/metrics") + assert response.status_code == 200 + assert "test-label-suggestions" in response.text + assert "test-acceptance-criteria" in response.text + + @pytest.mark.parametrize( "size_and_axis", [