
Commit d08311f

Centralized Request Middleware - also adds VAD tests
Signed-off-by: Dave Lee <dave@gray101.com>
1 parent 708cba0 commit d08311f


56 files changed: +481022 -803 lines changed

.bruno/LocalAI Test Requests/vad/vad test audio.bru

Lines changed: 240024 additions & 0 deletions
Large diffs are not rendered by default.
.bruno/LocalAI Test Requests/vad/vad test too few.bru

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+meta {
+  name: vad test too few
+  type: http
+  seq: 1
+}
+
+post {
+  url: {{PROTOCOL}}{{HOST}}:{{PORT}}/vad
+  body: json
+  auth: none
+}
+
+headers {
+  Content-Type: application/json
+}
+
+body:json {
+  {
+    "model": "silero-vad",
+    "audio": []
+  }
+}

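The Bruno request above exercises the new /vad endpoint with an empty sample buffer. For reference, the same call can be made from Go; the base URL below is an assumption standing in for the {{PROTOCOL}}, {{HOST}} and {{PORT}} variables, not a value defined by this commit:

package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    // Mirror of the "vad test too few" body: the silero-vad model with an empty audio array.
    payload, _ := json.Marshal(map[string]any{
        "model": "silero-vad",
        "audio": []float32{},
    })
    // Address of a locally running instance (assumption).
    resp, err := http.Post("http://localhost:8080/vad", "application/json", bytes.NewReader(payload))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    fmt.Println("status:", resp.Status)
}
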
aio/cpu/vad.yaml

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+  model: silero-vad.onnx
+download_files:
+  - filename: silero-vad.onnx
+    uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+    sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

aio/entrypoint.sh

Lines changed: 1 addition & 1 deletion

@@ -129,7 +129,7 @@ detect_gpu
 detect_gpu_size

 PROFILE="${PROFILE:-$GPU_SIZE}" # default to cpu
-export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vision.yaml}"
+export MODELS="${MODELS:-/aio/${PROFILE}/embeddings.yaml,/aio/${PROFILE}/rerank.yaml,/aio/${PROFILE}/text-to-speech.yaml,/aio/${PROFILE}/image-gen.yaml,/aio/${PROFILE}/text-to-text.yaml,/aio/${PROFILE}/speech-to-text.yaml,/aio/${PROFILE}/vad.yaml,/aio/${PROFILE}/vision.yaml}"

 check_vars

aio/gpu-8g/vad.yaml

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+  model: silero-vad.onnx
+download_files:
+  - filename: silero-vad.onnx
+    uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+    sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

aio/intel/vad.yaml

Lines changed: 8 additions & 0 deletions

@@ -0,0 +1,8 @@
+backend: silero-vad
+name: silero-vad
+parameters:
+  model: silero-vad.onnx
+download_files:
+  - filename: silero-vad.onnx
+    uri: https://huggingface.co/onnx-community/silero-vad/resolve/main/onnx/model.onnx
+    sha256: a4a068cd6cf1ea8355b84327595838ca748ec29a25bc91fc82e6c299ccdc5808

core/application/startup.go

Lines changed: 1 addition & 7 deletions

@@ -145,13 +145,7 @@ func New(opts ...config.AppOption) (*Application, error) {

     if options.LoadToMemory != nil {
         for _, m := range options.LoadToMemory {
-            cfg, err := application.BackendLoader().LoadBackendConfigFileByName(m, options.ModelPath,
-                config.LoadOptionDebug(options.Debug),
-                config.LoadOptionThreads(options.Threads),
-                config.LoadOptionContextSize(options.ContextSize),
-                config.LoadOptionF16(options.F16),
-                config.ModelPath(options.ModelPath),
-            )
+            cfg, err := application.BackendLoader().LoadBackendConfigFileByNameDefaultOptions(m, options)
             if err != nil {
                 return nil, err
             }

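The body of LoadBackendConfigFileByNameDefaultOptions is not part of this excerpt. Judging from the removed call site, it presumably folds the same load options into a single helper, roughly like the sketch below; the receiver type and field names are assumptions inferred from the deleted lines, not code from this commit:

// Hypothetical sketch in the config package: consolidate the load options
// that callers previously spelled out by hand.
func (bcl *BackendConfigLoader) LoadBackendConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*BackendConfig, error) {
    return bcl.LoadBackendConfigFileByName(modelName, appConfig.ModelPath,
        LoadOptionDebug(appConfig.Debug),
        LoadOptionThreads(appConfig.Threads),
        LoadOptionContextSize(appConfig.ContextSize),
        LoadOptionF16(appConfig.F16),
        ModelPath(appConfig.ModelPath),
    )
}
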
core/backend/llm.go

Lines changed: 3 additions & 3 deletions

@@ -31,7 +31,7 @@ type TokenUsage struct {
     Completion int
 }

-func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
+func ModelInference(ctx context.Context, s string, messages []schema.Message, images, videos, audios []string, loader *model.ModelLoader, c *config.BackendConfig, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool) (func() (LLMResponse, error), error) {
     modelFile := c.Model

     // Check if the modelFile exists, if it doesn't try to load it from the gallery
@@ -46,7 +46,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im
         }
     }

-    opts := ModelOptions(c, o)
+    opts := ModelOptions(*c, o)
     inferenceModel, err := loader.Load(opts...)
     if err != nil {
         return nil, err
@@ -82,7 +82,7 @@ func ModelInference(ctx context.Context, s string, messages []schema.Message, im

     // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported
     fn := func() (LLMResponse, error) {
-        opts := gRPCPredictOpts(c, loader.ModelPath)
+        opts := gRPCPredictOpts(*c, loader.ModelPath)
         opts.Prompt = s
         opts.Messages = protoMessages
         opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate

core/backend/rerank.go

Lines changed: 3 additions & 3 deletions

@@ -9,10 +9,10 @@ import (
     model "github.com/mudler/LocalAI/pkg/model"
 )

-func Rerank(modelFile string, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
-
-    opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) (*proto.RerankResult, error) {
+    opts := ModelOptions(backendConfig, appConfig)
     rerankModel, err := loader.Load(opts...)
+
     if err != nil {
         return nil, err
     }

core/backend/soundgeneration.go

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,6 @@ import (
 )

 func SoundGeneration(
-    modelFile string,
     text string,
     duration *float32,
     temperature *float32,
@@ -25,8 +24,9 @@ func SoundGeneration(
     backendConfig config.BackendConfig,
 ) (string, *proto.Result, error) {

-    opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+    opts := ModelOptions(backendConfig, appConfig)
     soundGenModel, err := loader.Load(opts...)
+
     if err != nil {
         return "", nil, err
     }
@@ -44,7 +44,7 @@ func SoundGeneration(

     res, err := soundGenModel.SoundGeneration(context.Background(), &proto.SoundGenerationRequest{
         Text:     text,
-        Model:    modelFile,
+        Model:    backendConfig.Model,
         Dst:      filePath,
         Sample:   doSample,
         Duration: duration,

core/backend/tokenize.go

Lines changed: 3 additions & 4 deletions

@@ -4,18 +4,17 @@ import (
     "github.com/mudler/LocalAI/core/config"
     "github.com/mudler/LocalAI/core/schema"
     "github.com/mudler/LocalAI/pkg/grpc"
-    model "github.com/mudler/LocalAI/pkg/model"
+    "github.com/mudler/LocalAI/pkg/model"
 )

 func ModelTokenize(s string, loader *model.ModelLoader, backendConfig config.BackendConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) {

-    modelFile := backendConfig.Model
-
     var inferenceModel grpc.Backend
     var err error

-    opts := ModelOptions(backendConfig, appConfig, model.WithModel(modelFile))
+    opts := ModelOptions(backendConfig, appConfig)

+    // TODO: looks weird, seems to be a correct merge?
     if backendConfig.Backend == "" {
         inferenceModel, err = loader.Load(opts...)
     } else {

core/backend/transcript.go

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ func ModelTranscription(audio, language string, translate bool, ml *model.ModelL
             tks = append(tks, int(t))
         }
         tr.Segments = append(tr.Segments,
-            schema.Segment{
+            schema.TranscriptionSegment{
                 Text:  s.Text,
                 Id:    int(s.Id),
                 Start: time.Duration(s.Start),

core/backend/tts.go

Lines changed: 16 additions & 23 deletions

@@ -14,28 +14,22 @@
 )

 func ModelTTS(
-    backend,
     text,
-    modelFile,
     voice,
     language string,
     loader *model.ModelLoader,
     appConfig *config.ApplicationConfig,
     backendConfig config.BackendConfig,
 ) (string, *proto.Result, error) {
-    bb := backend
-    if bb == "" {
-        bb = model.PiperBackend
-    }
-
-    opts := ModelOptions(backendConfig, appConfig, model.WithBackendString(bb), model.WithModel(modelFile))
+    opts := ModelOptions(backendConfig, appConfig, model.WithDefaultBackendString(model.PiperBackend))
     ttsModel, err := loader.Load(opts...)
+
     if err != nil {
         return "", nil, err
     }

     if ttsModel == nil {
-        return "", nil, fmt.Errorf("could not load piper model")
+        return "", nil, fmt.Errorf("could not load tts model %q", backendConfig.Model)
     }

     if err := os.MkdirAll(appConfig.AudioDir, 0750); err != nil {
@@ -45,22 +39,21 @@ func ModelTTS(
     fileName := utils.GenerateUniqueFileName(appConfig.AudioDir, "tts", ".wav")
     filePath := filepath.Join(appConfig.AudioDir, fileName)

-    // If the model file is not empty, we pass it joined with the model path
+    // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect.
+    // This should be addressed in a follow up PR soon.
+    // Copying it over nearly verbatim, as TTS backends are not functional without this.
     modelPath := ""
-    if modelFile != "" {
-        // If the model file is not empty, we pass it joined with the model path
-        // Checking first that it exists and is not outside ModelPath
-        // TODO: we should actually first check if the modelFile is looking like
-        // a FS path
-        mp := filepath.Join(loader.ModelPath, modelFile)
-        if _, err := os.Stat(mp); err == nil {
-            if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
-                return "", nil, err
-            }
-            modelPath = mp
-        } else {
-            modelPath = modelFile
+    // Checking first that it exists and is not outside ModelPath
+    // TODO: we should actually first check if the modelFile is looking like
+    // a FS path
+    mp := filepath.Join(loader.ModelPath, backendConfig.Model)
+    if _, err := os.Stat(mp); err == nil {
+        if err := utils.VerifyPath(mp, appConfig.ModelPath); err != nil {
+            return "", nil, err
         }
+        modelPath = mp
+    } else {
+        modelPath = backendConfig.Model // skip this step if it fails?????
     }

     res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{

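utils.VerifyPath itself is not shown in this diff; the intent of the check above is to keep the joined model path inside the configured model directory. A generic containment check of that kind, an assumption about VerifyPath's behaviour rather than its actual code, looks like:

package pathcheck

import (
    "fmt"
    "path/filepath"
    "strings"
)

// verifyWithin rejects any path that resolves outside basePath.
// Illustrative stand-in for utils.VerifyPath, whose real implementation
// is not part of this diff.
func verifyWithin(path, basePath string) error {
    rel, err := filepath.Rel(filepath.Clean(basePath), filepath.Clean(path))
    if err != nil {
        return err
    }
    if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
        return fmt.Errorf("path %q escapes base directory %q", path, basePath)
    }
    return nil
}
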
core/backend/vad.go

Lines changed: 38 additions & 0 deletions

@@ -0,0 +1,38 @@
+package backend
+
+import (
+    "context"
+
+    "github.com/mudler/LocalAI/core/config"
+    "github.com/mudler/LocalAI/core/schema"
+    "github.com/mudler/LocalAI/pkg/grpc/proto"
+    "github.com/mudler/LocalAI/pkg/model"
+)
+
+func VAD(request *schema.VADRequest,
+    ctx context.Context,
+    ml *model.ModelLoader,
+    appConfig *config.ApplicationConfig,
+    backendConfig config.BackendConfig) (*schema.VADResponse, error) {
+    opts := ModelOptions(backendConfig, appConfig)
+    vadModel, err := ml.Load(opts...)
+    if err != nil {
+        return nil, err
+    }
+    req := proto.VADRequest{
+        Audio: request.Audio,
+    }
+    resp, err := vadModel.VAD(ctx, &req)
+    if err != nil {
+        return nil, err
+    }
+
+    segments := []schema.VADSegment{}
+    for _, s := range resp.Segments {
+        segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End})
+    }
+
+    return &schema.VADResponse{
+        Segments: segments,
+    }, nil
+}

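backend.VAD above translates the HTTP-level schema.VADRequest into the gRPC proto request and maps the returned segments back into schema.VADSegment values. A minimal caller might look like the following sketch; how the loader and configs are obtained is left out, and the exact field types of VADSegment are assumed:

package vadexample

import (
    "context"
    "fmt"

    "github.com/mudler/LocalAI/core/backend"
    "github.com/mudler/LocalAI/core/config"
    "github.com/mudler/LocalAI/core/schema"
    "github.com/mudler/LocalAI/pkg/model"
)

// detectSpeech is a hypothetical caller of the new backend.VAD helper.
func detectSpeech(ctx context.Context, samples []float32, ml *model.ModelLoader,
    appConfig *config.ApplicationConfig, backendConfig config.BackendConfig) error {
    resp, err := backend.VAD(&schema.VADRequest{Audio: samples}, ctx, ml, appConfig, backendConfig)
    if err != nil {
        return err
    }
    for _, seg := range resp.Segments {
        fmt.Printf("speech segment: start=%v end=%v\n", seg.Start, seg.End)
    }
    return nil
}
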
core/cli/soundgeneration.go

Lines changed: 2 additions & 1 deletion

@@ -86,13 +86,14 @@ func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error {
     options := config.BackendConfig{}
     options.SetDefaults()
     options.Backend = t.Backend
+    options.Model = t.Model

     var inputFile *string
     if t.InputFile != "" {
         inputFile = &t.InputFile
     }

-    filePath, _, err := backend.SoundGeneration(t.Model, text,
+    filePath, _, err := backend.SoundGeneration(text,
         parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample,
         inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), ml, opts, options)

core/cli/tts.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ func (t *TTSCMD) Run(ctx *cliContext.Context) error {
5252

5353
options := config.BackendConfig{}
5454
options.SetDefaults()
55+
options.Backend = t.Backend
56+
options.Model = t.Model
5557

56-
filePath, _, err := backend.ModelTTS(t.Backend, text, t.Model, t.Voice, t.Language, ml, opts, options)
58+
filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options)
5759
if err != nil {
5860
return err
5961
}

core/config/backend_config.go

Lines changed: 28 additions & 11 deletions

@@ -441,19 +441,21 @@ func (c *BackendConfig) HasTemplate() bool {
 type BackendConfigUsecases int

 const (
-    FLAG_ANY              BackendConfigUsecases = 0b000000000
-    FLAG_CHAT             BackendConfigUsecases = 0b000000001
-    FLAG_COMPLETION       BackendConfigUsecases = 0b000000010
-    FLAG_EDIT             BackendConfigUsecases = 0b000000100
-    FLAG_EMBEDDINGS       BackendConfigUsecases = 0b000001000
-    FLAG_RERANK           BackendConfigUsecases = 0b000010000
-    FLAG_IMAGE            BackendConfigUsecases = 0b000100000
-    FLAG_TRANSCRIPT       BackendConfigUsecases = 0b001000000
-    FLAG_TTS              BackendConfigUsecases = 0b010000000
-    FLAG_SOUND_GENERATION BackendConfigUsecases = 0b100000000
+    FLAG_ANY              BackendConfigUsecases = 0b00000000000
+    FLAG_CHAT             BackendConfigUsecases = 0b00000000001
+    FLAG_COMPLETION       BackendConfigUsecases = 0b00000000010
+    FLAG_EDIT             BackendConfigUsecases = 0b00000000100
+    FLAG_EMBEDDINGS       BackendConfigUsecases = 0b00000001000
+    FLAG_RERANK           BackendConfigUsecases = 0b00000010000
+    FLAG_IMAGE            BackendConfigUsecases = 0b00000100000
+    FLAG_TRANSCRIPT       BackendConfigUsecases = 0b00001000000
+    FLAG_TTS              BackendConfigUsecases = 0b00010000000
+    FLAG_SOUND_GENERATION BackendConfigUsecases = 0b00100000000
+    FLAG_TOKENIZE         BackendConfigUsecases = 0b01000000000
+    FLAG_VAD              BackendConfigUsecases = 0b10000000000

     // Common Subsets
-    FLAG_LLM BackendConfigUsecases = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT
+    FLAG_LLM BackendConfigUsecases = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
 )

 func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
@@ -468,6 +470,8 @@ func GetAllBackendConfigUsecases() map[string]BackendConfigUsecases {
     "FLAG_TRANSCRIPT":       FLAG_TRANSCRIPT,
     "FLAG_TTS":              FLAG_TTS,
     "FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
+    "FLAG_TOKENIZE":         FLAG_TOKENIZE,
+    "FLAG_VAD":              FLAG_VAD,
     "FLAG_LLM":              FLAG_LLM,
     }
 }
@@ -553,5 +557,18 @@ func (c *BackendConfig) GuessUsecases(u BackendConfigUsecases) bool {
         }
     }

+    if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE {
+        tokenizeCapableBackends := []string{"llama.cpp", "rwkv"}
+        if !slices.Contains(tokenizeCapableBackends, c.Backend) {
+            return false
+        }
+    }
+
+    if (u & FLAG_VAD) == FLAG_VAD {
+        if c.Backend != "silero-vad" {
+            return false
+        }
+    }
+
     return true
 }

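Beyond widening the bitmask to make room for FLAG_TOKENIZE and FLAG_VAD, the hunk above fixes FLAG_LLM: the three flags share no bits, so the old `&` definition collapsed to 0 (the same value as FLAG_ANY), whereas `|` yields the intended union. A standalone illustration, not code from this commit:

package main

import "fmt"

const (
    FLAG_CHAT       = 0b001
    FLAG_COMPLETION = 0b010
    FLAG_EDIT       = 0b100

    flagLLMOld = FLAG_CHAT & FLAG_COMPLETION & FLAG_EDIT // 0b000: matches nothing
    flagLLMNew = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT // 0b111: union of the three usecases
)

func main() {
    fmt.Printf("old FLAG_LLM: %03b, new FLAG_LLM: %03b\n", flagLLMOld, flagLLMNew)
    // GuessUsecases-style test for a single flag within the mask.
    fmt.Println("chat included:", flagLLMNew&FLAG_CHAT == FLAG_CHAT)
}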