Skip to content

Commit

Permalink
feat: add azure_open_ai_text content provider
Browse files Browse the repository at this point in the history
  • Loading branch information
anasmuhmd committed Aug 4, 2024
1 parent d3bec52 commit f1bf794
Show file tree
Hide file tree
Showing 11 changed files with 549 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .mockery.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ packages:
config:
interfaces:
Client:
github.com/blackstork-io/fabric/internal/microsoft:
config:
interfaces:
AzureOpenAIClient:
github.com/blackstork-io/fabric/plugin/resolver:
config:
inpackage: true
Expand Down
38 changes: 38 additions & 0 deletions examples/templates/azureopenai/example.fabric
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

fabric {
plugin_versions = {
"blackstork/microsoft" = ">= 0.4 < 1.0 || 0.4.0-rev0"
}
}

document "example" {
meta {
name = "example_document"
}

title = "Document title"

section {
title = "Section 2"

section {
title = "Subsection 2"

content text {
value = "Text value 4"
}
}
}

content azure_openai_text {
config {
api_key = env.AZURE_OPENAI_KEY
resource_endpoint = env.AZURE_OPENAI_ENDPOINT
deployment_name = env.AZURE_OPENAI_DEPLOYMENT
api_version = "2024-02-01"
}
prompt = "How are you today?"
max_tokens = 10
}
}

3 changes: 3 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ replace github.com/evanphx/go-hclog-slog v0.0.0-20230905211129-6d31b63d6f09 => g

require (
dario.cat/mergo v1.0.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 // indirect
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect
github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 // indirect
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect
github.com/Masterminds/goutils v1.1.1 // indirect
github.com/Microsoft/go-winio v0.6.1 // indirect
Expand Down
6 changes: 6 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/Andrew-Morozko/go-hclog-slog v0.0.0-20240624145756-528e39db2968 h1:9lp9JWCEkcLqYRx8IZQmvZboh+5gcxJ4r9SfWQYjvbI=
github.com/Andrew-Morozko/go-hclog-slog v0.0.0-20240624145756-528e39db2968/go.mod h1:30+1dTR5EdDQGmcjkgOj4i6iVD3wnI0XymSOTPMlUeA=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 h1:FQOmDxJj1If0D0khZR00MDa2Eb+k9BBsSaK7cEbLwkk=
github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0/go.mod h1:X0+PSrHOZdTjkiEhgv53HS5gplbzVVl2jd6hQRYSS3c=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM=
github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 h1:rTfKOCZGy5ViVrlA74ZPE99a+SgoEE2K/yg3RyW9dFA=
github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8=
github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E=
github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI=
Expand Down
2 changes: 1 addition & 1 deletion internal/microsoft/cmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ var version string

func main() {
pluginapiv1.Serve(
microsoft.Plugin(version, microsoft.DefaultClientLoader),
microsoft.Plugin(version, microsoft.DefaultClientLoader, microsoft.DefaultAzureOpenAIClientLoader),
)
}
153 changes: 153 additions & 0 deletions internal/microsoft/content_ azure_openai_text.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package microsoft

import (
"bytes"
"context"
"fmt"
"strings"
"text/template"

"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Masterminds/sprig/v3"
"github.com/blackstork-io/fabric/pkg/diagnostics"
"github.com/blackstork-io/fabric/plugin"
"github.com/blackstork-io/fabric/plugin/dataspec"
"github.com/blackstork-io/fabric/plugin/dataspec/constraint"
"github.com/hashicorp/hcl/v2"
"github.com/zclconf/go-cty/cty"
)

func makeAzureOpenAITextContentSchema(loader AzureOpenAIClientLoadFn) *plugin.ContentProvider {
return &plugin.ContentProvider{
Config: dataspec.ObjectSpec{
&dataspec.AttrSpec{
Name: "api_key",
Type: cty.String,
Constraints: constraint.RequiredNonNull,
Secret: true,
},
&dataspec.AttrSpec{
Name: "resource_endpoint",
Type: cty.String,
Constraints: constraint.RequiredNonNull,
},
&dataspec.AttrSpec{
Name: "deployment_name",
Type: cty.String,
Constraints: constraint.RequiredNonNull,
},
&dataspec.AttrSpec{
Name: "api_version",
Type: cty.String,
DefaultVal: cty.StringVal("2024-02-01"),
},
},
Args: dataspec.ObjectSpec{
&dataspec.AttrSpec{
Name: "prompt",
Type: cty.String,
Constraints: constraint.RequiredNonNull,
ExampleVal: cty.StringVal("Summarize the following text: {{.vars.text_to_summarize}}"),
},
&dataspec.AttrSpec{
Name: "max_tokens",
Type: cty.Number,
DefaultVal: cty.NumberIntVal(1000),
},
&dataspec.AttrSpec{
Name: "temperature",
Type: cty.Number,
DefaultVal: cty.NumberFloatVal(0),
},
&dataspec.AttrSpec{
Name: "top_p",
Type: cty.Number,
},
&dataspec.AttrSpec{
Name: "completions_count",
Type: cty.Number,
DefaultVal: cty.NumberIntVal(1),
},
},
ContentFunc: genOpenAIText(loader),
}
}

func genOpenAIText(loader AzureOpenAIClientLoadFn) plugin.ProvideContentFunc {
return func(ctx context.Context, params *plugin.ProvideContentParams) (*plugin.ContentResult, diagnostics.Diag) {
apiKey := params.Config.GetAttr("api_key").AsString()
resourceEndpoint := params.Config.GetAttr("resource_endpoint").AsString()
client, err := loader(apiKey, resourceEndpoint)
if err != nil {
return nil, diagnostics.Diag{{
Severity: hcl.DiagError,
Summary: "Failed to create client",
Detail: err.Error(),
}}
}
result, err := renderText(ctx, client, params.Config, params.Args, params.DataContext)
if err != nil {
return nil, diagnostics.Diag{{
Severity: hcl.DiagError,
Summary: "Failed to generate text",
Detail: err.Error(),
}}
}
return &plugin.ContentResult{
Content: &plugin.ContentElement{
Markdown: result,
},
}, nil
}
}

func renderText(ctx context.Context, cli AzureOpenAIClient, cfg, args cty.Value, dataCtx plugin.MapData) (string, error) {

params := azopenai.CompletionsOptions{}
params.DeploymentName = to.Ptr(cfg.GetAttr("deployment_name").AsString())

maxTokens, _ := args.GetAttr("max_tokens").AsBigFloat().Int64()
params.MaxTokens = to.Ptr(int32(maxTokens))

temperature, _ := args.GetAttr("temperature").AsBigFloat().Float32()
params.Temperature = to.Ptr(temperature)

completionsCount, _ := args.GetAttr("completions_count").AsBigFloat().Int64()
params.N = to.Ptr(int32(completionsCount))

topPAttr := args.GetAttr("top_p")
if !topPAttr.IsNull() {
topP, _ := topPAttr.AsBigFloat().Float32()
params.TopP = to.Ptr(topP)
}

renderedPrompt, err := templateText(args.GetAttr("prompt").AsString(), dataCtx)
if err != nil {
return "", err
}
params.Prompt = []string{renderedPrompt}
// TODO: use api version from config
resp, err := cli.GetCompletions(ctx, params, nil)
if err != nil {
return "", err
}
if len(resp.Choices) == 0 {
return "", nil
}
return *resp.Choices[0].Text, nil
}

func templateText(text string, dataCtx plugin.MapData) (string, error) {
tmpl, err := template.New("text").Funcs(sprig.FuncMap()).Parse(text)
if err != nil {
return "", fmt.Errorf("failed to parse template: %w", err)
}

var buf bytes.Buffer
err = tmpl.Execute(&buf, dataCtx.Any())
if err != nil {
return "", fmt.Errorf("failed to execute template: %w", err)
}
return strings.TrimSpace(buf.String()), nil
}
Loading

0 comments on commit f1bf794

Please sign in to comment.