diff --git a/NAMESPACE b/NAMESPACE index 4b6c332..10a4537 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -1,5 +1,7 @@ # Generated by roxygen2: do not edit by hand +S3method(as_message,character) +S3method(as_message,list) S3method(print,chat_tibble) export(chat) export(models) diff --git a/R/chat.R b/R/chat.R index 7cbde0e..9aef6dd 100644 --- a/R/chat.R +++ b/R/chat.R @@ -1,9 +1,8 @@ #' Chat with the Mistral api #' -#' @param text some text +#' @param ... either a character string or a list with the message to send #' @param model which model to use. See [models()] for more information about which models are available #' @param dry_run if TRUE the request is not performed -#' @param ... ignored #' @inheritParams httr2::req_perform #' #' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request @@ -13,8 +12,8 @@ #' chat("Top 5 R packages", dry_run = TRUE) #' #' @export -chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { - req <- req_chat(text, model, error_call = error_call, dry_run = dry_run) +chat <- function(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) { + req <- req_chat(..., model = model, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } @@ -22,22 +21,20 @@ chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny resp_chat(resp, error_call = error_call) } -req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { +req_chat <- function(..., model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) { if (!is_true(dry_run)) { check_model(model, error_call = error_call) } + + messages <- as_messages(...) 
+ request(mistral_base_url) |> req_url_path_append("v1", "chat", "completions") |> authenticate() |> req_body_json( list( model = model, - messages = list( - list( - role = "user", - content = text - ) - ), + messages = messages, stream = is_true(stream) ) ) @@ -47,7 +44,8 @@ resp_chat <- function(response, error_call = current_env()) { data <- resp_body_json(response) tib <- map_dfr(data$choices, \(choice) { - as_tibble(choice$message) + tibble(role = choice$message$role, + content = choice$message$content) }) class(tib) <- c("chat_tibble", class(tib)) diff --git a/R/messages.R b/R/messages.R new file mode 100644 index 0000000..81012be --- /dev/null +++ b/R/messages.R @@ -0,0 +1,28 @@ +as_message <- function(x) { + UseMethod("as_message") +} + +#' @export +as_message.character <- function(x) { + list(role = "user", content = x) +} + +#' @export +as_message.list <- function(x) { + x +} + +as_messages <- function(...) { + x <- list(...) + messages <- list() + + for (i in seq_along(x)) { + if (is.character(x[[i]])) { + messages <- append(messages, list(as_message(x[[i]]))) + } else if (is.list(x[[i]])) { + messages <- append(messages, if (is.null(x[[i]]$role)) x[[i]] else list(as_message(x[[i]]))) + } + } + + messages +} diff --git a/R/stream.R b/R/stream.R index aaf793b..1ece1e5 100644 --- a/R/stream.R +++ b/R/stream.R @@ -3,25 +3,33 @@ #' @inheritParams chat #' #' @export -stream <- function(text, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) { +stream <- function(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) { check_model(model, error_call = error_call) - req <- req_chat(text, model, stream = TRUE, error_call = error_call, dry_run = dry_run) + req <- req_chat(..., model = model, stream = TRUE, error_call = error_call, dry_run = dry_run) if (is_true(dry_run)) { return(req) } + env <- new.env(parent = emptyenv()) + + env$response <- character() + resp <- req_perform_stream(req, - callback = stream_callback, - round = "line", - buffer_kb = 0.01 + callback = \(x) stream_callback(x, env), + 
round = "line", + buffer_kb = 0.01 ) - invisible(resp) + tib <- tibble(role = "assistant", + content = env$response) + class(tib) <- c("chat_tibble", class(tib)) + + invisible(tib) } #' @importFrom jsonlite fromJSON -stream_callback <- function(x) { +stream_callback <- function(x, env) { txt <- rawToChar(x) lines <- str_split(txt, "\n")[[1]] @@ -34,6 +42,8 @@ stream_callback <- function(x) { chunk$choices$delta$content }) + env$response <- paste0(env$response, paste(tokens, collapse = "")) + cat(tokens) TRUE diff --git a/man/chat.Rd b/man/chat.Rd index 67f8483..5de9718 100644 --- a/man/chat.Rd +++ b/man/chat.Rd @@ -4,23 +4,15 @@ \alias{chat} \title{Chat with the Mistral api} \usage{ -chat( - text = "What are the top 5 R packages ?", - model = "mistral-tiny", - dry_run = FALSE, - ..., - error_call = current_env() -) +chat(..., model = "mistral-tiny", dry_run = FALSE, error_call = current_env()) } \arguments{ -\item{text}{some text} +\item{...}{either a character string or a list with the message to send} \item{model}{which model to use. See \code{\link[=models]{models()}} for more information about which models are available} \item{dry_run}{if TRUE the request is not performed} -\item{...}{ignored} - \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be mentioned in error messages as the source of the error. See the diff --git a/man/stream.Rd b/man/stream.Rd index 7e7ae64..133aa7e 100644 --- a/man/stream.Rd +++ b/man/stream.Rd @@ -5,22 +5,19 @@ \title{stream} \usage{ stream( - text, + ..., model = "mistral-tiny", dry_run = FALSE, - ..., error_call = current_env() ) } \arguments{ -\item{text}{some text} +\item{...}{either a character string or a list with the message to send} \item{model}{which model to use. 
See \code{\link[=models]{models()}} for more information about which models are available} \item{dry_run}{if TRUE the request is not performed} -\item{...}{ignored} - \item{error_call}{The execution environment of a currently running function, e.g. \code{caller_env()}. The function will be mentioned in error messages as the source of the error. See the