
Commit 8adcd73

azure: Chat implementation (#2615)
* Azure Chat implementation
* Make azure work
1 parent bb455e1 commit 8adcd73

4 files changed: 121 additions, 42 deletions


core/src/providers/azure_openai.rs

Lines changed: 115 additions & 24 deletions
@@ -2,7 +2,9 @@ use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 use crate::providers::embedder::{Embedder, EmbedderVector};
 use crate::providers::llm::Tokens;
 use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
-use crate::providers::openai::{completion, embed, streamed_completion};
+use crate::providers::openai::{
+    chat_completion, completion, embed, streamed_chat_completion, streamed_completion,
+};
 use crate::providers::provider::{Provider, ProviderID};
 use crate::providers::tiktoken::tiktoken::{
     cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
@@ -160,7 +162,7 @@ impl AzureOpenAILLM {
         assert!(self.endpoint.is_some());

         Ok(format!(
-            "{}openai/deployments/{}/completions?api-version=2022-12-01",
+            "{}openai/deployments/{}/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -170,7 +172,7 @@ impl AzureOpenAILLM {
     #[allow(dead_code)]
     fn chat_uri(&self) -> Result<Uri> {
         Ok(format!(
-            "{}openai/deployments/{}/chat/completions?api-version=2023-03-15-preview",
+            "{}openai/deployments/{}/chat/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -430,7 +432,7 @@ impl LLM for AzureOpenAILLM {

         Ok(LLMGeneration {
             created: utils::now(),
-            provider: ProviderID::OpenAI.to_string(),
+            provider: ProviderID::AzureOpenAI.to_string(),
             model: self.model_id.clone().unwrap(),
             completions: c
                 .choices
@@ -462,22 +464,113 @@ impl LLM for AzureOpenAILLM {

     async fn chat(
         &self,
-        _messages: &Vec<ChatMessage>,
-        _functions: &Vec<ChatFunction>,
-        _function_call: Option<String>,
-        _temperature: f32,
-        _top_p: Option<f32>,
-        _n: usize,
-        _stop: &Vec<String>,
-        _max_tokens: Option<i32>,
-        _presence_penalty: Option<f32>,
-        _frequency_penalty: Option<f32>,
-        _extras: Option<Value>,
-        _event_sender: Option<UnboundedSender<Value>>,
+        messages: &Vec<ChatMessage>,
+        functions: &Vec<ChatFunction>,
+        function_call: Option<String>,
+        temperature: f32,
+        top_p: Option<f32>,
+        n: usize,
+        stop: &Vec<String>,
+        mut max_tokens: Option<i32>,
+        presence_penalty: Option<f32>,
+        frequency_penalty: Option<f32>,
+        extras: Option<Value>,
+        event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<LLMChatGeneration> {
-        Err(anyhow!(
-            "Chat capabilties are not implemented for provider `azure_openai`"
-        ))
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = None;
+            }
+        }
+
+        let c = match event_sender {
+            Some(_) => {
+                streamed_chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                    event_sender,
+                )
+                .await?
+            }
+            None => {
+                chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                )
+                .await?
+            }
+        };
+
+        // println!("COMPLETION: {:?}", c);
+
+        assert!(c.choices.len() > 0);
+
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::AzureOpenAI.to_string(),
+            model: self.model_id.clone().unwrap(),
+            completions: c
+                .choices
+                .iter()
+                .map(|c| c.message.clone())
+                .collect::<Vec<_>>(),
+        })
     }
 }

@@ -502,7 +595,7 @@ impl AzureOpenAIEmbedder {
         assert!(self.endpoint.is_some());

         Ok(format!(
-            "{}openai/deployments/{}/embeddings?api-version=2022-12-01",
+            "{}openai/deployments/{}/embeddings?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -597,13 +690,11 @@ impl Embedder for AzureOpenAIEmbedder {
     }

     async fn encode(&self, text: &str) -> Result<Vec<usize>> {
-        let tokens = { self.tokenizer().lock().encode_with_special_tokens(text) };
-        Ok(tokens)
+        encode_async(self.tokenizer(), text).await
     }

     async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
-        let str = { self.tokenizer().lock().decode(tokens)? };
-        Ok(str)
+        decode_async(self.tokenizer(), tokens).await
     }

     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {
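
Before delegating to the shared chat_completion / streamed_chat_completion helpers from the openai module, the new chat method normalizes its optional arguments: a max_tokens of -1 is treated as "no limit", and unset sampling parameters fall back to the OpenAI API defaults. A minimal sketch of that step, pulled out into a free function for illustration (the function and its name are hypothetical; the diff inlines this logic):

// Illustrative sketch of the argument normalization performed by `chat` above.
// This standalone function is hypothetical; the commit inlines the same logic.
fn normalize_params(
    mut max_tokens: Option<i32>,
    top_p: Option<f32>,
    presence_penalty: Option<f32>,
    frequency_penalty: Option<f32>,
) -> (Option<i32>, f32, f32, f32) {
    // A max_tokens of -1 means "no limit" and is mapped to None.
    if max_tokens == Some(-1) {
        max_tokens = None;
    }
    // Unset sampling parameters fall back to the OpenAI API defaults.
    (
        max_tokens,
        top_p.unwrap_or(1.0),
        presence_penalty.unwrap_or(0.0),
        frequency_penalty.unwrap_or(0.0),
    )
}

fn main() {
    assert_eq!(
        normalize_params(Some(-1), None, None, None),
        (None, 1.0, 0.0, 0.0)
    );
}

The match top_p { Some(t) => t, None => 1.0 } arms in the diff are equivalent to the unwrap_or calls here; the behavior is identical and the choice is purely stylistic.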

front/lib/api/credentials.ts

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,8 @@ const {
   DUST_MANAGED_OPENAI_API_KEY = "",
   DUST_MANAGED_ANTHROPIC_API_KEY = "",
   DUST_MANAGED_TEXTSYNTH_API_KEY = "",
+  DUST_MANAGED_AZURE_OPENAI_API_KEY = "",
+  DUST_MANAGED_AZURE_OPENAI_ENDPOINT = "",
 } = process.env;

 export const credentialsFromProviders = (
@@ -56,5 +58,7 @@ export const dustManagedCredentials = (): CredentialsType => {
     OPENAI_API_KEY: DUST_MANAGED_OPENAI_API_KEY,
     ANTHROPIC_API_KEY: DUST_MANAGED_ANTHROPIC_API_KEY,
     TEXTSYNTH_API_KEY: DUST_MANAGED_TEXTSYNTH_API_KEY,
+    AZURE_OPENAI_API_KEY: DUST_MANAGED_AZURE_OPENAI_API_KEY,
+    AZURE_OPENAI_ENDPOINT: DUST_MANAGED_AZURE_OPENAI_ENDPOINT,
   };
 };
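
These env-backed entries surface in the credentials payload that front passes along to core, where the Azure provider holds its api_key and endpoint as Option fields. A rough sketch of how a consumer might pull the Azure entries out of such a map (the struct and function below are hypothetical and named only to illustrate the key names introduced above, not the repository's actual API):

use std::collections::HashMap;

// Hypothetical consumer of the credentials map; the key names match the
// TypeScript above, but this struct and function are illustrative only.
struct AzureOpenAICredentials {
    api_key: Option<String>,
    endpoint: Option<String>,
}

fn azure_openai_credentials(credentials: &HashMap<String, String>) -> AzureOpenAICredentials {
    AzureOpenAICredentials {
        api_key: credentials.get("AZURE_OPENAI_API_KEY").cloned(),
        endpoint: credentials.get("AZURE_OPENAI_ENDPOINT").cloned(),
    }
}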

front/lib/providers.ts

Lines changed: 1 addition & 17 deletions
@@ -42,7 +42,7 @@ export const modelProviders: ModelProvider[] = [
     name: "Azure OpenAI",
     built: true,
     enabled: false,
-    chat: false,
+    chat: true,
     embed: true,
   },
   {
@@ -61,22 +61,6 @@ export const modelProviders: ModelProvider[] = [
     chat: true,
     embed: false,
   },
-  {
-    providerId: "hugging_face",
-    name: "Hugging Face",
-    built: false,
-    enabled: false,
-    chat: false,
-    embed: false,
-  },
-  {
-    providerId: "replicate",
-    name: "Replicate",
-    built: false,
-    enabled: false,
-    chat: false,
-    embed: false,
-  },
 ];

 type ServiceProvider = {

front/pages/api/w/[wId]/data_sources/index.ts

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ async function handler(
       api_error: {
         type: "data_source_auth_error",
         message:
-          "Only the users that are `admins` for the current workspace can create a managed data source.",
+          "Only the users that are `admins` for the current workspace can create a data source.",
       },
     });
   }
