diff --git a/cmd/apply.go b/cmd/apply.go index b7b3ef3..f4540aa 100644 --- a/cmd/apply.go +++ b/cmd/apply.go @@ -24,9 +24,8 @@ import ( "fmt" "io/ioutil" - "github.com/spf13/cobra" - "github.com/cloudspannerecosystem/wrench/pkg/spanner" + "github.com/spf13/cobra" ) var ( @@ -53,7 +52,7 @@ func apply(c *cobra.Command, _ []string) error { if ddlFile != "" { if dmlFile != "" { - return errors.New("Cannot specify DDL and DML at same time.") + return errors.New("cannot specify DDL and DML at same time") } ddl, err := ioutil.ReadFile(ddlFile) @@ -64,7 +63,7 @@ func apply(c *cobra.Command, _ []string) error { } } - err = client.ApplyDDLFile(ctx, ddl) + err = client.ApplyDDLFile(ctx, ddlFile, ddl) if err != nil { return &Error{ err: err, @@ -96,7 +95,7 @@ func apply(c *cobra.Command, _ []string) error { } } - numAffectedRows, err := client.ApplyDMLFile(ctx, dml, partitioned, p) + numAffectedRows, err := client.ApplyDMLFile(ctx, dmlFile, dml, partitioned, p) if err != nil { return &Error{ err: err, @@ -130,7 +129,6 @@ func priorityTypeOf(prioirty string) (spanner.PriorityType, error) { priority, priorityTypeHigh, priorityTypeMedium, priorityTypeLow, ) } - } func init() { diff --git a/cmd/create.go b/cmd/create.go index 0601400..b747899 100644 --- a/cmd/create.go +++ b/cmd/create.go @@ -40,7 +40,8 @@ func create(c *cobra.Command, _ []string) error { } defer client.Close() - ddl, err := ioutil.ReadFile(schemaFilePath(c)) + filename := schemaFilePath(c) + ddl, err := ioutil.ReadFile(filename) if err != nil { return &Error{ err: err, @@ -48,7 +49,7 @@ func create(c *cobra.Command, _ []string) error { } } - err = client.CreateDatabase(ctx, ddl) + err = client.CreateDatabase(ctx, filename, ddl) if err != nil { return &Error{ err: err, diff --git a/pkg/spanner/BUILD.bazel b/pkg/spanner/BUILD.bazel index 024b9d7..6f2358e 100644 --- a/pkg/spanner/BUILD.bazel +++ b/pkg/spanner/BUILD.bazel @@ -16,6 +16,7 @@ go_library( "@com_google_cloud_go_spanner//:go_default_library", "@com_google_cloud_go_spanner//admin/database/apiv1:go_default_library", "@com_google_cloud_go_spanner//admin/instance/apiv1:go_default_library", + "@com_google_cloud_go_spanner//spansql:go_default_library", "@org_golang_google_api//iterator:go_default_library", "@org_golang_google_api//option:go_default_library", "@org_golang_google_genproto//googleapis/spanner/admin/database/v1:go_default_library", diff --git a/pkg/spanner/client.go b/pkg/spanner/client.go index 81e058c..e72dad5 100644 --- a/pkg/spanner/client.go +++ b/pkg/spanner/client.go @@ -78,8 +78,14 @@ func NewClient(ctx context.Context, config *Config) (*Client, error) { }, nil } -func (c *Client) CreateDatabase(ctx context.Context, ddl []byte) error { - statements := toStatements(ddl) +func (c *Client) CreateDatabase(ctx context.Context, filename string, ddl []byte) error { + statements, err := ddlToStatements(filename, ddl) + if err != nil { + return &Error{ + Code: ErrorCodeLoadSchema, + err: err, + } + } createReq := &databasepb.CreateDatabaseRequest{ Parent: fmt.Sprintf("projects/%s/instances/%s", c.config.Project, c.config.Instance), @@ -189,8 +195,13 @@ func (c *Client) LoadDDL(ctx context.Context) ([]byte, error) { return schema, nil } -func (c *Client) ApplyDDLFile(ctx context.Context, ddl []byte) error { - return c.ApplyDDL(ctx, toStatements(ddl)) +func (c *Client) ApplyDDLFile(ctx context.Context, filename string, ddl []byte) error { + statements, err := ddlToStatements(filename, ddl) + if err != nil { + return err + } + + return c.ApplyDDL(ctx, statements) } func (c *Client) ApplyDDL(ctx context.Context, statements []string) error { @@ -227,8 +238,11 @@ const ( PriorityTypeLow ) -func (c *Client) ApplyDMLFile(ctx context.Context, ddl []byte, partitioned bool, priority PriorityType) (int64, error) { - statements := toStatements(ddl) +func (c *Client) ApplyDMLFile(ctx context.Context, filename string, ddl []byte, partitioned bool, priority PriorityType) (int64, error) { + statements, err := dmlToStatements(filename, ddl) + if err != nil { + return 0, err + } if partitioned { return c.ApplyPartitionedDML(ctx, statements, priority) diff --git a/pkg/spanner/client_test.go b/pkg/spanner/client_test.go index 99fbaa4..f552ddd 100644 --- a/pkg/spanner/client_test.go +++ b/pkg/spanner/client_test.go @@ -93,7 +93,7 @@ func TestApplyDDLFile(t *testing.T) { client, done := testClientWithDatabase(t, ctx) defer done() - if err := client.ApplyDDLFile(ctx, ddl); err != nil { + if err := client.ApplyDDLFile(ctx, "testdata/ddl.sql", ddl); err != nil { t.Fatalf("failed to apply ddl file: %v", err) } @@ -172,7 +172,7 @@ func TestApplyDMLFile(t *testing.T) { t.Fatalf("failed to read dml file: %v", err) } - n, err := client.ApplyDMLFile(ctx, dml, test.partitioned, test.priority) + n, err := client.ApplyDMLFile(ctx, "testdata/dml.sql", dml, test.partitioned, test.priority) if err != nil { t.Fatalf("failed to apply dml file: %v", err) } @@ -470,7 +470,7 @@ func testClientWithDatabase(t *testing.T, ctx context.Context) (*Client, func()) t.Fatalf("failed to read schema file: %v", err) } - if err := client.CreateDatabase(ctx, ddl); err != nil { + if err := client.CreateDatabase(ctx, "testdata/schema.sql", ddl); err != nil { t.Fatalf("failed to create database: %v", err) } diff --git a/pkg/spanner/migration.go b/pkg/spanner/migration.go index fef0804..f4f084f 100644 --- a/pkg/spanner/migration.go +++ b/pkg/spanner/migration.go @@ -20,18 +20,14 @@ package spanner import ( - "bytes" "errors" "fmt" "io/ioutil" "path/filepath" "regexp" "strconv" - "strings" -) -const ( - statementsSeparator = ";" + "cloud.google.com/go/spanner/spansql" ) var ( @@ -115,7 +111,15 @@ func LoadMigrations(dir string) (Migrations, error) { continue } - statements := toStatements(file) + statements, err := ddlToStatements(f.Name(), file) + if err != nil { + nstatements, nerr := dmlToStatements(f.Name(), file) + if nerr != nil { + return nil, errors.New("failed to parse DDL/DML statements") + } + statements = nstatements + } + kind, err := inspectStatementsKind(statements) if err != nil { return nil, err @@ -137,17 +141,32 @@ func LoadMigrations(dir string) (Migrations, error) { return migrations, nil } -func toStatements(file []byte) []string { - contents := bytes.Split(file, []byte(statementsSeparator)) +func ddlToStatements(filename string, data []byte) ([]string, error) { + ddl, err := spansql.ParseDDL(filename, string(data)) + if err != nil { + return nil, err + } - statements := make([]string, 0, len(contents)) - for _, c := range contents { - if statement := strings.TrimSpace(string(c)); statement != "" { - statements = append(statements, statement) - } + var statements []string + for _, stmt := range ddl.List { + statements = append(statements, stmt.SQL()) + } + + return statements, nil +} + +func dmlToStatements(filename string, data []byte) ([]string, error) { + dml, err := spansql.ParseDML(filename, string(data)) + if err != nil { + return nil, err + } + + var statements []string + for _, stmt := range dml.List { + statements = append(statements, stmt.SQL()) } - return statements + return statements, nil } func inspectStatementsKind(statements []string) (statementKind, error) { @@ -166,7 +185,7 @@ func inspectStatementsKind(statements []string) (statementKind, error) { if kindMap[statementKindDML] > 0 { if kindMap[statementKindDDL] > 0 { - return "", errors.New("Cannot specify DDL and DML at same migration file.") + return "", errors.New("cannot specify DDL and DML at same migration file") } return statementKindDML, nil