Skip to content

Commit

Permalink
breaking(go): Refactored into instance-based Genkit and Registry. (
Browse files Browse the repository at this point in the history
  • Loading branch information
apascal07 authored Jan 9, 2025
1 parent 2845d32 commit 38baf30
Show file tree
Hide file tree
Showing 62 changed files with 1,045 additions and 464 deletions.
17 changes: 11 additions & 6 deletions go/ai/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
)

// Embedder represents an embedder that can perform content embedding.
Expand Down Expand Up @@ -56,19 +57,23 @@ type DocumentEmbedding struct {

// DefineEmbedder registers the given embed function as an action, and returns an
// [Embedder] that runs it.
func DefineEmbedder(provider, name string, embed func(context.Context, *EmbedRequest) (*EmbedResponse, error)) Embedder {
return (*embedderActionDef)(core.DefineAction(provider, name, atype.Embedder, nil, embed))
func DefineEmbedder(
r *registry.Registry,
provider, name string,
embed func(context.Context, *EmbedRequest) (*EmbedResponse, error),
) Embedder {
return (*embedderActionDef)(core.DefineAction(r, provider, name, atype.Embedder, nil, embed))
}

// IsDefinedEmbedder reports whether an embedder is defined.
func IsDefinedEmbedder(provider, name string) bool {
return LookupEmbedder(provider, name) != nil
func IsDefinedEmbedder(r *registry.Registry, provider, name string) bool {
return LookupEmbedder(r, provider, name) != nil
}

// LookupEmbedder looks up an [Embedder] registered by [DefineEmbedder].
// It returns nil if the embedder was not defined.
func LookupEmbedder(provider, name string) Embedder {
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](atype.Embedder, provider, name)
func LookupEmbedder(r *registry.Registry, provider, name string) Embedder {
action := core.LookupActionFor[*EmbedRequest, *EmbedResponse, struct{}](r, atype.Embedder, provider, name)
if action == nil {
return nil
}
Expand Down
55 changes: 38 additions & 17 deletions go/ai/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ import (
"github.com/firebase/genkit/go/core/logger"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/base"
"github.com/firebase/genkit/go/internal/registry"
)

// Model represents a model that can perform content generation tasks.
type Model interface {
// Name returns the registry name of the model.
Name() string
// Generate applies the [Model] to provided request, handling tool requests and handles streaming.
Generate(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error)
Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error)
}

type modelActionDef core.Action[*ModelRequest, *ModelResponse, *ModelResponseChunk]
Expand All @@ -60,7 +61,12 @@ type ModelMetadata struct {

// DefineModel registers the given generate function as an action, and returns a
// [Model] that runs it.
func DefineModel(provider, name string, metadata *ModelMetadata, generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error)) Model {
func DefineModel(
r *registry.Registry,
provider, name string,
metadata *ModelMetadata,
generate func(context.Context, *ModelRequest, ModelStreamingCallback) (*ModelResponse, error),
) Model {
metadataMap := map[string]any{}
if metadata == nil {
// Always make sure there's at least minimal metadata.
Expand All @@ -79,20 +85,20 @@ func DefineModel(provider, name string, metadata *ModelMetadata, generate func(c
}
metadataMap["supports"] = supports

return (*modelActionDef)(core.DefineStreamingAction(provider, name, atype.Model, map[string]any{
return (*modelActionDef)(core.DefineStreamingAction(r, provider, name, atype.Model, map[string]any{
"model": metadataMap,
}, generate))
}

// IsDefinedModel reports whether a model is defined.
func IsDefinedModel(provider, name string) bool {
return core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](atype.Model, provider, name) != nil
func IsDefinedModel(r *registry.Registry, provider, name string) bool {
return core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, atype.Model, provider, name) != nil
}

// LookupModel looks up a [Model] registered by [DefineModel].
// It returns nil if the model was not defined.
func LookupModel(provider, name string) Model {
action := core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](atype.Model, provider, name)
func LookupModel(r *registry.Registry, provider, name string) Model {
action := core.LookupActionFor[*ModelRequest, *ModelResponse, *ModelResponseChunk](r, atype.Model, provider, name)
if action == nil {
return nil
}
Expand All @@ -102,6 +108,7 @@ func LookupModel(provider, name string) Model {
// generateParams represents various params of the Generate call.
type generateParams struct {
Request *ModelRequest
Model Model
Stream ModelStreamingCallback
History []*Message
SystemPrompt *Message
Expand All @@ -110,6 +117,14 @@ type generateParams struct {
// GenerateOption configures params of the Generate call.
type GenerateOption func(req *generateParams) error

// WithModel sets the model to use for the generate request.
func WithModel(m Model) GenerateOption {
return func(req *generateParams) error {
req.Model = m
return nil
}
}

// WithTextPrompt adds a simple text user prompt to ModelRequest.
func WithTextPrompt(prompt string) GenerateOption {
return func(req *generateParams) error {
Expand Down Expand Up @@ -174,6 +189,9 @@ func WithContext(c ...any) GenerateOption {
// WithTools adds provided tools to ModelRequest.
func WithTools(tools ...Tool) GenerateOption {
return func(req *generateParams) error {
if req.Request.Tools != nil {
return errors.New("cannot set Request.Tools (WithTools) more than once")
}
var toolDefs []*ToolDefinition
for _, t := range tools {
toolDefs = append(toolDefs, t.Definition())
Expand Down Expand Up @@ -221,7 +239,7 @@ func WithStreaming(cb ModelStreamingCallback) GenerateOption {
}

// Generate run generate request for this model. Returns ModelResponse struct.
func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*ModelResponse, error) {
func Generate(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (*ModelResponse, error) {
req := &generateParams{
Request: &ModelRequest{},
}
Expand All @@ -231,6 +249,9 @@ func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*ModelRespo
return nil, err
}
}
if req.Model == nil {
return nil, errors.New("model is required")
}
if req.History != nil {
prev := req.Request.Messages
req.Request.Messages = req.History
Expand All @@ -242,12 +263,12 @@ func Generate(ctx context.Context, m Model, opts ...GenerateOption) (*ModelRespo
req.Request.Messages = append(req.Request.Messages, prev...)
}

return m.Generate(ctx, req.Request, req.Stream)
return req.Model.Generate(ctx, r, req.Request, req.Stream)
}

// GenerateText run generate request for this model. Returns generated text only.
func GenerateText(ctx context.Context, m Model, opts ...GenerateOption) (string, error) {
res, err := Generate(ctx, m, opts...)
func GenerateText(ctx context.Context, r *registry.Registry, opts ...GenerateOption) (string, error) {
res, err := Generate(ctx, r, opts...)
if err != nil {
return "", err
}
Expand All @@ -257,9 +278,9 @@ func GenerateText(ctx context.Context, m Model, opts ...GenerateOption) (string,

// Generate run generate request for this model. Returns ModelResponse struct.
// TODO: Stream GenerateData with partial JSON
func GenerateData(ctx context.Context, m Model, value any, opts ...GenerateOption) (*ModelResponse, error) {
func GenerateData(ctx context.Context, r *registry.Registry, value any, opts ...GenerateOption) (*ModelResponse, error) {
opts = append(opts, WithOutputSchema(value))
resp, err := Generate(ctx, m, opts...)
resp, err := Generate(ctx, r, opts...)
if err != nil {
return nil, err
}
Expand All @@ -271,7 +292,7 @@ func GenerateData(ctx context.Context, m Model, value any, opts ...GenerateOptio
}

// Generate applies the [Action] to provided request, handling tool requests and handles streaming.
func (m *modelActionDef) Generate(ctx context.Context, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
func (m *modelActionDef) Generate(ctx context.Context, r *registry.Registry, req *ModelRequest, cb ModelStreamingCallback) (*ModelResponse, error) {
if m == nil {
return nil, errors.New("Generate called on a nil Model; check that all models are defined")
}
Expand All @@ -292,7 +313,7 @@ func (m *modelActionDef) Generate(ctx context.Context, req *ModelRequest, cb Mod
}
resp.Message = msg

newReq, err := handleToolRequest(ctx, req, resp)
newReq, err := handleToolRequest(ctx, r, req, resp)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -362,7 +383,7 @@ func validMessage(m *Message, output *ModelRequestOutput) (*Message, error) {
// handleToolRequest checks if a tool was requested by a model.
// If a tool was requested, this runs the tool and returns an
// updated ModelRequest. If no tool was requested this returns nil.
func handleToolRequest(ctx context.Context, req *ModelRequest, resp *ModelResponse) (*ModelRequest, error) {
func handleToolRequest(ctx context.Context, r *registry.Registry, req *ModelRequest, resp *ModelResponse) (*ModelRequest, error) {
msg := resp.Message
if msg == nil || len(msg.Content) == 0 {
return nil, nil
Expand All @@ -373,7 +394,7 @@ func handleToolRequest(ctx context.Context, req *ModelRequest, resp *ModelRespon
}

toolReq := part.ToolRequest
tool := LookupTool(toolReq.Name)
tool := LookupTool(r, toolReq.Name)
if tool == nil {
return nil, fmt.Errorf("tool %v not found", toolReq.Name)
}
Expand Down
18 changes: 11 additions & 7 deletions go/ai/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"strings"
"testing"

"github.com/firebase/genkit/go/internal/registry"
test_utils "github.com/firebase/genkit/go/tests/utils"
"github.com/google/go-cmp/cmp"
)
Expand All @@ -30,7 +31,9 @@ type GameCharacter struct {
Backstory string
}

var echoModel = DefineModel("test", "echo", nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
var r, _ = registry.New()

var echoModel = DefineModel(r, "test", "echo", nil, func(ctx context.Context, gr *ModelRequest, msc ModelStreamingCallback) (*ModelResponse, error) {
if msc != nil {
msc(ctx, &ModelResponseChunk{
Content: []*Part{NewTextPart("stream!")},
Expand All @@ -49,7 +52,7 @@ var echoModel = DefineModel("test", "echo", nil, func(ctx context.Context, gr *M
})

// with tools
var gablorkenTool = DefineTool("gablorken", "use when need to calculate a gablorken",
var gablorkenTool = DefineTool(r, "gablorken", "use when need to calculate a gablorken",
func(ctx context.Context, input struct {
Value float64
Over float64
Expand Down Expand Up @@ -292,7 +295,8 @@ func TestGenerate(t *testing.T) {

wantStreamText := "stream!"
streamText := ""
res, err := Generate(context.Background(), echoModel,
res, err := Generate(context.Background(), r,
WithModel(echoModel),
WithTextPrompt(charJSONmd),
WithMessages(NewModelTextMessage("banana again")),
WithSystemPrompt("you are"),
Expand Down Expand Up @@ -328,25 +332,25 @@ func TestGenerate(t *testing.T) {

func TestIsDefinedModel(t *testing.T) {
t.Run("should return true", func(t *testing.T) {
if IsDefinedModel("test", "echo") != true {
if IsDefinedModel(r, "test", "echo") != true {
t.Errorf("IsDefinedModel did not return true")
}
})
t.Run("should return false", func(t *testing.T) {
if IsDefinedModel("foo", "bar") != false {
if IsDefinedModel(r, "foo", "bar") != false {
t.Errorf("IsDefinedModel did not return false")
}
})
}

func TestLookupModel(t *testing.T) {
t.Run("should return model", func(t *testing.T) {
if LookupModel("test", "echo") == nil {
if LookupModel(r, "test", "echo") == nil {
t.Errorf("LookupModel did not return model")
}
})
t.Run("should return nil", func(t *testing.T) {
if LookupModel("foo", "bar") != nil {
if LookupModel(r, "foo", "bar") != nil {
t.Errorf("LookupModel did not return nil")
}
})
Expand Down
13 changes: 7 additions & 6 deletions go/ai/prompt.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
"github.com/invopop/jsonschema"
)

Expand All @@ -33,24 +34,24 @@ type Prompt core.Action[any, *ModelRequest, struct{}]
// The prompt expects some input described by inputSchema.
// DefinePrompt registers the function as an action,
// and returns a [Prompt] that runs it.
func DefinePrompt(provider, name string, metadata map[string]any, inputSchema *jsonschema.Schema, render func(context.Context, any) (*ModelRequest, error)) *Prompt {
func DefinePrompt(r *registry.Registry, provider, name string, metadata map[string]any, inputSchema *jsonschema.Schema, render func(context.Context, any) (*ModelRequest, error)) *Prompt {
mm := maps.Clone(metadata)
if mm == nil {
mm = make(map[string]any)
}
mm["type"] = "prompt"
return (*Prompt)(core.DefineActionWithInputSchema(provider, name, atype.Prompt, mm, inputSchema, render))
return (*Prompt)(core.DefineActionWithInputSchema(r, provider, name, atype.Prompt, mm, inputSchema, render))
}

// IsDefinedPrompt reports whether a [Prompt] is defined.
func IsDefinedPrompt(provider, name string) bool {
return LookupPrompt(provider, name) != nil
func IsDefinedPrompt(r *registry.Registry, provider, name string) bool {
return LookupPrompt(r, provider, name) != nil
}

// LookupPrompt looks up a [Prompt] registered by [DefinePrompt].
// It returns nil if the prompt was not defined.
func LookupPrompt(provider, name string) *Prompt {
return (*Prompt)(core.LookupActionFor[any, *ModelRequest, struct{}](atype.Prompt, provider, name))
func LookupPrompt(r *registry.Registry, provider, name string) *Prompt {
return (*Prompt)(core.LookupActionFor[any, *ModelRequest, struct{}](r, atype.Prompt, provider, name))
}

// Render renders the [Prompt] with some input data.
Expand Down
25 changes: 13 additions & 12 deletions go/ai/retriever.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/firebase/genkit/go/core"
"github.com/firebase/genkit/go/internal/atype"
"github.com/firebase/genkit/go/internal/registry"
)

// Retriever represents a document retriever.
Expand Down Expand Up @@ -67,39 +68,39 @@ type RetrieverResponse struct {

// DefineIndexer registers the given index function as an action, and returns an
// [Indexer] that runs it.
func DefineIndexer(provider, name string, index func(context.Context, *IndexerRequest) error) Indexer {
func DefineIndexer(r *registry.Registry, provider, name string, index func(context.Context, *IndexerRequest) error) Indexer {
f := func(ctx context.Context, req *IndexerRequest) (struct{}, error) {
return struct{}{}, index(ctx, req)
}
return (*indexerActionDef)(core.DefineAction(provider, name, atype.Indexer, nil, f))
return (*indexerActionDef)(core.DefineAction(r, provider, name, atype.Indexer, nil, f))
}

// IsDefinedIndexer reports whether an [Indexer] is defined.
func IsDefinedIndexer(provider, name string) bool {
return (*indexerActionDef)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name)) != nil
func IsDefinedIndexer(r *registry.Registry, provider, name string) bool {
return (*indexerActionDef)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](r, atype.Indexer, provider, name)) != nil
}

// LookupIndexer looks up an [Indexer] registered by [DefineIndexer].
// It returns nil if the model was not defined.
func LookupIndexer(provider, name string) Indexer {
return (*indexerActionDef)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](atype.Indexer, provider, name))
func LookupIndexer(r *registry.Registry, provider, name string) Indexer {
return (*indexerActionDef)(core.LookupActionFor[*IndexerRequest, struct{}, struct{}](r, atype.Indexer, provider, name))
}

// DefineRetriever registers the given retrieve function as an action, and returns a
// [Retriever] that runs it.
func DefineRetriever(provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *retrieverActionDef {
return (*retrieverActionDef)(core.DefineAction(provider, name, atype.Retriever, nil, ret))
func DefineRetriever(r *registry.Registry, provider, name string, ret func(context.Context, *RetrieverRequest) (*RetrieverResponse, error)) *retrieverActionDef {
return (*retrieverActionDef)(core.DefineAction(r, provider, name, atype.Retriever, nil, ret))
}

// IsDefinedRetriever reports whether a [Retriever] is defined.
func IsDefinedRetriever(provider, name string) bool {
return (*retrieverActionDef)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name)) != nil
func IsDefinedRetriever(r *registry.Registry, provider, name string) bool {
return (*retrieverActionDef)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, atype.Retriever, provider, name)) != nil
}

// LookupRetriever looks up a [Retriever] registered by [DefineRetriever].
// It returns nil if the model was not defined.
func LookupRetriever(provider, name string) Retriever {
return (*retrieverActionDef)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](atype.Retriever, provider, name))
func LookupRetriever(r *registry.Registry, provider, name string) Retriever {
return (*retrieverActionDef)(core.LookupActionFor[*RetrieverRequest, *RetrieverResponse, struct{}](r, atype.Retriever, provider, name))
}

// Index runs the given [Indexer].
Expand Down
Loading

0 comments on commit 38baf30

Please sign in to comment.