From 2102a9c7e5025d6e89b30068c356ca6044bce9f8 Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 12 Feb 2025 07:46:24 -0800 Subject: [PATCH 1/3] Added action-level and generate-level middleware. --- go/ai/action_test.go | 6 +- go/ai/embedder.go | 3 +- go/ai/gen.go | 6 +- go/ai/generate.go | 61 ++++++--- go/ai/generator_test.go | 65 ++++++++-- go/ai/middleware.go | 37 ++++++ go/ai/middleware_test.go | 134 ++++++++++++++++++++ go/ai/prompt.go | 3 +- go/ai/retriever.go | 5 +- go/ai/tools.go | 2 +- go/core/action.go | 74 ++++++++--- go/core/action_test.go | 164 ++++++++++++++++++++++++- go/core/schemas.config | 2 +- go/genkit/flow.go | 3 +- go/genkit/genkit.go | 10 +- go/genkit/servers_test.go | 10 +- go/go.mod | 2 +- go/plugins/dotprompt/dotprompt_test.go | 3 +- go/plugins/dotprompt/genkit.go | 25 +++- go/plugins/dotprompt/genkit_test.go | 9 +- go/plugins/googleai/googleai.go | 2 +- go/plugins/ollama/ollama.go | 2 +- go/plugins/ollama/ollama_live_test.go | 3 +- go/plugins/vertexai/vertexai.go | 2 +- go/tests/test_app/main.go | 3 +- 25 files changed, 543 insertions(+), 93 deletions(-) create mode 100644 go/ai/middleware.go create mode 100644 go/ai/middleware_test.go diff --git a/go/ai/action_test.go b/go/ai/action_test.go index 297eb2eec..f66e2232c 100644 --- a/go/ai/action_test.go +++ b/go/ai/action_test.go @@ -56,7 +56,11 @@ func (pm *programmableModel) Generate(ctx context.Context, r *registry.Registry, func defineProgrammableModel(r *registry.Registry) *programmableModel { pm := &programmableModel{r: r} - DefineModel(r, "default", "programmableModel", nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { + supports := &ModelInfoSupports{ + Tools: true, + Multiturn: true, + } + DefineModel(r, "", "programmableModel", &ModelInfo{Supports: supports}, nil, func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { return pm.Generate(ctx, r, req, &ToolConfig{MaxTurns: 5}, cb) }) return pm diff --git a/go/ai/embedder.go b/go/ai/embedder.go index 322a4548d..860fdb6c0 100644 --- a/go/ai/embedder.go +++ b/go/ai/embedder.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ai import ( @@ -52,7 +51,7 @@ func DefineEmbedder( provider, name string, embed func(context.Context, *EmbedRequest) (*EmbedResponse, error), ) Embedder { - return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed)) + return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, nil, embed)) } // IsDefinedEmbedder reports whether an embedder is defined. diff --git a/go/ai/gen.go b/go/ai/gen.go index 265051cd3..b357312ee 100644 --- a/go/ai/gen.go +++ b/go/ai/gen.go @@ -124,9 +124,9 @@ type ModelInfoSupports struct { // A ModelRequest is a request to generate completions from a model. type ModelRequest struct { - Config any `json:"config,omitempty"` - Context []any `json:"context,omitempty"` - Messages []*Message `json:"messages,omitempty"` + Config any `json:"config,omitempty"` + Context []*Document `json:"context,omitempty"` + Messages []*Message `json:"messages,omitempty"` // Output describes the desired response format. Output *ModelRequestOutput `json:"output,omitempty"` ToolChoice ToolChoice `json:"toolChoice,omitempty"` diff --git a/go/ai/generate.go b/go/ai/generate.go index cf4fe202a..41ed4bcae 100644 --- a/go/ai/generate.go +++ b/go/ai/generate.go @@ -24,9 +24,15 @@ type Model interface { // Name returns the registry name of the model. Name() string // Generate applies the [Model] to provided request, handling tool requests and handles streaming. - Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) + Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) } +// ModelFunc is a function that generates a model response. +type ModelFunc = core.Func[*ModelRequest, *ModelResponse, *ModelResponseChunk] + +// ModelMiddleware is middleware for model generate requests. +type ModelMiddleware = core.Middleware[*ModelRequest, *ModelResponse, *ModelResponseChunk] + type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk] type modelAction = core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk] @@ -44,7 +50,7 @@ type ToolConfig struct { // DefineGenerateAction defines a utility generate action. func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAction { - return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{}, + return (*generateAction)(core.DefineStreamingAction(r, "", "generate", atype.Util, map[string]any{}, nil, func(ctx context.Context, req *GenerateActionOptions, cb ModelStreamingCallback) (output *ModelResponse, err error) { logger.FromContext(ctx).Debug("GenerateAction", "input", fmt.Sprintf("%#v", req)) @@ -53,9 +59,10 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc "output", fmt.Sprintf("%#v", output), "err", err) }() + return tracing.RunInNewSpan(ctx, r.TracingState(), "generate", "util", false, req, func(ctx context.Context, input *GenerateActionOptions) (*ModelResponse, error) { - model := LookupModel(r, "default", req.Model) + model := LookupModel(r, "", req.Model) if model == nil { return nil, fmt.Errorf("model %q not found", req.Model) } @@ -95,17 +102,17 @@ func DefineGenerateAction(ctx context.Context, r *registry.Registry) *generateAc ReturnToolRequests: req.ReturnToolRequests, } - return model.Generate(ctx, r, modelReq, toolCfg, cb) + return model.Generate(ctx, r, modelReq, nil, toolCfg, cb) }) })) } -// DefineModel registers the given generate function as an action, and returns a -// [Model] that runs it. +// DefineModel registers the given generate function as an action, and returns a [Model] that runs it. func DefineModel( r *registry.Registry, provider, name string, metadata *ModelInfo, + mw []ModelMiddleware, generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error), ) Model { metadataMap := map[string]any{} @@ -129,9 +136,9 @@ func DefineModel( metadataMap["supports"] = supports metadataMap["versions"] = metadata.Versions - return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{ - "model": metadataMap, - }, generate)) + mw = append([]ModelMiddleware{ValidateSupport(name, metadata.Supports)}, mw...) + + return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{"model": metadataMap}, mw, generate)) } // IsDefinedModel reports whether a model is defined. @@ -158,6 +165,7 @@ type generateParams struct { SystemPrompt *Message MaxTurns int ReturnToolRequests bool + Middleware []ModelMiddleware } // GenerateOption configures params of the Generate call. @@ -224,10 +232,13 @@ func WithConfig(config any) GenerateOption { } } -// WithContext adds provided context to ModelRequest. -func WithContext(c ...any) GenerateOption { +// WithContext adds provided documents to ModelRequest. +func WithContext(docs ...*Document) GenerateOption { return func(req *generateParams) error { - req.Request.Context = append(req.Request.Context, c...) + if req.Request.Context != nil { + return errors.New("generate.WithContext: cannot set context more than once") + } + req.Request.Context = docs return nil } } @@ -320,6 +331,17 @@ func WithToolChoice(toolChoice ToolChoice) GenerateOption { } } +// WithMiddleware adds middleware to the generate request. +func WithMiddleware(middleware ...ModelMiddleware) GenerateOption { + return func(req *generateParams) error { + if req.Middleware != nil { + return errors.New("generate.WithMiddleware: cannot set Middleware more than once") + } + req.Middleware = middleware + return nil + } +} + // Generate run generate request for this model. Returns ModelResponse struct. func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelResponse, error) { req := &generateParams{ @@ -368,7 +390,7 @@ func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) ReturnToolRequests: req.ReturnToolRequests, } - return req.Model.Generate(ctx, r, req.Request, toolCfg, req.Stream) + return req.Model.Generate(ctx, r, req.Request, req.Middleware, toolCfg, req.Stream) } // validateModelVersion checks in the registry the action of the @@ -435,7 +457,7 @@ func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ... } // Generate applies the [Action] to provided request, handling tool requests and handles streaming. -func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) { +func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, mw []ModelMiddleware, toolCfg *ToolConfig, cb ModelStreamingCallback) (*ModelResponse, error) { if m == nil { return nil, errors.New("Generate called on a nil Model; check that all models are defined") } @@ -463,9 +485,18 @@ func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req return nil, err } + handler := (*modelAction)(m).Run + for i := len(mw) - 1; i >= 0; i-- { + currentHandler := handler + currentMiddleware := mw[i] + handler = func(ctx context.Context, in *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { + return currentMiddleware(ctx, in, cb, currentHandler) + } + } + currentTurn := 0 for { - resp, err := (*modelAction)(m).Run(ctx, req, cb) + resp, err := handler(ctx, req, cb) if err != nil { return nil, err } diff --git a/go/ai/generator_test.go b/go/ai/generator_test.go index 36337c423..a94e19c38 100644 --- a/go/ai/generator_test.go +++ b/go/ai/generator_test.go @@ -37,7 +37,7 @@ var ( Versions: []string{"echo-001", "echo-002"}, } - echoModel = DefineModel(r, "test", modelName, &metadata, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { + echoModel = DefineModel(r, "test", modelName, &metadata, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { if msc != nil { msc(ctx, &ModelResponseChunk{ Content: []*Part{NewTextPart("stream!")}, @@ -264,7 +264,7 @@ func TestGenerate(t *testing.T) { }, }, Config: GenerationCommonConfig{Temperature: 1}, - Context: []any{[]any{string("Banana")}}, + Context: []*Document{&Document{Content: []*Part{NewTextPart("Banana")}}}, Output: &ModelRequestOutput{ Format: "json", Schema: map[string]any{ @@ -310,7 +310,7 @@ func TestGenerate(t *testing.T) { Temperature: 1, }), WithHistory(NewUserTextMessage("banana"), NewModelTextMessage("yes, banana")), - WithContext([]any{"Banana"}), + WithContext(&Document{Content: []*Part{NewTextPart("Banana")}}), WithOutputSchema(&GameCharacter{}), WithTools(gablorkenTool), WithStreaming(func(ctx context.Context, grc *ModelResponseChunk) error { @@ -346,7 +346,13 @@ func TestGenerate(t *testing.T) { }, ) - interruptModel := DefineModel(r, "test", "interrupt", nil, + info := &ModelInfo{ + Supports: &ModelInfoSupports{ + Multiturn: true, + Tools: true, + }, + } + interruptModel := DefineModel(r, "test", "interrupt", info, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { return &ModelResponse{ Request: gr, @@ -399,7 +405,13 @@ func TestGenerate(t *testing.T) { t.Run("handles multiple parallel tool calls", func(t *testing.T) { roundCount := 0 - parallelModel := DefineModel(r, "test", "parallel", nil, + info := &ModelInfo{ + Supports: &ModelInfoSupports{ + Multiturn: true, + Tools: true, + }, + } + parallelModel := DefineModel(r, "test", "parallel", info, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { roundCount++ if roundCount == 1 { @@ -458,7 +470,13 @@ func TestGenerate(t *testing.T) { t.Run("handles multiple rounds of tool calls", func(t *testing.T) { roundCount := 0 - multiRoundModel := DefineModel(r, "test", "multiround", nil, + info := &ModelInfo{ + Supports: &ModelInfoSupports{ + Multiturn: true, + Tools: true, + }, + } + multiRoundModel := DefineModel(r, "test", "multiround", info, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { roundCount++ if roundCount == 1 { @@ -520,7 +538,13 @@ func TestGenerate(t *testing.T) { }) t.Run("exceeds maximum turns", func(t *testing.T) { - infiniteModel := DefineModel(r, "test", "infinite", nil, + info := &ModelInfo{ + Supports: &ModelInfoSupports{ + Multiturn: true, + Tools: true, + }, + } + infiniteModel := DefineModel(r, "test", "infinite", info, nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) { return &ModelResponse{ Request: gr, @@ -550,6 +574,33 @@ func TestGenerate(t *testing.T) { t.Errorf("unexpected error message: %v", err) } }) + + t.Run("applies middleware", func(t *testing.T) { + middlewareCalled := false + testMiddleware := func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback, next ModelFunc) (*ModelResponse, error) { + middlewareCalled = true + req.Messages = append(req.Messages, NewUserTextMessage("middleware was here")) + return next(ctx, req, cb) + } + + res, err := Generate(context.Background(), r, + WithModel(echoModel), + WithTextPrompt("test middleware"), + WithMiddleware(testMiddleware), + ) + if err != nil { + t.Fatal(err) + } + + if !middlewareCalled { + t.Error("middleware was not called") + } + + expectedText := "test middlewaremiddleware was here" + if res.Text() != expectedText { + t.Errorf("got text %q, want %q", res.Text(), expectedText) + } + }) } func TestModelVersion(t *testing.T) { diff --git a/go/ai/middleware.go b/go/ai/middleware.go new file mode 100644 index 000000000..4940bb1a4 --- /dev/null +++ b/go/ai/middleware.go @@ -0,0 +1,37 @@ +package ai + +import ( + "context" + "fmt" + + "github.com/firebase/genkit/go/core" +) + +// ValidateSupport creates middleware that validates whether a model supports the requested features. +func ValidateSupport(name string, supports *ModelInfoSupports) ModelMiddleware { + return func(ctx context.Context, input *ModelRequest, cb ModelStreamingCallback, next core.Func[*ModelRequest, *ModelResponse, *ModelResponseChunk]) (*ModelResponse, error) { + if supports == nil { + supports = &ModelInfoSupports{} + } + + if !supports.Media { + for _, msg := range input.Messages { + for _, part := range msg.Content { + if part.IsMedia() { + return nil, fmt.Errorf("model %q does not support media, but media was provided. Request: %+v", name, input) + } + } + } + } + + if !supports.Tools && len(input.Tools) > 0 { + return nil, fmt.Errorf("model %q does not support tool use, but tools were provided. Request: %+v", name, input) + } + + if !supports.Multiturn && len(input.Messages) > 1 { + return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", name, len(input.Messages), input) + } + + return next(ctx, input, cb) + } +} diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go new file mode 100644 index 000000000..d3c1fd2a5 --- /dev/null +++ b/go/ai/middleware_test.go @@ -0,0 +1,134 @@ +package ai + +import ( + "context" + "testing" +) + +func TestValidateSupport(t *testing.T) { + tests := []struct { + name string + supports *ModelInfoSupports + input *ModelRequest + wantErr bool + }{ + { + name: "valid request with no special features", + supports: &ModelInfoSupports{ + Media: false, + Tools: false, + Multiturn: false, + }, + input: &ModelRequest{ + Messages: []*Message{ + {Content: []*Part{NewTextPart("hello")}}, + }, + }, + wantErr: false, + }, + { + name: "media not supported but requested", + supports: &ModelInfoSupports{ + Media: false, + }, + input: &ModelRequest{ + Messages: []*Message{ + {Content: []*Part{NewMediaPart("image/png", "data:image/png;base64,...")}}, + }, + }, + wantErr: true, + }, + { + name: "tools not supported but requested", + supports: &ModelInfoSupports{ + Tools: false, + }, + input: &ModelRequest{ + Tools: []*ToolDefinition{ + { + Name: "test-tool", + Description: "A test tool", + }, + }, + }, + wantErr: true, + }, + { + name: "multiturn not supported but requested", + supports: &ModelInfoSupports{ + Multiturn: false, + }, + input: &ModelRequest{ + Messages: []*Message{ + {Content: []*Part{NewTextPart("message 1")}}, + {Content: []*Part{NewTextPart("message 2")}}, + }, + }, + wantErr: true, + }, + { + name: "all features supported and used", + supports: &ModelInfoSupports{ + Media: true, + Tools: true, + Multiturn: true, + }, + input: &ModelRequest{ + Messages: []*Message{ + {Content: []*Part{NewMediaPart("image/png", "data:image/png;base64,...")}}, + {Content: []*Part{NewTextPart("follow-up message")}}, + }, + Tools: []*ToolDefinition{ + { + Name: "test-tool", + Description: "A test tool", + }, + }, + }, + wantErr: false, + }, + { + name: "nil supports defaults to no features", + supports: nil, + input: &ModelRequest{ + Messages: []*Message{ + {Content: []*Part{NewMediaPart("image/png", "data:image/png;base64,...")}}, + }, + }, + wantErr: true, + }, + { + name: "mixed content types in message", + supports: &ModelInfoSupports{ + Media: false, + }, + input: &ModelRequest{ + Messages: []*Message{ + {Content: []*Part{ + NewTextPart("text content"), + NewMediaPart("image/png", "data:image/png;base64,..."), + }}, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + middleware := ValidateSupport("test-model", tt.supports) + + _, err := middleware(context.Background(), tt.input, nil, + func(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) { + return &ModelResponse{}, nil + }) + + if (err != nil) != tt.wantErr { + t.Errorf("ValidateSupport() error = %v, wantErr %v", err, tt.wantErr) + if err != nil { + t.Logf("Error message: %v", err) + } + } + }) + } +} diff --git a/go/ai/prompt.go b/go/ai/prompt.go index b36cdc8ff..41c97a900 100644 --- a/go/ai/prompt.go +++ b/go/ai/prompt.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ai import ( @@ -30,7 +29,7 @@ func DefinePrompt(r *registry.Registry, provider, name string, metadata map[stri mm = make(map[string]any) } mm["type"] = "prompt" - return (*Prompt)(core.DefineActionWithInputSchema(r, provider, name, atype.Prompt, mm, inputSchema, render)) + return (*Prompt)(core.DefineActionWithInputSchema(r, provider, name, atype.Prompt, mm, inputSchema, nil, render)) } // IsDefinedPrompt reports whether a [Prompt] is defined. diff --git a/go/ai/retriever.go b/go/ai/retriever.go index 141cecdb1..985d7d4d9 100644 --- a/go/ai/retriever.go +++ b/go/ai/retriever.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ai import ( @@ -62,7 +61,7 @@ func DefineIndexer(r *registry.Registry, provider, name string, index func(conte f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) { return struct{}{}, index(ctx, req) } - return (*indexerActionDef)(core.DefineAction(r, provider, name, atype.Indexer, nil, f)) + return (*indexerActionDef)(core.DefineAction(r, provider, name, atype.Indexer, nil, nil, f)) } // IsDefinedIndexer reports whether an [Indexer] is defined. @@ -79,7 +78,7 @@ func LookupIndexer(r *registry.Registry, provider, name string) Indexer { // DefineRetriever registers the given retrieve function as an action, and returns a // [Retriever] that runs it. func DefineRetriever(r *registry.Registry, provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *retrieverActionDef { - return (*retrieverActionDef)(core.DefineAction(r, provider, name, atype.Retriever, nil, ret)) + return (*retrieverActionDef)(core.DefineAction(r, provider, name, atype.Retriever, nil, nil, ret)) } // IsDefinedRetriever reports whether a [Retriever] is defined. diff --git a/go/ai/tools.go b/go/ai/tools.go index 912a2af71..72f34d3ab 100644 --- a/go/ai/tools.go +++ b/go/ai/tools.go @@ -80,7 +80,7 @@ func DefineTool[In, Out any](r *registry.Registry, name, description string, return fn(toolCtx, input) } - toolAction := core.DefineAction(r, provider, name, atype.Tool, metadata, wrappedFn) + toolAction := core.DefineAction(r, provider, name, atype.Tool, metadata, nil, wrappedFn) return &ToolDef[In, Out]{ action: toolAction, } diff --git a/go/core/action.go b/go/core/action.go index 89d360a82..5c9136308 100644 --- a/go/core/action.go +++ b/go/core/action.go @@ -27,9 +27,11 @@ import ( // stream the results by invoking the callback periodically, ultimately returning // with a final return value. Otherwise, it should ignore the StreamingCallback and // just return a result. -type Func[In, Out, Stream any] func(context.Context, In, func(context.Context, Stream) error) (Out, error) +type Func[In, Out, Stream any] = func(context.Context, In, func(context.Context, Stream) error) (Out, error) -// TODO: use a generic type alias for the above when they become available? +// Middleware is a function that wraps an action execution, similar to HTTP middleware. +// It can modify the input, output, and context, or perform side effects. +type Middleware[In, Out, Stream any] = func(ctx context.Context, input In, cb func(context.Context, Stream) error, next Func[In, Out, Stream]) (Out, error) // An Action is a named, observable operation. // It consists of a function that takes an input of type I and returns an output @@ -42,6 +44,7 @@ type Action[In, Out, Stream any] struct { name string atype atype.ActionType fn Func[In, Out, Stream] + middleware []Middleware[In, Out, Stream] tstate *tracing.State inputSchema *jsonschema.Schema outputSchema *jsonschema.Schema @@ -60,9 +63,10 @@ func DefineAction[In, Out any]( provider, name string, atype atype.ActionType, metadata map[string]any, + mw []Middleware[In, Out, struct{}], fn func(context.Context, In) (Out, error), ) *Action[In, Out, struct{}] { - return defineAction(r, provider, name, atype, metadata, nil, + return defineAction(r, provider, name, atype, metadata, nil, mw, func(ctx context.Context, in In, _ noStream) (Out, error) { return fn(ctx, in) }) @@ -74,9 +78,10 @@ func DefineStreamingAction[In, Out, Stream any]( provider, name string, atype atype.ActionType, metadata map[string]any, + mw []Middleware[In, Out, Stream], fn Func[In, Out, Stream], ) *Action[In, Out, Stream] { - return defineAction(r, provider, name, atype, metadata, nil, fn) + return defineAction(r, provider, name, atype, metadata, nil, mw, fn) } // DefineCustomAction defines a streaming action with type Custom. @@ -84,9 +89,10 @@ func DefineCustomAction[In, Out, Stream any]( r *registry.Registry, provider, name string, metadata map[string]any, + mw []Middleware[In, Out, Stream], fn Func[In, Out, Stream], ) *Action[In, Out, Stream] { - return DefineStreamingAction(r, provider, name, atype.Custom, metadata, fn) + return DefineStreamingAction(r, provider, name, atype.Custom, metadata, mw, fn) } // DefineActionWithInputSchema creates a new Action and registers it. @@ -99,9 +105,10 @@ func DefineActionWithInputSchema[Out any]( atype atype.ActionType, metadata map[string]any, inputSchema *jsonschema.Schema, + mw []Middleware[any, Out, struct{}], fn func(context.Context, any) (Out, error), ) *Action[any, Out, struct{}] { - return defineAction(r, provider, name, atype, metadata, inputSchema, + return defineAction(r, provider, name, atype, metadata, inputSchema, mw, func(ctx context.Context, in any, _ noStream) (Out, error) { return fn(ctx, in) }) @@ -114,13 +121,14 @@ func defineAction[In, Out, Stream any]( atype atype.ActionType, metadata map[string]any, inputSchema *jsonschema.Schema, + mw []Middleware[In, Out, Stream], fn Func[In, Out, Stream], ) *Action[In, Out, Stream] { fullName := name if provider != "" { fullName = provider + "/" + name } - a := newAction(fullName, atype, metadata, inputSchema, fn) + a := newAction(fullName, atype, metadata, inputSchema, mw, fn) r.RegisterAction(atype, a) return a } @@ -132,6 +140,7 @@ func newAction[In, Out, Stream any]( atype atype.ActionType, metadata map[string]any, inputSchema *jsonschema.Schema, + mw []Middleware[In, Out, Stream], fn Func[In, Out, Stream], ) *Action[In, Out, Stream] { var i In @@ -155,6 +164,7 @@ func newAction[In, Out, Stream any]( inputSchema: inputSchema, outputSchema: outputSchema, metadata: metadata, + middleware: mw, } } @@ -164,6 +174,12 @@ func (a *Action[In, Out, Stream]) Name() string { return a.name } // setTracingState sets the action's tracing.State. func (a *Action[In, Out, Stream]) SetTracingState(tstate *tracing.State) { a.tstate = tstate } +// Use adds middleware to the action with type inference +func (a *Action[In, Out, Stream]) Use(middleware ...Middleware[In, Out, Stream]) *Action[In, Out, Stream] { + a.middleware = append(a.middleware, middleware...) + return a +} + // Run executes the Action's function in a new trace span. func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(context.Context, Stream) error) (output Out, err error) { logger.FromContext(ctx).Debug("Action.Run", @@ -175,28 +191,45 @@ func (a *Action[In, Out, Stream]) Run(ctx context.Context, input In, cb func(con "output", fmt.Sprintf("%#v", output), "err", err) }() + return tracing.RunInNewSpan(ctx, a.tstate, a.name, "action", false, input, func(ctx context.Context, input In) (Out, error) { start := time.Now() - var err error - if err = base.ValidateValue(input, a.inputSchema); err != nil { - err = fmt.Errorf("invalid input: %w", err) + + if err := base.ValidateValue(input, a.inputSchema); err != nil { + return base.Zero[Out](), fmt.Errorf("invalid input: %w", err) } - var output Out - if err == nil { - output, err = a.fn(ctx, input, cb) - if err == nil { - if err = base.ValidateValue(output, a.outputSchema); err != nil { - err = fmt.Errorf("invalid output: %w", err) - } + + handler := func(ctx context.Context, in In, cb func(context.Context, Stream) error) (Out, error) { + out, err := a.fn(ctx, in, cb) + if err != nil { + return base.Zero[Out](), err + } + + if err := base.ValidateValue(out, a.outputSchema); err != nil { + return base.Zero[Out](), fmt.Errorf("invalid output: %w", err) } + + return out, nil } + + for i := len(a.middleware) - 1; i >= 0; i-- { + currentHandler := handler + currentMiddleware := a.middleware[i] + handler = func(ctx context.Context, in In, cb func(context.Context, Stream) error) (Out, error) { + return currentMiddleware(ctx, in, cb, currentHandler) + } + } + + output, err := handler(ctx, input, cb) + latency := time.Since(start) if err != nil { metrics.WriteActionFailure(ctx, a.name, latency, err) return base.Zero[Out](), err } metrics.WriteActionSuccess(ctx, a.name, latency) + return output, nil }) } @@ -255,7 +288,12 @@ func (a *Action[I, O, S]) Desc() action.Desc { // or nil if there is none. // It panics if the action is of the wrong type. func LookupActionFor[In, Out, Stream any](r *registry.Registry, typ atype.ActionType, provider, name string) *Action[In, Out, Stream] { - key := fmt.Sprintf("/%s/%s/%s", typ, provider, name) + var key string + if provider != "" { + key = fmt.Sprintf("/%s/%s/%s", typ, provider, name) + } else { + key = fmt.Sprintf("/%s/%s", typ, name) + } a := r.LookupAction(key) if a == nil { return nil diff --git a/go/core/action_test.go b/go/core/action_test.go index 694ff7609..421c9dc56 100644 --- a/go/core/action_test.go +++ b/go/core/action_test.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package core import ( @@ -25,7 +24,7 @@ func TestActionRun(t *testing.T) { if err != nil { t.Fatal(err) } - a := defineAction(r, "test", "inc", atype.Custom, nil, nil, inc) + a := defineAction(r, "test", "inc", atype.Custom, nil, nil, nil, inc) got, err := a.Run(context.Background(), 3, nil) if err != nil { t.Fatal(err) @@ -40,7 +39,7 @@ func TestActionRunJSON(t *testing.T) { if err != nil { t.Fatal(err) } - a := defineAction(r, "test", "inc", atype.Custom, nil, nil, inc) + a := defineAction(r, "test", "inc", atype.Custom, nil, nil, nil, inc) input := []byte("3") want := []byte("4") got, err := a.RunJSON(context.Background(), input, nil) @@ -70,7 +69,7 @@ func TestActionStreaming(t *testing.T) { if err != nil { t.Fatal(err) } - a := defineAction(r, "test", "count", atype.Custom, nil, nil, count) + a := defineAction(r, "test", "count", atype.Custom, nil, nil, nil, count) const n = 3 // Non-streaming. @@ -109,7 +108,7 @@ func TestActionTracing(t *testing.T) { tc := tracing.NewTestOnlyTelemetryClient() r.TracingState().WriteTelemetryImmediate(tc) const actionName = "TestTracing-inc" - a := defineAction(r, provider, actionName, atype.Custom, nil, nil, inc) + a := defineAction(r, provider, actionName, atype.Custom, nil, nil, nil, inc) if _, err := a.Run(context.Background(), 3, nil); err != nil { t.Fatal(err) } @@ -127,3 +126,158 @@ func TestActionTracing(t *testing.T) { } t.Fatalf("did not find trace named %q", actionName) } + +func countingMiddleware[In, Out, Stream any](counts *int) Middleware[In, Out, Stream] { + return func(ctx context.Context, input In, cb func(context.Context, Stream) error, next Func[In, Out, Stream]) (Out, error) { + *counts++ + return next(ctx, input, cb) + } +} + +func addOneMiddleware[Stream any](ctx context.Context, input int, cb func(context.Context, Stream) error, next Func[int, int, Stream]) (int, error) { + return next(ctx, input+1, cb) +} + +func multiplyOutputMiddleware[Stream any](ctx context.Context, input int, cb func(context.Context, Stream) error, next Func[int, int, Stream]) (int, error) { + out, err := next(ctx, input, cb) + if err != nil { + return 0, err + } + return out * 2, nil +} + +func doubleStreamMiddleware(ctx context.Context, input int, cb func(context.Context, int) error, next Func[int, int, int]) (int, error) { + wrappedCb := func(ctx context.Context, val int) error { + return cb(ctx, val*2) + } + return next(ctx, input, wrappedCb) +} + +func TestMiddleware(t *testing.T) { + ctx := context.Background() + + t.Run("Single middleware execution count", func(t *testing.T) { + r, err := registry.New() + if err != nil { + t.Fatal(err) + } + + counts := 0 + middleware := []Middleware[int, int, struct{}]{countingMiddleware[int, int, struct{}](&counts)} + a := defineAction(r, "test", "inc", atype.Custom, nil, nil, middleware, inc) + + _, err = a.Run(ctx, 3, nil) + if err != nil { + t.Fatal(err) + } + + if counts != 1 { + t.Errorf("middleware execution count: got %d, want 1", counts) + } + }) + + t.Run("Multiple middleware order", func(t *testing.T) { + r, err := registry.New() + if err != nil { + t.Fatal(err) + } + + middleware := []Middleware[int, int, struct{}]{ + addOneMiddleware[struct{}], + multiplyOutputMiddleware[struct{}], + } + a := defineAction(r, "test", "inc", atype.Custom, nil, nil, middleware, inc) + + got, err := a.Run(ctx, 3, nil) + if err != nil { + t.Fatal(err) + } + + want := 10 + if got != want { + t.Errorf("got %d, want %d", got, want) + } + }) + + t.Run("Streaming middleware", func(t *testing.T) { + r, err := registry.New() + if err != nil { + t.Fatal(err) + } + + middleware := []Middleware[int, int, int]{doubleStreamMiddleware} + a := defineAction(r, "test", "count", atype.Custom, nil, nil, middleware, count) + + var gotStreamed []int + got, err := a.Run(ctx, 3, func(_ context.Context, i int) error { + gotStreamed = append(gotStreamed, i) + return nil + }) + if err != nil { + t.Fatal(err) + } + + wantStreamed := []int{0, 2, 4} + if !slices.Equal(gotStreamed, wantStreamed) { + t.Errorf("got streamed values %v, want %v", gotStreamed, wantStreamed) + } + if got != 3 { + t.Errorf("got final value %d, want 3", got) + } + }) + + t.Run("Error handling in middleware", func(t *testing.T) { + r, err := registry.New() + if err != nil { + t.Fatal(err) + } + + expectedErr := fmt.Errorf("middleware error") + errorMiddleware := func(ctx context.Context, input int, cb func(context.Context, struct{}) error, next Func[int, int, struct{}]) (int, error) { + return 0, expectedErr + } + + middleware := []Middleware[int, int, struct{}]{errorMiddleware} + a := defineAction(r, "test", "inc", atype.Custom, nil, nil, middleware, inc) + + _, err = a.Run(ctx, 3, nil) + if err != expectedErr { + t.Errorf("got error %v, want %v", err, expectedErr) + } + }) + + t.Run("Context modification in middleware", func(t *testing.T) { + r, err := registry.New() + if err != nil { + t.Fatal(err) + } + + key := "test_key" + value := "test_value" + var gotValue string + + contextMiddleware := func(ctx context.Context, input int, cb func(context.Context, struct{}) error, next Func[int, int, struct{}]) (int, error) { + newCtx := context.WithValue(ctx, key, value) + return next(newCtx, input, cb) + } + + checkContextMiddleware := func(ctx context.Context, input int, cb func(context.Context, struct{}) error, next Func[int, int, struct{}]) (int, error) { + if v := ctx.Value(key); v != nil { + gotValue = v.(string) + } + return next(ctx, input, cb) + } + + middleware := []Middleware[int, int, struct{}]{contextMiddleware, checkContextMiddleware} + a := defineAction(r, "test", "inc", atype.Custom, nil, nil, middleware, inc) + + _, err = a.Run(ctx, 3, nil) + if err != nil { + t.Fatal(err) + } + + if gotValue != value { + t.Errorf("got context value %q, want %q", gotValue, value) + } + }) +} diff --git a/go/core/schemas.config b/go/core/schemas.config index 2c57159e1..905a2d3e7 100644 --- a/go/core/schemas.config +++ b/go/core/schemas.config @@ -205,7 +205,7 @@ GenerateActionOptionsOutput.jsonSchema type map[string]any # ModelRequest ModelRequest pkg ai ModelRequest.config type any -ModelRequest.context type []any +ModelRequest.context type []*Document ModelRequest.messages type []*Message ModelRequest.output type *ModelRequestOutput ModelRequest.tools type []*ToolDefinition diff --git a/go/genkit/flow.go b/go/genkit/flow.go index b0a5d39fc..c502bd4e4 100644 --- a/go/genkit/flow.go +++ b/go/genkit/flow.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package genkit import ( @@ -225,7 +224,7 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core. } return &result, err } - core.DefineStreamingAction(r, "", f.name, atype.Flow, metadata, afunc) + core.DefineStreamingAction(r, "", f.name, atype.Flow, metadata, nil, afunc) f.tstate = r.TracingState() r.RegisterFlow(f) return f diff --git a/go/genkit/genkit.go b/go/genkit/genkit.go index 73e7c9241..2633ea145 100644 --- a/go/genkit/genkit.go +++ b/go/genkit/genkit.go @@ -153,15 +153,15 @@ func (g *Genkit) Start(ctx context.Context, opts *StartOptions) error { return shutdownServers(servers) } -// DefineModel registers the given generate function as an action, and returns a -// [Model] that runs it. +// DefineModel registers the given generate function as an action, and returns a [Model] that runs it. func DefineModel( g *Genkit, provider, name string, metadata *ai.ModelInfo, + mw []ai.ModelMiddleware, generate func(context.Context, *ai.ModelRequest, ai.ModelStreamingCallback) (*ai.ModelResponse, error), ) ai.Model { - return ai.DefineModel(g.reg, provider, name, metadata, generate) + return ai.DefineModel(g.reg, provider, name, metadata, mw, generate) } // IsDefinedModel reports whether a model is defined. @@ -239,8 +239,8 @@ func GenerateData(ctx context.Context, g *Genkit, value any, opts ...ai.Generate } // GenerateWithRequest runs the model with the given request and streaming callback. -func GenerateWithRequest(ctx context.Context, g *Genkit, m ai.Model, req *ai.ModelRequest, toolCfg *ai.ToolConfig, cb ai.ModelStreamingCallback) (*ai.ModelResponse, error) { - return m.Generate(ctx, g.reg, req, toolCfg, cb) +func GenerateWithRequest(ctx context.Context, g *Genkit, m ai.Model, req *ai.ModelRequest, mw []ai.ModelMiddleware, toolCfg *ai.ToolConfig, cb ai.ModelStreamingCallback) (*ai.ModelResponse, error) { + return m.Generate(ctx, g.reg, req, mw, toolCfg, cb) } // DefineIndexer registers the given index function as an action, and returns an diff --git a/go/genkit/servers_test.go b/go/genkit/servers_test.go index 6ee102013..cc1a5f547 100644 --- a/go/genkit/servers_test.go +++ b/go/genkit/servers_test.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package genkit import ( @@ -40,12 +39,9 @@ func TestDevServer(t *testing.T) { tc := tracing.NewTestOnlyTelemetryClient() r.TracingState().WriteTelemetryImmediate(tc) - core.DefineAction(r, "devServer", "inc", atype.Custom, map[string]any{ - "foo": "bar", - }, inc) - core.DefineAction(r, "devServer", "dec", atype.Custom, map[string]any{ - "bar": "baz", - }, dec) + core.DefineAction(r, "devServer", "inc", atype.Custom, map[string]any{"foo": "bar"}, nil, inc) + core.DefineAction(r, "devServer", "dec", atype.Custom, map[string]any{"bar": "baz"}, nil, dec) + srv := httptest.NewServer(newDevServeMux(&devServer{reg: r})) defer srv.Close() diff --git a/go/go.mod b/go/go.mod index 9e1fb4f0e..7ff4c28bc 100644 --- a/go/go.mod +++ b/go/go.mod @@ -1,6 +1,6 @@ module github.com/firebase/genkit/go -go 1.22.0 +go 1.24.0 retract ( v0.1.4 // Retraction only. diff --git a/go/plugins/dotprompt/dotprompt_test.go b/go/plugins/dotprompt/dotprompt_test.go index 44a45d8a5..b4493871a 100644 --- a/go/plugins/dotprompt/dotprompt_test.go +++ b/go/plugins/dotprompt/dotprompt_test.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package dotprompt import ( @@ -33,7 +32,7 @@ var g, _ = genkit.New(&genkit.Options{ PromptDir: "testdata", }) -var testModel = genkit.DefineModel(g, "defineoptions", "test", nil, testGenerate) +var testModel = genkit.DefineModel(g, "defineoptions", "test", nil, nil, testGenerate) func TestPrompts(t *testing.T) { g, err := genkit.New(&genkit.Options{ diff --git a/go/plugins/dotprompt/genkit.go b/go/plugins/dotprompt/genkit.go index eec20b4ba..4669838bf 100644 --- a/go/plugins/dotprompt/genkit.go +++ b/go/plugins/dotprompt/genkit.go @@ -24,8 +24,8 @@ type PromptRequest struct { Input any `json:"input,omitempty"` // Model configuration. If nil will be taken from the prompt config. Config *ai.GenerationCommonConfig `json:"config,omitempty"` - // Context to pass to model, if any. - Context []any `json:"context,omitempty"` + // Documents to pass to the model as context, if any. + Context []*ai.Document `json:"context,omitempty"` // The model to use. This overrides any model specified by the prompt. Model ai.Model `json:"model,omitempty"` // The name of the model to use. This overrides any model specified by the prompt. @@ -40,6 +40,8 @@ type PromptRequest struct { IsReturnToolRequestsSet bool `json:"-"` // Whether tool calls are required, disabled, or optional for the prompt. ToolChoice ai.ToolChoice `json:"toolChoice,omitempty"` + // Middleware to apply to the prompt. + Middleware []ai.ModelMiddleware `json:"-"` } // GenerateOption configures params for Generate function @@ -259,7 +261,7 @@ func (p *Prompt) Generate(ctx context.Context, g *genkit.Genkit, opts ...Generat ReturnToolRequests: returnToolRequests, } - resp, err := genkit.GenerateWithRequest(ctx, g, model, mr, toolCfg, pr.Stream) + resp, err := genkit.GenerateWithRequest(ctx, g, model, mr, pr.Middleware, toolCfg, pr.Stream) if err != nil { return nil, err } @@ -319,13 +321,13 @@ func WithConfig(config *ai.GenerationCommonConfig) GenerateOption { } } -// WithContext add context to pass to model, if any. -func WithContext(context []any) GenerateOption { +// WithContext adds documents to pass to model as context, if any. +func WithContext(docs ...*ai.Document) GenerateOption { return func(p *PromptRequest) error { if p.Context != nil { return errors.New("dotprompt.WithContext: cannot set Context more than once") } - p.Context = context + p.Context = docs return nil } } @@ -399,3 +401,14 @@ func WithToolChoice(toolChoice ai.ToolChoice) GenerateOption { return nil } } + +// WithMiddleware adds middleware to the prompt request. +func WithMiddleware(middleware ...ai.ModelMiddleware) GenerateOption { + return func(p *PromptRequest) error { + if p.Middleware != nil { + return errors.New("dotprompt.WithMiddleware: cannot set Middleware more than once") + } + p.Middleware = middleware + return nil + } +} diff --git a/go/plugins/dotprompt/genkit_test.go b/go/plugins/dotprompt/genkit_test.go index f6aa2488b..1fd2165a4 100644 --- a/go/plugins/dotprompt/genkit_test.go +++ b/go/plugins/dotprompt/genkit_test.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package dotprompt import ( @@ -45,7 +44,7 @@ func TestExecute(t *testing.T) { if err != nil { log.Fatal(err) } - testModel := genkit.DefineModel(g, "test", "test", nil, testGenerate) + testModel := genkit.DefineModel(g, "test", "test", nil, nil, testGenerate) t.Run("Model", func(t *testing.T) { p, err := New("TestExecute", "TestExecute", Config{Model: testModel}) if err != nil { @@ -100,7 +99,7 @@ func TestOptionsPatternGenerate(t *testing.T) { if err != nil { log.Fatal(err) } - testModel := genkit.DefineModel(g, "options", "test", nil, testGenerate) + testModel := genkit.DefineModel(g, "options", "test", nil, nil, testGenerate) t.Run("Streaming", func(t *testing.T) { p, err := Define(g, "TestExecute", "TestExecute", WithInputType(InputOutput{})) @@ -120,7 +119,7 @@ func TestOptionsPatternGenerate(t *testing.T) { return nil }), WithModel(testModel), - WithContext([]any{"context"}), + WithContext(&ai.Document{Content: []*ai.Part{ai.NewTextPart("context")}}), ) if err != nil { t.Fatal(err) @@ -178,7 +177,7 @@ func TestGenerateOptions(t *testing.T) { }, { name: "WithContext", - with: WithContext([]any{"context"}), + with: WithContext(&ai.Document{Content: []*ai.Part{ai.NewTextPart("context")}}), }, { name: "WithModelName", diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index aed57b86f..139ceb07c 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -160,7 +160,7 @@ func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model { Supports: info.Supports, Versions: info.Versions, } - return genkit.DefineModel(g, provider, name, meta, func( + return genkit.DefineModel(g, provider, name, meta, nil, func( ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, diff --git a/go/plugins/ollama/ollama.go b/go/plugins/ollama/ollama.go index ed0f9b093..e7cfd1fff 100644 --- a/go/plugins/ollama/ollama.go +++ b/go/plugins/ollama/ollama.go @@ -66,7 +66,7 @@ func DefineModel(g *genkit.Genkit, model ModelDefinition, info *ai.ModelInfo) ai Versions: []string{}, } gen := &generator{model: model, serverAddress: state.serverAddress} - return genkit.DefineModel(g, provider, model.Name, meta, gen.generate) + return genkit.DefineModel(g, provider, model.Name, meta, nil, gen.generate) } // IsDefinedModel reports whether a model is defined. diff --git a/go/plugins/ollama/ollama_live_test.go b/go/plugins/ollama/ollama_live_test.go index a7e11fd50..d87933a26 100644 --- a/go/plugins/ollama/ollama_live_test.go +++ b/go/plugins/ollama/ollama_live_test.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package ollama_test import ( @@ -57,7 +56,7 @@ func TestLive(t *testing.T) { ai.NewModelRequest( &ai.GenerationCommonConfig{Temperature: 1}, ai.NewUserTextMessage("I'm hungry, what should I eat?")), - nil, nil) + nil, nil, nil) if err != nil { t.Fatalf("failed to generate response: %s", err) } diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index 15580566a..490fe1d06 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -182,7 +182,7 @@ func defineModel(g *genkit.Genkit, name string, info ai.ModelInfo) ai.Model { Supports: info.Supports, Versions: info.Versions, } - return genkit.DefineModel(g, provider, name, meta, func( + return genkit.DefineModel(g, provider, name, meta, nil, func( ctx context.Context, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, diff --git a/go/tests/test_app/main.go b/go/tests/test_app/main.go index 55d13a7f9..379fe0397 100644 --- a/go/tests/test_app/main.go +++ b/go/tests/test_app/main.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - // This program doesn't do anything interesting. // It is used by go/tests/api_test.go. package main @@ -30,7 +29,7 @@ func main() { if err != nil { log.Fatal(err) } - model := genkit.DefineModel(g, "", "customReflector", nil, echo) + model := genkit.DefineModel(g, "", "customReflector", nil, nil, echo) genkit.DefineFlow(g, "testFlow", func(ctx context.Context, in string) (string, error) { res, err := genkit.Generate(ctx, g, ai.WithModel(model), ai.WithTextPrompt(in)) if err != nil { From a42eb37cda11be61d5738788e4623eccb6b6672a Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 12 Feb 2025 07:50:45 -0800 Subject: [PATCH 2/3] Fixed docs. --- go/ai/middleware.go | 2 ++ go/internal/doc-snippets/modelplugin/modelplugin.go | 5 ++++- go/internal/doc-snippets/prompts.go | 3 +-- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/go/ai/middleware.go b/go/ai/middleware.go index 4940bb1a4..d991c1c65 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -32,6 +32,8 @@ func ValidateSupport(name string, supports *ModelInfoSupports) ModelMiddleware { return nil, fmt.Errorf("model %q does not support multiple messages, but %d were provided. Request: %+v", name, len(input.Messages), input) } + // TODO: Add validation for features that won't have simulated support via middleware. + return next(ctx, input, cb) } } diff --git a/go/internal/doc-snippets/modelplugin/modelplugin.go b/go/internal/doc-snippets/modelplugin/modelplugin.go index 669ef9f8b..851400b73 100644 --- a/go/internal/doc-snippets/modelplugin/modelplugin.go +++ b/go/internal/doc-snippets/modelplugin/modelplugin.go @@ -40,7 +40,10 @@ func Init() error { Media: false, // Can the model accept media input? Tools: false, // Does the model support function calling (tools)? }, - Versions: []string{}, + Versions: []string{"my-model-001", "..."}, + }, + []ai.ModelMiddleware{ + // Add any middleware you want to apply to the model here. }, func(ctx context.Context, genRequest *ai.ModelRequest, diff --git a/go/internal/doc-snippets/prompts.go b/go/internal/doc-snippets/prompts.go index d39636599..fe32b8191 100644 --- a/go/internal/doc-snippets/prompts.go +++ b/go/internal/doc-snippets/prompts.go @@ -1,7 +1,6 @@ // Copyright 2024 Google LLC // SPDX-License-Identifier: Apache-2.0 - package snippets import ( @@ -95,7 +94,7 @@ func pr03() error { if err != nil { return err } - response, err := genkit.GenerateWithRequest(context.Background(), g, model, request, nil, nil) + response, err := genkit.GenerateWithRequest(context.Background(), g, model, request, nil, nil, nil) // [END pr03_2] _ = response From 71bcd70693efc1cc0b79c4e3833617413802d61f Mon Sep 17 00:00:00 2001 From: Alex Pascal Date: Wed, 12 Feb 2025 08:23:32 -0800 Subject: [PATCH 3/3] Formatting. --- go/ai/middleware.go | 14 ++++++++++++++ go/ai/middleware_test.go | 14 ++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/go/ai/middleware.go b/go/ai/middleware.go index d991c1c65..e7666dedd 100644 --- a/go/ai/middleware.go +++ b/go/ai/middleware.go @@ -1,3 +1,17 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package ai import ( diff --git a/go/ai/middleware_test.go b/go/ai/middleware_test.go index d3c1fd2a5..9b4504b7e 100644 --- a/go/ai/middleware_test.go +++ b/go/ai/middleware_test.go @@ -1,3 +1,17 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package ai import (