diff --git a/.mockery.yaml b/.mockery.yaml index 4079849e..a7ee190e 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -45,6 +45,10 @@ packages: config: interfaces: Client: + github.com/blackstork-io/fabric/internal/microsoft: + config: + interfaces: + AzureOpenAIClient: github.com/blackstork-io/fabric/plugin/resolver: config: inpackage: true diff --git a/examples/templates/azureopenai/example.fabric b/examples/templates/azureopenai/example.fabric new file mode 100644 index 00000000..48d131f6 --- /dev/null +++ b/examples/templates/azureopenai/example.fabric @@ -0,0 +1,38 @@ + +fabric { + plugin_versions = { + "blackstork/microsoft" = ">= 0.4 < 1.0 || 0.4.0-rev0" + } +} + +document "example" { + meta { + name = "example_document" + } + + title = "Document title" + + section { + title = "Section 2" + + section { + title = "Subsection 2" + + content text { + value = "Text value 4" + } + } + } + + content azure_openai_text { + config { + api_key = env.AZURE_OPENAI_KEY + resource_endpoint = env.AZURE_OPENAI_ENDPOINT + deployment_name = env.AZURE_OPENAI_DEPLOYMENT + api_version = "2024-02-01" + } + prompt = "How are you today?" + max_tokens = 10 + } +} + diff --git a/go.mod b/go.mod index dab48e32..4363aafd 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,9 @@ replace github.com/evanphx/go-hclog-slog v0.0.0-20230905211129-6d31b63d6f09 => g require ( dario.cat/mergo v1.0.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 // indirect + github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 // indirect + github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Microsoft/go-winio v0.6.1 // indirect diff --git a/go.sum b/go.sum index f48a08cf..6760e11f 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,12 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9 github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Andrew-Morozko/go-hclog-slog v0.0.0-20240624145756-528e39db2968 h1:9lp9JWCEkcLqYRx8IZQmvZboh+5gcxJ4r9SfWQYjvbI= github.com/Andrew-Morozko/go-hclog-slog v0.0.0-20240624145756-528e39db2968/go.mod h1:30+1dTR5EdDQGmcjkgOj4i6iVD3wnI0XymSOTPMlUeA= +github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 h1:FQOmDxJj1If0D0khZR00MDa2Eb+k9BBsSaK7cEbLwkk= +github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0/go.mod h1:X0+PSrHOZdTjkiEhgv53HS5gplbzVVl2jd6hQRYSS3c= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 h1:E+OJmp2tPvt1W+amx48v1eqbjDYsgN+RzP4q16yV5eM= +github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1/go.mod h1:a6xsAQUZg+VsS3TJ05SRp524Hs4pZ/AeFSr5ENf0Yjo= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0 h1:rTfKOCZGy5ViVrlA74ZPE99a+SgoEE2K/yg3RyW9dFA= +github.com/Azure/azure-sdk-for-go/sdk/internal v1.7.0/go.mod h1:4OG6tQ9EOP/MT0NMjDlRzWoVFxfu9rN9B2X+tlSVktg= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= diff --git a/internal/microsoft/cmd/main.go b/internal/microsoft/cmd/main.go index 10245c02..9fb41629 100644 --- a/internal/microsoft/cmd/main.go +++ b/internal/microsoft/cmd/main.go @@ -9,6 +9,6 @@ var version string func main() { pluginapiv1.Serve( - microsoft.Plugin(version, microsoft.DefaultClientLoader), + microsoft.Plugin(version, microsoft.DefaultClientLoader, microsoft.DefaultAzureOpenAIClientLoader), ) } diff --git a/internal/microsoft/content_ azure_openai_text.go b/internal/microsoft/content_ azure_openai_text.go new file mode 100644 index 00000000..261c8bd4 --- /dev/null +++ b/internal/microsoft/content_ azure_openai_text.go @@ -0,0 +1,153 @@ +package microsoft + +import ( + "bytes" + "context" + "fmt" + "strings" + "text/template" + + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/Masterminds/sprig/v3" + "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/plugin" + "github.com/blackstork-io/fabric/plugin/dataspec" + "github.com/blackstork-io/fabric/plugin/dataspec/constraint" + "github.com/hashicorp/hcl/v2" + "github.com/zclconf/go-cty/cty" +) + +func makeAzureOpenAITextContentSchema(loader AzureOpenAIClientLoadFn) *plugin.ContentProvider { + return &plugin.ContentProvider{ + Config: dataspec.ObjectSpec{ + &dataspec.AttrSpec{ + Name: "api_key", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + Secret: true, + }, + &dataspec.AttrSpec{ + Name: "resource_endpoint", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + }, + &dataspec.AttrSpec{ + Name: "deployment_name", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + }, + &dataspec.AttrSpec{ + Name: "api_version", + Type: cty.String, + DefaultVal: cty.StringVal("2024-02-01"), + }, + }, + Args: dataspec.ObjectSpec{ + &dataspec.AttrSpec{ + Name: "prompt", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + ExampleVal: cty.StringVal("Summarize the following text: {{.vars.text_to_summarize}}"), + }, + &dataspec.AttrSpec{ + Name: "max_tokens", + Type: cty.Number, + DefaultVal: cty.NumberIntVal(1000), + }, + &dataspec.AttrSpec{ + Name: "temperature", + Type: cty.Number, + DefaultVal: cty.NumberFloatVal(0), + }, + &dataspec.AttrSpec{ + Name: "top_p", + Type: cty.Number, + }, + &dataspec.AttrSpec{ + Name: "completions_count", + Type: cty.Number, + DefaultVal: cty.NumberIntVal(1), + }, + }, + ContentFunc: genOpenAIText(loader), + } +} + +func genOpenAIText(loader AzureOpenAIClientLoadFn) plugin.ProvideContentFunc { + return func(ctx context.Context, params *plugin.ProvideContentParams) (*plugin.ContentResult, diagnostics.Diag) { + apiKey := params.Config.GetAttr("api_key").AsString() + resourceEndpoint := params.Config.GetAttr("resource_endpoint").AsString() + client, err := loader(apiKey, resourceEndpoint) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to create client", + Detail: err.Error(), + }} + } + result, err := renderText(ctx, client, params.Config, params.Args, params.DataContext) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to generate text", + Detail: err.Error(), + }} + } + return &plugin.ContentResult{ + Content: &plugin.ContentElement{ + Markdown: result, + }, + }, nil + } +} + +func renderText(ctx context.Context, cli AzureOpenAIClient, cfg, args cty.Value, dataCtx plugin.MapData) (string, error) { + + params := azopenai.CompletionsOptions{} + params.DeploymentName = to.Ptr(cfg.GetAttr("deployment_name").AsString()) + + maxTokens, _ := args.GetAttr("max_tokens").AsBigFloat().Int64() + params.MaxTokens = to.Ptr(int32(maxTokens)) + + temperature, _ := args.GetAttr("temperature").AsBigFloat().Float32() + params.Temperature = to.Ptr(temperature) + + completionsCount, _ := args.GetAttr("completions_count").AsBigFloat().Int64() + params.N = to.Ptr(int32(completionsCount)) + + topPAttr := args.GetAttr("top_p") + if !topPAttr.IsNull() { + topP, _ := topPAttr.AsBigFloat().Float32() + params.TopP = to.Ptr(topP) + } + + renderedPrompt, err := templateText(args.GetAttr("prompt").AsString(), dataCtx) + if err != nil { + return "", err + } + params.Prompt = []string{renderedPrompt} + // TODO: use api version from config + resp, err := cli.GetCompletions(ctx, params, nil) + if err != nil { + return "", err + } + if len(resp.Choices) == 0 { + return "", nil + } + return *resp.Choices[0].Text, nil +} + +func templateText(text string, dataCtx plugin.MapData) (string, error) { + tmpl, err := template.New("text").Funcs(sprig.FuncMap()).Parse(text) + if err != nil { + return "", fmt.Errorf("failed to parse template: %w", err) + } + + var buf bytes.Buffer + err = tmpl.Execute(&buf, dataCtx.Any()) + if err != nil { + return "", fmt.Errorf("failed to execute template: %w", err) + } + return strings.TrimSpace(buf.String()), nil +} diff --git a/internal/microsoft/content_azure_openai_text_test.go b/internal/microsoft/content_azure_openai_text_test.go new file mode 100644 index 00000000..495667ff --- /dev/null +++ b/internal/microsoft/content_azure_openai_text_test.go @@ -0,0 +1,227 @@ +package microsoft_test + +import ( + "context" + "errors" + "fmt" + "testing" + + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/to" + "github.com/hashicorp/hcl/v2" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/zclconf/go-cty/cty" + + "github.com/blackstork-io/fabric/internal/microsoft" + client_mocks "github.com/blackstork-io/fabric/mocks/internalpkg/microsoft" + "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/pkg/diagnostics/diagtest" + "github.com/blackstork-io/fabric/plugin" + "github.com/blackstork-io/fabric/plugin/plugintest" + "github.com/blackstork-io/fabric/print/mdprint" +) + +type AzureOpenAITextContentTestSuite struct { + suite.Suite + plugin *plugin.Schema + schema *plugin.ContentProvider + cli *client_mocks.AzureOpenAIClient +} + +func TestAzureOpenAITextContentSuite(t *testing.T) { + suite.Run(t, &AzureOpenAITextContentTestSuite{}) +} + +func (s *AzureOpenAITextContentTestSuite) SetupSuite() { + s.plugin = microsoft.Plugin("1.0.0", nil, (func(apiKey string, endPoint string) (cli microsoft.AzureOpenAIClient, err error) { + return s.cli, nil + })) + s.schema = s.plugin.ContentProviders["azure_openai_text"] +} + +func (s *AzureOpenAITextContentTestSuite) SetupTest() { + s.cli = &client_mocks.AzureOpenAIClient{} +} + +func (s *AzureOpenAITextContentTestSuite) TearDownTest() { + s.cli.AssertExpectations(s.T()) +} + +func (s *AzureOpenAITextContentTestSuite) TestSchema() { + s.Require().NotNil(s.plugin) + s.Require().NotNil(s.schema) + s.NotNil(s.schema.Args) + s.NotNil(s.schema.ContentFunc) + s.NotNil(s.schema.Config) +} + +func (s *AzureOpenAITextContentTestSuite) TestBasic() { + s.cli.On("GetCompletions", mock.Anything, azopenai.CompletionsOptions{ + DeploymentName: to.Ptr("test"), + MaxTokens: to.Ptr(int32(1000)), + Temperature: to.Ptr(float32(0)), + Prompt: []string{"Tell me a story"}, + TopP: nil, + N: to.Ptr(int32(1)), + }, mock.Anything).Return(azopenai.GetCompletionsResponse{ + Completions: azopenai.Completions{ + Choices: []azopenai.Choice{ + {Text: to.Ptr("Once upon a time.")}, + }, + }, + }, nil) + ctx := context.Background() + dataCtx := plugin.MapData{} + result, diags := s.schema.ContentFunc(ctx, &plugin.ProvideContentParams{ + Args: cty.ObjectVal(map[string]cty.Value{ + "prompt": cty.StringVal("Tell me a story"), + "max_tokens": cty.NumberIntVal(1000), + "temperature": cty.NumberFloatVal(0), + "top_p": cty.NilVal, + "completions_count": cty.NumberIntVal(1), + }), + Config: cty.ObjectVal(map[string]cty.Value{ + "api_key": cty.StringVal("testtoken"), + "resource_endpoint": cty.StringVal("http://test"), + "deployment_name": cty.StringVal("test"), + "api_version": cty.StringVal("2024-02-01"), + }), + DataContext: dataCtx, + }) + fmt.Println(diags) + s.Nil(diags) + s.Equal("Once upon a time.", mdprint.PrintString(result.Content)) +} + +func (s *AzureOpenAITextContentTestSuite) TestAdvanced() { + s.cli.On("GetCompletions", mock.Anything, azopenai.CompletionsOptions{ + DeploymentName: to.Ptr("test"), + MaxTokens: to.Ptr(int32(1000)), + Temperature: to.Ptr(float32(0)), + Prompt: []string{"Tell me a story about BAR. {\"foo\":\"bar\"}"}, + TopP: nil, + N: to.Ptr(int32(1)), + }, mock.Anything).Return(azopenai.GetCompletionsResponse{ + Completions: azopenai.Completions{ + Choices: []azopenai.Choice{ + {Text: to.Ptr("Once upon a time.")}, + }, + }, + }, nil) + + ctx := context.Background() + dataCtx := plugin.MapData{ + "local": plugin.MapData{ + "foo": plugin.StringData("bar"), + }, + } + + result, diags := s.schema.ContentFunc(ctx, &plugin.ProvideContentParams{ + Args: cty.ObjectVal(map[string]cty.Value{ + "prompt": cty.StringVal("Tell me a story about {{.local.foo | upper}}. {{ .local | toRawJson }}"), + "max_tokens": cty.NumberIntVal(1000), + "temperature": cty.NumberFloatVal(0), + "top_p": cty.NilVal, + "completions_count": cty.NumberIntVal(1), + }), + Config: cty.ObjectVal(map[string]cty.Value{ + "api_key": cty.StringVal("testtoken"), + "resource_endpoint": cty.StringVal("http://test"), + "deployment_name": cty.StringVal("test"), + "api_version": cty.StringVal("2024-02-01"), + }), + DataContext: dataCtx, + }) + s.Empty(diags) + s.Equal("Once upon a time.", mdprint.PrintString(result.Content)) +} + +func (s *AzureOpenAITextContentTestSuite) TestMissingPrompt() { + plugintest.DecodeAndAssert(s.T(), s.schema.Args, "", plugin.MapData{}, diagtest.Asserts{ + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("prompt"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("prompt"), + }, + }) +} + +func (s *AzureOpenAITextContentTestSuite) TestMissingAPIKey() { + plugintest.DecodeAndAssert(s.T(), s.schema.Args, ` + prompt = "Tell me a story" + `, plugin.MapData{}, diagtest.Asserts{}) + plugintest.DecodeAndAssert(s.T(), s.schema.Config, ` + `, plugin.MapData{}, diagtest.Asserts{ + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("api_key"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("api_key"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("resource_endpoint"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("resource_endpoint"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Missing required argument"), + diagtest.DetailContains("deployment_name"), + }, + { + diagtest.IsError, + diagtest.SummaryEquals("Argument value must be non-null"), + diagtest.DetailContains("deployment_name"), + }, + }) +} + +func (s *AzureOpenAITextContentTestSuite) TestFailingClient() { + s.cli.On("GetCompletions", mock.Anything, azopenai.CompletionsOptions{ + DeploymentName: to.Ptr("test"), + MaxTokens: to.Ptr(int32(1000)), + Temperature: to.Ptr(float32(0)), + Prompt: []string{"Tell me a story"}, + TopP: nil, + N: to.Ptr(int32(1)), + }, mock.Anything).Return(azopenai.GetCompletionsResponse{}, errors.New("failed to generate text from model")) + ctx := context.Background() + dataCtx := plugin.MapData{} + result, diags := s.schema.ContentFunc(ctx, &plugin.ProvideContentParams{ + Args: cty.ObjectVal(map[string]cty.Value{ + "prompt": cty.StringVal("Tell me a story"), + "max_tokens": cty.NumberIntVal(1000), + "temperature": cty.NumberFloatVal(0), + "top_p": cty.NilVal, + "completions_count": cty.NumberIntVal(1), + }), + Config: cty.ObjectVal(map[string]cty.Value{ + "api_key": cty.StringVal("testtoken"), + "resource_endpoint": cty.StringVal("http://test"), + "deployment_name": cty.StringVal("test"), + "api_version": cty.StringVal("2024-02-01"), + }), + DataContext: dataCtx, + }) + s.Nil(result) + s.Equal(diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to generate text", + Detail: "failed to generate text from model", + }}, diags) +} diff --git a/internal/microsoft/plugin.go b/internal/microsoft/plugin.go index 053405fa..5e010a53 100644 --- a/internal/microsoft/plugin.go +++ b/internal/microsoft/plugin.go @@ -6,6 +6,8 @@ import ( "github.com/zclconf/go-cty/cty" + "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + "github.com/Azure/azure-sdk-for-go/sdk/azcore" "github.com/blackstork-io/fabric/internal/microsoft/client" "github.com/blackstork-io/fabric/plugin" ) @@ -14,7 +16,19 @@ type ClientLoadFn func() client.Client var DefaultClientLoader ClientLoadFn = client.New -func Plugin(version string, loader ClientLoadFn) *plugin.Schema { +type AzureOpenAIClientLoadFn func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenAIClient, err error) + +type AzureOpenAIClient interface { + GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) +} + +var DefaultAzureOpenAIClientLoader AzureOpenAIClientLoadFn = func(azureOpenAIKey string, azureOpenAIEndpoint string) (client AzureOpenAIClient, err error) { + keyCredential := azcore.NewKeyCredential(azureOpenAIKey) + client, err = azopenai.NewClientWithKeyCredential(azureOpenAIEndpoint, keyCredential, nil) + return +} + +func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenAIClientLoadFn) *plugin.Schema { if loader == nil { loader = DefaultClientLoader } @@ -25,6 +39,9 @@ func Plugin(version string, loader ClientLoadFn) *plugin.Schema { DataSources: plugin.DataSources{ "microsoft_sentinel_incidents": makeMicrosoftSentinelIncidentsDataSource(loader), }, + ContentProviders: plugin.ContentProviders{ + "azure_openai_text": makeAzureOpenAITextContentSchema(openAiClientLoader), + }, } } diff --git a/internal/microsoft/plugin_test.go b/internal/microsoft/plugin_test.go index 3e44160a..c9aaa301 100644 --- a/internal/microsoft/plugin_test.go +++ b/internal/microsoft/plugin_test.go @@ -7,8 +7,9 @@ import ( ) func TestPlugin_Schema(t *testing.T) { - schema := Plugin("1.2.3", nil) + schema := Plugin("1.2.3", nil, nil) assert.Equal(t, "blackstork/microsoft", schema.Name) assert.Equal(t, "1.2.3", schema.Version) assert.NotNil(t, schema.DataSources["microsoft_sentinel_incidents"]) + assert.NotNil(t, schema.ContentProviders["azure_openai_text"]) } diff --git a/internal/plugin_validity_test.go b/internal/plugin_validity_test.go index 707da142..1680ff64 100644 --- a/internal/plugin_validity_test.go +++ b/internal/plugin_validity_test.go @@ -46,7 +46,7 @@ func TestAllPluginSchemaValidity(t *testing.T) { splunk.Plugin(ver, nil), nistnvd.Plugin(ver, nil), snyk.Plugin(ver, nil), - microsoft.Plugin(ver, nil), + microsoft.Plugin(ver, nil, nil), } for _, p := range plugins { p := p diff --git a/mocks/internalpkg/microsoft/azure_open_ai_client.go b/mocks/internalpkg/microsoft/azure_open_ai_client.go new file mode 100644 index 00000000..b3938f82 --- /dev/null +++ b/mocks/internalpkg/microsoft/azure_open_ai_client.go @@ -0,0 +1,96 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package microsoft_mocks + +import ( + context "context" + + azopenai "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" + + mock "github.com/stretchr/testify/mock" +) + +// AzureOpenAIClient is an autogenerated mock type for the AzureOpenAIClient type +type AzureOpenAIClient struct { + mock.Mock +} + +type AzureOpenAIClient_Expecter struct { + mock *mock.Mock +} + +func (_m *AzureOpenAIClient) EXPECT() *AzureOpenAIClient_Expecter { + return &AzureOpenAIClient_Expecter{mock: &_m.Mock} +} + +// GetCompletions provides a mock function with given fields: ctx, body, options +func (_m *AzureOpenAIClient) GetCompletions(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error) { + ret := _m.Called(ctx, body, options) + + if len(ret) == 0 { + panic("no return value specified for GetCompletions") + } + + var r0 azopenai.GetCompletionsResponse + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error)); ok { + return rf(ctx, body, options) + } + if rf, ok := ret.Get(0).(func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) azopenai.GetCompletionsResponse); ok { + r0 = rf(ctx, body, options) + } else { + r0 = ret.Get(0).(azopenai.GetCompletionsResponse) + } + + if rf, ok := ret.Get(1).(func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) error); ok { + r1 = rf(ctx, body, options) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// AzureOpenAIClient_GetCompletions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'GetCompletions' +type AzureOpenAIClient_GetCompletions_Call struct { + *mock.Call +} + +// GetCompletions is a helper method to define mock.On call +// - ctx context.Context +// - body azopenai.CompletionsOptions +// - options *azopenai.GetCompletionsOptions +func (_e *AzureOpenAIClient_Expecter) GetCompletions(ctx interface{}, body interface{}, options interface{}) *AzureOpenAIClient_GetCompletions_Call { + return &AzureOpenAIClient_GetCompletions_Call{Call: _e.mock.On("GetCompletions", ctx, body, options)} +} + +func (_c *AzureOpenAIClient_GetCompletions_Call) Run(run func(ctx context.Context, body azopenai.CompletionsOptions, options *azopenai.GetCompletionsOptions)) *AzureOpenAIClient_GetCompletions_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(azopenai.CompletionsOptions), args[2].(*azopenai.GetCompletionsOptions)) + }) + return _c +} + +func (_c *AzureOpenAIClient_GetCompletions_Call) Return(_a0 azopenai.GetCompletionsResponse, _a1 error) *AzureOpenAIClient_GetCompletions_Call { + _c.Call.Return(_a0, _a1) + return _c +} + +func (_c *AzureOpenAIClient_GetCompletions_Call) RunAndReturn(run func(context.Context, azopenai.CompletionsOptions, *azopenai.GetCompletionsOptions) (azopenai.GetCompletionsResponse, error)) *AzureOpenAIClient_GetCompletions_Call { + _c.Call.Return(run) + return _c +} + +// NewAzureOpenAIClient creates a new instance of AzureOpenAIClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewAzureOpenAIClient(t interface { + mock.TestingT + Cleanup(func()) +}) *AzureOpenAIClient { + mock := &AzureOpenAIClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +}