From cf3c0d1c1ecb0b452fbae2973df4dd963ca409b4 Mon Sep 17 00:00:00 2001 From: Anas Muhammed <102966891+anasmuhmd@users.noreply.github.com> Date: Tue, 24 Sep 2024 17:42:23 +0400 Subject: [PATCH] add microsoft graph data source (#242) --- .mockery.yaml | 1 + .../microsoft/data-sources/microsoft_graph.md | 105 ++++++++++++ docs/plugins/plugins.json | 17 ++ .../microsoft/graph_data_source.fabric | 44 +++++ go.mod | 3 + internal/microsoft/client/client.go | 20 +++ internal/microsoft/client/graph_client.go | 57 +++++++ internal/microsoft/cmd/main.go | 3 +- .../content_azure_openai_text_test.go | 2 +- internal/microsoft/data_microsoft_graph.go | 124 ++++++++++++++ .../microsoft/data_microsoft_graph_test.go | 155 ++++++++++++++++++ internal/microsoft/plugin.go | 74 ++++++++- internal/microsoft/plugin_test.go | 91 +++++++++- internal/plugin_validity_test.go | 2 +- .../microsoft/microsoft_graph_client.go | 98 +++++++++++ tools/docgen/main.go | 2 +- 16 files changed, 792 insertions(+), 6 deletions(-) create mode 100644 docs/plugins/microsoft/data-sources/microsoft_graph.md create mode 100644 examples/templates/microsoft/graph_data_source.fabric create mode 100644 internal/microsoft/client/graph_client.go create mode 100644 internal/microsoft/data_microsoft_graph.go create mode 100644 internal/microsoft/data_microsoft_graph_test.go create mode 100644 mocks/internalpkg/microsoft/microsoft_graph_client.go diff --git a/.mockery.yaml b/.mockery.yaml index 30cac1b8..3f8064c1 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -51,6 +51,7 @@ packages: config: interfaces: AzureOpenaiClient: + MicrosoftGraphClient: github.com/blackstork-io/fabric/plugin/resolver: config: inpackage: true diff --git a/docs/plugins/microsoft/data-sources/microsoft_graph.md b/docs/plugins/microsoft/data-sources/microsoft_graph.md new file mode 100644 index 00000000..efd72c0d --- /dev/null +++ b/docs/plugins/microsoft/data-sources/microsoft_graph.md @@ -0,0 +1,105 @@ +--- +title: "`microsoft_graph` data source" +plugin: + name: blackstork/microsoft + description: "The `microsoft_graph` data source queries Microsoft Graph" + tags: [] + version: "v0.4.2" + source_github: "https://github.com/blackstork-io/fabric/tree/main/internal/microsoft/" +resource: + type: data-source +type: docs +--- + +{{< breadcrumbs 2 >}} + +{{< plugin-resource-header "blackstork/microsoft" "microsoft" "v0.4.2" "microsoft_graph" "data source" >}} + +## Description +The `microsoft_graph` data source queries Microsoft Graph. + +## Installation + +To use `microsoft_graph` data source, you must install the plugin `blackstork/microsoft`. + +To install the plugin, add the full plugin name to the `plugin_versions` map in the Fabric global configuration block (see [Global configuration]({{< ref "configs.md#global-configuration" >}}) for more details), as shown below: + +```hcl +fabric { + plugin_versions = { + "blackstork/microsoft" = ">= v0.4.2" + } +} +``` + +Note the version constraint set for the plugin. + +## Configuration + +The data source supports the following configuration arguments: + +```hcl +config data microsoft_graph { + # The Azure client ID + # + # Required string. + # For example: + client_id = "some string" + + # The Azure client secret. Required if private_key_file/privat_key/cert_thumbprint is not provided. + # + # Optional string. + # Default value: + client_secret = null + + # The Azure tenant ID + # + # Required string. + # For example: + tenant_id = "some string" + + # The path to the private key file. Ignored if private_key/client_secret is provided. + # + # Optional string. + # Default value: + private_key_file = null + + # The private key contents. Ignored if client_secret is provided. + # + # Optional string. + # Default value: + private_key = null + + # The key passphrase. Ignored if client_secret is provided. + # + # Optional string. + # Default value: + key_passphrase = null +} +``` + +## Usage + +The data source supports the following execution arguments: + +```hcl +data microsoft_graph { + # The API version + # + # Optional string. + # Default value: + api_version = "beta" + + # The endpoint to query + # + # Required string. + # For example: + endpoint = "/security/incidents" + + # The query parameters + # + # Optional map of string. + # Default value: + query_params = null +} +``` \ No newline at end of file diff --git a/docs/plugins/plugins.json b/docs/plugins/plugins.json index e0193dd6..c55d36e6 100644 --- a/docs/plugins/plugins.json +++ b/docs/plugins/plugins.json @@ -333,6 +333,23 @@ "top_p" ] }, + { + "name": "microsoft_graph", + "type": "data-source", + "config_params": [ + "client_id", + "client_secret", + "key_passphrase", + "private_key", + "private_key_file", + "tenant_id" + ], + "arguments": [ + "api_version", + "endpoint", + "query_params" + ] + }, { "name": "microsoft_sentinel_incidents", "type": "data-source", diff --git a/examples/templates/microsoft/graph_data_source.fabric b/examples/templates/microsoft/graph_data_source.fabric new file mode 100644 index 00000000..2d45e9e4 --- /dev/null +++ b/examples/templates/microsoft/graph_data_source.fabric @@ -0,0 +1,44 @@ +fabric { + plugin_versions = { + "blackstork/microsoft" = ">= 0.4 < 1.0 || 0.4.0-rev0" + } +} + +document "example" { + meta { + name = "example_document" + } + + data microsoft_graph "mygraph" { + config { + client_id = "" + client_secret = "" + tenant_id = "" + # private_key_file = "" + } + api_version = "v1.0" + endpoint = "/security/incidents" + query_params = { + "$top" = "10" + } + } + + title = "List of Security Incidents" + + content table { + rows = query_jq(".data.microsoft_graph.mygraph.value") + columns = [ + { + "header" = "Severity" + "value" = "{{.row.value.severity}}" + }, + { + "header" = "Display Name" + "value" = "{{.row.value.displayName}}" + } + ] + } + + +} + diff --git a/go.mod b/go.mod index 1189a642..6b28ed09 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.22 require ( github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai v0.6.0 github.com/Azure/azure-sdk-for-go/sdk/azcore v1.11.1 + github.com/AzureAD/microsoft-authentication-library-for-go v1.2.2 github.com/Masterminds/semver/v3 v3.2.1 github.com/Masterminds/sprig/v3 v3.2.3 github.com/TylerBrock/colorjson v0.0.0-20200706003622-8a50f05110d2 @@ -89,6 +90,7 @@ require ( github.com/go-ole/go-ole v1.2.6 // indirect github.com/go-swiss/fonts v0.0.0-20221219152310-0b267088f53d // indirect github.com/gogo/protobuf v1.3.2 // indirect + github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/google/uuid v1.6.0 // indirect @@ -102,6 +104,7 @@ require ( github.com/jellydator/ttlcache/v3 v3.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.17.8 // indirect + github.com/kylelemons/godebug v1.1.0 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.7 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/internal/microsoft/client/client.go b/internal/microsoft/client/client.go index a56d6f1e..d8f2de83 100644 --- a/internal/microsoft/client/client.go +++ b/internal/microsoft/client/client.go @@ -9,6 +9,7 @@ import ( "strings" "time" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/google/go-querystring/query" ) @@ -18,6 +19,8 @@ const ( version = "2023-11-01" ) +var scopes = []string{"https://graph.microsoft.com/.default"} + func String(s string) *string { return &s } @@ -113,6 +116,23 @@ func (c *client) GetClientCredentialsToken(ctx context.Context, req *GetClientCr return &data, nil } +func AcquireToken(ctx context.Context, tenantId string, clientId string, cred confidential.Credential) (accessToken string, err error) { + confidentialClient, err := confidential.New(authURL+"/"+tenantId, clientId, cred) + if err != nil { + return + } + result, err := confidentialClient.AcquireTokenSilent(ctx, scopes) + if err != nil { + // cache miss, authenticate with another AcquireToken... method + result, err = confidentialClient.AcquireTokenByCredential(ctx, scopes) + if err != nil { + return + } + } + accessToken = result.AccessToken + return +} + func (c *client) ListIncidents(ctx context.Context, req *ListIncidentsReq) (*ListIncidentsRes, error) { format := "/subscriptions/%s/resourceGroups/%s/providers/Microsoft.OperationalInsights/workspaces/%s/providers/Microsoft.SecurityInsights/incidents" u, err := url.Parse(c.url + fmt.Sprintf(format, req.SubscriptionID, req.ResourceGroupName, req.WorkspaceName)) diff --git a/internal/microsoft/client/graph_client.go b/internal/microsoft/client/graph_client.go new file mode 100644 index 00000000..3bf3a0f9 --- /dev/null +++ b/internal/microsoft/client/graph_client.go @@ -0,0 +1,57 @@ +package client + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" +) + +const graphUrl = "https://graph.microsoft.com" + +type graphClient struct { + accessToken string + apiVersion string + client *http.Client +} + +func NewGraphClient(accessToken string, apiVersion string) *graphClient { + return &graphClient{ + accessToken: accessToken, + apiVersion: apiVersion, + client: &http.Client{}, + } +} + +func (cli *graphClient) prepare(r *http.Request) { + r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", cli.accessToken)) +} + +func (cli *graphClient) QueryGraph(ctx context.Context, endpoint string, queryParams url.Values) (result interface{}, err error) { + requestUrl, err := url.Parse(graphUrl + fmt.Sprintf("/%s%s", cli.apiVersion, endpoint)) + if err != nil { + return + } + if queryParams != nil { + requestUrl.RawQuery = queryParams.Encode() + } + r, err := http.NewRequestWithContext(ctx, http.MethodGet, requestUrl.String(), nil) + if err != nil { + return + } + cli.prepare(r) + res, err := cli.client.Do(r) + if err != nil { + return + } + if res.StatusCode != http.StatusOK { + err = fmt.Errorf("microsoft graph client returned status code: %d", res.StatusCode) + return + } + defer res.Body.Close() + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + return nil, err + } + return +} diff --git a/internal/microsoft/cmd/main.go b/internal/microsoft/cmd/main.go index 9fb41629..e9525183 100644 --- a/internal/microsoft/cmd/main.go +++ b/internal/microsoft/cmd/main.go @@ -2,6 +2,7 @@ package main import ( "github.com/blackstork-io/fabric/internal/microsoft" + "github.com/blackstork-io/fabric/internal/microsoft/client" pluginapiv1 "github.com/blackstork-io/fabric/plugin/pluginapi/v1" ) @@ -9,6 +10,6 @@ var version string func main() { pluginapiv1.Serve( - microsoft.Plugin(version, microsoft.DefaultClientLoader, microsoft.DefaultAzureOpenAIClientLoader), + microsoft.Plugin(version, microsoft.DefaultClientLoader, microsoft.DefaultAzureOpenAIClientLoader, microsoft.MakeDefaultMicrosoftGraphClientLoader(client.AcquireToken)), ) } diff --git a/internal/microsoft/content_azure_openai_text_test.go b/internal/microsoft/content_azure_openai_text_test.go index cf3e35ca..dc3eb6a5 100644 --- a/internal/microsoft/content_azure_openai_text_test.go +++ b/internal/microsoft/content_azure_openai_text_test.go @@ -37,7 +37,7 @@ func TestAzureOpenAITextContentSuite(t *testing.T) { func (s *AzureOpenAITextContentTestSuite) SetupSuite() { s.plugin = microsoft.Plugin("1.0.0", nil, (func(apiKey string, endPoint string) (cli microsoft.AzureOpenaiClient, err error) { return s.cli, nil - })) + }), nil) s.schema = s.plugin.ContentProviders["azure_openai_text"] } diff --git a/internal/microsoft/data_microsoft_graph.go b/internal/microsoft/data_microsoft_graph.go new file mode 100644 index 00000000..28435614 --- /dev/null +++ b/internal/microsoft/data_microsoft_graph.go @@ -0,0 +1,124 @@ +package microsoft + +import ( + "context" + "net/url" + + "github.com/hashicorp/hcl/v2" + "github.com/zclconf/go-cty/cty" + + "github.com/blackstork-io/fabric/pkg/diagnostics" + "github.com/blackstork-io/fabric/plugin" + "github.com/blackstork-io/fabric/plugin/dataspec" + "github.com/blackstork-io/fabric/plugin/dataspec/constraint" + "github.com/blackstork-io/fabric/plugin/plugindata" +) + +func makeMicrosoftGraphDataSource(loader MicrosoftGraphClientLoadFn) *plugin.DataSource { + return &plugin.DataSource{ + Doc: "The `microsoft_graph` data source queries Microsoft Graph.", + DataFunc: fetchMicrosoftGraph(loader), + Config: &dataspec.RootSpec{ + Attrs: []*dataspec.AttrSpec{ + { + Doc: "The Azure client ID", + Name: "client_id", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + }, + { + Doc: "The Azure client secret. Required if private_key_file/privat_key/cert_thumbprint is not provided.", + Name: "client_secret", + Type: cty.String, + Secret: true, + }, + { + Doc: "The Azure tenant ID", + Name: "tenant_id", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + }, + { + Doc: "The path to the private key file. Ignored if private_key/client_secret is provided.", + Name: "private_key_file", + Type: cty.String, + }, + { + Doc: "The private key contents. Ignored if client_secret is provided.", + Name: "private_key", + Type: cty.String, + }, + { + Doc: "The key passphrase. Ignored if client_secret is provided.", + Name: "key_passphrase", + Type: cty.String, + }, + }, + }, + Args: &dataspec.RootSpec{ + Attrs: []*dataspec.AttrSpec{ + { + Doc: "The API version", + Name: "api_version", + Type: cty.String, + DefaultVal: cty.StringVal("beta"), + }, + { + Doc: "The endpoint to query", + Name: "endpoint", + Type: cty.String, + Constraints: constraint.RequiredNonNull, + ExampleVal: cty.StringVal("/security/incidents"), + }, + { + Doc: "The query parameters", + Name: "query_params", + Type: cty.Map(cty.String), + }, + }, + }, + } +} + +func fetchMicrosoftGraph(loader MicrosoftGraphClientLoadFn) plugin.RetrieveDataFunc { + return func(ctx context.Context, params *plugin.RetrieveDataParams) (plugindata.Data, diagnostics.Diag) { + apiVersion := params.Args.GetAttrVal("api_version").AsString() + cli, err := loader(ctx, apiVersion, params.Config) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Unable to create microsoft graph client", + Detail: err.Error(), + }} + } + endPoint := params.Args.GetAttrVal("endpoint").AsString() + queryParamsAttr := params.Args.GetAttrVal("query_params") + var queryParams url.Values + + if !queryParamsAttr.IsNull() { + queryParams = url.Values{} + queryMap := queryParamsAttr.AsValueMap() + for k, v := range queryMap { + queryParams.Add(k, v.AsString()) + } + } + + response, err := cli.QueryGraph(ctx, endPoint, queryParams) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to query microsoft graph", + Detail: err.Error(), + }} + } + data, err := plugindata.ParseAny(response) + if err != nil { + return nil, diagnostics.Diag{{ + Severity: hcl.DiagError, + Summary: "Failed to parse response", + Detail: err.Error(), + }} + } + return data, nil + } +} diff --git a/internal/microsoft/data_microsoft_graph_test.go b/internal/microsoft/data_microsoft_graph_test.go new file mode 100644 index 00000000..3d5c61c6 --- /dev/null +++ b/internal/microsoft/data_microsoft_graph_test.go @@ -0,0 +1,155 @@ +package microsoft_test + +import ( + "context" + "errors" + "net/url" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "github.com/zclconf/go-cty/cty" + + "github.com/blackstork-io/fabric/internal/microsoft" + client_mocks "github.com/blackstork-io/fabric/mocks/internalpkg/microsoft" + "github.com/blackstork-io/fabric/pkg/diagnostics/diagtest" + "github.com/blackstork-io/fabric/plugin" + "github.com/blackstork-io/fabric/plugin/dataspec" + "github.com/blackstork-io/fabric/plugin/plugindata" + "github.com/blackstork-io/fabric/plugin/plugintest" +) + +type MicrosoftGraphDataSourceTestSuite struct { + suite.Suite + plugin *plugin.Schema + schema *plugin.DataSource + cli *client_mocks.MicrosoftGraphClient +} + +func TestMicrosoftGraphDataSourceTestSuite(t *testing.T) { + suite.Run(t, &MicrosoftGraphDataSourceTestSuite{}) +} + +func (s *MicrosoftGraphDataSourceTestSuite) SetupSuite() { + s.plugin = microsoft.Plugin("1.0.0", nil, nil, (func(ctx context.Context, apiVersion string, cfg *dataspec.Block) (client microsoft.MicrosoftGraphClient, err error) { + return s.cli, nil + })) + s.schema = s.plugin.DataSources["microsoft_graph"] +} + +func (s *MicrosoftGraphDataSourceTestSuite) SetupTest() { + s.cli = &client_mocks.MicrosoftGraphClient{} +} + +func (s *MicrosoftGraphDataSourceTestSuite) TearDownTest() { + s.cli.AssertExpectations(s.T()) +} + +func (s *MicrosoftGraphDataSourceTestSuite) TestSchema() { + s.Require().NotNil(s.plugin) + s.Require().NotNil(s.schema) + s.NotNil(s.schema.Args) + s.NotNil(s.schema.DataFunc) + s.NotNil(s.schema.Config) +} + +func (s *MicrosoftGraphDataSourceTestSuite) TestBasic() { + expectedData := map[string]interface{}{ + "value": []interface{}{map[string]interface{}{ + "severity": "High", + "displayName": "Incident 1", + }}, + } + s.cli.On("QueryGraph", mock.Anything, "/security/incidents", url.Values{"$top": []string{"10"}}).Return(expectedData, nil) + ctx := context.Background() + result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{ + Config: plugintest.NewTestDecoder(s.T(), s.schema.Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + SetAttr("client_secret", cty.StringVal("csecret")). + Decode(), + Args: plugintest.NewTestDecoder(s.T(), s.schema.Args). + SetAttr("endpoint", cty.StringVal("/security/incidents")). + SetAttr("api_version", cty.StringVal("v1")). + SetAttr("query_params", cty.MapVal(map[string]cty.Value{"$top": cty.StringVal("10")})). + Decode(), + }) + s.Nil(diags) + parsedData, err := plugindata.ParseAny(expectedData) + s.Nil(err) + s.Equal(parsedData, result.AsPluginData()) +} + +func (s *MicrosoftGraphDataSourceTestSuite) TestClientError() { + s.cli.On("QueryGraph", mock.Anything, "/security/incidents", url.Values{"$top": []string{"10"}}).Return(nil, errors.New("microsoft graph client returned status code: 400")) + ctx := context.Background() + result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{ + Config: plugintest.NewTestDecoder(s.T(), s.schema.Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + SetAttr("client_secret", cty.StringVal("csecret")). + Decode(), + Args: plugintest.NewTestDecoder(s.T(), s.schema.Args). + SetAttr("endpoint", cty.StringVal("/security/incidents")). + SetAttr("api_version", cty.StringVal("v1")). + SetAttr("query_params", cty.MapVal(map[string]cty.Value{"$top": cty.StringVal("10")})). + Decode(), + }) + s.Nil(result) + diagtest.Asserts{{ + diagtest.IsError, + diagtest.DetailContains("microsoft graph client returned status code: 400"), + }}.AssertMatch(s.T(), diags, nil) +} + +func (s *MicrosoftGraphDataSourceTestSuite) TestMissingArgs() { + plugintest.NewTestDecoder( + s.T(), + s.schema.Args, + ).Decode([]diagtest.Assert{ + diagtest.IsError, + diagtest.SummaryEquals("Missing required attribute"), + diagtest.DetailContains("endpoint"), + }) +} + +func (s *MicrosoftGraphDataSourceTestSuite) TestMissingConfig() { + plugintest.NewTestDecoder( + s.T(), + s.schema.Config, + ).Decode([]diagtest.Assert{ + diagtest.IsError, + diagtest.SummaryEquals("Missing required attribute"), + diagtest.DetailContains("client_id"), + }, []diagtest.Assert{ + diagtest.IsError, + diagtest.SummaryEquals("Missing required attribute"), + diagtest.DetailContains("tenant_id"), + }) +} + +func (s *MicrosoftGraphDataSourceTestSuite) TestMissingCredentials() { + expectedData := map[string]interface{}{ + "value": []interface{}{map[string]interface{}{ + "severity": "High", + "displayName": "Incident 1", + }}, + } + s.cli.On("QueryGraph", mock.Anything, "/security/incidents", url.Values{"$top": []string{"10"}}).Return(expectedData, nil) + ctx := context.Background() + result, diags := s.schema.DataFunc(ctx, &plugin.RetrieveDataParams{ + Config: plugintest.NewTestDecoder(s.T(), s.schema.Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + Decode(), + Args: plugintest.NewTestDecoder(s.T(), s.schema.Args). + SetAttr("endpoint", cty.StringVal("/security/incidents")). + SetAttr("api_version", cty.StringVal("v1")). + SetAttr("query_params", cty.MapVal(map[string]cty.Value{"$top": cty.StringVal("10")})). + Decode(), + }) + s.Nil(diags) + parsedData, err := plugindata.ParseAny(expectedData) + s.Nil(err) + s.Equal(parsedData, result.AsPluginData()) +} diff --git a/internal/microsoft/plugin.go b/internal/microsoft/plugin.go index 5cbc28c7..94435d87 100644 --- a/internal/microsoft/plugin.go +++ b/internal/microsoft/plugin.go @@ -3,9 +3,12 @@ package microsoft import ( "context" "fmt" + "net/url" + "os" "github.com/Azure/azure-sdk-for-go/sdk/ai/azopenai" "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/blackstork-io/fabric/internal/microsoft/client" "github.com/blackstork-io/fabric/plugin" @@ -28,7 +31,75 @@ var DefaultAzureOpenAIClientLoader AzureOpenaiClientLoadFn = func(azureOpenAIKey return } -func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenaiClientLoadFn) *plugin.Schema { +type MicrosoftGraphClient interface { + QueryGraph(ctx context.Context, endpoint string, queryParams url.Values) (result interface{}, err error) +} + +type AcquireTokenFn func(ctx context.Context, tenantId string, clientId string, cred confidential.Credential) (string, error) + +type MicrosoftGraphClientLoadFn func(ctx context.Context, apiVersion string, cfg *dataspec.Block) (client MicrosoftGraphClient, err error) + +func MakeDefaultMicrosoftGraphClientLoader(tokenFn AcquireTokenFn) MicrosoftGraphClientLoadFn { + return func(ctx context.Context, apiVersion string, cfg *dataspec.Block) (cli MicrosoftGraphClient, err error) { + if cfg == nil { + return nil, fmt.Errorf("configuration is required") + } + tenantId := cfg.GetAttrVal("tenant_id").AsString() + clientId := cfg.GetAttrVal("client_id").AsString() + clientSecretAttr := cfg.GetAttrVal("client_secret") + if !clientSecretAttr.IsNull() { + cred, err := confidential.NewCredFromSecret(clientSecretAttr.AsString()) + if err != nil { + return nil, err + } + accessToken, err := tokenFn(ctx, tenantId, clientId, cred) + if err != nil { + return nil, err + } + return client.NewGraphClient(accessToken, apiVersion), nil + } + + // if client_secret is not provided, try to use private_key + privateKeyFileAttr := cfg.GetAttrVal("private_key_file") + privateKeyAttr := cfg.GetAttrVal("private_key") + + if !privateKeyFileAttr.IsNull() || !privateKeyAttr.IsNull() { + var pemData []byte + if !privateKeyAttr.IsNull() { + pemData = []byte(privateKeyAttr.AsString()) + } else { + pemData, err = os.ReadFile(privateKeyFileAttr.AsString()) + if err != nil { + return nil, fmt.Errorf("failed to read private key file: %w", err) + } + } + + keyPassphrase := "" + keyPassphraseAttr := cfg.GetAttrVal("key_passphrase") + if !keyPassphraseAttr.IsNull() { + keyPassphrase = keyPassphraseAttr.AsString() + } + + certs, privateKey, err := confidential.CertFromPEM(pemData, keyPassphrase) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + cred, err := confidential.NewCredFromCert(certs, privateKey) + if err != nil { + return nil, fmt.Errorf("failed to create credential from cert: %w", err) + } + accessToken, err := tokenFn(ctx, tenantId, clientId, cred) + if err != nil { + return nil, err + } + return client.NewGraphClient(accessToken, apiVersion), nil + } + + return nil, fmt.Errorf("missing credentials to authenticate. client_secret or private_key is required") + } +} + +func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenaiClientLoadFn, graphClientLoader MicrosoftGraphClientLoadFn) *plugin.Schema { if loader == nil { loader = DefaultClientLoader } @@ -38,6 +109,7 @@ func Plugin(version string, loader ClientLoadFn, openAiClientLoader AzureOpenaiC Version: version, DataSources: plugin.DataSources{ "microsoft_sentinel_incidents": makeMicrosoftSentinelIncidentsDataSource(loader), + "microsoft_graph": makeMicrosoftGraphDataSource(graphClientLoader), }, ContentProviders: plugin.ContentProviders{ "azure_openai_text": makeAzureOpenAITextContentSchema(openAiClientLoader), diff --git a/internal/microsoft/plugin_test.go b/internal/microsoft/plugin_test.go index c9aaa301..c414a557 100644 --- a/internal/microsoft/plugin_test.go +++ b/internal/microsoft/plugin_test.go @@ -1,15 +1,104 @@ package microsoft import ( + "bytes" + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" "testing" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential" "github.com/stretchr/testify/assert" + "github.com/zclconf/go-cty/cty" + + "github.com/blackstork-io/fabric/plugin/dataspec" + "github.com/blackstork-io/fabric/plugin/plugintest" ) func TestPlugin_Schema(t *testing.T) { - schema := Plugin("1.2.3", nil, nil) + schema := Plugin("1.2.3", nil, nil, nil) assert.Equal(t, "blackstork/microsoft", schema.Name) assert.Equal(t, "1.2.3", schema.Version) assert.NotNil(t, schema.DataSources["microsoft_sentinel_incidents"]) assert.NotNil(t, schema.ContentProviders["azure_openai_text"]) + assert.NotNil(t, schema.DataSources["microsoft_graph"]) +} + +func TestMakeDefaultGraphClientLoader(t *testing.T) { + loader := MakeDefaultMicrosoftGraphClientLoader(func(ctx context.Context, tenantId, clientId string, cred confidential.Credential) (string, error) { + return "test-token", nil + }) + assert.NotNil(t, loader) + + plugin := Plugin("1.0.0", nil, nil, (func(ctx context.Context, apiVersion string, cfg *dataspec.Block) (client MicrosoftGraphClient, err error) { + return nil, nil + })) + t.Run("with client secret", func(t *testing.T) { + result := plugintest.NewTestDecoder(t, plugin.DataSources["microsoft_graph"].Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + SetAttr("client_secret", cty.StringVal("csecret")). + Decode() + client, err := loader(context.Background(), "2023-11-01", result) + assert.Nil(t, err) + assert.NotNil(t, client) + }) + + t.Run("should create with client secret", func(t *testing.T) { + result := plugintest.NewTestDecoder(t, plugin.DataSources["microsoft_graph"].Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + SetAttr("client_secret", cty.StringVal("csecret")). + Decode() + client, err := loader(context.Background(), "2023-11-01", result) + assert.Nil(t, err) + assert.NotNil(t, client) + }) + + t.Run("should fail if no auth specified", func(t *testing.T) { + result := plugintest.NewTestDecoder(t, plugin.DataSources["microsoft_graph"].Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + Decode() + _, err := loader(context.Background(), "2023-11-01", result) + assert.NotNil(t, err) + assert.EqualError(t, err, "missing credentials to authenticate. client_secret or private_key is required") + }) + + t.Run("should use private key contents if specified", func(t *testing.T) { + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + assert.Nil(t, err) + key, err := x509.MarshalPKCS8PrivateKey(privateKey) + assert.Nil(t, err) + + block := &pem.Block{ + Type: "PRIVATE KEY", + Bytes: key, + } + buff := bytes.NewBuffer([]byte{}) + err = pem.Encode(buff, block) + assert.Nil(t, err) + // add certificate + cert := x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + } + certBytes, err := x509.CreateCertificate(rand.Reader, &cert, &cert, &privateKey.PublicKey, privateKey) + assert.Nil(t, err) + certBlock := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes} + err = pem.Encode(buff, certBlock) + assert.Nil(t, err) + result := plugintest.NewTestDecoder(t, plugin.DataSources["microsoft_graph"].Config). + SetAttr("client_id", cty.StringVal("cid")). + SetAttr("tenant_id", cty.StringVal("tid")). + SetAttr("private_key", cty.StringVal(buff.String())). + Decode() + client, err := loader(context.Background(), "2023-11-01", result) + assert.Nil(t, err) + assert.NotNil(t, client) + }) } diff --git a/internal/plugin_validity_test.go b/internal/plugin_validity_test.go index de2e0616..6366ce65 100644 --- a/internal/plugin_validity_test.go +++ b/internal/plugin_validity_test.go @@ -45,7 +45,7 @@ func TestAllPluginSchemaValidity(t *testing.T) { splunk.Plugin(ver, nil), nistnvd.Plugin(ver, nil), snyk.Plugin(ver, nil), - microsoft.Plugin(ver, nil, nil), + microsoft.Plugin(ver, nil, nil, nil), } for _, p := range plugins { p := p diff --git a/mocks/internalpkg/microsoft/microsoft_graph_client.go b/mocks/internalpkg/microsoft/microsoft_graph_client.go new file mode 100644 index 00000000..3bfe9916 --- /dev/null +++ b/mocks/internalpkg/microsoft/microsoft_graph_client.go @@ -0,0 +1,98 @@ +// Code generated by mockery v2.42.1. DO NOT EDIT. + +package microsoft_mocks + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + url "net/url" +) + +// MicrosoftGraphClient is an autogenerated mock type for the MicrosoftGraphClient type +type MicrosoftGraphClient struct { + mock.Mock +} + +type MicrosoftGraphClient_Expecter struct { + mock *mock.Mock +} + +func (_m *MicrosoftGraphClient) EXPECT() *MicrosoftGraphClient_Expecter { + return &MicrosoftGraphClient_Expecter{mock: &_m.Mock} +} + +// QueryGraph provides a mock function with given fields: ctx, endpoint, queryParams +func (_m *MicrosoftGraphClient) QueryGraph(ctx context.Context, endpoint string, queryParams url.Values) (interface{}, error) { + ret := _m.Called(ctx, endpoint, queryParams) + + if len(ret) == 0 { + panic("no return value specified for QueryGraph") + } + + var r0 interface{} + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, url.Values) (interface{}, error)); ok { + return rf(ctx, endpoint, queryParams) + } + if rf, ok := ret.Get(0).(func(context.Context, string, url.Values) interface{}); ok { + r0 = rf(ctx, endpoint, queryParams) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(interface{}) + } + } + + if rf, ok := ret.Get(1).(func(context.Context, string, url.Values) error); ok { + r1 = rf(ctx, endpoint, queryParams) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// MicrosoftGraphClient_QueryGraph_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'QueryGraph' +type MicrosoftGraphClient_QueryGraph_Call struct { + *mock.Call +} + +// QueryGraph is a helper method to define mock.On call +// - ctx context.Context +// - endpoint string +// - queryParams url.Values +func (_e *MicrosoftGraphClient_Expecter) QueryGraph(ctx interface{}, endpoint interface{}, queryParams interface{}) *MicrosoftGraphClient_QueryGraph_Call { + return &MicrosoftGraphClient_QueryGraph_Call{Call: _e.mock.On("QueryGraph", ctx, endpoint, queryParams)} +} + +func (_c *MicrosoftGraphClient_QueryGraph_Call) Run(run func(ctx context.Context, endpoint string, queryParams url.Values)) *MicrosoftGraphClient_QueryGraph_Call { + _c.Call.Run(func(args mock.Arguments) { + run(args[0].(context.Context), args[1].(string), args[2].(url.Values)) + }) + return _c +} + +func (_c *MicrosoftGraphClient_QueryGraph_Call) Return(result interface{}, err error) *MicrosoftGraphClient_QueryGraph_Call { + _c.Call.Return(result, err) + return _c +} + +func (_c *MicrosoftGraphClient_QueryGraph_Call) RunAndReturn(run func(context.Context, string, url.Values) (interface{}, error)) *MicrosoftGraphClient_QueryGraph_Call { + _c.Call.Return(run) + return _c +} + +// NewMicrosoftGraphClient creates a new instance of MicrosoftGraphClient. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// The first argument is typically a *testing.T value. +func NewMicrosoftGraphClient(t interface { + mock.TestingT + Cleanup(func()) +}) *MicrosoftGraphClient { + mock := &MicrosoftGraphClient{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/tools/docgen/main.go b/tools/docgen/main.go index 88aa18f6..c99b5b2e 100644 --- a/tools/docgen/main.go +++ b/tools/docgen/main.go @@ -277,7 +277,7 @@ func main() { stixview.Plugin(version), nistnvd.Plugin(version, nil), snyk.Plugin(version, nil), - microsoft.Plugin(version, nil, nil), + microsoft.Plugin(version, nil, nil, nil), } // generate markdown for each plugin for _, p := range plugins {