Skip to content

Commit

Permalink
Merge pull request #16 from tadascience/dry_run
Browse files Browse the repository at this point in the history
+ `chat|models|stream(dry_run =)`
  • Loading branch information
romainfrancois authored Mar 9, 2024
2 parents 6b21752 + 796ee08 commit b10fc84
Show file tree
Hide file tree
Showing 7 changed files with 52 additions and 32 deletions.
5 changes: 2 additions & 3 deletions R/authenticate.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
authenticate <- function(request, error_call = caller_env()){
authenticate <- function(request, dry_run = FALSE, error_call = caller_env()){
key <- Sys.getenv("MISTRAL_API_KEY")
if (identical(key, "")) {
if (!is_true(dry_run) && identical(key, "")) {
cli_abort(call = error_call, c(
"Please set the {.code MISTRAL_API_KEY} environment variable",
i = "Get an API key from {.url https://console.mistral.ai/api-keys/}",
Expand All @@ -9,4 +9,3 @@ authenticate <- function(request, error_call = caller_env()){
}
req_auth_bearer_token(request, key)
}

27 changes: 16 additions & 11 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,33 @@
#'
#' @param text some text
#' @param model which model to use. See [models()] for more information about which models are available
#' @param dry_run if TRUE the request is not performed
#' @param ... ignored
#' @inheritParams httr2::req_perform
#'
#' @return Result text from Mistral
#' @return A tibble with columns `role` and `content` with class `chat_tibble` or a request
#' if this is a `dry_run`
#'
#' @examples
#' \dontrun{
#' chat("Top 5 R packages")
#' }
#' chat("Top 5 R packages", dry_run = TRUE)
#'
#' @export
chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", ..., error_call = current_env()) {
req_chat(text, model, error_call = error_call) |>
req_mistral_perform(error_call = error_call) |>
resp_chat(error_call = error_call)
chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) {
req <- req_chat(text, model, error_call = error_call, dry_run = dry_run)
if (is_true(dry_run)) {
return(req)
}
resp <- req_mistral_perform(req, error_call = error_call)
resp_chat(resp, error_call = error_call)
}

req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, error_call = caller_env()) {
check_model(model, error_call = error_call)
req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, dry_run = FALSE, error_call = caller_env()) {
if (!is_true(dry_run)) {
check_model(model, error_call = error_call)
}
request(mistral_base_url) |>
req_url_path_append("v1", "chat", "completions") |>
authenticate(error_call = error_call) |>
authenticate(error_call = error_call, dry_run = dry_run) |>
req_body_json(
list(
model = model,
Expand Down
15 changes: 8 additions & 7 deletions R/models.R
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
#' Retrieve all models available in the Mistral API
#'
#' @inheritParams httr2::req_perform
#' @inheritParams chat
#'
#' @return A character vector with the models available in the Mistral API
#'
#' @examples
#' \dontrun{
#' models()
#' }
#' models(dry_run = TRUE)
#'
#' @export
models <- function(error_call = current_env()) {

models <- function(error_call = caller_env(), dry_run = FALSE) {
req <- request(mistral_base_url) |>
req_url_path_append("v1", "models") |>
authenticate(error_call = call) |>
authenticate(error_call = call, dry_run = dry_run) |>
req_cache(tempdir(),
use_on_error = TRUE,
max_age = 2 * 60 * 60) # 2 hours

if (is_true(dry_run)) {
return(req)
}

req_mistral_perform(req, error_call = error_call) |>
resp_body_json(simplifyVector = TRUE) |>
pluck("data","id")
Expand Down
9 changes: 7 additions & 2 deletions R/stream.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
#' stream
#'
#' @inheritParams chat
#'
#' @export
stream <- function(text, model = "mistral-tiny", ..., error_call = current_env()) {
stream <- function(text, model = "mistral-tiny", dry_run = FALSE, ..., error_call = current_env()) {
check_model(model, error_call = error_call)

req <- req_chat(text, model, stream = TRUE, error_call = error_call)
req <- req_chat(text, model, stream = TRUE, error_call = error_call, dry_run = dry_run)
if (is_true(dry_run)) {
return(req)
}

resp <- req_perform_stream(req,
callback = stream_callback,
round = "line",
Expand Down
10 changes: 6 additions & 4 deletions man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/models.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 9 additions & 1 deletion man/stream.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit b10fc84

Please sign in to comment.