Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(go): add firestore retriever #749

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 114 additions & 17 deletions go/genkit/flow.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"strconv"
"sync"
Expand Down Expand Up @@ -98,28 +99,74 @@ type Flow[In, Out, Stream any] struct {
tstate *tracing.State // set from the action when the flow is defined
inputSchema *jsonschema.Schema // Schema of the input to the flow
outputSchema *jsonschema.Schema // Schema of the output out of the flow
auth FlowAuth
// TODO: scheduler
// TODO: experimentalDurable
// TODO: authPolicy
// TODO: middleware
}

// runOptions configures a single flow run.
// Its fields are filled in by the flowRunOption values passed to that run.
type runOptions struct {
	authContext map[string]any // Auth context to pass to auth policy checker when calling a flow directly.
}

// flowOptions configures a flow.
// Its fields are filled in by the flowOption values passed when the flow is defined.
type flowOptions struct {
	auth FlowAuth // Auth provider and policy checker for the flow.
}

// noStream is an alias for a callback function whose streamed value is the empty struct.
// NOTE(review): presumably the callback type for flows that do not stream — confirm against callers.
type noStream = func(context.Context, struct{}) error

// FlowAuth configures an auth context provider and an auth policy check for a flow.
type FlowAuth interface {
	// ProvideAuthContext provides auth context from an auth header.
	// A non-nil error means no auth context could be provided for the header.
	ProvideAuthContext(ctx context.Context, authHeader string) (map[string]any, error)

	// CheckAuthPolicy checks auth context against policy.
	// input is the input the flow is about to run with; a non-nil error
	// means the policy rejected the call.
	CheckAuthPolicy(auth map[string]any, input any) error
}

// streamingCallback is the type of streaming callbacks.
// A callback receives a context and one streamed value of type Stream and
// reports failure by returning a non-nil error.
type streamingCallback[Stream any] func(context.Context, Stream) error

// flowOption modifies the flow with the provided option.
// It is a function that mutates the flowOptions it is given.
type flowOption func(opts *flowOptions)

// flowRunOption modifies a flow run with the provided option.
// It is a function that mutates the runOptions it is given.
type flowRunOption func(opts *runOptions)

// WithFlowAuth returns a flow option that installs auth as the flow's
// auth provider and policy checker.
//
// Applying the option to options that already carry an auth value logs
// the problem and panics (via log.Panic).
func WithFlowAuth(auth FlowAuth) flowOption {
	return func(o *flowOptions) {
		if o.auth == nil {
			o.auth = auth
			return
		}
		log.Panic("auth already set in flow")
	}
}

// WithLocalAuth returns a run option that supplies authContext as the
// local auth value when running or streaming a flow.
//
// Applying the option to run options that already carry an auth context
// logs the problem and panics (via log.Panic).
func WithLocalAuth(authContext map[string]any) flowRunOption {
	return func(ro *runOptions) {
		if ro.authContext == nil {
			ro.authContext = authContext
			return
		}
		log.Panic("authContext already set in runOptions")
	}
}

// DefineFlow creates a Flow that runs fn, and registers it as an action.
//
// fn takes an input of type In and returns an output of type Out.
func DefineFlow[In, Out any](
name string,
fn func(ctx context.Context, input In) (Out, error),
opts ...flowOption,
) *Flow[In, Out, struct{}] {
return defineFlow(registry.Global, name, core.Func[In, Out, struct{}](
func(ctx context.Context, input In, cb func(ctx context.Context, _ struct{}) error) (Out, error) {
return fn(ctx, input)
}))
}), opts...)
}

// DefineStreamingFlow creates a streaming Flow that runs fn, and registers it as an action.
Expand All @@ -134,11 +181,12 @@ func DefineFlow[In, Out any](
func DefineStreamingFlow[In, Out, Stream any](
name string,
fn func(ctx context.Context, input In, callback func(context.Context, Stream) error) (Out, error),
opts ...flowOption,
) *Flow[In, Out, Stream] {
return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn))
return defineFlow(registry.Global, name, core.Func[In, Out, Stream](fn), opts...)
}

func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream]) *Flow[In, Out, Stream] {
func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.Func[In, Out, Stream], opts ...flowOption) *Flow[In, Out, Stream] {
var i In
var o Out
f := &Flow[In, Out, Stream]{
Expand All @@ -148,12 +196,24 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
outputSchema: base.InferJSONSchema(o),
// TODO: set stateStore?
}
flowOpts := &flowOptions{}
for _, opt := range opts {
opt(flowOpts)
}
f.auth = flowOpts.auth
metadata := map[string]any{
"inputSchema": f.inputSchema,
"outputSchema": f.outputSchema,
"requiresAuth": f.auth != nil,
}
afunc := func(ctx context.Context, inst *flowInstruction[In], cb func(context.Context, Stream) error) (*flowState[In, Out], error) {
tracing.SetCustomMetadataAttr(ctx, "flow:wrapperAction", "true")
// Only non-durable flows have an auth policy, so we can safely assume the input is in Start.Input.
if inst.Start != nil {
if err := f.checkAuthPolicy(inst.Auth, any(inst.Start.Input)); err != nil {
return nil, err
}
}
return f.runInstruction(ctx, inst, streamingCallback[Stream](cb))
}
core.DefineActionInRegistry(r, "", f.name, atype.Flow, metadata, nil, afunc)
Expand All @@ -167,18 +227,19 @@ func defineFlow[In, Out, Stream any](r *registry.Registry, name string, fn core.
// A flowInstruction is an instruction to follow with a flow.
// It is the input for the flow's action.
// Exactly one field will be non-nil.
type flowInstruction[I any] struct {
Start *startInstruction[I] `json:"start,omitempty"`
type flowInstruction[In any] struct {
Start *startInstruction[In] `json:"start,omitempty"`
Resume *resumeInstruction `json:"resume,omitempty"`
Schedule *scheduleInstruction[I] `json:"schedule,omitempty"`
Schedule *scheduleInstruction[In] `json:"schedule,omitempty"`
RunScheduled *runScheduledInstruction `json:"runScheduled,omitempty"`
State *stateInstruction `json:"state,omitempty"`
Retry *retryInstruction `json:"retry,omitempty"`
Auth map[string]any `json:"auth,omitempty"`
}

// A startInstruction starts a flow.
type startInstruction[I any] struct {
Input I `json:"input,omitempty"`
type startInstruction[In any] struct {
Input In `json:"input,omitempty"`
Labels map[string]string `json:"labels,omitempty"`
}

Expand All @@ -189,9 +250,9 @@ type resumeInstruction struct {
}

// A scheduleInstruction schedules a flow to start at a later time.
type scheduleInstruction[I any] struct {
type scheduleInstruction[In any] struct {
DelaySecs float64 `json:"delay,omitempty"`
Input I `json:"input,omitempty"`
Input In `json:"input,omitempty"`
}

// A runScheduledInstruction starts a scheduled flow.
Expand Down Expand Up @@ -324,7 +385,7 @@ func (f *Flow[In, Out, Stream]) runInstruction(ctx context.Context, inst *flowIn
// Name returns the name that the flow was defined with.
func (f *Flow[In, Out, Stream]) Name() string { return f.name }

func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error) {
// Validate input before unmarshaling it because invalid or unknown fields will be discarded in the process.
if err := base.ValidateJSON(input, f.inputSchema); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
Expand All @@ -333,6 +394,13 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
if err := json.Unmarshal(input, &in); err != nil {
return nil, &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
authContext, err := f.provideAuthContext(ctx, authHeader)
if err != nil {
return nil, &base.HTTPError{Code: http.StatusUnauthorized, Err: err}
}
if err := f.checkAuthPolicy(authContext, in); err != nil {
return nil, &base.HTTPError{Code: http.StatusForbidden, Err: err}
}
// If there is a callback, wrap it to turn an S into a json.RawMessage.
var callback streamingCallback[Stream]
if cb != nil {
Expand Down Expand Up @@ -361,6 +429,28 @@ func (f *Flow[In, Out, Stream]) runJSON(ctx context.Context, input json.RawMessa
return json.Marshal(res.Response)
}

// provideAuthContext asks the flow's configured auth, if any, to turn
// authHeader into an auth context.
//
// With no auth configured it returns (nil, nil). A failure from the
// provider is wrapped with an "unauthorized" prefix.
func (f *Flow[In, Out, Stream]) provideAuthContext(ctx context.Context, authHeader string) (map[string]any, error) {
	if f.auth == nil {
		return nil, nil
	}
	authCtx, provideErr := f.auth.ProvideAuthContext(ctx, authHeader)
	if provideErr != nil {
		return nil, fmt.Errorf("unauthorized: %w", provideErr)
	}
	return authCtx, nil
}

// checkAuthPolicy runs the flow's configured auth policy, if any, against
// authContext and input.
//
// With no auth configured it returns nil. A rejection from the policy is
// wrapped with a "permission denied for resource" prefix.
func (f *Flow[In, Out, Stream]) checkAuthPolicy(authContext map[string]any, input any) error {
	if f.auth == nil {
		return nil
	}
	policyErr := f.auth.CheckAuthPolicy(authContext, input)
	if policyErr == nil {
		return nil
	}
	return fmt.Errorf("permission denied for resource: %w", policyErr)
}

// start starts executing the flow with the given input.
func (f *Flow[In, Out, Stream]) start(ctx context.Context, input In, cb streamingCallback[Stream]) (_ *flowState[In, Out], err error) {
flowID, err := generateFlowID()
Expand Down Expand Up @@ -569,11 +659,18 @@ func Run[Out any](ctx context.Context, name string, f func() (Out, error)) (Out,

// Run runs the flow in the context of another flow. The flow must run to completion when started
// (that is, it must not have interrupts).
func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In) (Out, error) {
return f.run(ctx, input, nil)
func (f *Flow[In, Out, Stream]) Run(ctx context.Context, input In, opts ...flowRunOption) (Out, error) {
return f.run(ctx, input, nil, opts...)
}

func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error) (Out, error) {
func (f *Flow[In, Out, Stream]) run(ctx context.Context, input In, cb func(context.Context, Stream) error, opts ...flowRunOption) (Out, error) {
runOpts := &runOptions{}
for _, opt := range opts {
opt(runOpts)
}
if err := f.checkAuthPolicy(runOpts.authContext, input); err != nil {
return base.Zero[Out](), err
}
state, err := f.start(ctx, input, cb)
if err != nil {
return base.Zero[Out](), err
Expand Down Expand Up @@ -602,7 +699,7 @@ type StreamFlowValue[Out, Stream any] struct {
// again.
//
// Otherwise the Stream field of the passed [StreamFlowValue] holds a streamed result.
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(*StreamFlowValue[Out, Stream], error) bool) {
func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In, opts ...flowRunOption) func(func(*StreamFlowValue[Out, Stream], error) bool) {
return func(yield func(*StreamFlowValue[Out, Stream], error) bool) {
cb := func(ctx context.Context, s Stream) error {
if ctx.Err() != nil {
Expand All @@ -613,7 +710,7 @@ func (f *Flow[In, Out, Stream]) Stream(ctx context.Context, input In) func(func(
}
return nil
}
output, err := f.run(ctx, input, cb)
output, err := f.run(ctx, input, cb, opts...)
if err != nil {
yield(nil, err)
} else {
Expand Down
78 changes: 32 additions & 46 deletions go/genkit/servers.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,13 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
"os/signal"
"strconv"
"sync"
"sync/atomic"
"syscall"
"time"

"github.com/firebase/genkit/go/core/logger"
Expand Down Expand Up @@ -76,7 +73,7 @@ type flow interface {

// runJSON uses encoding/json to unmarshal the input,
// calls Flow.start, then returns the marshaled result.
runJSON(ctx context.Context, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error)
runJSON(ctx context.Context, authHeader string, input json.RawMessage, cb streamingCallback[json.RawMessage]) (json.RawMessage, error)
}

// startServer starts an HTTP server listening on the address.
Expand Down Expand Up @@ -163,19 +160,17 @@ func (s *devServer) handleRunAction(w http.ResponseWriter, r *http.Request) erro
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
stream := false
if s := r.FormValue("stream"); s != "" {
var err error
stream, err = strconv.ParseBool(s)
if err != nil {
return err
}
stream, err := parseBoolQueryParam(r, "stream")
if err != nil {
return err
}
logger.FromContext(ctx).Debug("running action",
"key", body.Key,
"stream", stream)
var callback streamingCallback[json.RawMessage]
if stream {
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
// Stream results are newline-separated JSON.
callback = func(ctx context.Context, msg json.RawMessage) error {
_, err := fmt.Fprintf(w, "%s\n", msg)
Expand Down Expand Up @@ -328,29 +323,42 @@ func newFlowServeMux(r *registry.Registry, flows []string) *http.ServeMux {

func nonDurableFlowHandler(f flow) func(http.ResponseWriter, *http.Request) error {
return func(w http.ResponseWriter, r *http.Request) error {
var body struct {
Data json.RawMessage `json:"data"`
}
defer r.Body.Close()
input, err := io.ReadAll(r.Body)
if err != nil {
return err
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
return &base.HTTPError{Code: http.StatusBadRequest, Err: err}
}
stream, err := parseBoolQueryParam(r, "stream")
if err != nil {
return err
}
var callback streamingCallback[json.RawMessage]
if stream {
// TODO: implement streaming.
return &base.HTTPError{Code: http.StatusNotImplemented, Err: errors.New("streaming")}
} else {
// TODO: telemetry
out, err := f.runJSON(r.Context(), json.RawMessage(input), nil)
if err != nil {
return err
w.Header().Set("Content-Type", "text/plain")
w.Header().Set("Transfer-Encoding", "chunked")
// Stream results are newline-separated JSON.
callback = func(ctx context.Context, msg json.RawMessage) error {
_, err := fmt.Fprintf(w, "%s\n", msg)
if err != nil {
return err
}
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
return nil
}
// Responses for non-streaming, non-durable flows are passed back
// with the flow result stored in a field called "result."
_, err = fmt.Fprintf(w, `{"result": %s}\n`, out)
}
// TODO: telemetry
out, err := f.runJSON(r.Context(), r.Header.Get("Authorization"), body.Data, callback)
if err != nil {
return err
}
// Responses for non-streaming, non-durable flows are passed back
// with the flow result stored in a field called "result."
_, err = fmt.Fprintf(w, `{"result": %s}\n`, out)
return err
}
}

Expand All @@ -365,28 +373,6 @@ func serverAddress(arg, envVar, defaultValue string) string {
return defaultValue
}

// listenAndServe serves mux on addr until the server fails or the process
// receives SIGTERM.
//
// On SIGTERM the server is shut down gracefully with a 5 second timeout.
// http.Server.ListenAndServe returns http.ErrServerClosed as soon as
// Shutdown is called, so this function waits for Shutdown to finish before
// returning; otherwise the caller could exit while connections are still
// draining.
//
// The returned error is whatever ListenAndServe returned, which is
// http.ErrServerClosed after a signal-triggered shutdown. When listening
// fails for another reason, the signal handler is unregistered and the
// helper goroutine is stopped, so nothing is leaked and a later SIGTERM
// gets its default handling again.
func listenAndServe(addr string, mux *http.ServeMux) error {
	server := &http.Server{
		Addr:    addr,
		Handler: mux,
	}
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGTERM)
	defer signal.Stop(sigCh)

	quit := make(chan struct{}) // closed on return so the goroutine below never outlives this call
	defer close(quit)
	shutdownDone := make(chan struct{}) // closed once the goroutine below has finished
	go func() {
		defer close(shutdownDone)
		select {
		case <-sigCh:
		case <-quit:
			return
		}
		slog.Info("received SIGTERM, shutting down server")
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		if err := server.Shutdown(ctx); err != nil {
			slog.Error("server shutdown failed", "err", err)
		} else {
			slog.Info("server shutdown successfully")
		}
	}()
	slog.Info("listening", "addr", addr)
	err := server.ListenAndServe()
	if errors.Is(err, http.ErrServerClosed) {
		// Only the goroutine above can close this server, so it is in (or
		// past) Shutdown; wait for it to finish draining connections.
		<-shutdownDone
	}
	return err
}

// requestID is a unique ID for each request.
var requestID atomic.Int64

Expand Down
Loading
Loading