From f1bf794b294669b56d570391e4e32e7a7219c5cb Mon Sep 17 00:00:00 2001 From: Anas Muhammed Date: Sun, 4 Aug 2024 22:52:30 +0400 Subject: [PATCH 1/3] feat: add azure_open_ai_text content provider --- .mockery.yaml | 4 + examples/templates/azureopenai/example.fabric | 38 +++ go.mod | 3 + go.sum | 6 + internal/microsoft/cmd/main.go | 2 +- .../microsoft/content_ azure_openai_text.go | 153 ++++++++++++ .../content_azure_openai_text_test.go | 227 ++++++++++++++++++ internal/microsoft/plugin.go | 19 +- internal/microsoft/plugin_test.go | 3 +- internal/plugin_validity_test.go | 2 +- .../microsoft/azure_open_ai_client.go | 96 ++++++++ 11 files changed, 549 insertions(+), 4 deletions(-) create mode 100644 examples/templates/azureopenai/example.fabric create mode 100644 internal/microsoft/content_ azure_openai_text.go create mode 100644 internal/microsoft/content_azure_openai_text_test.go create mode 100644 mocks/internalpkg/microsoft/azure_open_ai_client.go diff --git a/.mockery.yaml b/.mockery.yaml index 4079849e..a7ee190e 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -45,6 +45,10 @@ packages: config: interfaces: Client: + github.com/blackstork-io/fabric/internal/microsoft: + config: + interfaces: + AzureOpenAIClient: github.com/blackstork-io/fabric/plugin/resolver: config: inpackage: true diff --git a/examples/templates/azureopenai/example.fabric b/examples/templates/azureopenai/example.fabric new file mode 100644 index 00000000..48d131f6 --- /dev/null +++ b/examples/templates/azureopenai/example.fabric @@ -0,0 +1,38 @@ + +fabric { + plugin_versions = { + "blackstork/microsoft" = ">= 0.4 < 1.0 || 0.4.0-rev0" + } +} + +document "example" { + meta { + name = "example_document" + } + + title = "Document title" + + section { + title = "Section 2" + + section { + title = "Subsection 2" + + content text { + value = "Text value 4" + } + } + } + + content azure_openai_text { + config { + api_key = env.AZURE_OPENAI_KEY + resource_endpoint = env.AZURE_OPENAI_ENDPOINT + deployment_name = env.AZURE_OPENAI_DEPLOYMENT + api_version = "2024-02-01" + } + prompt = "How are you today?" 
+ max_tokens = 10 + } +} + diff --git a/go.mod b/go.mod index dab48e32..4363aafd 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,9 @@ replace github.com/evanphx/go-hclog-slog v0.0.0-20230905211129-6d31b63d6f09 => g require ( dario.cat/mergo v1.0.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect diff --git a/go.sum b/go.sum index f48a08cf..6760e11f 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,12 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9 github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Andrew-Morozko/go-hclog-slog v0.0.0-20240624145756-528e39db2968 h1:9lp9JWCEkcLqYRx8IZQmvZboh+5gcxJ4r9SfWQYjvbI= github.com/Andrew-Morozko/go-hclog-slog v0.0.0-20240624145756-528e39db2968/go.mod h1:30+1dTR5EdDQGmcjkgOj4i6iVD3wnI0XymSOTPMlUeA= +github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 h1:FQOmDxJj1If0D0khZR00MDa2Eb+k9BBsSaK7cEbLwkk= +github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0/go.mod h1:X0+PSrHOZdTjkiEhgv53HS5gplbzVVl2jd6hQRYSS3c= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 h1:rTfKOCZGy5ViVrlA74ZPE99a+SgoEE2K/yg3RyW9dFA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= diff --git a/internal/microsoft/cmd/main.go b/internal/microsoft/cmd/main.go index 10245c02..9fb41629 100644 --- a/internal/microsoft/cmd/main.go +++ b/internal/microsoft/cmd/main.go @@ -9,6 +9,6 @@ var version string func main() { pluginapiv1.Serve( - microsoft.Plugin(version, microsoft.DefaultClientLoader), + microsoft.Plugin(version, microsoft.DefaultClientLoader, microsoft.DefaultAzureOpenAIClientLoader), ) } diff --git a/internal/microsoft/content_ azure_openai_text.go b/internal/microsoft/content_ azure_openai_text.go new file mode 100644 index 00000000..261c8bd4 --- /dev/null +++ b/internal/microsoft/content_ azure_openai_text.go @@ -0,0 +1,153 @@ +package microsoft + +import ( + "bytes" + "context" + "fmt" + "strings" + "text/template" + + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Masterminds/sprig/v3" + "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/plugin" + "github.com/blackstork-io/fabric/plugin/dataspec" + "github.com/blackstork-io/fabric/plugin/dataspec/constraint" + "github.com/hashicorp/hcl/v2" + "github.com/zclconf/go-cty/cty" +) + +func makeAzureOpenAITextContentSchema(loader AzureOpenAIClientLoadFn) *plugin.ContentProvider { + return &plugin.ContentProvider{ + Config: dataspec.ObjectSpec{ + &dataspec.AttrSpec{ + Name: "api_key", + Type: 
cty.String, + Constraints: constraint.RequiredNonNull, + Secret: true, + }, + &dataspec.AttrSpec{ + Name: "resource_endpoint", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + }, + &dataspec.AttrSpec{ + Name: "deployment_name", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + }, + &dataspec.AttrSpec{ + Name: "api_version", + Type: cty.String, + DefaultVal: cty.StringVal("2024-02-01"), + }, + }, + Args: dataspec.ObjectSpec{ + &dataspec.AttrSpec{ + Name: "prompt", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + ExampleVal: cty.StringVal("Summarize the following text: {{.vars.text_to_summarize}}"), + }, + &dataspec.AttrSpec{ + Name: "max_tokens", + Type: cty.Number, + DefaultVal: cty.NumberIntVal(1000), + }, + &dataspec.AttrSpec{ + Name: "temperature", + Type: cty.Number, + DefaultVal: cty.NumberFloatVal(0), + }, + &dataspec.AttrSpec{ + Name: "top_p", + Type: cty.Number, + }, + &dataspec.AttrSpec{ + Name: "completions_count", + Type: cty.Number, + DefaultVal: cty.NumberIntVal(1), + }, + }, + ContentFunc: genOpenAIText(loader), + } +} + +func genOpenAIText(loader AzureOpenAIClientLoadFn) plugin.ProvideContentFunc { + return func(ctx context.Context, params *plugin.ProvideContentParams) (*plugin.ContentResult, diagnostics.Diag) { + apiKey := params.Config.GetAttr("api_key").AsString() + resourceEndpoint := params.Config.GetAttr("resource_endpoint").AsString() + client, err := loader(apiKey, resourceEndpoint) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to create client", + Detail: err.Error(), + }} + } + result, err := renderText(ctx, client, params.Config, params.Args, params.DataContext) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to generate text", + Detail: err.Error(), + }} + } + return &plugin.ContentResult{ + Content: &plugin.ContentElement{ + Markdown: result, + }, + }, nil + } +} + +func renderText(ctx context.Context, cli AzureOpenAIClient, cfg, args cty.Value, dataCtx plugin.MapData) (string, error) { + + params := azopenai.CompletionsOptions{} + params.DeploymentName = to.Ptr(cfg.GetAttr("deployment_name").AsString()) + + maxTokens, _ := args.GetAttr("max_tokens").AsBigFloat().Int64() + params.MaxTokens = to.Ptr(int32(maxTokens)) + + temperature, _ := args.GetAttr("temperature").AsBigFloat().Float32() + params.Temperature = to.Ptr(temperature) + + completionsCount, _ := args.GetAttr("completions_count").AsBigFloat().Int64() + params.N = to.Ptr(int32(completionsCount)) + + topPAttr := args.GetAttr("top_p") + if !topPAttr.IsNull() { + topP, _ := topPAttr.AsBigFloat().Float32() + params.TopP = to.Ptr(topP) + } + + renderedPrompt, err := templateText(args.GetAttr("prompt").AsString(), dataCtx) + if err != nil { + return "", err + } + params.Prompt = []string{renderedPrompt} + // TODO: use api version from config + resp, err := cli.GetCompletions(ctx, params, nil) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", nil + } + return *resp.Choices[0].Text, nil +} + +func templateText(text string, dataCtx plugin.MapData) (string, error) { + tmpl, err := template.New("text").Funcs(sprig.FuncMap()).Parse(text) + if err != nil { + return "", fmt.Errorf("failed to parse template: %w", err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, dataCtx.Any()) + if err != nil { + return "", fmt.Errorf("failed to execute template: %w", err) + } + return strings.TrimSpace(buf.String()), nil +} diff --git 
a/internal/microsoft/content_azure_openai_text_test.go b/internal/microsoft/content_azure_openai_text_test.go new file mode 100644 index 00000000..495667ff --- /dev/null +++ b/internal/microsoft/content_azure_openai_text_test.go @@ -0,0 +1,227 @@ +package microsoft_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/hashicorp/hcl/v2" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/zclconf/go-cty/cty" + + "github.com/blackstork-io/fabric/internal/microsoft" + client_mocks "github.com/blackstork-io/fabric/mocks/internalpkg/microsoft" + "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/pkg/diagnostics/diagtest" + "github.com/blackstork-io/fabric/plugin" + "github.com/blackstork-io/fabric/plugin/plugintest" + "github.com/blackstork-io/fabric/print/mdprint" +) + +type AzureOpenAITextContentTestSuite struct { + suite.Suite + plugin *plugin.Schema + schema *plugin.ContentProvider + cli *client_mocks.AzureOpenAIClient +} + +func TestAzureOpenAITextContentSuite(t *testing.T) { + suite.Run(t, &AzureOpenAITextContentTestSuite{}) +} + +func (s *AzureOpenAITextContentTestSuite) SetupSuite() { + s.plugin = microsoft.Plugin("1.0.0", nil, (func(apiKey string, endPoint string) (cli microsoft.AzureOpenAIClient, err error) { + return s.cli, nil + })) + s.schema = s.plugin.ContentProviders["azure_openai_text"] +} + +func (s *AzureOpenAITextContentTestSuite) SetupTest() { + s.cli = &client_mocks.AzureOpenAIClient{} +} + +func (s *AzureOpenAITextContentTestSuite) TearDownTest() { + s.cli.AssertExpectations(s.T()) +} + +func (s *AzureOpenAITextContentTestSuite) TestSchema() { + s.Require().NotNil(s.plugin) + s.Require().NotNil(s.schema) + s.NotNil(s.schema.Args) + s.NotNil(s.schema.ContentFunc) + s.NotNil(s.schema.Config) +} + +func (s *AzureOpenAITextContentTestSuite) TestBasic() { + s.cli.On("GetCompletions", mock.Anything, azopenai.CompletionsOptions{ + DeploymentName: to.Ptr("test"), + MaxTokens: to.Ptr(int32(1000)), + Temperature: to.Ptr(float32(0)), + Prompt: []string{"Tell me a story"}, + TopP: nil, + N: to.Ptr(int32(1)), + }, mock.Anything).Return(azopenai.GetCompletionsResponse{ + Completions: azopenai.Completions{ + Choices: []azopenai.Choice{ + {Text: to.Ptr("Once upon a time.")}, + }, + }, + }, nil) + ctx := context.Background() + dataCtx := plugin.MapData{} + result, diags := s.schema.ContentFunc(ctx, &plugin.ProvideContentParams{ + Args: cty.ObjectVal(map[string]cty.Value{ + "prompt": cty.StringVal("Tell me a story"), + "max_tokens": cty.NumberIntVal(1000), + "temperature": cty.NumberFloatVal(0), + "top_p": cty.NilVal, + "completions_count": cty.NumberIntVal(1), + }), + Config: cty.ObjectVal(map[string]cty.Value{ + "api_key": cty.StringVal("testtoken"), + "resource_endpoint": cty.StringVal("http://test"), + "deployment_name": cty.StringVal("test"), + "api_version": cty.StringVal("2024-02-01"), + }), + DataContext: dataCtx, + }) + fmt.Println(diags) + s.Nil(diags) + s.Equal("Once upon a time.", mdprint.PrintString(result.Content)) +} + +func (s *AzureOpenAITextContentTestSuite) TestAdvanced() { + s.cli.On("GetCompletions", mock.Anything, azopenai.CompletionsOptions{ + DeploymentName: to.Ptr("test"), + MaxTokens: to.Ptr(int32(1000)), + Temperature: to.Ptr(float32(0)), + Prompt: []string{"Tell me a story about BAR. 
{\"foo\":\"bar\"}"}, + TopP: nil, + N: to.Ptr(int32(1)), + }, mock.Anything).Return(azopenai.GetCompletionsResponse{ + Completions: azopenai.Completions{ + Choices: []azopenai.Choice{ + {Text: to.Ptr("Once upon a time.")}, + }, + }, + }, nil) + + ctx := context.Background() + dataCtx := plugin.MapData{ + "local": plugin.MapData{ + "foo": plugin.StringData("bar"), + }, + } + + result, diags := s.schema.ContentFunc(ctx, &plugin.ProvideContentParams{ + Args: cty.ObjectVal(map[string]cty.Value{ + "prompt": cty.StringVal("Tell me a story about {{.local.foo | upper}}. {{ .local | toRawJson }}"), + "max_tokens": cty.NumberIntVal(1000), + "temperature": cty.NumberFloatVal(0), + "top_p": cty.NilVal, + "completions_count": cty.NumberIntVal(1), + }), + Config: cty.ObjectVal(map[string]cty.Value{ + "api_key": cty.StringVal("testtoken"), + "resource_endpoint": cty.StringVal("http://test"), + "deployment_name": cty.StringVal("test"), + "api_version": cty.StringVal("2024-02-01"), + }), + DataContext: dataCtx, + }) + s.Empty(diags) + s.Equal("Once upon a time.", mdprint.PrintString(result.Content)) +} + +func (s *AzureOpenAITextContentTestSuite) TestMissingPrompt() { + plugintest.DecodeAndAssert(s.T(), s.schema.Args, "", plugin.MapData{}, diagtest.Asserts{ + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("prompt"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("prompt"), + }, + }) +} + +func (s *AzureOpenAITextContentTestSuite) TestMissingAPIKey() { + plugintest.DecodeAndAssert(s.T(), s.schema.Args, ` + prompt = "Tell me a story" + `, plugin.MapData{}, diagtest.Asserts{}) + plugintest.DecodeAndAssert(s.T(), s.schema.Config, ` + `, plugin.MapData{}, diagtest.Asserts{ + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("api_key"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("api_key"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("resource_endpoint"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("resource_endpoint"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("deployment_name"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("deployment_name"), + }, + }) +} + +func (s *AzureOpenAITextContentTestSuite) TestFailingClient() { + s.cli.On("GetCompletions", mock.Anything, azopenai.CompletionsOptions{ + DeploymentName: to.Ptr("test"), + MaxTokens: to.Ptr(int32(1000)), + Temperature: to.Ptr(float32(0)), + Prompt: []string{"Tell me a story"}, + TopP: nil, + N: to.Ptr(int32(1)), + }, mock.Anything).Return(azopenai.GetCompletionsResponse{}, errors.New("failed to generate text from model")) + ctx := context.Background() + dataCtx := plugin.MapData{} + result, diags := s.schema.ContentFunc(ctx, &plugin.ProvideContentParams{ + Args: cty.ObjectVal(map[string]cty.Value{ + "prompt": cty.StringVal("Tell me a story"), + "max_tokens": cty.NumberIntVal(1000), + "temperature": cty.NumberFloatVal(0), + "top_p": cty.NilVal, + "completions_count": cty.NumberIntVal(1), + }), + Config: cty.ObjectVal(map[string]cty.Value{ + "api_key": cty.StringVal("testtoken"), + "resource_endpoint": cty.StringVal("http://test"), + 
"deployment_name": cty.StringVal("test"), + "api_version": cty.StringVal("2024-02-01"), + }), + DataContext: dataCtx, + }) + s.Nil(result) + s.Equal(diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to generate text", + Detail: "failed to generate text from model", + }}, diags) +} diff --git a/internal/microsoft/plugin.go b/internal/microsoft/plugin.go index 053405fa..5e010a53 100644 --- a/internal/microsoft/plugin.go +++ b/internal/microsoft/plugin.go @@ -6,6 +6,8 @@ import ( "github.com/zclconf/go-cty/cty" + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/blackstork-io/fabric/internal/microsoft/client" "github.com/blackstork-io/fabric/plugin" ) @@ -14,7 +16,19 @@ type ClientLoadFn func() client.Client var DefaultClientLoader ClientLoadFn = client.New -func Plugin(version string, loader ClientLoadFn) *plugin.Schema { +type AzureOpenAIClientLoadFn func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenAIClient, err error) + +type AzureOpenAIClient interface { + GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) +} + +var DefaultAzureOpenAIClientLoader AzureOpenAIClientLoadFn = func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenAIClient, err error) { + keyCredential := azcore.NewKeyCredential(azureOpenAIKey) + client, err = azopenai.NewClientWithKeyCredential(azureOpenAIEndpoint, keyCredential, nil) + return +} + +func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenAIClientLoadFn) *plugin.Schema { if loader == nil { loader = DefaultClientLoader } @@ -25,6 +39,9 @@ func Plugin(version string, loader ClientLoadFn) *plugin.Schema { DataSources: plugin.DataSources{ "microsoft_sentinel_incidents": makeMicrosoftSentinelIncidentsDataSource(loader), }, + ContentProviders: plugin.ContentProviders{ + "azure_openai_text": makeAzureOpenAITextContentSchema(openAiClientLoader), + }, } } diff --git a/internal/microsoft/plugin_test.go b/internal/microsoft/plugin_test.go index 3e44160a..c9aaa301 100644 --- a/internal/microsoft/plugin_test.go +++ b/internal/microsoft/plugin_test.go @@ -7,8 +7,9 @@ import ( ) func TestPlugin_Schema(t *testing.T) { - schema := Plugin("1.2.3", nil) + schema := Plugin("1.2.3", nil, nil) assert.Equal(t, "blackstork/microsoft", schema.Name) assert.Equal(t, "1.2.3", schema.Version) assert.NotNil(t, schema.DataSources["microsoft_sentinel_incidents"]) + assert.NotNil(t, schema.ContentProviders["azure_openai_text"]) } diff --git a/internal/plugin_validity_test.go b/internal/plugin_validity_test.go index 707da142..1680ff64 100644 --- a/internal/plugin_validity_test.go +++ b/internal/plugin_validity_test.go @@ -46,7 +46,7 @@ func TestAllPluginSchemaValidity(t *testing.T) { splunk.Plugin(ver, nil), nistnvd.Plugin(ver, nil), snyk.Plugin(ver, nil), - microsoft.Plugin(ver, nil), + microsoft.Plugin(ver, nil, nil), } for _, p := range plugins { p := p diff --git a/mocks/internalpkg/microsoft/azure_open_ai_client.go b/mocks/internalpkg/microsoft/azure_open_ai_client.go new file mode 100644 index 00000000..b3938f82 --- /dev/null +++ b/mocks/internalpkg/microsoft/azure_open_ai_client.go @@ -0,0 +1,96 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. 
+ +package microsoft_mocks + +import ( + context "context" + + azopenai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + + mock "github.com/stretchr/testify/mock" +) + +// AzureOpenAIClient is an autogenerated mock type for the AzureOpenAIClient type +type AzureOpenAIClient struct { + mock.Mock +} + +type AzureOpenAIClient_Expecter struct { + mock *mock.Mock +} + +func (_m *AzureOpenAIClient) EXPECT() *AzureOpenAIClient_Expecter { + return &AzureOpenAIClient_Expecter{mock: &_m.Mock} +} + +// GetCompletions provides a mock function with given fields: ctx, body, options +func (_m *AzureOpenAIClient) GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) { + ret := _m.Called(ctx, body, options) + + if len(ret) == 0 { + panic("no return value specified for GetCompletions") + } + + var r0 azopenai.GetCompletionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error)); ok { + return rf(ctx, body, options) + } + if rf, ok := ret.Get(0).(func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) azopenai.GetCompletionsResponse); ok { + r0 = rf(ctx, body, options) + } else { + r0 = ret.Get(0).(azopenai.GetCompletionsResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) error); ok { + r1 = rf(ctx, body, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AzureOpenAIClient_GetCompletions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompletions' +type AzureOpenAIClient_GetCompletions_Call struct { + *mock.Call +} + +// GetCompletions is a helper method to define mock.On call +// - ctx context.Context +// - body azopenai.CompletionsOptions +// - options *azopenai.GetCompletionsOptions +func (_e *AzureOpenAIClient_Expecter) GetCompletions(ctx interface{}, body interface{}, options interface{}) *AzureOpenAIClient_GetCompletions_Call { + return &AzureOpenAIClient_GetCompletions_Call{Call: _e.mock.On("GetCompletions", ctx, body, options)} +} + +func (_c *AzureOpenAIClient_GetCompletions_Call) Run(run func(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions)) *AzureOpenAIClient_GetCompletions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(azopenai.CompletionsOptions), args[2].(*azopenai.GetCompletionsOptions)) + }) + return _c +} + +func (_c *AzureOpenAIClient_GetCompletions_Call) Return(_a0 azopenai.GetCompletionsResponse, _a1 error) *AzureOpenAIClient_GetCompletions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *AzureOpenAIClient_GetCompletions_Call) RunAndReturn(run func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error)) *AzureOpenAIClient_GetCompletions_Call { + _c.Call.Return(run) + return _c +} + +// NewAzureOpenAIClient creates a new instance of AzureOpenAIClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. 
+func NewAzureOpenAIClient(t interface { + mock.TestingT + Cleanup(func()) +}) *AzureOpenAIClient { + mock := &AzureOpenAIClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} From 6bc9ea59cd4147184adda5aaccf988dac18d9523 Mon Sep 17 00:00:00 2001 From: Anas Muhammed Date: Sun, 4 Aug 2024 23:03:31 +0400 Subject: [PATCH 2/3] fix: linting --- ...nt_ azure_openai_text.go => content_azure_openai_text.go} | 5 +++-- internal/microsoft/plugin.go | 4 ++-- tools/docgen/main.go | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) rename internal/microsoft/{content_ azure_openai_text.go => content_azure_openai_text.go} (99%) diff --git a/internal/microsoft/content_ azure_openai_text.go b/internal/microsoft/content_azure_openai_text.go similarity index 99% rename from internal/microsoft/content_ azure_openai_text.go rename to internal/microsoft/content_azure_openai_text.go index 261c8bd4..d4ea4e03 100644 --- a/internal/microsoft/content_ azure_openai_text.go +++ b/internal/microsoft/content_azure_openai_text.go @@ -10,12 +10,13 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" "github.com/Masterminds/sprig/v3" + "github.com/hashicorp/hcl/v2" + "github.com/zclconf/go-cty/cty" + "github.com/blackstork-io/fabric/pkg/diagnostics" "github.com/blackstork-io/fabric/plugin" "github.com/blackstork-io/fabric/plugin/dataspec" "github.com/blackstork-io/fabric/plugin/dataspec/constraint" - "github.com/hashicorp/hcl/v2" - "github.com/zclconf/go-cty/cty" ) func makeAzureOpenAITextContentSchema(loader AzureOpenAIClientLoadFn) *plugin.ContentProvider { diff --git a/internal/microsoft/plugin.go b/internal/microsoft/plugin.go index 5e010a53..1c302e0e 100644 --- a/internal/microsoft/plugin.go +++ b/internal/microsoft/plugin.go @@ -4,10 +4,10 @@ import ( "context" "fmt" - "github.com/zclconf/go-cty/cty" - "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/zclconf/go-cty/cty" + "github.com/blackstork-io/fabric/internal/microsoft/client" "github.com/blackstork-io/fabric/plugin" ) diff --git a/tools/docgen/main.go b/tools/docgen/main.go index 4d4ce66d..adc34cdf 100644 --- a/tools/docgen/main.go +++ b/tools/docgen/main.go @@ -310,7 +310,7 @@ func main() { stixview.Plugin(version), nistnvd.Plugin(version, nil), snyk.Plugin(version, nil), - microsoft.Plugin(version, nil), + microsoft.Plugin(version, nil, nil), } // generate markdown for each plugin for _, p := range plugins { From 5b9be67e80a9ac98638177a279e879658e5e5937 Mon Sep 17 00:00:00 2001 From: Anas Muhammed Date: Mon, 5 Aug 2024 20:17:06 +0400 Subject: [PATCH 3/3] fix: rename client --- .mockery.yaml | 2 +- .../microsoft/content_azure_openai_text.go | 6 ++-- .../content_azure_openai_text_test.go | 6 ++-- internal/microsoft/plugin.go | 8 ++--- ...en_ai_client.go => azure_openai_client.go} | 34 +++++++++---------- 5 files changed, 28 insertions(+), 28 deletions(-) rename mocks/internalpkg/microsoft/{azure_open_ai_client.go => azure_openai_client.go} (66%) diff --git a/.mockery.yaml b/.mockery.yaml index a7ee190e..21ca887f 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -48,7 +48,7 @@ packages: github.com/blackstork-io/fabric/internal/microsoft: config: interfaces: - AzureOpenAIClient: + AzureOpenaiClient: github.com/blackstork-io/fabric/plugin/resolver: config: inpackage: true diff --git a/internal/microsoft/content_azure_openai_text.go b/internal/microsoft/content_azure_openai_text.go index 
d4ea4e03..313e251d 100644 --- a/internal/microsoft/content_azure_openai_text.go +++ b/internal/microsoft/content_azure_openai_text.go @@ -19,7 +19,7 @@ import ( "github.com/blackstork-io/fabric/plugin/dataspec/constraint" ) -func makeAzureOpenAITextContentSchema(loader AzureOpenAIClientLoadFn) *plugin.ContentProvider { +func makeAzureOpenAITextContentSchema(loader AzureOpenaiClientLoadFn) *plugin.ContentProvider { return &plugin.ContentProvider{ Config: dataspec.ObjectSpec{ &dataspec.AttrSpec{ @@ -75,7 +75,7 @@ func makeAzureOpenAITextContentSchema(loader AzureOpenAIClientLoadFn) *plugin.Co } } -func genOpenAIText(loader AzureOpenAIClientLoadFn) plugin.ProvideContentFunc { +func genOpenAIText(loader AzureOpenaiClientLoadFn) plugin.ProvideContentFunc { return func(ctx context.Context, params *plugin.ProvideContentParams) (*plugin.ContentResult, diagnostics.Diag) { apiKey := params.Config.GetAttr("api_key").AsString() resourceEndpoint := params.Config.GetAttr("resource_endpoint").AsString() @@ -103,7 +103,7 @@ func genOpenAIText(loader AzureOpenAIClientLoadFn) plugin.ProvideContentFunc { } } -func renderText(ctx context.Context, cli AzureOpenAIClient, cfg, args cty.Value, dataCtx plugin.MapData) (string, error) { +func renderText(ctx context.Context, cli AzureOpenaiClient, cfg, args cty.Value, dataCtx plugin.MapData) (string, error) { params := azopenai.CompletionsOptions{} params.DeploymentName = to.Ptr(cfg.GetAttr("deployment_name").AsString()) diff --git a/internal/microsoft/content_azure_openai_text_test.go b/internal/microsoft/content_azure_openai_text_test.go index 495667ff..3f9be01d 100644 --- a/internal/microsoft/content_azure_openai_text_test.go +++ b/internal/microsoft/content_azure_openai_text_test.go @@ -26,7 +26,7 @@ type AzureOpenAITextContentTestSuite struct { suite.Suite plugin *plugin.Schema schema *plugin.ContentProvider - cli *client_mocks.AzureOpenAIClient + cli *client_mocks.AzureOpenaiClient } func TestAzureOpenAITextContentSuite(t *testing.T) { @@ -34,14 +34,14 @@ func TestAzureOpenAITextContentSuite(t *testing.T) { } func (s *AzureOpenAITextContentTestSuite) SetupSuite() { - s.plugin = microsoft.Plugin("1.0.0", nil, (func(apiKey string, endPoint string) (cli microsoft.AzureOpenAIClient, err error) { + s.plugin = microsoft.Plugin("1.0.0", nil, (func(apiKey string, endPoint string) (cli microsoft.AzureOpenaiClient, err error) { return s.cli, nil })) s.schema = s.plugin.ContentProviders["azure_openai_text"] } func (s *AzureOpenAITextContentTestSuite) SetupTest() { - s.cli = &client_mocks.AzureOpenAIClient{} + s.cli = &client_mocks.AzureOpenaiClient{} } func (s *AzureOpenAITextContentTestSuite) TearDownTest() { diff --git a/internal/microsoft/plugin.go b/internal/microsoft/plugin.go index 1c302e0e..05765e8f 100644 --- a/internal/microsoft/plugin.go +++ b/internal/microsoft/plugin.go @@ -16,19 +16,19 @@ type ClientLoadFn func() client.Client var DefaultClientLoader ClientLoadFn = client.New -type AzureOpenAIClientLoadFn func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenAIClient, err error) +type AzureOpenaiClientLoadFn func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenaiClient, err error) -type AzureOpenAIClient interface { +type AzureOpenaiClient interface { GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) } -var DefaultAzureOpenAIClientLoader AzureOpenAIClientLoadFn = func(azureOpenAIKey string, azureOpenAIEndpoint string) 
(client AzureOpenAIClient, err error) { +var DefaultAzureOpenAIClientLoader AzureOpenaiClientLoadFn = func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenaiClient, err error) { keyCredential := azcore.NewKeyCredential(azureOpenAIKey) client, err = azopenai.NewClientWithKeyCredential(azureOpenAIEndpoint, keyCredential, nil) return } -func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenAIClientLoadFn) *plugin.Schema { +func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenaiClientLoadFn) *plugin.Schema { if loader == nil { loader = DefaultClientLoader } diff --git a/mocks/internalpkg/microsoft/azure_open_ai_client.go b/mocks/internalpkg/microsoft/azure_openai_client.go similarity index 66% rename from mocks/internalpkg/microsoft/azure_open_ai_client.go rename to mocks/internalpkg/microsoft/azure_openai_client.go index b3938f82..14711b31 100644 --- a/mocks/internalpkg/microsoft/azure_open_ai_client.go +++ b/mocks/internalpkg/microsoft/azure_openai_client.go @@ -10,21 +10,21 @@ import ( mock "github.com/stretchr/testify/mock" ) -// AzureOpenAIClient is an autogenerated mock type for the AzureOpenAIClient type -type AzureOpenAIClient struct { +// AzureOpenaiClient is an autogenerated mock type for the AzureOpenaiClient type +type AzureOpenaiClient struct { mock.Mock } -type AzureOpenAIClient_Expecter struct { +type AzureOpenaiClient_Expecter struct { mock *mock.Mock } -func (_m *AzureOpenAIClient) EXPECT() *AzureOpenAIClient_Expecter { - return &AzureOpenAIClient_Expecter{mock: &_m.Mock} +func (_m *AzureOpenaiClient) EXPECT() *AzureOpenaiClient_Expecter { + return &AzureOpenaiClient_Expecter{mock: &_m.Mock} } // GetCompletions provides a mock function with given fields: ctx, body, options -func (_m *AzureOpenAIClient) GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) { +func (_m *AzureOpenaiClient) GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) { ret := _m.Called(ctx, body, options) if len(ret) == 0 { @@ -51,8 +51,8 @@ func (_m *AzureOpenAIClient) GetCompletions(ctx context.Context, body azopenai.C return r0, r1 } -// AzureOpenAIClient_GetCompletions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompletions' -type AzureOpenAIClient_GetCompletions_Call struct { +// AzureOpenaiClient_GetCompletions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompletions' +type AzureOpenaiClient_GetCompletions_Call struct { *mock.Call } @@ -60,34 +60,34 @@ type AzureOpenAIClient_GetCompletions_Call struct { // - ctx context.Context // - body azopenai.CompletionsOptions // - options *azopenai.GetCompletionsOptions -func (_e *AzureOpenAIClient_Expecter) GetCompletions(ctx interface{}, body interface{}, options interface{}) *AzureOpenAIClient_GetCompletions_Call { - return &AzureOpenAIClient_GetCompletions_Call{Call: _e.mock.On("GetCompletions", ctx, body, options)} +func (_e *AzureOpenaiClient_Expecter) GetCompletions(ctx interface{}, body interface{}, options interface{}) *AzureOpenaiClient_GetCompletions_Call { + return &AzureOpenaiClient_GetCompletions_Call{Call: _e.mock.On("GetCompletions", ctx, body, options)} } -func (_c *AzureOpenAIClient_GetCompletions_Call) Run(run func(ctx context.Context, body azopenai.CompletionsOptions, options 
*azopenai.GetCompletionsOptions)) *AzureOpenAIClient_GetCompletions_Call { +func (_c *AzureOpenaiClient_GetCompletions_Call) Run(run func(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions)) *AzureOpenaiClient_GetCompletions_Call { _c.Call.Run(func(args mock.Arguments) { run(args[0].(context.Context), args[1].(azopenai.CompletionsOptions), args[2].(*azopenai.GetCompletionsOptions)) }) return _c } -func (_c *AzureOpenAIClient_GetCompletions_Call) Return(_a0 azopenai.GetCompletionsResponse, _a1 error) *AzureOpenAIClient_GetCompletions_Call { +func (_c *AzureOpenaiClient_GetCompletions_Call) Return(_a0 azopenai.GetCompletionsResponse, _a1 error) *AzureOpenaiClient_GetCompletions_Call { _c.Call.Return(_a0, _a1) return _c } -func (_c *AzureOpenAIClient_GetCompletions_Call) RunAndReturn(run func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error)) *AzureOpenAIClient_GetCompletions_Call { +func (_c *AzureOpenaiClient_GetCompletions_Call) RunAndReturn(run func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error)) *AzureOpenaiClient_GetCompletions_Call { _c.Call.Return(run) return _c } -// NewAzureOpenAIClient creates a new instance of AzureOpenAIClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewAzureOpenaiClient creates a new instance of AzureOpenaiClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. -func NewAzureOpenAIClient(t interface { +func NewAzureOpenaiClient(t interface { mock.TestingT Cleanup(func()) -}) *AzureOpenAIClient { - mock := &AzureOpenAIClient{} +}) *AzureOpenaiClient { + mock := &AzureOpenaiClient{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) })
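
Note: the generated mock also exposes the type-safe expecter API, which the tests in these patches do not exercise (they use the raw .On(...) calls). Below is a minimal sketch, not part of the patches, of how a test could drive the renamed mock through EXPECT(); the test name and the canned completion values are illustrative only.

package microsoft_test

import (
	"context"
	"testing"

	"github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
	"github.com/stretchr/testify/mock"

	client_mocks "github.com/blackstork-io/fabric/mocks/internalpkg/microsoft"
)

// Illustrative test: configure a canned completion via the expecter API and call the mock directly.
func TestAzureOpenaiClientMockExpecter(t *testing.T) {
	// NewAzureOpenaiClient registers the testing interface and asserts expectations on cleanup.
	cli := client_mocks.NewAzureOpenaiClient(t)

	// Accept any context, options and body; return a single canned choice.
	cli.EXPECT().
		GetCompletions(mock.Anything, mock.Anything, mock.Anything).
		Return(azopenai.GetCompletionsResponse{
			Completions: azopenai.Completions{
				Choices: []azopenai.Choice{{Text: to.Ptr("Once upon a time.")}},
			},
		}, nil)

	// The mock satisfies microsoft.AzureOpenaiClient, so it could also be returned
	// from a custom AzureOpenaiClientLoadFn passed to microsoft.Plugin(...).
	resp, err := cli.GetCompletions(context.Background(), azopenai.CompletionsOptions{}, nil)
	if err != nil || len(resp.Choices) == 0 {
		t.Fatalf("unexpected mock behaviour: %v", err)
	}
	if got := *resp.Choices[0].Text; got != "Once upon a time." {
		t.Fatalf("unexpected completion text: %q", got)
	}
}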