feat: tools
Signed-off-by: thxCode <thxcode0824@gmail.com>
thxCode committed Jul 17, 2024
1 parent b5a353c commit 6829b5e
Showing 4 changed files with 211 additions and 31 deletions.
22 changes: 22 additions & 0 deletions README.md
@@ -205,6 +205,28 @@ logging:
see https://platform.openai.com/docs/api-reference/embeddings/create.
+ This endpoint is only available if the `--embeddings` flag is enabled.
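
For reference, a minimal `curl` sketch against the OpenAI-compatible embeddings path (the URL and `model` value below are illustrative placeholders, assuming the server listens on the default tools address):

```shell
$ # assumes the server was started with --embeddings
$ curl http://127.0.0.1:8080/v1/embeddings \
    --header "Content-Type: application/json" \
    --data-raw '{"input": "Hello, llama-box!", "model": "default"}'
```
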
## Tools
It was so hard to find a Chat UI that is directly compatible with OpenAI, by which I mean: no installation required (I can live
with `docker run`), no tokens (or at least optional ones), no Ollama required (don't you think Ollama's API is hard to use?), just a
simple RESTful API.

So I took inspiration from
[llama.cpp/chat.sh](https://github.com/ggerganov/llama.cpp/blob/e6f291d15844398f8326940fe5ad7f2e02b5aa56/examples/server/chat.sh)
and adjusted it to interact with llama-box.

All you need is a Bash shell, `curl`, and `jq`.

- **completion.sh**: A simple script to interact with the `/completion` endpoint.

```shell
$ # one-shot completion
$ N_PREDICT=4096 TOP_K=1 ./llama-box/tools/completion.sh "// Quick-sort implementation in C (4 spaces indentation + detailed comments) and sample usage:\n\n#include"
$ # interactive completion
$ N_PREDICT=4096 ./llama-box/tools/completion.sh
```
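
Under the hood, `completion.sh` simply streams from the server's `/completion` endpoint, so the same thing can be done with plain `curl`; a minimal non-streaming sketch (the prompt and parameter values are illustrative) might look like:

```shell
$ curl http://127.0.0.1:8080/completion \
    --header "Content-Type: application/json" \
    --data-raw '{"prompt": "#include", "n_predict": 128, "stream": false}'
```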

## License

MIT
57 changes: 26 additions & 31 deletions llama-box/main.cpp
@@ -2,7 +2,6 @@
#include <chrono>
#include <condition_variable>
#include <csignal>
#include <cstddef>
#include <memory>
#include <set>
#include <thread>
@@ -1097,6 +1096,7 @@ struct server_context {
if (slot.ctx_sampling != nullptr) {
llama_sampling_free(slot.ctx_sampling);
}

slot.ctx_sampling = llama_sampling_init(slot.sparams);
if (slot.ctx_sampling == nullptr) {
// for now, the only error that may happen here is invalid
@@ -1113,8 +1113,8 @@
}
int tps = task.tps;
#ifndef NDEBUG
tps = json_value(data, "tps",
task.tps); // allow overriding tps for debugging
// allow overriding tps for debugging
tps = json_value(data, "tps", task.tps);
if (tps > n_tps) {
tps = n_tps;
}
@@ -1188,7 +1188,7 @@ struct server_context {
slot.find_stopping_strings(str_test, token_str.size(), STOP_TYPE_FULL);
if (stop_pos != std::string::npos) {
is_stop_full = true;
slot.generated_text.erase(slot.generated_text.begin() + pos + stop_pos,
slot.generated_text.erase(slot.generated_text.begin() + long(pos) + long(stop_pos),
slot.generated_text.end());
pos = std::min(slot.n_sent_text, slot.generated_text.size());
} else {
@@ -1350,8 +1350,8 @@ struct server_context {
std::vector<completion_token_output> probs_output;
if (probs_pos < probs_stop_pos) {
probs_output = std::vector<completion_token_output>(
slot.generated_token_probs.begin() + probs_pos,
slot.generated_token_probs.begin() + probs_stop_pos);
slot.generated_token_probs.begin() + long(probs_pos),
slot.generated_token_probs.begin() + long(probs_stop_pos));
}
slot.n_sent_token_probs = probs_stop_pos;

@@ -1835,11 +1835,10 @@ struct server_context {
// TODO: simplify and improve
for (server_slot &slot : slots) {
if (slot.ga_n == 1) {
if (slot.is_processing() &&
(int)system_tokens.size() + slot.n_past >= slot.n_ctx - 1) {
if (slot.is_processing() && n_system_tokens + slot.n_past >= slot.n_ctx - 1) {
// Shift context
const int n_keep = slot.params.n_keep + add_bos_token;
const int n_left = (int)system_tokens.size() + slot.n_past - n_keep;
const int n_left = n_system_tokens + slot.n_past - n_keep;
const int n_discard =
slot.params.n_discard ? slot.params.n_discard : (n_left / 2);

@@ -2077,13 +2076,13 @@ struct server_context {
}

// keep only the common part
int p0 = (int)system_tokens.size() + slot.n_past;
int p0 = n_system_tokens + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a
// non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);

p0 = (int)system_tokens.size();
p0 = n_system_tokens;
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
@@ -2102,9 +2101,6 @@
// remove the non-common part from the cache
slot.cache_tokens.resize(slot.n_past);

LOG_INFO("kv cache rm [p0, end)",
{{"id_slot", slot.id}, {"id_task", slot.id_task}, {"p0", p0}});

int32_t slot_npast = slot.n_past_se > 0 ? slot.n_past_se : slot.n_past;

int32_t ga_i = slot.ga_i;
@@ -2228,7 +2224,6 @@ struct server_context {
// clang-format on

const int ret = llama_decode(ctx, batch_view);

if (ret != 0) {
if (n_batch == 1 || ret < 0) {
// if you get here, it means the KV cache is full - try
@@ -2285,10 +2280,9 @@ struct server_context {
}

completion_token_output result;
const llama_token id =
result.tok =
llama_sampling_sample(slot.ctx_sampling, ctx, nullptr, slot.i_batch - i);

llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
llama_sampling_accept(slot.ctx_sampling, ctx, result.tok, true);

slot.n_decoded += 1;
if (slot.n_decoded == 1) {
@@ -2298,13 +2292,12 @@ struct server_context {
metrics.on_prompt_eval(slot);
}

llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(),
slot.ctx_sampling->cur.size(), false};
result.tok = id;

const size_t n_probs = std::min(cur_p.size, (size_t)slot.sparams.n_probs);
const size_t n_probs =
std::min(slot.ctx_sampling->cur.size(), (size_t)slot.sparams.n_probs);
if (n_probs > 0) {
const size_t n_valid = slot.ctx_sampling->n_valid;
llama_token_data_array cur_p = {slot.ctx_sampling->cur.data(),
slot.ctx_sampling->cur.size(), false};

// Make sure at least n_probs top tokens are at the front of
// the vector:
@@ -2441,9 +2434,9 @@ struct server_context {
if (tokens.empty()) {
std::string suffix;
if (batch.n_tokens == 0) {
suffix = "\nASSISTANT:\n";
suffix = "\n### Assistant:\n";
} else {
suffix = "\nAnswer the questions.\nASSISTANT:\n";
suffix = "\nAnswer the questions.\n### Assistant:\n";
}
tokens = llama_tokenize(ctx, suffix, false, true);
}
@@ -2513,7 +2506,7 @@ int main(int argc, char **argv) {
if (!params.system_prompt.empty()) {
ctx_server.system_prompt = params.system_prompt;
} else if (!params.mmproj.empty()) {
ctx_server.system_prompt = "SYSTEM:\nAnswer the questions.\nUSER:";
ctx_server.system_prompt = "### System: You are a helpful assistant.\n### Human: ";
}

if (params.model_alias == "unknown") {
@@ -3135,10 +3128,12 @@ int main(int argc, char **argv) {
continue;
}

const std::string done = "data: [DONE] \n\n";
if (!sink.write(done.c_str(), done.size())) {
sink.done();
return false;
if (oaicompat) {
const std::string done = "data: [DONE] \n\n";
if (!sink.write(done.c_str(), done.size())) {
sink.done();
return false;
}
}

sink.done_with_trailer(
@@ -3164,7 +3159,7 @@
{
{{"id", params.model_alias},
{"object", "model"},
{"created", std::time(0)},
{"created", std::time(nullptr)},
{"owned_by", "llama-box"},
{"meta", ctx_server.model_meta()}},
}}};
6 changes: 6 additions & 0 deletions llama-box/param.hpp
@@ -105,6 +105,7 @@ static void llama_box_params_print_usage(int, char **argv, const llama_box_param
opts.push_back({ "*", " --keep N", "number of tokens to keep from the initial prompt (default: %d, -1 = all)", params.n_keep });
opts.push_back({ "*", " --chunks N", "max number of chunks to process (default: %d, -1 = all)", params.n_chunks });
opts.push_back({ "*", "-fa, --flash-attn", "enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled" });
opts.push_back({ "*", "-e, --escape", R"(process escapes sequences (\n, \r, \t, \', \", \\) (default: %s))", params.escape ? "true" : "false" });
opts.push_back({ "*", " --no-escape", "do not process escape sequences" });
opts.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
"(default: %s)", sampler_type_names.c_str() });
@@ -404,6 +405,11 @@ static bool llama_box_params_parse(int argc, char **argv, llama_box_params &bpar
continue;
}

if (!strcmp(flag, "-e") || !strcmp(flag, "--escape")) {
bparams.gparams.escape = true;
continue;
}

if (!strcmp(flag, "--no-escape")) {
bparams.gparams.escape = false;
continue;
157 changes: 157 additions & 0 deletions llama-box/tools/completion.sh
@@ -0,0 +1,157 @@
#!/bin/bash

#
# MIT license
# Copyright (c) 2024 llama-box authors
# SPDX-License-Identifier: MIT
#

#
# MIT license
# Copyright (c) 2023-2024 The ggml authors
# SPDX-License-Identifier: MIT
#

LOG_FILE=${LOG_FILE:-/dev/null}

API_URL="${API_URL:-http://127.0.0.1:8080}"

CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
"Please tell me the largest city in Europe."
"Sure. The largest city in Europe is Moscow, the capital of Russia."
)

INSTRUCTION="### System: You are a helpful assistant."

trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
printf "%s" "${1%%+([[:space:]])}"
}

trim_trailing() {
shopt -s extglob
printf "%s" "${1%%+([[:space:]])}"
}

format_prompt() {
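# Render the prompt: the system INSTRUCTION, then the "### Human:"/"### Assistant:" turns
# recorded in CHAT, ending with the new question and an empty assistant slot.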
echo -n "${INSTRUCTION}"
printf "\n### Human: %s\n### Assistant: %s" "${CHAT[@]}" "$1"
}

tokenize() {
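# Ask the server's /tokenize endpoint to tokenize the given text; prints one token id per
# line (used below to size N_KEEP).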
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--header "Content-Type: application/json" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}

N_PREDICT="${N_PREDICT:-"256"}"
TEMPERATURE="${TEMPERATURE:-"0.2"}"
TOP_P="${TOP_P:-"0.9"}"
TOP_K="${TOP_K:-"40"}"
STOP="${STOP:-"[\"\\n### Human:\"]"}"
SEED="${SEED:-"-1"}"
N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)

completion() {
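# Send one question to the /completion endpoint, stream the generated text to stdout, and
# record the question/answer pair in CHAT.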
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
DATA="$(echo -n "$PROMPT" | jq -Rs \
--argjson n_predict "${N_PREDICT}" \
--argjson temperature "${TEMPERATURE}" \
--argjson top_p "${TOP_P}" \
--argjson top_k "${TOP_K}" \
--argjson stop "${STOP}" \
--argjson seed "${SEED}" \
--argjson n_keep "${N_KEEP}" \
'{
prompt: .,
temperature: $temperature,
top_k: $top_k,
top_p: $top_p,
n_predict: $n_predict,
cache_prompt: false,
n_keep: $n_keep,
seed: $seed,
stop: $stop,
stream: true
}')"
echo "Q: ${DATA}" >> "${LOG_FILE}"

ANSWER=''
PRE_CONTENT=''

while IFS= read -r LINE; do
if [[ $LINE = data:* ]]; then
echo "A: ${LINE}" >> "${LOG_FILE}"
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
if [[ "${PRE_CONTENT: -1}" == "\\" ]] && [[ "${CONTENT}" =~ ^b|n|r|t|\\|\'|\"$ ]]; then
printf "\b "
case "${CONTENT}" in
b) printf "\b\b" ;;
n) printf "\b\n" ;;
r) printf "\b\r" ;;
t) printf "\b\t" ;;
\\) printf "\b\\" ;;
\') printf "\b'" ;;
\") printf "\b\"" ;;
esac
CONTENT=""
fi
PRE_CONTENT="${CONTENT}"
printf "%s" "${CONTENT}"
ANSWER+="${CONTENT}"
TIMINGS="$(echo "${LINE:5}" | jq -r '.timings')"
if [ "${TIMINGS}" != "null" ]; then
printf "\n------------------------"
printf "\n- TTFT : %10.2f ms -" "$(echo "${TIMINGS}" | jq -r '.prompt_ms')"
printf "\n- TBT : %10.2f ms -" "$(echo "${TIMINGS}" | jq -r '.predicted_per_token_ms')"
printf "\n- TPS : %10.2f -" "$(echo "${TIMINGS}" | jq -r '.predicted_per_second')"
printf "\n------------------------"
fi
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--header "Content-Type: application/json" \
--data-raw "${DATA}")

printf "\n"

CHAT+=("$1" "$(trim "$ANSWER")")
}

echo "====================================================="
echo "LOG_FILE : ${LOG_FILE}"
echo "API_URL : ${API_URL}"
echo "TEMPERATURE : ${TEMPERATURE}"
echo "TOP_P : ${TOP_P}"
echo "TOP_K : ${TOP_K}"
echo "N_PREDICT : ${N_PREDICT}"
echo "STOP : ${STOP}"
echo "SEED : ${SEED}"
printf "=====================================================\n\n"

if [[ -f "${LOG_FILE}" ]]; then
rm -f "${LOG_FILE}"
fi
if [[ ! -f "${LOG_FILE}" ]]; then
touch "${LOG_FILE}"
fi

if [[ "${#@}" -ge 1 ]]; then
echo "> ${*}"
completion "${*}"
else
while true; do
read -r -e -p "> " QUESTION
completion "${QUESTION}"
done
fi
