diff --git a/.github/workflows/rego.yaml b/.github/workflows/rego.yaml index d29fe70..e553bc3 100644 --- a/.github/workflows/rego.yaml +++ b/.github/workflows/rego.yaml @@ -29,4 +29,4 @@ jobs: version: latest - name: Run OPA Tests - run: opa test classification/labels/*.rego -v + run: opa test ./classification/labels/*.rego -v diff --git a/Makefile b/Makefile index 569d168..0eba7a7 100644 --- a/Makefile +++ b/Makefile @@ -15,3 +15,12 @@ integration-test: clean: go clean -i ./... + +opt-fmt: + opa fmt --write ./classification/labels + +opa-lint: + regal lint --disable=line-length ./classification/labels/ + +opa-test: + opa test ./classification/labels/*.rego -v diff --git a/classification/label.go b/classification/label.go index 7aa38b5..8f5d566 100644 --- a/classification/label.go +++ b/classification/label.go @@ -50,7 +50,10 @@ func NewLabel(name, description, classificationRule string, tags ...string) (Lab // GetEmbeddedLabels returns the predefined embedded labels and their // classification rules. The labels are read from the embedded labels.yaml file -// and the classification rules are read from the embedded Rego files. +// and the classification rules are read from the embedded Rego files. If there +// is an error unmarshalling the labels file, it is returned. If there is an +// error reading or parsing a classification rule for a label, a warning is +// logged and that label is skipped. 
func GetEmbeddedLabels() ([]Label, error) { labels := struct { Labels map[string]Label `yaml:"labels"` @@ -62,11 +65,13 @@ func GetEmbeddedLabels() ([]Label, error) { fname := "labels/" + strings.ReplaceAll(strings.ToLower(name), " ", "_") + ".rego" b, err := regoFs.ReadFile(fname) if err != nil { - return nil, fmt.Errorf("error reading rego file %s: %w", fname, err) + log.WithError(err).Warnf("error reading rego file %s", fname) + continue } rule, err := parseRego(string(b)) if err != nil { - return nil, fmt.Errorf("error preparing classification rule for label %s: %w", lbl.Name, err) + log.WithError(err).Warnf("error parsing classification rule for label %s", lbl.Name) + continue } lbl.Name = name lbl.ClassificationRule = rule diff --git a/classification/label_classifier.go b/classification/label_classifier.go index c7a23b9..0bbdb46 100644 --- a/classification/label_classifier.go +++ b/classification/label_classifier.go @@ -18,7 +18,6 @@ type LabelClassifier struct { var _ Classifier = (*LabelClassifier)(nil) // NewLabelClassifier creates a new LabelClassifier with the provided labels. -// func NewLabelClassifier(labels ...Label) (*LabelClassifier, error) { if len(labels) == 0 { return nil, fmt.Errorf("labels cannot be empty") @@ -27,7 +26,7 @@ func NewLabelClassifier(labels ...Label) (*LabelClassifier, error) { for _, lbl := range labels { queries[lbl.Name] = rego.New( // We only care about the 'output' variable. 
- rego.Query(lbl.ClassificationRule.Package.Path.String() + ".output"), + rego.Query(lbl.ClassificationRule.Package.Path.String()+".output"), rego.ParsedModule(lbl.ClassificationRule), ) } diff --git a/cmd/main.go b/cmd/main.go index fee5fd2..f2038ca 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -64,7 +64,7 @@ func main() { ctx := kong.Parse( &cli, kong.Name("dmap"), - kong.Description("Assess your data security posture in AWS."), + kong.Description("Discover your data repositories and classify their sensitive data."), kong.UsageOnError(), kong.ConfigureHelp( kong.HelpOptions{ diff --git a/cmd/repo_scan.go b/cmd/repo_scan.go index ce4f327..a691598 100644 --- a/cmd/repo_scan.go +++ b/cmd/repo_scan.go @@ -53,6 +53,7 @@ func (g GlobFlag) Decode(ctx *kong.DecodeContext) error { func (cmd *RepoScanCmd) Run(_ *Globals) error { ctx := context.Background() + // Configure and instantiate the scanner. cfg := sql.ScannerConfig{ RepoType: cmd.Type, RepoConfig: sql.RepoConfig{ @@ -73,15 +74,17 @@ func (cmd *RepoScanCmd) Run(_ *Globals) error { if err != nil { return fmt.Errorf("error creating new scanner: %w", err) } + // Scan the repository. results, err := scanner.Scan(ctx) if err != nil { return fmt.Errorf("error scanning repository: %w", err) } + // Print the results to stdout. jsonResults, err := json.MarshalIndent(results, "", " ") if err != nil { return fmt.Errorf("error marshalling results: %w", err) } fmt.Println(string(jsonResults)) - // TODO: publish results to API -ccampo 2024-04-03 + // TODO: publish results to the API -ccampo 2024-04-03 return nil } diff --git a/scan/scanner.go b/scan/scanner.go index 8a39f4f..642ed4d 100644 --- a/scan/scanner.go +++ b/scan/scanner.go @@ -27,7 +27,7 @@ type RepoScanResults struct { Classifications []Classification `json:"classifications"` } -// TODO: godoc -ccampo 2024-04-03 +// Classification represents the classification of a data repository attribute. 
type Classification struct { // AttributePath is the full path of the data repository attribute // (e.g. the column). Each element corresponds to a component, in increasing diff --git a/sql/classify.go b/sql/classify.go deleted file mode 100644 index 30dd914..0000000 --- a/sql/classify.go +++ /dev/null @@ -1,61 +0,0 @@ -package sql - -import ( - "context" - "fmt" - "maps" - "strings" - - "github.com/cyralinc/dmap/classification" - "github.com/cyralinc/dmap/scan" -) - -// classifySamples uses the provided classifiers to classify the sample data -// passed via the "samples" parameter. It is mostly a helper function which -// loops through each repository.Sample, retrieves the attribute names and -// values of that sample, passes them to Classifier.Classify, and then -// aggregates the results. Please see the documentation for Classifier and its -// Classify method for more details. The returned slice represents all the -// unique classification results for a given sample set. -func classifySamples( - ctx context.Context, - samples []Sample, - classifier classification.Classifier, -) ([]scan.Classification, error) { - uniqueResults := make(map[string]scan.Classification) - for _, sample := range samples { - // Classify each sampled row and combine the results. - for _, sampleResult := range sample.Results { - res, err := classifier.Classify(ctx, sampleResult) - if err != nil { - return nil, fmt.Errorf("error classifying sample: %w", err) - } - for attr, labels := range res { - attrPath := append(sample.TablePath, attr) - key := pathKey(attrPath) - result, ok := uniqueResults[key] - if !ok { - uniqueResults[key] = scan.Classification{ - AttributePath: attrPath, - Labels: labels, - } - } else { - // Merge the labels from the new result into the existing result. - maps.Copy(result.Labels, labels) - } - } - } - } - // Convert the map of unique results to a slice. 
- results := make([]scan.Classification, 0, len(uniqueResults)) - for _, result := range uniqueResults { - results = append(results, result) - } - return results, nil -} - -func pathKey(path []string) string { - // U+2063 is an invisible separator. It is used here to ensure that the - // pathKey is unique and does not conflict with any of the path elements. - return strings.Join(path, "\u2063") -} diff --git a/sql/classify_test.go b/sql/classify_test.go deleted file mode 100644 index abce877..0000000 --- a/sql/classify_test.go +++ /dev/null @@ -1,158 +0,0 @@ -package sql - -import ( - "context" - "testing" - - "github.com/stretchr/testify/require" - - "github.com/cyralinc/dmap/classification" - "github.com/cyralinc/dmap/scan" -) - -func Test_classifySamples_SingleSample(t *testing.T) { - ctx := context.Background() - sample := Sample{ - TablePath: []string{"db", "schema", "table"}, - Results: []SampleResult{ - { - "age": "52", - "social_sec_num": "512-23-4258", - "credit_card_num": "4111111111111111", - }, - { - "age": "101", - "social_sec_num": "foobarbaz", - "credit_card_num": "4111111111111111", - }, - }, - } - classifier := NewMockClassifier(t) - // Need to explicitly convert it to a map because Mockery isn't smart enough - // to infer the type. 
- classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return( - classification.Result{ - "age": lblSet("AGE"), - "social_sec_num": lblSet("SSN"), - "credit_card_num": lblSet("CCN"), - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return( - classification.Result{ - "age": lblSet("AGE", "CVV"), - "credit_card_num": lblSet("CCN"), - }, - nil, - ) - - expected := []scan.Classification{ - { - AttributePath: append(sample.TablePath, "age"), - Labels: lblSet("AGE", "CVV"), - }, - { - AttributePath: append(sample.TablePath, "social_sec_num"), - Labels: lblSet("SSN"), - }, - { - AttributePath: append(sample.TablePath, "credit_card_num"), - Labels: lblSet("CCN"), - }, - } - actual, err := classifySamples(ctx, []Sample{sample}, classifier) - require.NoError(t, err) - require.ElementsMatch(t, expected, actual) -} - -func Test_classifySamples_MultipleSamples(t *testing.T) { - ctx := context.Background() - samples := []Sample{ - { - TablePath: []string{"db1", "schema1", "table1"}, - Results: []SampleResult{ - { - "age": "52", - "social_sec_num": "512-23-4258", - "credit_card_num": "4111111111111111", - }, - { - "age": "101", - "social_sec_num": "foobarbaz", - "credit_card_num": "4111111111111111", - }, - }, - }, - { - TablePath: []string{"db2", "schema2", "table2"}, - Results: []SampleResult{ - { - "fullname": "John Doe", - "dob": "2000-01-01", - "random": "foobarbaz", - }, - }, - }, - } - - classifier := NewMockClassifier(t) - // Need to explicitly convert it to a map because Mockery isn't smart enough - // to infer the type. 
- classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return( - classification.Result{ - "age": lblSet("AGE"), - "social_sec_num": lblSet("SSN"), - "credit_card_num": lblSet("CCN"), - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return( - classification.Result{ - "age": lblSet("AGE", "CVV"), - "credit_card_num": lblSet("CCN"), - }, - nil, - ) - classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return( - classification.Result{ - "fullname": lblSet("FULL_NAME"), - "dob": lblSet("DOB"), - }, - nil, - ) - - expected := []scan.Classification{ - { - AttributePath: append(samples[0].TablePath, "age"), - Labels: lblSet("AGE", "CVV"), - }, - { - AttributePath: append(samples[0].TablePath, "social_sec_num"), - Labels: lblSet("SSN"), - }, - { - AttributePath: append(samples[0].TablePath, "credit_card_num"), - Labels: lblSet("CCN"), - }, - { - AttributePath: append(samples[1].TablePath, "fullname"), - Labels: lblSet("FULL_NAME"), - }, - { - AttributePath: append(samples[1].TablePath, "dob"), - Labels: lblSet("DOB"), - }, - } - actual, err := classifySamples(ctx, samples, classifier) - require.NoError(t, err) - require.ElementsMatch(t, expected, actual) -} - -func lblSet(labels ...string) classification.LabelSet { - set := make(classification.LabelSet) - for _, label := range labels { - set[label] = struct { - }{} - } - return set -} diff --git a/sql/config.go b/sql/config.go index a8785b9..648fa4a 100644 --- a/sql/config.go +++ b/sql/config.go @@ -4,8 +4,6 @@ import ( "fmt" ) -const configConnOpts = "connection-string-args" - // RepoConfig is the necessary configuration to connect to a data sql. type RepoConfig struct { // Host is the hostname of the database. @@ -24,61 +22,17 @@ type RepoConfig struct { Advanced map[string]any } -// FetchAdvancedConfigString fetches a map in the repo advanced configuration, -// for a given repo and set of parameters. 
Example: -// -// repo-advanced: -// -// snowflake: -// account: exampleAccount -// role: exampleRole -// warehouse: exampleWarehouse -// -// Calling FetchAdvancedMapConfig(, "snowflake", -// []string{"account", "role", "warehouse"}) returns the map -// -// {"account": "exampleAccount", "role": "exampleRole", "warehouse": -// "exampleWarehouse"} -// -// The suffix 'String' means that the values of the map are strings. This gives -// room to have FetchAdvancedConfigList or FetchAdvancedConfigMap, for example, -// without name conflicts. -func FetchAdvancedConfigString( - cfg RepoConfig, - repo string, - parameters []string, -) (map[string]string, error) { - advancedCfg, err := getAdvancedConfig(cfg, repo) - if err != nil { - return nil, err - } - repoSpecificMap := make(map[string]string) - for _, key := range parameters { - var valInterface any - var val string - var ok bool - if valInterface, ok = advancedCfg[key]; !ok { - return nil, fmt.Errorf("unable to find '%s' in %s advanced config", key, repo) - } - if val, ok = valInterface.(string); !ok { - return nil, fmt.Errorf("'%s' in %s config must be a string", key, repo) - } - repoSpecificMap[key] = val - } - return repoSpecificMap, nil -} - -// getAdvancedConfig gets the Advanced field in a repo config and converts it to -// a map[string]any. In every step, it checks for error and generates -// nice messages. -func getAdvancedConfig(cfg RepoConfig, repo string) (map[string]any, error) { - advancedCfgInterface, ok := cfg.Advanced[repo] +// keyAsString returns the value of the given key as a string from the given +// configuration map. It returns an error if the key does not exist or if the +// value is not a string. 
+func keyAsString(cfg map[string]any, key string) (string, error) { + val, ok := cfg[key] if !ok { - return nil, fmt.Errorf("unable to find '%s' in advanced config", repo) + return "", fmt.Errorf("%s key does not exist", key) } - advancedCfg, ok := advancedCfgInterface.(map[string]any) + valStr, ok := val.(string) if !ok { - return nil, fmt.Errorf("'%s' in advanced config is not a map", repo) + return "", fmt.Errorf("%s key must be a string", key) } - return advancedCfg, nil + return valStr, nil } diff --git a/sql/config_test.go b/sql/config_test.go deleted file mode 100644 index e92a5c8..0000000 --- a/sql/config_test.go +++ /dev/null @@ -1,77 +0,0 @@ -package sql - -import ( - "testing" - - "github.com/stretchr/testify/require" -) - -func TestAdvancedConfigSucc(t *testing.T) { - sampleCfg := RepoConfig{ - Advanced: map[string]any{ - "snowflake": map[string]any{ - "account": "exampleAccount", - "role": "exampleRole", - "warehouse": "exampleWarehouse", - }, - }, - } - repoSpecificMap, err := FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.NoError(t, err) - require.EqualValues( - t, repoSpecificMap, map[string]string{ - "account": "exampleAccount", - "role": "exampleRole", - "warehouse": "exampleWarehouse", - }, - ) -} - -func TestAdvancedConfigMissing(t *testing.T) { - // Without the snowflake config at all - sampleCfg := RepoConfig{ - Advanced: map[string]any{}, - } - _, err := FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.Error(t, err) - - sampleCfg = RepoConfig{ - Advanced: map[string]any{ - "snowflake": map[string]any{ - // Missing account - - "role": "exampleRole", - "warehouse": "exampleWarehouse", - }, - }, - } - _, err = FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.Error(t, err) -} - -func TestAdvancedConfigMalformed(t *testing.T) { - sampleCfg := RepoConfig{ - 
Advanced: map[string]any{ - "snowflake": map[string]any{ - // Let's give a _list_ of things - "account": []string{"account1", "account2"}, - "role": []string{"role1", "role2"}, - "warehouse": []string{"warehouse1", "warehouse2"}, - }, - }, - } - _, err := FetchAdvancedConfigString( - sampleCfg, - "snowflake", []string{"account", "role", "warehouse"}, - ) - require.Error(t, err) -} diff --git a/sql/denodo.go b/sql/denodo.go index 9e7a730..2e11248 100644 --- a/sql/denodo.go +++ b/sql/denodo.go @@ -11,11 +11,11 @@ import ( const ( RepoTypeDenodo = "denodo" - // DenodoIntrospectQuery is the SQL query used to introspect the database. For - // Denodo, the object hierarchy is (database > views). When querying - // Denodo, the database corresponds to a schema, and the view corresponds - // to a table (see SampleTable). - DenodoIntrospectQuery = "SELECT " + + // denodoIntrospectQuery is the SQL query used to introspect the database. + // For Denodo, the object hierarchy is (database > views). When querying + // Denodo, the database corresponds to a schema, and the view corresponds to + // a table. + denodoIntrospectQuery = "SELECT " + "database_name AS table_schema, " + "view_name AS table_name, " + "column_name, " + @@ -36,21 +36,16 @@ var _ Repository = (*DenodoRepository)(nil) // NewDenodoRepository is the constructor for sql. 
func NewDenodoRepository(cfg RepoConfig) (*DenodoRepository, error) { - pgCfg, err := parsePostgresConfig(cfg) - if err != nil { - return nil, fmt.Errorf("unable to parse postgres config: %w", err) - } if cfg.Database == "" { return nil, errors.New("database name is mandatory for Denodo repositories") } connStr := fmt.Sprintf( - "postgresql://%s:%s@%s:%d/%s%s", + "postgresql://%s:%s@%s:%d/%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, cfg.Database, - pgCfg.ConnOptsStr, ) generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { @@ -69,7 +64,7 @@ func (r *DenodoRepository) ListDatabases(_ context.Context) ([]string, error) { // Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. func (r *DenodoRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { - return r.generic.IntrospectWithQuery(ctx, DenodoIntrospectQuery, params) + return r.generic.IntrospectWithQuery(ctx, denodoIntrospectQuery, params) } // SampleTable delegates sampling to GenericRepository, using a Denodo-specific diff --git a/sql/doc.go b/sql/doc.go index e1f8447..80ec5d7 100644 --- a/sql/doc.go +++ b/sql/doc.go @@ -1,14 +1,15 @@ -// Package sql provides mechanisms to perform database introspection and -// data discovery on various SQL data repositories. -// -// Additionally, the Repository interface provides an API for performing -// database introspection and data discovery on SQL databases. It encapsulates -// the concept of a Dmap data SQL repository. All out-of-the-box Repository -// implementations are included in their own files named after the repository -// type, e.g. mysql.go, postgres.go, etc. +// Package sql provides mechanisms to perform database introspection, sampling, +// and classification on various SQL data repositories. The Repository interface +// provides the API for performing database introspection and sampling. 
It +// encapsulates the concept of a Dmap data SQL repository. All out-of-the-box +// Repository implementations are included in their own files named after the +// repository type, e.g. mysql.go, postgres.go, etc. // // Registry provides an API for registering and constructing Repository // implementations within an application. There is a global DefaultRegistry // which has all-out-of-the-box Repository implementations registered to it // by default. +// +// Scanner is a scan.RepoScanner implementation that can be used to perform +// data discovery and classification on SQL repositories. package sql diff --git a/sql/generic.go b/sql/generic.go index 1c870fd..3d5e8f4 100644 --- a/sql/generic.go +++ b/sql/generic.go @@ -12,7 +12,7 @@ import ( ) const ( - GenericIntrospectQuery = "SELECT " + + genericIntrospectQuery = "SELECT " + "table_schema, " + "table_name, " + "column_name, " + @@ -29,8 +29,8 @@ const ( "'performance_schema', " + "'pg_catalog'" + ")" - GenericPingQuery = "SELECT 1" - GenericSampleQueryTemplate = "SELECT %s FROM %s.%s LIMIT ? OFFSET ?" + genericPingQuery = "SELECT 1" + genericSampleQueryTemplate = "SELECT %s FROM %s.%s LIMIT ? OFFSET ?" 
) // GenericRepository implements generic SQL functionalities that work for a @@ -128,7 +128,7 @@ func (r *GenericRepository) Introspect( ctx context.Context, params IntrospectParameters, ) (*Metadata, error) { - return r.IntrospectWithQuery(ctx, GenericIntrospectQuery, params) + return r.IntrospectWithQuery(ctx, genericIntrospectQuery, params) } // IntrospectWithQuery executes a query against the information_schema table in @@ -165,7 +165,7 @@ func (r *GenericRepository) SampleTable( ) (Sample, error) { // ANSI SQL uses double-quotes to quote identifiers attrStr := params.Metadata.QuotedAttributeNamesString("\"") - query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + query := fmt.Sprintf(genericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) return r.SampleTableWithQuery(ctx, query, params) } @@ -213,8 +213,8 @@ func (r *GenericRepository) SampleTableWithQuery( // Ping verifies the connection to the database used by this repository by // executing a simple query. If the query fails, an error is returned. func (r *GenericRepository) Ping(ctx context.Context) error { - log.Tracef("Query: %s", GenericPingQuery) - rows, err := r.db.QueryContext(ctx, GenericPingQuery) + log.Tracef("Query: %s", genericPingQuery) + rows, err := r.db.QueryContext(ctx, genericPingQuery) if err != nil { return err } diff --git a/sql/metadata.go b/sql/metadata.go index 622e8ef..5564661 100644 --- a/sql/metadata.go +++ b/sql/metadata.go @@ -87,7 +87,6 @@ func newMetadataFromQueryResult( if err := rows.Scan(&attr.Schema, &attr.Table, &attr.Name, &attr.DataType); err != nil { return nil, fmt.Errorf("error scanning metadata query result row: %w", err) } - // Skip tables that match excludePaths or does not match includePaths. 
log.Tracef("checking if %s.%s.%s matches excludePaths %s\n", db, attr.Schema, attr.Table, excludePaths) if matchPathPatterns(db, attr.Schema, attr.Table, excludePaths) { @@ -97,7 +96,6 @@ func newMetadataFromQueryResult( if !matchPathPatterns(db, attr.Schema, attr.Table, includePaths) { continue } - // SchemaMetadata exists - add a table if necessary. if schema, ok := repo.Schemas[attr.Schema]; ok { // TableMetadata exists - just append the attribute. @@ -116,7 +114,6 @@ func newMetadataFromQueryResult( repo.Schemas[attr.Schema] = schema } } - // Something broke while iterating the row set. if err := rows.Err(); err != nil { return nil, fmt.Errorf("error iterating metadata query rows: %w", err) diff --git a/sql/mysql.go b/sql/mysql.go index 0973c8b..c53f49b 100644 --- a/sql/mysql.go +++ b/sql/mysql.go @@ -10,7 +10,7 @@ import ( const ( RepoTypeMysql = "mysql" - MySqlDatabaseQuery = ` + mySqlDatabaseQuery = ` SELECT schema_name FROM @@ -55,7 +55,7 @@ func NewMySqlRepository(cfg RepoConfig) (*MySqlRepository, error) { // using a MySQL-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *MySqlRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.generic.ListDatabasesWithQuery(ctx, MySqlDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, mySqlDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See @@ -76,7 +76,7 @@ func (r *MySqlRepository) SampleTable( attrStr := params.Metadata.QuotedAttributeNamesString("`") // The generic select/limit/offset query and ? placeholders work fine with // MySQL. 
- query := fmt.Sprintf(GenericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + query := fmt.Sprintf(genericSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) return r.generic.SampleTableWithQuery(ctx, query, params) } diff --git a/sql/mysql_test.go b/sql/mysql_test.go index 346d05b..e3be384 100644 --- a/sql/mysql_test.go +++ b/sql/mysql_test.go @@ -13,7 +13,7 @@ func TestMySqlRepository_ListDatabases(t *testing.T) { ctx, db, mock, r := initMySqlRepoTest(t) defer func() { _ = db.Close() }() dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") - mock.ExpectQuery(MySqlDatabaseQuery).WillReturnRows(dbRows) + mock.ExpectQuery(mySqlDatabaseQuery).WillReturnRows(dbRows) dbs, err := r.ListDatabases(ctx) require.NoError(t, err) require.ElementsMatch(t, []string{"db1", "db2"}, dbs) diff --git a/sql/oracle.go b/sql/oracle.go index e4c0394..7290298 100644 --- a/sql/oracle.go +++ b/sql/oracle.go @@ -11,7 +11,7 @@ import ( const ( RepoTypeOracle = "oracle" - OracleIntrospectQuery = ` + oracleIntrospectQuery = ` WITH users AS ( SELECT username @@ -48,7 +48,7 @@ var _ Repository = (*OracleRepository)(nil) // NewOracleRepository creates a new Oracle repository. func NewOracleRepository(cfg RepoConfig) (*OracleRepository, error) { - oracleCfg, err := ParseOracleConfig(cfg) + oracleCfg, err := NewOracleConfigFromMap(cfg.Advanced) if err != nil { return nil, fmt.Errorf("unable to parse oracle config: %w", err) } @@ -78,7 +78,7 @@ func (r *OracleRepository) ListDatabases(_ context.Context) ([]string, error) { // Oracle-specific introspection query. See Repository.Introspect and // GenericRepository.IntrospectWithQuery for more details. 
func (r *OracleRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { - return r.generic.IntrospectWithQuery(ctx, OracleIntrospectQuery, params) + return r.generic.IntrospectWithQuery(ctx, oracleIntrospectQuery, params) } // SampleTable delegates sampling to GenericRepository, using an Oracle-specific @@ -119,17 +119,13 @@ type OracleConfig struct { ServiceName string } -// ParseOracleConfig parses the Oracle-specific configuration from the -// given The Oracle configuration is expected to be in the -// config's "advanced config" property. -func ParseOracleConfig(cfg RepoConfig) (*OracleConfig, error) { - oracleCfg, err := FetchAdvancedConfigString( - cfg, - RepoTypeOracle, - []string{configServiceName}, - ) +// NewOracleConfigFromMap creates a new OracleConfig from the given map. This is +// useful for parsing the Oracle-specific configuration from the +// RepoConfig.Advanced map, for example. +func NewOracleConfigFromMap(cfg map[string]any) (OracleConfig, error) { + serviceName, err := keyAsString(cfg, configServiceName) if err != nil { - return nil, fmt.Errorf("error fetching advanced oracle config: %w", err) + return OracleConfig{}, err } - return &OracleConfig{ServiceName: oracleCfg[configServiceName]}, nil + return OracleConfig{ServiceName: serviceName}, nil } diff --git a/sql/oracle_test.go b/sql/oracle_test.go new file mode 100644 index 0000000..081e0b6 --- /dev/null +++ b/sql/oracle_test.go @@ -0,0 +1,44 @@ +package sql + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOracleConfigFromMap(t *testing.T) { + tests := []struct { + name string + cfg map[string]any + want OracleConfig + wantErr require.ErrorAssertionFunc + }{ + { + name: "Returns config when ServiceName key is present", + cfg: map[string]any{ + configServiceName: "testServiceName", + }, + want: OracleConfig{ + ServiceName: "testServiceName", + }, + }, + { + name: "Returns error when ServiceName key is missing", + cfg: 
map[string]any{}, + wantErr: require.Error, + }, + } + + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + got, err := NewOracleConfigFromMap(tt.cfg) + if tt.wantErr == nil { + tt.wantErr = require.NoError + } + tt.wantErr(t, err) + require.Equal(t, tt.want, got) + }, + ) + } +} diff --git a/sql/postgres.go b/sql/postgres.go index e5a81ee..f96632b 100644 --- a/sql/postgres.go +++ b/sql/postgres.go @@ -3,16 +3,13 @@ package sql import ( "context" "fmt" - "strings" - // Postgresql DB driver _ "github.com/lib/pq" ) const ( RepoTypePostgres = "postgres" - - PostgresDatabaseQuery = ` + postgresDatabaseQuery = ` SELECT datname FROM @@ -36,23 +33,18 @@ var _ Repository = (*PostgresRepository)(nil) // NewPostgresRepository creates a new PostgresRepository. func NewPostgresRepository(cfg RepoConfig) (*PostgresRepository, error) { - pgCfg, err := parsePostgresConfig(cfg) - if err != nil { - return nil, fmt.Errorf("error parsing postgres config: %w", err) - } database := cfg.Database // Connect to the default database, if unspecified. if database == "" { database = "postgres" } connStr := fmt.Sprintf( - "postgresql://%s:%s@%s:%d/%s%s", + "postgresql://%s:%s@%s:%d/%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, database, - pgCfg.ConnOptsStr, ) generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { @@ -65,7 +57,7 @@ func NewPostgresRepository(cfg RepoConfig) (*PostgresRepository, error) { // using a Postgres-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *PostgresRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.generic.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, postgresDatabaseQuery) } // Introspect delegates introspection to GenericRepository. 
See @@ -105,87 +97,3 @@ func (r *PostgresRepository) Ping(ctx context.Context) error { func (r *PostgresRepository) Close() error { return r.generic.Close() } - -// PostgresConfig contains Postgres-specific configuration parameters. -type PostgresConfig struct { - // ConnOptsStr is a string containing Postgres-specific connection options. - ConnOptsStr string -} - -// parsePostgresConfig parses the Postgres-specific configuration parameters -// from the given The Postgres connection options are built from the -// config and stored in the ConnOptsStr field of the returned Postgres -func parsePostgresConfig(cfg RepoConfig) (*PostgresConfig, error) { - connOptsStr, err := buildConnOptsStr(cfg) - if err != nil { - return nil, fmt.Errorf("error building connection options string: %w", err) - } - return &PostgresConfig{ConnOptsStr: connOptsStr}, nil -} - -// buildConnOptsStr parses the repo config to produce a string in the format -// "?option=value&option2=value2". Example: -// -// buildConnOptsStr(RepoConfig{ -// Advanced: map[string]any{ -// "connection-string-args": []any{"sslmode=disable"}, -// }, -// }) -// -// returns ("?sslmode=disable", nil). -func buildConnOptsStr(cfg RepoConfig) (string, error) { - connOptsMap, err := mapFromConnOpts(cfg) - if err != nil { - return "", fmt.Errorf("connection options: %w", err) - } - connOptsStr := "" - for key, val := range connOptsMap { - // Don't add if the value is empty, since that would make the - // string malformed. - if val != "" { - if connOptsStr == "" { - connOptsStr += fmt.Sprintf("%s=%s", key, val) - } else { - // Need & for subsequent options - connOptsStr += fmt.Sprintf("&%s=%s", key, val) - } - } - } - // Only add ? if connection string is not empty - if connOptsStr != "" { - connOptsStr = "?" + connOptsStr - } - return connOptsStr, nil -} - -// mapFromConnOpts builds a map from the list of connection options given. Each -// option has the format 'option=value'. 
An error is returned if the config is -// malformed. -func mapFromConnOpts(cfg RepoConfig) (map[string]string, error) { - m := make(map[string]string) - connOptsInterface, ok := cfg.Advanced[configConnOpts] - if !ok { - return nil, nil - } - connOpts, ok := connOptsInterface.([]any) - if !ok { - return nil, fmt.Errorf("'%s' is not a list", configConnOpts) - } - for _, optInterface := range connOpts { - opt, ok := optInterface.(string) - if !ok { - return nil, fmt.Errorf("'%v' is not a string", optInterface) - } - splitOpt := strings.Split(opt, "=") - if len(splitOpt) != 2 { - return nil, fmt.Errorf( - "malformed '%s'. "+ - "Please follow the format 'option=value'", configConnOpts, - ) - } - key := splitOpt[0] - val := splitOpt[1] - m[key] = val - } - return m, nil -} diff --git a/sql/postgres_test.go b/sql/postgres_test.go index 84a1974..e46a2f4 100644 --- a/sql/postgres_test.go +++ b/sql/postgres_test.go @@ -13,76 +13,12 @@ func TestPostgresRepository_ListDatabases(t *testing.T) { ctx, db, mock, r := initPostgresRepoTest(t) defer func() { _ = db.Close() }() dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") - mock.ExpectQuery(PostgresDatabaseQuery).WillReturnRows(dbRows) + mock.ExpectQuery(postgresDatabaseQuery).WillReturnRows(dbRows) dbs, err := r.ListDatabases(ctx) require.NoError(t, err) require.ElementsMatch(t, []string{"db1", "db2"}, dbs) } -func TestBuildConnOptionsSucc(t *testing.T) { - sampleRepoCfg := getSampleRepoConfig() - connOptsStr, err := buildConnOptsStr(sampleRepoCfg) - require.NoError(t, err) - require.Equal(t, connOptsStr, "?sslmode=disable") -} - -func TestBuildConnOptionsFail(t *testing.T) { - invalidRepoCfg := RepoConfig{ - Advanced: map[string]any{ - // Invalid: map instead of string - configConnOpts: []any{ - map[string]string{"sslmode": "disable"}, - }, - }, - } - connOptsStr, err := buildConnOptsStr(invalidRepoCfg) - require.Error(t, err) - require.Empty(t, connOptsStr) -} - -func TestMapConnOptionsSucc(t *testing.T) 
{ - sampleRepoCfg := getSampleRepoConfig() - connOptsMap, err := mapFromConnOpts(sampleRepoCfg) - require.NoError(t, err) - require.EqualValues( - t, connOptsMap, map[string]string{ - "sslmode": "disable", - }, - ) -} - -// The mapping should only fail if the config is malformed, not if it is missing -func TestMapConnOptionsMissing(t *testing.T) { - sampleCfg := RepoConfig{} - optsMap, err := mapFromConnOpts(sampleCfg) - require.NoError(t, err) - require.Empty(t, optsMap) -} - -func TestMapConnOptionsMalformedMap(t *testing.T) { - sampleCfg := RepoConfig{ - Advanced: map[string]any{ - // Let's put a map instead of the required list - configConnOpts: map[string]any{ - "testKey": "testValue", - }, - }, - } - _, err := mapFromConnOpts(sampleCfg) - require.Error(t, err) -} - -func TestMapConnOptionsMalformedColon(t *testing.T) { - sampleCfg := RepoConfig{ - Advanced: map[string]any{ - // Let's use a colon instead of '=' to divide options - configConnOpts: []string{"sslmode:disable"}, - }, - } - _, err := mapFromConnOpts(sampleCfg) - require.Error(t, err) -} - func initPostgresRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *PostgresRepository) { ctx := context.Background() db, mock, err := sqlmock.New() @@ -91,12 +27,3 @@ func initPostgresRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmo generic: NewGenericRepositoryFromDB(RepoTypePostgres, "dbName", db), } } - -// Returns a correct repo config -func getSampleRepoConfig() RepoConfig { - return RepoConfig{ - Advanced: map[string]any{ - configConnOpts: []any{"sslmode=disable"}, - }, - } -} diff --git a/sql/redshift.go b/sql/redshift.go index ae847e8..2f7cf93 100644 --- a/sql/redshift.go +++ b/sql/redshift.go @@ -24,23 +24,18 @@ var _ Repository = (*RedshiftRepository)(nil) // NewRedshiftRepository creates a new RedshiftRepository. 
func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { - pgCfg, err := parsePostgresConfig(cfg) - if err != nil { - return nil, fmt.Errorf("unable to parse postgres config: %w", err) - } database := cfg.Database // Connect to the default database, if unspecified. if database == "" { database = "dev" } connStr := fmt.Sprintf( - "postgresql://%s:%s@%s:%d/%s%s", + "postgresql://%s:%s@%s:%d/%s", cfg.User, cfg.Password, cfg.Host, cfg.Port, database, - pgCfg.ConnOptsStr, ) generic, err := NewGenericRepository(RepoTypePostgres, cfg.Database, connStr, cfg.MaxOpenConns) if err != nil { @@ -54,7 +49,7 @@ func NewRedshiftRepository(cfg RepoConfig) (*RedshiftRepository, error) { // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *RedshiftRepository) ListDatabases(ctx context.Context) ([]string, error) { // Redshift and Postgres use the same query to list the server databases. - return r.generic.ListDatabasesWithQuery(ctx, PostgresDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, postgresDatabaseQuery) } // Introspect delegates introspection to GenericRepository. 
See diff --git a/sql/redshift_test.go b/sql/redshift_test.go index 3b45d75..4f1fadd 100644 --- a/sql/redshift_test.go +++ b/sql/redshift_test.go @@ -13,7 +13,7 @@ func TestRedshiftRepository_ListDatabases(t *testing.T) { ctx, db, mock, r := initRedshiftRepoTest(t) defer func() { _ = db.Close() }() dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") - mock.ExpectQuery(PostgresDatabaseQuery).WillReturnRows(dbRows) + mock.ExpectQuery(postgresDatabaseQuery).WillReturnRows(dbRows) dbs, err := r.ListDatabases(ctx) require.NoError(t, err) require.ElementsMatch(t, []string{"db1", "db2"}, dbs) diff --git a/sql/sample.go b/sql/sample.go deleted file mode 100644 index 5497f7c..0000000 --- a/sql/sample.go +++ /dev/null @@ -1,190 +0,0 @@ -package sql - -import ( - "context" - "fmt" - "sync" - - "github.com/hashicorp/go-multierror" - log "github.com/sirupsen/logrus" - "golang.org/x/sync/semaphore" -) - -// sampleAndErr is a "pair" type intended to be passed to a channel (see -// sampleDb) -type sampleAndErr struct { - sample Sample - err error -} - -// samplesAndErr is a "pair" type intended to be passed to a channel (see -// sampleAllDbs) -type samplesAndErr struct { - samples []Sample - err error -} - -// sampleAllDbs uses the given Repository to list all the -// databases on the server, and samples each one in parallel by calling -// sampleDb for each database. The repository is intended to be -// configured to connect to the default database on the server, or at least some -// database which can be used to enumerate the full set of databases on the -// server. An error will be returned if the set of databases cannot be listed. -// If there is an error connecting to or sampling a database, the error will be -// logged and no samples will be returned for that database. 
Therefore, the -// returned slice of samples contains samples for only the databases which could -// be discovered and successfully sampled, and could potentially be empty if no -// databases were sampled. -func sampleAllDbs( - ctx context.Context, - ctor RepoConstructor, - cfg RepoConfig, - introspectParams IntrospectParameters, - sampleSize, offset uint, -) ( - []Sample, - error, -) { - // Create a repository instance that will be used to list all the databases - // on the server. - repo, err := ctor(ctx, cfg) - if err != nil { - return nil, fmt.Errorf("error creating repository instance: %w", err) - } - defer func() { _ = repo.Close() }() - - // We assume that this repository will be connected to the default database - // (or at least some database that can discover all the other databases), - // and we use that to discover all other databases. - dbs, err := repo.ListDatabases(ctx) - if err != nil { - return nil, fmt.Errorf("error listing databases: %w", err) - } - - // Sample each database on a separate goroutine, and send the samples to - // the 'out' channel. Each slice of samples will be aggregated below on the - // main goroutine and returned. - var wg sync.WaitGroup - out := make(chan samplesAndErr) - wg.Add(len(dbs)) - // Ensures that we avoid opening more than the specified number of - // connections. - var sema *semaphore.Weighted - if cfg.MaxOpenConns > 0 { - sema = semaphore.NewWeighted(int64(cfg.MaxOpenConns)) - } - for _, db := range dbs { - go func(db string, cfg RepoConfig) { - defer wg.Done() - if sema != nil { - _ = sema.Acquire(ctx, 1) - defer sema.Release(1) - } - // Create a repository instance for this specific database. It will - // be used to connect to and sample the database. - cfg.Database = db - repo, err := ctor(ctx, cfg) - if err != nil { - log.WithError(err).Errorf("error creating repository instance for database %s", db) - return - } - defer func() { _ = repo.Close() }() - // Sample the database. 
- s, err := sampleDb(ctx, repo, introspectParams, sampleSize, offset) - if err != nil && len(s) == 0 { - log.WithError(err).Errorf("error gathering repository data samples for database %s", db) - return - } - // Send the samples for this database to the 'out' channel. The - // samples for each database will be aggregated into a single slice - // on the main goroutine and returned. - out <- samplesAndErr{samples: s, err: err} - }(db, cfg) - } - - // Start a goroutine to close the 'out' channel once all the goroutines - // we launched above are done. This will allow the aggregation range loop - // below to terminate properly. Note that this must start after the wg.Add - // call. See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). - go func() { - wg.Wait() - close(out) - }() - - // Aggregate and return the results. - var ret []Sample - var errs error - for res := range out { - ret = append(ret, res.samples...) - if res.err != nil { - errs = multierror.Append(errs, res.err) - } - } - return ret, errs -} - -// sampleDb is a helper function which will sample every table in a -// given repository and return them as a collection of Sample. First the -// repository is introspected by calling Introspect to return the -// repository metadata (Metadata). Then, for each schema and table in the -// metadata, it calls SampleTable in a new goroutine. Once all the -// sampling goroutines are finished, their results are collected and returned -// as a slice of Sample. -func sampleDb( - ctx context.Context, - repo Repository, - introspectParams IntrospectParameters, - sampleSize, offset uint, -) ( - []Sample, - error, -) { - // Introspect the repository to get the metadata. - meta, err := repo.Introspect(ctx, introspectParams) - if err != nil { - return nil, fmt.Errorf("error introspecting repository: %w", err) - } - - // Fan out sample executions. 
- out := make(chan sampleAndErr) - numTables := 0 - for _, schemaMeta := range meta.Schemas { - for _, tableMeta := range schemaMeta.Tables { - numTables++ - go func(meta *TableMetadata) { - params := SampleParameters{ - Metadata: meta, - SampleSize: sampleSize, - Offset: offset, - } - sample, err := repo.SampleTable(ctx, params) - select { - case <-ctx.Done(): - return - case out <- sampleAndErr{sample: sample, err: err}: - } - }(tableMeta) - } - } - - // Aggregate and return the results. - var samples []Sample - var errs error - for i := 0; i < numTables; i++ { - select { - case <-ctx.Done(): - return samples, ctx.Err() - case res:= <-out: - if res.err != nil { - errs = multierror.Append(errs, res.err) - } else { - samples = append(samples, res.sample) - } - } - } - close(out) - if errs != nil { - return samples, fmt.Errorf("error(s) while sampling repository: %w", errs) - } - return samples, nil -} diff --git a/sql/sample_test.go b/sql/sample_test.go deleted file mode 100644 index 811cfa5..0000000 --- a/sql/sample_test.go +++ /dev/null @@ -1,205 +0,0 @@ -package sql - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -var ( - table1Sample = Sample{ - TablePath: []string{"database", "schema1", "table1"}, - Results: []SampleResult{ - { - "name1": "foo", - "name2": "bar", - }, - { - "name1": "baz", - "name2": "qux", - }, - }, - } - - table2Sample = Sample{ - TablePath: []string{"database", "schema2", "table2"}, - Results: []SampleResult{ - { - "name3": "foo1", - "name4": "bar1", - }, - { - "name3": "baz1", - "name4": "qux1", - }, - }, - } -) - -func Test_sampleDb_Success(t *testing.T) { - ctx := context.Background() - repo := NewMockRepository(t) - meta := Metadata{ - Database: "database", - Schemas: map[string]*SchemaMetadata{ - "schema1": { - Name: "", - Tables: map[string]*TableMetadata{ - "table1": { - Schema: "schema1", - Name: "table1", - Attributes: []*AttributeMetadata{ - { - 
Schema: "schema1", - Table: "table1", - Name: "name1", - DataType: "varchar", - }, - { - Schema: "schema1", - Table: "table1", - Name: "name2", - DataType: "decimal", - }, - }, - }, - }, - }, - "schema2": { - Name: "", - Tables: map[string]*TableMetadata{ - "table2": { - Schema: "schema2", - Name: "table2", - Attributes: []*AttributeMetadata{ - { - Schema: "schema2", - Table: "table2", - Name: "name3", - DataType: "int", - }, - { - Schema: "schema2", - Table: "table2", - Name: "name4", - DataType: "timestamp", - }, - }, - }, - }, - }, - }, - } - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - sampleParams1 := SampleParameters{ - Metadata: meta.Schemas["schema1"].Tables["table1"], - } - sampleParams2 := SampleParameters{ - Metadata: meta.Schemas["schema2"].Tables["table2"], - } - repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) - repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) - samples, err := sampleDb(ctx, repo, IntrospectParameters{}, 0, 0) - require.NoError(t, err) - // Order is not important and is actually non-deterministic due to concurrency - expected := []Sample{table1Sample, table2Sample} - require.ElementsMatch(t, expected, samples) -} - -func Test_sampleDb_PartialError(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - repo := NewMockRepository(t) - meta := Metadata{ - Database: "database", - Schemas: map[string]*SchemaMetadata{ - "schema1": { - Name: "", - Tables: map[string]*TableMetadata{ - "table1": { - Schema: "schema1", - Name: "table1", - Attributes: []*AttributeMetadata{ - { - Schema: "schema1", - Table: "table1", - Name: "name1", - DataType: "varchar", - }, - { - Schema: "schema1", - Table: "table1", - Name: "name2", - DataType: "decimal", - }, - }, - }, - "forbidden": { - Schema: "schema1", - Name: "forbidden", - Attributes: []*AttributeMetadata{ - { - Schema: "schema1", - Table: "forbidden", - Name: "name1", - DataType: "varchar", - }, - 
{ - Schema: "schema1", - Table: "forbidden", - Name: "name2", - DataType: "decimal", - }, - }, - }, - }, - }, - "schema2": { - Name: "", - Tables: map[string]*TableMetadata{ - "table2": { - Schema: "schema2", - Name: "table2", - Attributes: []*AttributeMetadata{ - { - Schema: "schema2", - Table: "table2", - Name: "name3", - DataType: "int", - }, - { - Schema: "schema2", - Table: "table2", - Name: "name4", - DataType: "timestamp", - }, - }, - }, - }, - }, - }, - } - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - sampleParams1 := SampleParameters{ - Metadata: meta.Schemas["schema1"].Tables["table1"], - } - sampleParams2 := SampleParameters{ - Metadata: meta.Schemas["schema1"].Tables["forbidden"], - } - sampleParamsForbidden := SampleParameters{ - Metadata: meta.Schemas["schema2"].Tables["table2"], - } - repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) - errForbidden := errors.New("forbidden table") - repo.EXPECT().SampleTable(ctx, sampleParamsForbidden).Return(Sample{}, errForbidden) - repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) - - samples, err := sampleDb(ctx, repo, IntrospectParameters{}, 0, 0) - require.ErrorIs(t, err, errForbidden) - // Order is not important and is actually non-deterministic due to concurrency - expected := []Sample{table1Sample, table2Sample} - require.ElementsMatch(t, expected, samples) -} diff --git a/sql/samplealldatabases_test.go b/sql/samplealldatabases_test.go deleted file mode 100644 index 831c7df..0000000 --- a/sql/samplealldatabases_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package sql - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -func Test_sampleAllDbs_Error(t *testing.T) { - ctx := context.Background() - listDbErr := errors.New("error listing databases") - repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(nil, listDbErr) - repo.EXPECT().Close().Return(nil) - 
ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { - return repo, nil - } - cfg := RepoConfig{} - samples, err := sampleAllDbs(ctx, ctor, cfg, IntrospectParameters{}, 0, 0) - require.Nil(t, samples) - require.ErrorIs(t, err, listDbErr) -} - -func Test_sampleAllDbs_Successful_TwoDatabases(t *testing.T) { - ctx := context.Background() - dbs := []string{"db1", "db2"} - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - sample := Sample{ - TablePath: []string{"db", "schema", "table"}, - Results: []SampleResult{ - { - "attr": "foo", - }, - }, - } - repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil) - repo.EXPECT().Close().Return(nil) - ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { - return repo, nil - } - samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) - require.NoError(t, err) - // Two databases should be sampled, and our mock will return the sample for - // each sample call. This really just asserts that we've sampled the correct - // number of times. 
- require.ElementsMatch(t, samples, []Sample{sample, sample}) -} - -func Test_sampleAllDbs_IntrospectError(t *testing.T) { - ctx := context.Background() - dbs := []string{"db1", "db2"} - introspectErr := errors.New("introspect error") - repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(nil, introspectErr) - repo.EXPECT().Close().Return(nil) - ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { - return repo, nil - } - samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) - require.Empty(t, samples) - require.NoError(t, err) -} - -func Test_sampleAllDbs_SampleError(t *testing.T) { - ctx := context.Background() - dbs := []string{"db1", "db2"} - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Database: "db", - Schemas: map[string]*SchemaMetadata{ - "schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - sampleErr := errors.New("sample error") - repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr) - repo.EXPECT().Close().Return(nil) - ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { - return repo, nil - } - samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) - require.NoError(t, err) - require.Empty(t, samples) -} - -func Test_sampleAllDbs_TwoDatabases_OneSampleError(t *testing.T) { - ctx := context.Background() - dbs := []string{"db1", "db2"} - // Dummy metadata returned for each Introspect call - meta := Metadata{ - Database: "db", - Schemas: map[string]*SchemaMetadata{ - 
"schema": { - Name: "schema", - Tables: map[string]*TableMetadata{ - "table": { - Schema: "schema", - Name: "table", - Attributes: []*AttributeMetadata{ - { - Schema: "schema", - Table: "table", - Name: "attr", - DataType: "string", - }, - }, - }, - }, - }, - }, - } - sample := Sample{ - TablePath: []string{"db", "schema", "table"}, - Results: []SampleResult{ - { - "attr": "foo", - }, - }, - } - sampleErr := errors.New("sample error") - repo := NewMockRepository(t) - repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) - repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil).Once() - repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr).Once() - repo.EXPECT().Close().Return(nil) - ctor := func(ctx context.Context, cfg RepoConfig) (Repository, error) { - return repo, nil - } - samples, err := sampleAllDbs(ctx, ctor, RepoConfig{}, IntrospectParameters{}, 0, 0) - require.NoError(t, err) - // Because of a single sample error, we expect only one database was - // sampled. - require.ElementsMatch(t, samples, []Sample{sample}) -} diff --git a/sql/scanner.go b/sql/scanner.go index 72e7710..e14454a 100644 --- a/sql/scanner.go +++ b/sql/scanner.go @@ -2,10 +2,15 @@ package sql import ( "context" + "errors" "fmt" + "maps" + "strings" + "sync" "github.com/gobwas/glob" log "github.com/sirupsen/logrus" + "golang.org/x/sync/semaphore" "github.com/cyralinc/dmap/classification" "github.com/cyralinc/dmap/scan" @@ -58,8 +63,23 @@ func NewScanner(cfg ScannerConfig) (*Scanner, error) { // repository, classifies the sampled data, and publishes the results to the // configured classification publisher. func (s *Scanner) Scan(ctx context.Context) (*scan.RepoScanResults, error) { - // Introspect and sample the data repository. - samples, err := s.sample(ctx) + // First introspect and sample the data repository. 
+ var ( + samples []Sample + err error + ) + // Check if the user specified a single database, or told us to scan an + // Oracle DB. In that case, we only need to sample that single + // database. Note that Oracle doesn't really have the concept of + // "databases", therefore a single repository instance will always scan the + // entire database. + if s.Config.RepoConfig.Database != "" || s.Config.RepoType == RepoTypeOracle { + samples, err = s.sampleDb(ctx, s.Config.RepoConfig.Database) + } else { + // The name of the database to connect to has been left unspecified by + // the user, so we try to connect and sample all databases instead. + samples, err = s.sampleAllDbs(ctx) + } if err != nil { msg := "error sampling repository" // If we didn't get any samples, just return the error. @@ -71,7 +91,7 @@ func (s *Scanner) Scan(ctx context.Context) (*scan.RepoScanResults, error) { log.WithError(err).Warn(msg) } // Classify the sampled data. - classifications, err := classifySamples(ctx, samples, s.classifier) + classifications, err := s.classifySamples(ctx, samples) if err != nil { return nil, fmt.Errorf("error classifying samples: %w", err) } @@ -81,41 +101,232 @@ func (s *Scanner) Scan(ctx context.Context) (*scan.RepoScanResults, error) { }, nil } -func (s *Scanner) sample(ctx context.Context) ([]Sample, error) { - // This closure is used to create a new repository instance for each - // database that is sampled. When there are multiple databases to sample, - // it is passed to sampleAllDbs to create the necessary repository instances - // for each database. When there is only a single database to sample, it is - // used directly below to create the repository instance for that database, - // which is passed to sampleDb to sample the database. 
- newRepo := func(ctx context.Context, cfg RepoConfig) (Repository, error) { - return s.Config.Registry.NewRepository(ctx, s.Config.RepoType, cfg) +// sampleDb samples every table in a given database and returns the samples. +// The repository instance is created with the provided database name by +// newRepository. The database is then introspected by calling +// Repository.Introspect to return the repository metadata (Metadata). Then, for +// each schema and table in the metadata, it calls Repository.SampleTable in a +// new goroutine to sample all tables concurrently. Note however that the level +// of concurrency should be limited by the max number of open connections +// specified for the scanner, since the underlying repository should respect +// this across goroutines. This of course depends on the implementation, however +// for all the out-of-the-box Repository implementations, this applies. Once all +// the sampling goroutines are finished, their results are collected and +// returned as a slice of Sample. +func (s *Scanner) sampleDb(ctx context.Context, db string) ([]Sample, error) { + // Create the repository instance that will be used to sample the database. + cfg := s.Config.RepoConfig + cfg.Database = db + repo, err := s.newRepository(ctx, cfg) + if err != nil { + return nil, fmt.Errorf("error creating repository instance: %w", err) } + defer func() { _ = repo.Close() }() + // Introspect the repository to get the metadata. introspectParams := IntrospectParameters{ IncludePaths: s.Config.IncludePaths, ExcludePaths: s.Config.ExcludePaths, } - // Check if the user specified a single database, or told us to scan an - // Oracle DB. In that case, therefore we only need to sample that single - // database. Note that Oracle doesn't really have the concept of - // "databases", therefore a single repository instance will always scan the - // entire database. 
- if s.Config.RepoConfig.Database != "" || s.Config.RepoType == RepoTypeOracle { - repo, err := newRepo(ctx, s.Config.RepoConfig) - if err != nil { - return nil, fmt.Errorf("error creating repository: %w", err) + meta, err := repo.Introspect(ctx, introspectParams) + if err != nil { + return nil, fmt.Errorf("error introspecting repository: %w", err) + } + // This is a "pair" type intended to be passed to the channel below. + type sampleAndErr struct { + sample Sample + err error + } + // Fan out sample executions. + out := make(chan sampleAndErr) + numTables := 0 + for _, schemaMeta := range meta.Schemas { + for _, tableMeta := range schemaMeta.Tables { + numTables++ + go func(meta *TableMetadata) { + params := SampleParameters{ + Metadata: meta, + SampleSize: s.Config.SampleSize, + Offset: s.Config.Offset, + } + sample, err := repo.SampleTable(ctx, params) + select { + case <-ctx.Done(): + return + case out <- sampleAndErr{sample: sample, err: err}: + } + }(tableMeta) } - defer func() { _ = repo.Close() }() - return sampleDb(ctx, repo, introspectParams, s.Config.SampleSize, s.Config.Offset) - } - // The name of the database to connect to has been left unspecified by the - // user, so we try to connect and sample all databases instead. - return sampleAllDbs( - ctx, - newRepo, - s.Config.RepoConfig, - introspectParams, - s.Config.SampleSize, - s.Config.Offset, - ) + } + + // Aggregate and return the results. + var samples []Sample + var errs error + for i := 0; i < numTables; i++ { + select { + case <-ctx.Done(): + return samples, ctx.Err() + case res := <-out: + if res.err != nil { + errs = errors.Join(errs, res.err) + } else { + samples = append(samples, res.sample) + } + } + } + close(out) + if errs != nil { + return samples, fmt.Errorf("error(s) while sampling repository: %w", errs) + } + return samples, nil +} + +// sampleAllDbs samples all the databases on the server. 
It samples each +// database concurrently by calling sampleDb for each database on a new +// goroutine. It first creates a new Repository instance by calling +// newRepository. This repository is intended to be configured to connect to the +// default database on the server, or at least some database which can be used +// to enumerate the full set of databases on the server. An error will be +// returned if the set of databases cannot be listed. If there is an error +// connecting to or sampling a database, the error will be logged and no samples +// will be returned for that database. Therefore, the returned slice of samples +// contains samples for only the databases which could be discovered and +// successfully sampled, and could potentially be empty if no databases were +// sampled. +func (s *Scanner) sampleAllDbs(ctx context.Context) ([]Sample, error) { + // Create a repository instance that will be used to list all the databases + // on the server. + repo, err := s.newRepository(ctx, s.Config.RepoConfig) + if err != nil { + return nil, fmt.Errorf("error creating repository instance: %w", err) + } + defer func() { _ = repo.Close() }() + + // We assume that this repository will be connected to the default database + // (or at least some database that can discover all the other databases). + // Use it to discover all the other databases on the server. + dbs, err := repo.ListDatabases(ctx) + if err != nil { + return nil, fmt.Errorf("error listing databases: %w", err) + } + + // Sample each database on a separate goroutine, and send the samples to + // the 'out' channel. Each slice of samples will be aggregated below on this + // goroutine and returned. + var wg sync.WaitGroup + // This is a "pair" type intended to be passed to the channel below. 
+ type samplesAndErr struct { + samples []Sample + err error + } + out := make(chan samplesAndErr) + wg.Add(len(dbs)) + // Using a semaphore here ensures that we avoid opening more than the + // specified total number of connections, since we end up creating multiple + // database handles (one per database). + var sema *semaphore.Weighted + if s.Config.RepoConfig.MaxOpenConns > 0 { + sema = semaphore.NewWeighted(int64(s.Config.RepoConfig.MaxOpenConns)) + } + for _, db := range dbs { + go func(db string, cfg RepoConfig) { + defer wg.Done() + if sema != nil { + _ = sema.Acquire(ctx, 1) + defer sema.Release(1) + } + // Sample this specific database. + samples, err := s.sampleDb(ctx, db) + if err != nil && len(samples) == 0 { + log.WithError(err).Errorf("error gathering repository data samples for database %s", db) + return + } + // Send the samples for this database to the 'out' channel. The + // samples for each database will be aggregated into a single slice + // on the main goroutine and returned. + select { + case <-ctx.Done(): + return + case out <- samplesAndErr{samples: samples, err: err}: + } + }(db, s.Config.RepoConfig) + } + + // Start a goroutine to close the 'out' channel once all the goroutines we + // launched above are done. This will allow the aggregation range loop below + // to terminate properly. Note that this must start after the wg.Add call. + // See https://go.dev/blog/pipelines ("Fan-out, fan-in" section). + go func() { wg.Wait(); close(out) }() + + // Aggregate and return the results. + var ret []Sample + var errs error + for { + select { + case <-ctx.Done(): + return ret, errors.Join(errs, ctx.Err()) + case res, ok := <-out: + if !ok { + // The 'out' channel has been closed, so we're done. + return ret, errs + } + ret = append(ret, res.samples...) + if res.err != nil { + errs = errors.Join(errs, res.err) + } + } + } +} + +// classifySamples uses the scanner's classifier to classify the provided slice +// of samples. 
Each sampled row is individually classified. The returned slice +// of classifications represents all the UNIQUE classifications for a given +// sample set. +func (s *Scanner) classifySamples(ctx context.Context, samples []Sample) ([]scan.Classification, error) { + uniqueClassifications := make(map[string]scan.Classification) + for _, sample := range samples { + // Classify each sampled row and combine the classifications. + for _, sampleResult := range sample.Results { + res, err := s.classifier.Classify(ctx, sampleResult) + if err != nil { + return nil, fmt.Errorf("error classifying sample: %w", err) + } + for attr, labels := range res { + attrPath := append(sample.TablePath, attr) + // U+2063 is an invisible separator. It is used here to ensure + // that the path key is unique and does not conflict with any of + // the path elements. + key := strings.Join(attrPath, "\u2063") + result, ok := uniqueClassifications[key] + if !ok { + uniqueClassifications[key] = scan.Classification{ + AttributePath: attrPath, + Labels: labels, + } + } else { + // Merge the labels from the new result into the existing result. + maps.Copy(result.Labels, labels) + } + } + } + } + // Convert the map of unique classifications to a slice. + classifications := make([]scan.Classification, 0, len(uniqueClassifications)) + for _, result := range uniqueClassifications { + classifications = append(classifications, result) + } + return classifications, nil +} + +// newRepository creates a new Repository instance with the provided +// configuration. It delegates the actual creation of the repository to the +// scanner's Registry.NewRepository method, using the scanner's RepoType and +// the provided configuration. You may wonder why we just don't use the +// scanner's repo configuration directly (i.e. s.Config.RepoConfig) instead of +// passing it as an argument. 
The reason is that we want to be able to create +// a new repository instance with a different configuration than the one +// specified in the scanner's configuration. This is useful when we want to +// sample a specific database, for example, and we want to create a new +// repository instance with the database name set to that specific database. +func (s *Scanner) newRepository(ctx context.Context, cfg RepoConfig) (Repository, error) { + return s.Config.Registry.NewRepository(ctx, s.Config.RepoType, cfg) } diff --git a/sql/scanner_test.go b/sql/scanner_test.go index e4b317b..d4d86a4 100644 --- a/sql/scanner_test.go +++ b/sql/scanner_test.go @@ -1 +1,643 @@ package sql + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/cyralinc/dmap/classification" + "github.com/cyralinc/dmap/scan" +) + +func TestScanner_sampleDb_Success(t *testing.T) { + ctx := context.Background() + repo := NewMockRepository(t) + meta := Metadata{ + Database: "database", + Schemas: map[string]*SchemaMetadata{ + "schema1": { + Name: "", + Tables: map[string]*TableMetadata{ + "table1": { + Schema: "schema1", + Name: "table1", + Attributes: []*AttributeMetadata{ + { + Schema: "schema1", + Table: "table1", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "table1", + Name: "name2", + DataType: "decimal", + }, + }, + }, + }, + }, + "schema2": { + Name: "", + Tables: map[string]*TableMetadata{ + "table2": { + Schema: "schema2", + Name: "table2", + Attributes: []*AttributeMetadata{ + { + Schema: "schema2", + Table: "table2", + Name: "name3", + DataType: "int", + }, + { + Schema: "schema2", + Table: "table2", + Name: "name4", + DataType: "timestamp", + }, + }, + }, + }, + }, + }, + } + table1Sample := Sample{ + TablePath: []string{"database", "schema1", "table1"}, + Results: []SampleResult{ + { + "name1": "foo", + "name2": "bar", + }, + { + "name1": "baz", + "name2": "qux", + }, + }, + 
} + table2Sample := Sample{ + TablePath: []string{"database", "schema2", "table2"}, + Results: []SampleResult{ + { + "name3": "foo1", + "name4": "bar1", + }, + { + "name3": "baz1", + "name4": "qux1", + }, + }, + } + + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + sampleParams1 := SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["table1"], + } + sampleParams2 := SampleParameters{ + Metadata: meta.Schemas["schema2"].Tables["table2"], + } + repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) + repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := s.sampleDb(ctx, meta.Database) + require.NoError(t, err) + // Order is not important and is actually non-deterministic due to concurrency + expected := []Sample{table1Sample, table2Sample} + require.ElementsMatch(t, expected, samples) +} + +func TestScanner_sampleDb_PartialError(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + repo := NewMockRepository(t) + meta := Metadata{ + Database: "database", + Schemas: map[string]*SchemaMetadata{ + "schema1": { + Name: "", + Tables: map[string]*TableMetadata{ + "table1": { + Schema: "schema1", + Name: "table1", + Attributes: []*AttributeMetadata{ + { + Schema: "schema1", + Table: "table1", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", + Table: "table1", + Name: "name2", + DataType: "decimal", + }, + }, + }, + "forbidden": { + Schema: "schema1", + Name: "forbidden", + Attributes: []*AttributeMetadata{ + { + Schema: "schema1", + Table: "forbidden", + Name: "name1", + DataType: "varchar", + }, + { + Schema: "schema1", 
+ Table: "forbidden", + Name: "name2", + DataType: "decimal", + }, + }, + }, + }, + }, + "schema2": { + Name: "", + Tables: map[string]*TableMetadata{ + "table2": { + Schema: "schema2", + Name: "table2", + Attributes: []*AttributeMetadata{ + { + Schema: "schema2", + Table: "table2", + Name: "name3", + DataType: "int", + }, + { + Schema: "schema2", + Table: "table2", + Name: "name4", + DataType: "timestamp", + }, + }, + }, + }, + }, + }, + } + table1Sample := Sample{ + TablePath: []string{"database", "schema1", "table1"}, + Results: []SampleResult{ + { + "name1": "foo", + "name2": "bar", + }, + { + "name1": "baz", + "name2": "qux", + }, + }, + } + table2Sample := Sample{ + TablePath: []string{"database", "schema2", "table2"}, + Results: []SampleResult{ + { + "name3": "foo1", + "name4": "bar1", + }, + { + "name3": "baz1", + "name4": "qux1", + }, + }, + } + + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + sampleParams1 := SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["table1"], + } + sampleParams2 := SampleParameters{ + Metadata: meta.Schemas["schema1"].Tables["forbidden"], + } + sampleParamsForbidden := SampleParameters{ + Metadata: meta.Schemas["schema2"].Tables["table2"], + } + repo.EXPECT().SampleTable(ctx, sampleParams1).Return(table1Sample, nil) + errForbidden := errors.New("forbidden table") + repo.EXPECT().SampleTable(ctx, sampleParamsForbidden).Return(Sample{}, errForbidden) + repo.EXPECT().SampleTable(ctx, sampleParams2).Return(table2Sample, nil) + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := s.sampleDb(ctx, meta.Database) + require.ErrorIs(t, err, errForbidden) + // Order is not important and is actually non-deterministic due to 
concurrency + expected := []Sample{table1Sample, table2Sample} + require.ElementsMatch(t, expected, samples) +} + +func TestScanner_sampleAllDbs_Error(t *testing.T) { + ctx := context.Background() + listDbErr := errors.New("error listing databases") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(nil, listDbErr) + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := s.sampleAllDbs(ctx) + require.Nil(t, samples) + require.ErrorIs(t, err, listDbErr) +} + +func TestScanner_sampleAllDbs_Successful_TwoDatabases(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := Metadata{ + Database: "db", + Schemas: map[string]*SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sample := Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []SampleResult{ + { + "attr": "foo", + }, + }, + } + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil) + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := 
s.sampleAllDbs(ctx) + require.NoError(t, err) + // Two databases should be sampled, and our mock will return the sample for + // each sample call. This really just asserts that we've sampled the correct + // number of times. + require.ElementsMatch(t, samples, []Sample{sample, sample}) +} + +func TestScanner_sampleAllDbs_IntrospectError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + introspectErr := errors.New("introspect error") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(nil, introspectErr) + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := s.sampleAllDbs(ctx) + require.Empty(t, samples) + require.NoError(t, err) +} + +func TestScanner_sampleAllDbs_SampleError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := Metadata{ + Database: "db", + Schemas: map[string]*SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sampleErr := errors.New("sample error") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr) + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) 
(Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := s.sampleAllDbs(ctx) + require.NoError(t, err) + require.Empty(t, samples) +} + +func TestScanner_sampleAllDbs_TwoDatabases_OneSampleError(t *testing.T) { + ctx := context.Background() + dbs := []string{"db1", "db2"} + // Dummy metadata returned for each Introspect call + meta := Metadata{ + Database: "db", + Schemas: map[string]*SchemaMetadata{ + "schema": { + Name: "schema", + Tables: map[string]*TableMetadata{ + "table": { + Schema: "schema", + Name: "table", + Attributes: []*AttributeMetadata{ + { + Schema: "schema", + Table: "table", + Name: "attr", + DataType: "string", + }, + }, + }, + }, + }, + }, + } + sample := Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []SampleResult{ + { + "attr": "foo", + }, + }, + } + sampleErr := errors.New("sample error") + repo := NewMockRepository(t) + repo.EXPECT().ListDatabases(ctx).Return(dbs, nil) + repo.EXPECT().Introspect(ctx, mock.Anything).Return(&meta, nil) + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(sample, nil).Once() + repo.EXPECT().SampleTable(ctx, mock.Anything).Return(Sample{}, sampleErr).Once() + repo.EXPECT().Close().Return(nil) + repoType := "mock" + reg := NewRegistry() + reg.MustRegister( + repoType, + func(ctx context.Context, cfg RepoConfig) (Repository, error) { + return repo, nil + }, + ) + s := Scanner{ + Config: ScannerConfig{ + RepoType: repoType, + RepoConfig: RepoConfig{}, + Registry: reg, + }, + } + samples, err := s.sampleAllDbs(ctx) + require.NoError(t, err) + // Because of a single sample error, we expect only one database was + // sampled. 
+ require.ElementsMatch(t, samples, []Sample{sample}) +} + +func TestScanner_classifySamples_SingleSample(t *testing.T) { + ctx := context.Background() + sample := Sample{ + TablePath: []string{"db", "schema", "table"}, + Results: []SampleResult{ + { + "age": "52", + "social_sec_num": "512-23-4258", + "credit_card_num": "4111111111111111", + }, + { + "age": "101", + "social_sec_num": "foobarbaz", + "credit_card_num": "4111111111111111", + }, + }, + } + classifier := NewMockClassifier(t) + // Need to explicitly convert it to a map because Mockery isn't smart enough + // to infer the type. + classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[0])).Return( + classification.Result{ + "age": lblSet("AGE"), + "social_sec_num": lblSet("SSN"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(sample.Results[1])).Return( + classification.Result{ + "age": lblSet("AGE", "CVV"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + + expected := []scan.Classification{ + { + AttributePath: append(sample.TablePath, "age"), + Labels: lblSet("AGE", "CVV"), + }, + { + AttributePath: append(sample.TablePath, "social_sec_num"), + Labels: lblSet("SSN"), + }, + { + AttributePath: append(sample.TablePath, "credit_card_num"), + Labels: lblSet("CCN"), + }, + } + s := Scanner{classifier: classifier} + actual, err := s.classifySamples(ctx, []Sample{sample}) + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) +} + +func TestScanner_classifySamples_MultipleSamples(t *testing.T) { + ctx := context.Background() + samples := []Sample{ + { + TablePath: []string{"db1", "schema1", "table1"}, + Results: []SampleResult{ + { + "age": "52", + "social_sec_num": "512-23-4258", + "credit_card_num": "4111111111111111", + }, + { + "age": "101", + "social_sec_num": "foobarbaz", + "credit_card_num": "4111111111111111", + }, + }, + }, + { + TablePath: []string{"db2", "schema2", "table2"}, + Results: []SampleResult{ + { + 
"fullname": "John Doe", + "dob": "2000-01-01", + "random": "foobarbaz", + }, + }, + }, + } + + classifier := NewMockClassifier(t) + // Need to explicitly convert it to a map because Mockery isn't smart enough + // to infer the type. + classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[0])).Return( + classification.Result{ + "age": lblSet("AGE"), + "social_sec_num": lblSet("SSN"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(samples[0].Results[1])).Return( + classification.Result{ + "age": lblSet("AGE", "CVV"), + "credit_card_num": lblSet("CCN"), + }, + nil, + ) + classifier.EXPECT().Classify(ctx, map[string]any(samples[1].Results[0])).Return( + classification.Result{ + "fullname": lblSet("FULL_NAME"), + "dob": lblSet("DOB"), + }, + nil, + ) + + expected := []scan.Classification{ + { + AttributePath: append(samples[0].TablePath, "age"), + Labels: lblSet("AGE", "CVV"), + }, + { + AttributePath: append(samples[0].TablePath, "social_sec_num"), + Labels: lblSet("SSN"), + }, + { + AttributePath: append(samples[0].TablePath, "credit_card_num"), + Labels: lblSet("CCN"), + }, + { + AttributePath: append(samples[1].TablePath, "fullname"), + Labels: lblSet("FULL_NAME"), + }, + { + AttributePath: append(samples[1].TablePath, "dob"), + Labels: lblSet("DOB"), + }, + } + s := Scanner{classifier: classifier} + actual, err := s.classifySamples(ctx, samples) + require.NoError(t, err) + require.ElementsMatch(t, expected, actual) +} + +func lblSet(labels ...string) classification.LabelSet { + set := make(classification.LabelSet) + for _, label := range labels { + set[label] = struct { + }{} + } + return set +} diff --git a/sql/snowflake.go b/sql/snowflake.go index d788013..848f888 100644 --- a/sql/snowflake.go +++ b/sql/snowflake.go @@ -10,7 +10,7 @@ import ( const ( RepoTypeSnowflake = "snowflake" - SnowflakeDatabaseQuery = ` + snowflakeDatabaseQuery = ` SELECT DATABASE_NAME FROM @@ -35,7 +35,7 @@ var _ 
Repository = (*SnowflakeRepository)(nil) // NewSnowflakeRepository creates a new SnowflakeRepository. func NewSnowflakeRepository(cfg RepoConfig) (*SnowflakeRepository, error) { - snowflakeCfg, err := ParseSnowflakeConfig(cfg) + snowflakeCfg, err := NewSnowflakeConfigFromMap(cfg.Advanced) if err != nil { return nil, fmt.Errorf("error parsing snowflake config: %w", err) } @@ -64,7 +64,7 @@ func NewSnowflakeRepository(cfg RepoConfig) (*SnowflakeRepository, error) { // using a Snowflake-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. func (r *SnowflakeRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.generic.ListDatabasesWithQuery(ctx, SnowflakeDatabaseQuery) + return r.generic.ListDatabasesWithQuery(ctx, snowflakeDatabaseQuery) } // Introspect delegates introspection to GenericRepository. See @@ -105,21 +105,25 @@ type SnowflakeConfig struct { Warehouse string } -// ParseSnowflakeConfig produces a config structure with Snowflake-specific -// parameters found in the repo The Snowflake account, role, and -// warehouse are required in the advanced -func ParseSnowflakeConfig(cfg RepoConfig) (*SnowflakeConfig, error) { - snowflakeCfg, err := FetchAdvancedConfigString( - cfg, - RepoTypeSnowflake, - []string{configAccount, configRole, configWarehouse}, - ) +// NewSnowflakeConfigFromMap creates a new SnowflakeConfig from the given map. +// This is useful for parsing the Snowflake-specific configuration from the +// RepoConfig.Advanced map, for example. 
+func NewSnowflakeConfigFromMap(cfg map[string]any) (SnowflakeConfig, error) { + acct, err := keyAsString(cfg, configAccount) + if err != nil { + return SnowflakeConfig{}, err + } + role, err := keyAsString(cfg, configRole) + if err != nil { + return SnowflakeConfig{}, err + } + warehouse, err := keyAsString(cfg, configWarehouse) if err != nil { - return nil, fmt.Errorf("error fetching advanced config string: %w", err) + return SnowflakeConfig{}, err } - return &SnowflakeConfig{ - Account: snowflakeCfg[configAccount], - Role: snowflakeCfg[configRole], - Warehouse: snowflakeCfg[configWarehouse], + return SnowflakeConfig{ + Account: acct, + Role: role, + Warehouse: warehouse, }, nil } diff --git a/sql/snowflake_test.go b/sql/snowflake_test.go index bfbc97d..f5a4f70 100644 --- a/sql/snowflake_test.go +++ b/sql/snowflake_test.go @@ -13,12 +13,71 @@ func TestSnowflakeRepository_ListDatabases(t *testing.T) { ctx, db, mock, r := initSnowflakeRepoTest(t) defer func() { _ = db.Close() }() dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") - mock.ExpectQuery(SnowflakeDatabaseQuery).WillReturnRows(dbRows) + mock.ExpectQuery(snowflakeDatabaseQuery).WillReturnRows(dbRows) dbs, err := r.ListDatabases(ctx) require.NoError(t, err) require.ElementsMatch(t, []string{"db1", "db2"}, dbs) } +func TestNewSnowflakeConfigFromMap(t *testing.T) { + tests := []struct { + name string + cfg map[string]any + want SnowflakeConfig + wantErr require.ErrorAssertionFunc + }{ + { + name: "Returns config when all keys are present", + cfg: map[string]any{ + configAccount: "testAccount", + configRole: "testRole", + configWarehouse: "testWarehouse", + }, + want: SnowflakeConfig{ + Account: "testAccount", + Role: "testRole", + Warehouse: "testWarehouse", + }, + }, + { + name: "Returns error when account key is missing", + cfg: map[string]any{ + configRole: "testRole", + configWarehouse: "testWarehouse", + }, + wantErr: require.Error, + }, + { + name: "Returns error when role key is 
missing", + cfg: map[string]any{ + configAccount: "testAccount", + configWarehouse: "testWarehouse", + }, + wantErr: require.Error, + }, + { + name: "Returns error when warehouse key is missing", + cfg: map[string]any{ + configAccount: "testAccount", + configRole: "testRole", + }, + wantErr: require.Error, + }, + } + for _, tt := range tests { + t.Run( + tt.name, func(t *testing.T) { + got, err := NewSnowflakeConfigFromMap(tt.cfg) + if tt.wantErr == nil { + tt.wantErr = require.NoError + } + tt.wantErr(t, err) + require.Equal(t, tt.want, got) + }, + ) + } +} + func initSnowflakeRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *SnowflakeRepository) { ctx := context.Background() db, mock, err := sqlmock.New() diff --git a/sql/sqlserver.go b/sql/sqlserver.go index de8d501..b866689 100644 --- a/sql/sqlserver.go +++ b/sql/sqlserver.go @@ -10,30 +10,30 @@ import ( const ( RepoTypeSqlServer = "sqlserver" - // SqlServerSampleQueryTemplate is the string template for the SQL query used to + // sqlServerSampleQueryTemplate is the string template for the SQL query used to // sample a SQL Server database. SQL Server doesn't support limit/offset, so // we use top. It also uses the @ symbol for statement parameter // placeholders. It is intended to be templated by the database name to // query. - SqlServerSampleQueryTemplate = `SELECT TOP (@p1) %s FROM "%s"."%s"` - // SqlServerDatabaseQuery is the query to list all the databases on the server, minus + sqlServerSampleQueryTemplate = `SELECT TOP (@p1) %s FROM "%s"."%s"` + // sqlServerDatabaseQuery is the query to list all the databases on the server, minus // the system default databases 'model' and 'tempdb'. 
- SqlServerDatabaseQuery = "SELECT name FROM sys.databases WHERE name != 'model' AND name != 'tempdb'" + sqlServerDatabaseQuery = "SELECT name FROM sys.databases WHERE name != 'model' AND name != 'tempdb'" ) -// SQLServerRepository is a Repository implementation for MS SQL Server +// SqlServerRepository is a Repository implementation for MS SQL Server // databases. -type SQLServerRepository struct { +type SqlServerRepository struct { // The majority of the Repository functionality is delegated to a generic // SQL repository instance. generic *GenericRepository } -// SQLServerRepository implements Repository -var _ Repository = (*SQLServerRepository)(nil) +// SqlServerRepository implements Repository +var _ Repository = (*SqlServerRepository)(nil) // NewSqlServerRepository creates a new MS SQL Server sql. -func NewSqlServerRepository(cfg RepoConfig) (*SQLServerRepository, error) { +func NewSqlServerRepository(cfg RepoConfig) (*SqlServerRepository, error) { connStr := fmt.Sprintf( "sqlserver://%s:%s@%s:%d", cfg.User, @@ -49,44 +49,44 @@ func NewSqlServerRepository(cfg RepoConfig) (*SQLServerRepository, error) { if err != nil { return nil, fmt.Errorf("could not instantiate generic sql repository: %w", err) } - return &SQLServerRepository{generic: generic}, nil + return &SqlServerRepository{generic: generic}, nil } // ListDatabases returns a list of the names of all databases on the server by // using a SQL Server-specific database query. It delegates the actual work to // GenericRepository.ListDatabasesWithQuery - see that method for more details. -func (r *SQLServerRepository) ListDatabases(ctx context.Context) ([]string, error) { - return r.generic.ListDatabasesWithQuery(ctx, SqlServerDatabaseQuery) +func (r *SqlServerRepository) ListDatabases(ctx context.Context) ([]string, error) { + return r.generic.ListDatabasesWithQuery(ctx, sqlServerDatabaseQuery) } // Introspect delegates introspection to GenericRepository. 
See // Repository.Introspect and GenericRepository.IntrospectWithQuery for more // details. -func (r *SQLServerRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { +func (r *SqlServerRepository) Introspect(ctx context.Context, params IntrospectParameters) (*Metadata, error) { return r.generic.Introspect(ctx, params) } // SampleTable delegates sampling to GenericRepository, using a // SQL Server-specific table sample query. See Repository.SampleTable and // GenericRepository.SampleTableWithQuery for more details. -func (r *SQLServerRepository) SampleTable( +func (r *SqlServerRepository) SampleTable( ctx context.Context, params SampleParameters, ) (Sample, error) { // Sqlserver uses double-quotes to quote identifiers attrStr := params.Metadata.QuotedAttributeNamesString("\"") - query := fmt.Sprintf(SqlServerSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) + query := fmt.Sprintf(sqlServerSampleQueryTemplate, attrStr, params.Metadata.Schema, params.Metadata.Name) return r.generic.SampleTableWithQuery(ctx, query, params) } // Ping delegates the ping to GenericRepository. See Repository.Ping and // GenericRepository.Ping for more details. -func (r *SQLServerRepository) Ping(ctx context.Context) error { +func (r *SqlServerRepository) Ping(ctx context.Context) error { return r.generic.Ping(ctx) } // Close delegates the close to GenericRepository. See Repository.Close and // GenericRepository.Close for more details. 
-func (r *SQLServerRepository) Close() error { +func (r *SqlServerRepository) Close() error { return r.generic.Close() } diff --git a/sql/sqlserver_test.go b/sql/sqlserver_test.go index 223dfec..7fbeb4f 100644 --- a/sql/sqlserver_test.go +++ b/sql/sqlserver_test.go @@ -13,17 +13,17 @@ func TestSqlServerRepository_ListDatabases(t *testing.T) { ctx, db, mock, r := initSqlServerRepoTest(t) defer func() { _ = db.Close() }() dbRows := sqlmock.NewRows([]string{"name"}).AddRow("db1").AddRow("db2") - mock.ExpectQuery(SqlServerDatabaseQuery).WillReturnRows(dbRows) + mock.ExpectQuery(sqlServerDatabaseQuery).WillReturnRows(dbRows) dbs, err := r.ListDatabases(ctx) require.NoError(t, err) require.ElementsMatch(t, []string{"db1", "db2"}, dbs) } -func initSqlServerRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *SQLServerRepository) { +func initSqlServerRepoTest(t *testing.T) (context.Context, *sql.DB, sqlmock.Sqlmock, *SqlServerRepository) { ctx := context.Background() db, mock, err := sqlmock.New() require.NoError(t, err) - return ctx, db, mock, &SQLServerRepository{ + return ctx, db, mock, &SqlServerRepository{ generic: NewGenericRepositoryFromDB(RepoTypeSqlServer, "dbName", db), } }