Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go): Added action-level and generate-level middleware. #1949

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion go/ai/action_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,11 @@ func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry,

func defineProgrammableModel(r *registry.Registry) *programmableModel {
pm := &programmableModel{r: r}
DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
supports := &ModelInfoSupports{
Tools: true,
Multiturn: true,
}
DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb)
})
return pm
Expand Down
3 changes: 1 addition & 2 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0


package ai

import (
Expand Down Expand Up @@ -52,7 +51,7 @@ func DefineEmbedder(
provider, name string,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, nil, embed))
}

// IsDefinedEmbedder reports whether an embedder is defined.
Expand Down
6 changes: 3 additions & 3 deletions go/ai/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ type ModelInfoSupports struct {

// A ModelRequest is a request to generate completions from a model.
type ModelRequest struct {
Config any `json:"config,omitempty"`
Context []any `json:"context,omitempty"`
Messages []*Message `json:"messages,omitempty"`
Config any `json:"config,omitempty"`
Context []*Document `json:"context,omitempty"`
Messages []*Message `json:"messages,omitempty"`
// Output describes the desired response format.
Output *ModelRequestOutput `json:"output,omitempty"`
ToolChoice ToolChoice `json:"toolChoice,omitempty"`
Expand Down
61 changes: 46 additions & 15 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,15 @@ type Model interface {
// Name returns the registry name of the model.
Name() string
// Generate applies the [Model] to provided request, handling tool requests and handles streaming.
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error)
}

// ModelFunc is a function that generates a model response.
type ModelFunc = core.Func[*ModelRequest, *ModelResponse, *ModelResponseChunk]

// ModelMiddleware is middleware for model generate requests.
type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]

type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
Expand All @@ -44,7 +50,7 @@ type ToolConfig struct {

// DefineGenerateAction defines a utility generate action.
func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction {
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{},
return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{}, nil,
func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamingCallback) (output *ModelResponse, err error) {
logger.FromContext(ctx).Debug("GenerateAction",
"input", fmt.Sprintf("%#v", req))
Expand All @@ -53,9 +59,10 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
"output", fmt.Sprintf("%#v", output),
"err", err)
}()

return tracing.RunInNewSpan(ctx, r.TracingState(), "generate", "util", false, req,
func(ctx context.Context, input *GenerateActionOptions) (*ModelResponse, error) {
model := LookupModel(r, "default", req.Model)
model := LookupModel(r, "", req.Model)
if model == nil {
return nil, fmt.Errorf("model %q not found", req.Model)
}
Expand Down Expand Up @@ -95,17 +102,17 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc
ReturnToolRequests: req.ReturnToolRequests,
}

return model.Generate(ctx, r, modelReq, toolCfg, cb)
return model.Generate(ctx, r, modelReq, nil, toolCfg, cb)
})
}))
}

// DefineModel registers the given generate function as an action, and returns a
// [Model] that runs it.
// DefineModel registers the given generate function as an action, and returns a [Model] that runs it.
func DefineModel(
r *registry.Registry,
provider, name string,
metadata *ModelInfo,
mw []ModelMiddleware,
generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error),
) Model {
metadataMap := map[string]any{}
Expand All @@ -129,9 +136,9 @@ func DefineModel(
metadataMap["supports"] = supports
metadataMap["versions"] = metadata.Versions

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate))
mw = append([]ModelMiddleware{ValidateSupport(name, metadata.Supports)}, mw...)

return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, mw, generate))
}

// IsDefinedModel reports whether a model is defined.
Expand All @@ -158,6 +165,7 @@ type generateParams struct {
SystemPrompt *Message
MaxTurns int
ReturnToolRequests bool
Middleware []ModelMiddleware
}

// GenerateOption configures params of the Generate call.
Expand Down Expand Up @@ -224,10 +232,13 @@ func WithConfig(config any) GenerateOption {
}
}

// WithContext adds provided context to ModelRequest.
func WithContext(c ...any) GenerateOption {
// WithContext adds provided documents to ModelRequest.
func WithContext(docs ...*Document) GenerateOption {
return func(req *generateParams) error {
req.Request.Context = append(req.Request.Context, c...)
if req.Request.Context != nil {
return errors.New("generate.WithContext: cannot set context more than once")
}
req.Request.Context = docs
return nil
}
}
Expand Down Expand Up @@ -320,6 +331,17 @@ func WithToolChoice(toolChoice ToolChoice) GenerateOption {
}
}

// WithMiddleware adds middleware to the generate request.
func WithMiddleware(middleware ...ModelMiddleware) GenerateOption {
return func(req *generateParams) error {
if req.Middleware != nil {
return errors.New("generate.WithMiddleware: cannot set Middleware more than once")
}
req.Middleware = middleware
return nil
}
}

// Generate run generate request for this model. Returns ModelResponse struct.
func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelResponse, error) {
req := &generateParams{
Expand Down Expand Up @@ -368,7 +390,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption)
ReturnToolRequests: req.ReturnToolRequests,
}

return req.Model.Generate(ctx, r, req.Request, toolCfg, req.Stream)
return req.Model.Generate(ctx, r, req.Request, req.Middleware, toolCfg, req.Stream)
}

// validateModelVersion checks in the registry the action of the
Expand Down Expand Up @@ -435,7 +457,7 @@ func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ...
}

// Generate applies the [Action] to provided request, handling tool requests and handles streaming.
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) {
if m == nil {
return nil, errors.New("Generate called on a nil Model; check that all models are defined")
}
Expand Down Expand Up @@ -463,9 +485,18 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req
return nil, err
}

handler := (*modelAction)(m).Run
for i := len(mw) - 1; i >= 0; i-- {
currentHandler := handler
currentMiddleware := mw[i]
handler = func(ctx context.Context, in *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
return currentMiddleware(ctx, in, cb, currentHandler)
}
}

currentTurn := 0
for {
resp, err := (*modelAction)(m).Run(ctx, req, cb)
resp, err := handler(ctx, req, cb)
if err != nil {
return nil, err
}
Expand Down
65 changes: 58 additions & 7 deletions go/ai/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ var (
Versions: []string{"echo-001", "echo-002"},
}

echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
echoModel = DefineModel(r, "test", modelName, &metadata, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
Expand Down Expand Up @@ -264,7 +264,7 @@ func TestGenerate(t *testing.T) {
},
},
Config: GenerationCommonConfig{Temperature: 1},
Context: []any{[]any{string("Banana")}},
Context: []*Document{&Document{Content: []*Part{NewTextPart("Banana")}}},
Output: &ModelRequestOutput{
Format: "json",
Schema: map[string]any{
Expand Down Expand Up @@ -310,7 +310,7 @@ func TestGenerate(t *testing.T) {
Temperature: 1,
}),
WithHistory(NewUserTextMessage("banana"), NewModelTextMessage("yes, banana")),
WithContext([]any{"Banana"}),
WithContext(&Document{Content: []*Part{NewTextPart("Banana")}}),
WithOutputSchema(&GameCharacter{}),
WithTools(gablorkenTool),
WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error {
Expand Down Expand Up @@ -346,7 +346,13 @@ func TestGenerate(t *testing.T) {
},
)

interruptModel := DefineModel(r, "test", "interrupt", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
interruptModel := DefineModel(r, "test", "interrupt", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Expand Down Expand Up @@ -399,7 +405,13 @@ func TestGenerate(t *testing.T) {

t.Run("handles multiple parallel tool calls", func(t *testing.T) {
roundCount := 0
parallelModel := DefineModel(r, "test", "parallel", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
parallelModel := DefineModel(r, "test", "parallel", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
Expand Down Expand Up @@ -458,7 +470,13 @@ func TestGenerate(t *testing.T) {

t.Run("handles multiple rounds of tool calls", func(t *testing.T) {
roundCount := 0
multiRoundModel := DefineModel(r, "test", "multiround", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
multiRoundModel := DefineModel(r, "test", "multiround", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
roundCount++
if roundCount == 1 {
Expand Down Expand Up @@ -520,7 +538,13 @@ func TestGenerate(t *testing.T) {
})

t.Run("exceeds maximum turns", func(t *testing.T) {
infiniteModel := DefineModel(r, "test", "infinite", nil,
info := &ModelInfo{
Supports: &ModelInfoSupports{
Multiturn: true,
Tools: true,
},
}
infiniteModel := DefineModel(r, "test", "infinite", info, nil,
func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
return &ModelResponse{
Request: gr,
Expand Down Expand Up @@ -550,6 +574,33 @@ func TestGenerate(t *testing.T) {
t.Errorf("unexpected error message: %v", err)
}
})

t.Run("applies middleware", func(t *testing.T) {
middlewareCalled := false
testMiddleware := func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback, next ModelFunc) (*ModelResponse, error) {
middlewareCalled = true
req.Messages = append(req.Messages, NewUserTextMessage("middleware was here"))
return next(ctx, req, cb)
}

res, err := Generate(context.Background(), r,
WithModel(echoModel),
WithTextPrompt("test middleware"),
WithMiddleware(testMiddleware),
)
if err != nil {
t.Fatal(err)
}

if !middlewareCalled {
t.Error("middleware was not called")
}

expectedText := "test middlewaremiddleware was here"
if res.Text() != expectedText {
t.Errorf("got text %q, want %q", res.Text(), expectedText)
}
})
}

func TestModelVersion(t *testing.T) {
Expand Down
53 changes: 53 additions & 0 deletions go/ai/middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright 2025 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package ai

import (
"context"
"fmt"

"github.com/firebase/genkit/go/core"
)

// ValidateSupport creates middleware that validates whether a model supports the requested features.
func ValidateSupport(name string, supports *ModelInfoSupports) ModelMiddleware {
return func(ctx context.Context, input *ModelRequest, cb ModelStreamingCallback, next core.Func[*ModelRequest, *ModelResponse, *ModelResponseChunk]) (*ModelResponse, error) {
if supports == nil {
supports = &ModelInfoSupports{}
}

if !supports.Media {
for _, msg := range input.Messages {
for _, part := range msg.Content {
if part.IsMedia() {
return nil, fmt.Errorf("model %q does not support media, but media was provided. Request: %+v", name, input)
}
}
}
}

if !supports.Tools && len(input.Tools) > 0 {
return nil, fmt.Errorf("model %q does not support tool use, but tools were provided. Request: %+v", name, input)
}

if !supports.Multiturn && len(input.Messages) > 1 {
return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", name, len(input.Messages), input)
}

// TODO: Add validation for features that won't have simulated support via middleware.

return next(ctx, input, cb)
}
}
Loading
Loading