diff --git a/fixtures/postgres/postgres.go b/fixtures/postgres/postgres.go index dd2ff9f..3983d59 100644 --- a/fixtures/postgres/postgres.go +++ b/fixtures/postgres/postgres.go @@ -210,48 +210,54 @@ func (f *LoaderPostgres) loadTables(ctx *loadContext) error { defer func() { _ = tx.Rollback() }() // truncate first - truncatedTables := make(map[string]bool) - for _, lt := range ctx.tables { - if _, ok := truncatedTables[lt.name.getFullName()]; ok { - // already truncated - continue - } - if err := f.truncateTable(lt.name); err != nil { - return err - } - truncatedTables[lt.name.getFullName()] = true + if err := f.truncateTables(tx, ctx.tables...); err != nil { + return err } + // then load data for _, lt := range ctx.tables { if len(lt.rows) == 0 { continue } - if err := f.loadTable(ctx, lt.name, lt.rows); err != nil { - return fmt.Errorf("failed to load table '%s' because:\n%s", lt.name, err) + if err := f.loadTable(ctx, tx, lt.name, lt.rows); err != nil { + return fmt.Errorf("failed to load table '%s' because:\n%s", lt.name.getFullName(), err) } } // alter the sequences so they contain max id + 1 - if err := f.fixSequences(); err != nil { + if err := f.fixSequences(tx); err != nil { return err } return tx.Commit() } -// truncateTable truncates table -func (f *LoaderPostgres) truncateTable(name tableName) error { - query := fmt.Sprintf("TRUNCATE TABLE %s CASCADE", name.getFullName()) +// truncateTables truncates table +func (f *LoaderPostgres) truncateTables(tx *sql.Tx, tables ...loadedTable) error { + set := make(map[string]struct{}) + tablesToTruncate := make([]string, 0, len(tables)) + for _, t := range tables { + tableName := t.name.getFullName() + if _, ok := set[tableName]; ok { + // already truncated + continue + } + + tablesToTruncate = append(tablesToTruncate, tableName) + set[tableName] = struct{}{} + } + + query := fmt.Sprintf("TRUNCATE TABLE %s CASCADE", strings.Join(tablesToTruncate, ",")) if f.debug { fmt.Println("Issuing SQL:", query) } - _, err := f.db.Exec(query) + _, err := tx.Exec(query) if err != nil { return err } return nil } -func (f *LoaderPostgres) loadTable(ctx *loadContext, t tableName, rows table) error { +func (f *LoaderPostgres) loadTable(ctx *loadContext, tx *sql.Tx, t tableName, rows table) error { // $extend keyword allows to import values from a named row for i, row := range rows { if base, ok := row["$extend"]; ok { @@ -275,7 +281,7 @@ func (f *LoaderPostgres) loadTable(ctx *loadContext, t tableName, rows table) er fmt.Println("Issuing SQL:", query) } // issuing query - insertedRows, err := f.db.Query(query) + insertedRows, err := tx.Query(query) if err != nil { return err } @@ -326,7 +332,7 @@ func (f *LoaderPostgres) loadTable(ctx *loadContext, t tableName, rows table) er return err } -func (f *LoaderPostgres) fixSequences() error { +func (f *LoaderPostgres) fixSequences(tx *sql.Tx) error { query := ` DO $$ DECLARE @@ -353,7 +359,7 @@ END$$ if f.debug { fmt.Println("Issuing SQL:", query) } - _, err := f.db.Exec(query) + _, err := tx.Exec(query) return err } diff --git a/fixtures/postgres/postgres_test.go b/fixtures/postgres/postgres_test.go index e03c32c..985e535 100644 --- a/fixtures/postgres/postgres_test.go +++ b/fixtures/postgres/postgres_test.go @@ -30,16 +30,9 @@ func TestBuildInsertQuery(t *testing.T) { require.NoError(t, err) query, err := l.buildInsertQuery(&ctx, newTableName("table"), ctx.tables[0].rows) + require.NoError(t, err) - if err != nil { - t.Error("must not produce error, error:", err.Error()) - t.Fail() - } - - if query != expected { - t.Error("must generate proper SQL, got result:", query) - t.Fail() - } + require.Equal(t, expected, query) } func TestLoadTablesShouldResolveSchema(t *testing.T) { @@ -47,9 +40,7 @@ func TestLoadTablesShouldResolveSchema(t *testing.T) { require.NoError(t, err) db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } + require.NoError(t, err) defer func() { _ = db.Close() }() ctx := loadContext{ @@ -59,21 +50,12 @@ func TestLoadTablesShouldResolveSchema(t *testing.T) { l := New(db, "", true) - err = l.loadYml([]byte(yml), &ctx) - if err != nil { - t.Error(err) - t.Fail() - } + err = l.loadYml(yml, &ctx) + require.NoError(t, err) mock.ExpectBegin() - mock.ExpectExec("^TRUNCATE TABLE \"schema1\".\"table1\" CASCADE$"). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec("^TRUNCATE TABLE \"schema2\".\"table2\" CASCADE$"). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table3\" CASCADE$"). + mock.ExpectExec("^TRUNCATE TABLE \"schema1\".\"table1\",\"schema2\".\"table2\",\"public\".\"table3\" CASCADE$"). WillReturnResult(sqlmock.NewResult(0, 0)) q := `^INSERT INTO "schema1"."table1" AS row \("f1", "f2"\) VALUES ` + @@ -112,15 +94,10 @@ func TestLoadTablesShouldResolveSchema(t *testing.T) { mock.ExpectCommit() err = l.loadTables(&ctx) - if err != nil { - t.Error(err) - t.Fail() - } + require.NoError(t, err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - t.Fail() - } + err = mock.ExpectationsWereMet() + require.NoError(t, err) } func TestLoadTablesShouldResolveRefs(t *testing.T) { @@ -128,9 +105,7 @@ func TestLoadTablesShouldResolveRefs(t *testing.T) { require.NoError(t, err) db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } + require.NoError(t, err) defer func() { _ = db.Close() }() ctx := loadContext{ @@ -140,21 +115,12 @@ func TestLoadTablesShouldResolveRefs(t *testing.T) { l := New(db, "", true) - err = l.loadYml([]byte(yml), &ctx) - if err != nil { - t.Error(err) - t.Fail() - } + err = l.loadYml(yml, &ctx) + require.NoError(t, err) mock.ExpectBegin() - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\" CASCADE$"). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table2\" CASCADE$"). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table3\" CASCADE$"). + mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\",\"public\".\"table2\",\"public\".\"table3\" CASCADE$"). WillReturnResult(sqlmock.NewResult(0, 0)) q := `^INSERT INTO "public"."table1" AS row \("f1", "f2"\) VALUES ` + @@ -193,15 +159,10 @@ func TestLoadTablesShouldResolveRefs(t *testing.T) { mock.ExpectCommit() err = l.loadTables(&ctx) - if err != nil { - t.Error(err) - t.Fail() - } + require.NoError(t, err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - t.Fail() - } + err = mock.ExpectationsWereMet() + require.NoError(t, err) } func TestLoadTablesShouldExtendRows(t *testing.T) { @@ -209,9 +170,7 @@ func TestLoadTablesShouldExtendRows(t *testing.T) { require.NoError(t, err) db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("an error '%s' was not expected when opening a stub database connection", err) - } + require.NoError(t, err) defer func() { _ = db.Close() }() ctx := loadContext{ @@ -221,21 +180,12 @@ func TestLoadTablesShouldExtendRows(t *testing.T) { l := New(db, "", true) - err = l.loadYml([]byte(yml), &ctx) - if err != nil { - t.Error(err) - t.Fail() - } + err = l.loadYml(yml, &ctx) + require.NoError(t, err) mock.ExpectBegin() - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\" CASCADE$"). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table2\" CASCADE$"). - WillReturnResult(sqlmock.NewResult(0, 0)) - - mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table3\" CASCADE$"). + mock.ExpectExec("^TRUNCATE TABLE \"public\".\"table1\",\"public\".\"table2\",\"public\".\"table3\" CASCADE$"). WillReturnResult(sqlmock.NewResult(0, 0)) q := `^INSERT INTO "public"."table1" AS row \("f1", "f2"\) VALUES ` + @@ -276,13 +226,8 @@ func TestLoadTablesShouldExtendRows(t *testing.T) { mock.ExpectCommit() err = l.loadTables(&ctx) - if err != nil { - t.Error(err) - t.Fail() - } + require.NoError(t, err) - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("there were unfulfilled expectations: %s", err) - t.Fail() - } + err = mock.ExpectationsWereMet() + require.NoError(t, err) }