@@ -2,7 +2,9 @@ use super::tiktoken::tiktoken::{decode_async, encode_async, tokenize_async};
 use crate::providers::embedder::{Embedder, EmbedderVector};
 use crate::providers::llm::Tokens;
 use crate::providers::llm::{ChatMessage, LLMChatGeneration, LLMGeneration, LLM};
-use crate::providers::openai::{completion, embed, streamed_completion};
+use crate::providers::openai::{
+    chat_completion, completion, embed, streamed_chat_completion, streamed_completion,
+};
 use crate::providers::provider::{Provider, ProviderID};
 use crate::providers::tiktoken::tiktoken::{
     cl100k_base_singleton, p50k_base_singleton, r50k_base_singleton, CoreBPE,
@@ -160,7 +162,7 @@ impl AzureOpenAILLM {
         assert!(self.endpoint.is_some());

         Ok(format!(
-            "{}openai/deployments/{}/completions?api-version=2022-12-01",
+            "{}openai/deployments/{}/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -170,7 +172,7 @@ impl AzureOpenAILLM {
     #[allow(dead_code)]
     fn chat_uri(&self) -> Result<Uri> {
         Ok(format!(
-            "{}openai/deployments/{}/chat/completions?api-version=2023-03-15-preview",
+            "{}openai/deployments/{}/chat/completions?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -430,7 +432,7 @@ impl LLM for AzureOpenAILLM {

         Ok(LLMGeneration {
             created: utils::now(),
-            provider: ProviderID::OpenAI.to_string(),
+            provider: ProviderID::AzureOpenAI.to_string(),
             model: self.model_id.clone().unwrap(),
             completions: c
                 .choices
@@ -462,22 +464,113 @@ impl LLM for AzureOpenAILLM {

     async fn chat(
         &self,
-        _messages: &Vec<ChatMessage>,
-        _functions: &Vec<ChatFunction>,
-        _function_call: Option<String>,
-        _temperature: f32,
-        _top_p: Option<f32>,
-        _n: usize,
-        _stop: &Vec<String>,
-        _max_tokens: Option<i32>,
-        _presence_penalty: Option<f32>,
-        _frequency_penalty: Option<f32>,
-        _extras: Option<Value>,
-        _event_sender: Option<UnboundedSender<Value>>,
+        messages: &Vec<ChatMessage>,
+        functions: &Vec<ChatFunction>,
+        function_call: Option<String>,
+        temperature: f32,
+        top_p: Option<f32>,
+        n: usize,
+        stop: &Vec<String>,
+        mut max_tokens: Option<i32>,
+        presence_penalty: Option<f32>,
+        frequency_penalty: Option<f32>,
+        extras: Option<Value>,
+        event_sender: Option<UnboundedSender<Value>>,
     ) -> Result<LLMChatGeneration> {
-        Err(anyhow!(
-            "Chat capabilties are not implemented for provider `azure_openai`"
-        ))
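+        // A `max_tokens` of -1 is the sentinel for "no explicit limit": clear it
+        // to `None` so no token limit is sent upstream.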
+        if let Some(m) = max_tokens {
+            if m == -1 {
+                max_tokens = None;
+            }
+        }
+
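+        // When an event channel is provided, stream the completion back as it is
+        // generated; otherwise issue a single blocking chat completion request.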
+        let c = match event_sender {
+            Some(_) => {
+                streamed_chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
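+                    // Forward an optional `openai_user` string from `extras` as
+                    // the upstream `user` parameter.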
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                    event_sender,
+                )
+                .await?
+            }
+            None => {
+                chat_completion(
+                    self.chat_uri()?,
+                    self.api_key.clone().unwrap(),
+                    None,
+                    None,
+                    messages,
+                    functions,
+                    function_call,
+                    temperature,
+                    match top_p {
+                        Some(t) => t,
+                        None => 1.0,
+                    },
+                    n,
+                    stop,
+                    max_tokens,
+                    match presence_penalty {
+                        Some(p) => p,
+                        None => 0.0,
+                    },
+                    match frequency_penalty {
+                        Some(f) => f,
+                        None => 0.0,
+                    },
+                    match &extras {
+                        Some(e) => match e.get("openai_user") {
+                            Some(Value::String(u)) => Some(u.to_string()),
+                            _ => None,
+                        },
+                        None => None,
+                    },
+                )
+                .await?
+            }
+        };
+
+        // println!("COMPLETION: {:?}", c);
+
+        assert!(c.choices.len() > 0);
+
+        Ok(LLMChatGeneration {
+            created: utils::now(),
+            provider: ProviderID::AzureOpenAI.to_string(),
+            model: self.model_id.clone().unwrap(),
+            completions: c
+                .choices
+                .iter()
+                .map(|c| c.message.clone())
+                .collect::<Vec<_>>(),
+        })
     }
 }

@@ -502,7 +595,7 @@ impl AzureOpenAIEmbedder {
         assert!(self.endpoint.is_some());

         Ok(format!(
-            "{}openai/deployments/{}/embeddings?api-version=2022-12-01",
+            "{}openai/deployments/{}/embeddings?api-version=2023-08-01-preview",
             self.endpoint.as_ref().unwrap(),
             self.deployment_id
         )
@@ -597,13 +690,11 @@ impl Embedder for AzureOpenAIEmbedder {
     }

     async fn encode(&self, text: &str) -> Result<Vec<usize>> {
-        let tokens = { self.tokenizer().lock().encode_with_special_tokens(text) };
-        Ok(tokens)
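+        // Delegate to the shared async tokenizer helpers rather than taking the
+        // BPE lock inline on the async path.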
+        encode_async(self.tokenizer(), text).await
     }

     async fn decode(&self, tokens: Vec<usize>) -> Result<String> {
-        let str = { self.tokenizer().lock().decode(tokens)? };
-        Ok(str)
+        decode_async(self.tokenizer(), tokens).await
     }

     async fn embed(&self, text: Vec<&str>, extras: Option<Value>) -> Result<Vec<EmbedderVector>> {