diff --git a/README.md b/README.md
index fbd726a..4e8074e 100644
--- a/README.md
+++ b/README.md
@@ -94,10 +94,6 @@ make install
 - Log into the [Supabase](https://supabase.io/) dashboard and go to your project, or create a new one.
 - Check if logical replication is enabled. This should be the default setting, so you shouldn't have to change anything. You can do this in the `SQL Editor` section on the left hand side of the Supabase dashboard by running `SHOW wal_level;` query, which should log `logical`.
 - You can find the database connection information on the left hand side under `Project Settings` > `Database`. You will need the `Host`, `Port`, `Database`, `Username`, and `Password` to connect to your database.
-  - When you create a vault, the `--dburi` should follow this format:
-    ```sh
-    postgresql://postgres:[PASSWORD]@db.[PROJECT_ID].supabase.co:5432/postgres
-    ```
 
 ### Create a vault
 
@@ -114,7 +110,7 @@
 A new private key will be written to `FILENAME`.
 
 The name of a vault contains a `namespace` (e.g. `my_company`) and the name of an existing database relation (e.g. `my_table`), separated by a period (`.`). Use `vaults create` to create a new vault. See `vaults create --help` for more info.
 
 ```bash
-vaults create --dburi [DBURI] --account [WALLET_ADDRESS] namespace.relation_name
+vaults create --account [WALLET_ADDRESS] namespace.relation_name
 ```
 
 🚧 Vaults currently only replicates `INSERT` statements, which means that it only replicates append-only data (e.g., log-style data). Row updates and deletes will be ignored. 🚧
 
@@ -124,14 +120,18 @@ vaults create --dburi [DBURI] --account [WALLET_ADDRESS] namespace.relation_name
 Use `vaults stream` to start a daemon that will continuously push changes to the underlying table/view to the network. See `vaults stream --help` for more info.
 
 ```bash
-vaults stream --private-key [PRIVATE_KEY] namespace.relation_name
+vaults stream --dburi [DB_URI] --tables t1,t2 --private-key [PRIVATE_KEY] namespace.relation_name
 ```
 
-### Write a Parquet file
+The `--dburi` should follow this format:
+
+```sh
+postgresql://[USER]:[PASSWORD]@[HOST]:[PORT]/[DATABASE]
+```
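+
+For example, streaming tables `t1` and `t2` from a local Postgres instance (hypothetical credentials shown) could look like:
+
+```bash
+vaults stream --dburi postgresql://postgres:secret@localhost:5432/postgres --tables t1,t2 --private-key [PRIVATE_KEY] namespace.relation_name
+```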
 
-Before writing a Parquet file, you need to [Create a vault](#create-a-vault), if not already created. You can omit the `--dburi` flag, in this case.
+### Write a Parquet file
 
-Then, use `vaults write` to write a Parquet file.
+Before writing a Parquet file, you need to [Create a vault](#create-a-vault) if you haven't already. Then, use `vaults write` to write a Parquet file.
 
 ```bash
 vaults write --vault [namespace.relation_name] --private-key [PRIVATE_KEY] filepath
@@ -205,7 +205,7 @@ PORT=8888 ./scripts/server.sh
 ./scripts/run.sh account create pk.out
 
 # Start replicating
-./scripts/run.sh vaults stream --private-key [PRIVATE_KEY] namespace.relation_name
+./scripts/run.sh vaults stream --dburi [DB_URI] --tables t1,t2 --private-key [PRIVATE_KEY] namespace.relation_name
 ```
 
 ### Run tests
diff --git a/cmd/vaults/commands.go b/cmd/vaults/commands.go
index 50254bb..064ca8d 100644
--- a/cmd/vaults/commands.go
+++ b/cmd/vaults/commands.go
@@ -32,8 +32,8 @@ import (
 var vaultNameRx = regexp.MustCompile(`^([a-zA-Z_][a-zA-Z0-9_]*)[.]([a-zA-Z_][a-zA-Z0-9_]*$)`)
 
 func newVaultCreateCommand() *cli.Command {
-	var address, dburi, provider string
-	var winSize, cache int64
+	var address, provider string
+	var cache int64
 
 	return &cli.Command{
 		Name: "create",
@@ -68,20 +68,6 @@
 				Destination: &cache,
 				Value:       0,
 			},
-			&cli.StringFlag{
-				Name:        "dburi",
-				Category:    "OPTIONAL:",
-				Usage:       "PostgreSQL connection string (e.g., postgresql://postgres:[PASSWORD]@[HOST]:[PORT]/postgres)",
-				Destination: &dburi,
-			},
-			&cli.Int64Flag{
-				Name:        "window-size",
-				Category:    "OPTIONAL:",
-				Usage:       "Number of seconds for which WAL updates are buffered before being sent to the provider",
-				DefaultText: fmt.Sprintf("%d", DefaultWindowSize),
-				Destination: &winSize,
-				Value:       DefaultWindowSize,
-			},
 		},
 		Action: func(cCtx *cli.Context) error {
 			if cCtx.NArg() != 1 {
@@ -98,10 +84,6 @@ func newVaultCreateCommand() *cli.Command {
 			if err != nil {
 				return fmt.Errorf("not a valid account: %s", err)
 			}
-			pgConfig, err := pgconn.ParseConfig(dburi)
-			if err != nil {
-				return fmt.Errorf("parse config: %s", err)
-			}
 
 			dir, err := defaultConfigLocation(cCtx.String("dir"))
 			if err != nil {
@@ -122,27 +104,22 @@
 			}
 
 			cfg.Vaults[pub] = vault{
-				Host:         pgConfig.Host,
-				Port:         int(pgConfig.Port),
-				User:         pgConfig.User,
-				Password:     pgConfig.Password,
-				Database:     pgConfig.Database,
 				ProviderHost: provider,
-				WindowSize:   winSize,
 			}
 
 			if err := yaml.NewEncoder(f).Encode(cfg); err != nil {
 				return fmt.Errorf("encode: %s", err)
 			}
 
-			exists, err := createVault(cCtx.Context, dburi, ns, rel, provider, account, cache)
-			if err != nil {
-				return fmt.Errorf("failed to create vault: %s", err)
+			bp := vaultsprovider.New(provider)
+			req := app.CreateVaultParams{
+				Account:       account,
+				Vault:         app.Vault(fmt.Sprintf("%s.%s", ns, rel)),
+				CacheDuration: app.CacheDuration(cache),
 			}
 
-			if exists {
-				fmt.Printf("Vault %s.%s already exists.\n\n", ns, rel)
-				return nil
+			if err := bp.CreateVault(cCtx.Context, req); err != nil {
+				return fmt.Errorf("create vault: %s", err)
 			}
 
 			if err := os.MkdirAll(path.Join(dir, pub), 0o755); err != nil {
@@ -156,7 +133,8 @@
 }
 
 func newStreamCommand() *cli.Command {
-	var privateKey string
+	var privateKey, dburi, tables string
+	var winSize int64
 
 	return &cli.Command{
 		Name: "stream",
@@ -175,6 +153,29 @@
 			Destination: &privateKey,
 			Required:    true,
 		},
+		&cli.StringFlag{
+			Name:        "dburi",
+			Category:    "REQUIRED:",
+			Usage:       "PostgreSQL connection string (e.g., postgresql://postgres:[PASSWORD]@[HOST]:[PORT]/postgres)",
+			Destination: &dburi,
+			Required:    true,
+		},
+		&cli.StringFlag{
+			Name:        "tables",
+			Aliases:     []string{"t"},
+			Category:    "REQUIRED:",
+			Usage:       "Comma-separated list of PostgreSQL tables to replicate (e.g., tbl1,tbl2,tbl3)",
+			Destination: &tables,
+			Required:    true,
+		},
+		&cli.Int64Flag{
+			Name:        "window-size",
+			Category:    "OPTIONAL:",
+			Usage:       "Number of seconds for which WAL updates are buffered before being sent to the provider",
+			DefaultText: fmt.Sprintf("%d", DefaultWindowSize),
+			Destination: &winSize,
+			Value:       DefaultWindowSize,
+		},
 	},
 	Action: func(cCtx *cli.Context) error {
 		if cCtx.NArg() != 1 {
@@ -197,53 +198,39 @@
 			return fmt.Errorf("load config: %s", err)
 		}
 
-		connString := fmt.Sprintf("postgres://%s:%s@%s:%d/%s",
-			cfg.Vaults[vault].User,
-			cfg.Vaults[vault].Password,
-			cfg.Vaults[vault].Host,
-			cfg.Vaults[vault].Port,
-			cfg.Vaults[vault].Database,
-		)
-
-		r, err := pgrepl.New(connString, pgrepl.Publication(rel))
+		publication := pgrepl.Publication(strings.Replace(vault, ".", "_", -1))
+		setup, err := NewDatabaseStreamSetup(cCtx.Context, dburi, publication, tables)
 		if err != nil {
-			return fmt.Errorf("failed to create replicator: %s", err)
+			return fmt.Errorf("new database stream setup: %s", err)
 		}
+		defer func() {
+			_ = setup.Close(cCtx.Context)
+		}()
 
-		privateKey, err := crypto.HexToECDSA(privateKey)
-		if err != nil {
-			return err
+		if err := setup.CreatePublicationIfNotExists(cCtx.Context); err != nil {
+			return fmt.Errorf("failed to create database publication: %s", err)
 		}
 
-		pgxConn, err := pgx.Connect(cCtx.Context, connString)
+		tableSchemas, err := setup.TableSchemas(cCtx.Context)
 		if err != nil {
-			return fmt.Errorf("connect: %s", err)
+			return fmt.Errorf("getting table schemas: %s", err)
 		}
-		defer func() {
-			_ = pgxConn.Close(cCtx.Context)
-		}()
 
-		tx, err := pgxConn.Begin(cCtx.Context)
+		r, err := pgrepl.New(dburi, publication)
 		if err != nil {
-			return fmt.Errorf("failed to begin transaction")
+			return fmt.Errorf("failed to create replicator: %s", err)
 		}
-		defer func() {
-			if err != nil {
-				_ = tx.Rollback(cCtx.Context)
-			}
-		}()
 
-		cols, err := inspectTable(cCtx.Context, tx, rel)
+		privateKey, err := crypto.HexToECDSA(privateKey)
 		if err != nil {
-			return fmt.Errorf("failed to inspect source table: %s", err)
+			return err
 		}
 
 		// Creates a new db manager when replication starts
 		bp := vaultsprovider.New(cfg.Vaults[vault].ProviderHost)
 		uploader := app.NewVaultsUploader(ns, rel, bp, privateKey)
 		dbDir := path.Join(dir, vault)
-		winSize := time.Duration(cfg.Vaults[vault].WindowSize) * time.Second
-		dbm := app.NewDBManager(dbDir, rel, cols, winSize, uploader)
+		dbm := app.NewDBManager(dbDir, tableSchemas, time.Duration(winSize)*time.Second, uploader)
 
 		// Before starting replication, upload the remaining data
 		if err := dbm.UploadAll(cCtx.Context); err != nil {
@@ -788,139 +775,143 @@ func parseVaultName(name string) (ns string, rel string, err error) {
 	return
 }
 
-func inspectTable(ctx context.Context, tx pgx.Tx, rel string) ([]app.Column, error) {
-	rows, err := tx.Query(ctx,
-		`
-		WITH primary_key_info AS
-		  (SELECT tc.constraint_schema,
-				  tc.table_name,
-				  ccu.column_name
-		   FROM information_schema.table_constraints tc
-		   JOIN information_schema.constraint_column_usage AS ccu USING (CONSTRAINT_SCHEMA, CONSTRAINT_NAME)
-		   WHERE constraint_type = 'PRIMARY KEY' ),
-		  array_type_info AS
-		  (SELECT c.table_name,
-				  c.column_name,
-				  pg_catalog.format_type(t.oid, NULL) AS full_data_type
-		   FROM information_schema.columns AS c
-		   JOIN pg_catalog.pg_type AS t ON c.udt_name = t.typname
-		   WHERE c.data_type = 'ARRAY')
-		SELECT
-			c.column_name,
-			CASE
-				WHEN c.data_type = 'ARRAY' THEN ati.full_data_type
-				ELSE c.data_type
-			END AS data_type,
-			c.is_nullable = 'YES' AS is_nullable,
-			pki.column_name IS NOT NULL AS is_primary
-		FROM information_schema.columns AS c
-		LEFT JOIN primary_key_info pki ON c.table_schema = pki.constraint_schema
-			AND pki.table_name = c.table_name
-			AND pki.column_name = c.column_name
-		LEFT JOIN array_type_info ati ON c.table_name = ati.table_name
-			AND c.column_name = ati.column_name
-		WHERE c.table_name = $1;
-		`, rel,
-	)
-	if err != nil {
-		return []app.Column{}, fmt.Errorf("failed to fetch schema")
-	}
-	defer rows.Close()
-
-	var colName, typ string
-	var isNull, isPrimary bool
-	var columns []app.Column
-	for rows.Next() {
-		if err := rows.Scan(&colName, &typ, &isNull, &isPrimary); err != nil {
-			return []app.Column{}, fmt.Errorf("scan: %s", err)
-		}
-		columns = append(columns, app.Column{
-			Name:      colName,
-			Typ:       typ,
-			IsNull:    isNull,
-			IsPrimary: isPrimary,
-		})
-	}
-	return columns, nil
-}
+func validateBeforeAndAfter(before, after, at string) (app.Timestamp, app.Timestamp, error) {
+	if !strings.EqualFold(at, "") {
+		before, after = at, at
+	}
+
+	b, err := app.ParseTimestamp(before)
+	if err != nil {
+		return app.Timestamp{}, app.Timestamp{}, err
+	}
+
+	a, err := app.ParseTimestamp(after)
+	if err != nil {
+		return app.Timestamp{}, app.Timestamp{}, err
+	}
+
+	return b, a, nil
+}
 
-func createVault(
-	ctx context.Context,
-	dburi string,
-	ns string,
-	rel string,
-	provider string,
-	account *app.Account,
-	cacheDuration int64,
-) (exists bool, err error) {
-	bp := vaultsprovider.New(provider)
-	req := app.CreateVaultParams{
-		Account:       account,
-		Vault:         app.Vault(fmt.Sprintf("%s.%s", ns, rel)),
-		CacheDuration: app.CacheDuration(cacheDuration),
-	}
-
-	if dburi == "" {
-		if err := bp.CreateVault(ctx, req); err != nil {
-			return false, fmt.Errorf("create vault: %s", err)
-		}
-
-		return exists, nil
-	}
-
-	pgxConn, err := pgx.Connect(ctx, dburi)
-	if err != nil {
-		return false, fmt.Errorf("connect: %s", err)
-	}
-	defer func() {
-		_ = pgxConn.Close(ctx)
-	}()
-
-	tx, err := pgxConn.Begin(ctx)
-	if err != nil {
-		return false, fmt.Errorf("failed to begin transaction")
-	}
-	defer func() {
-		if err != nil {
-			_ = tx.Rollback(ctx)
-		}
-	}()
-
-	if _, err := tx.Exec(
-		ctx, fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s", pgrepl.Publication(rel).FullName(), rel),
+// DatabaseStreamSetup sets up the database for streaming.
+type DatabaseStreamSetup struct {
+	publication pgrepl.Publication
+	tables      []string
+
+	// Postgres
+	pgConfig *pgconn.Config
+	pgConn   *pgx.Conn
+}
+
+// NewDatabaseStreamSetup creates a new database stream setup.
+func NewDatabaseStreamSetup(
+	ctx context.Context, dburi string, publication pgrepl.Publication, tables string,
+) (*DatabaseStreamSetup, error) {
+	pgConfig, err := pgconn.ParseConfig(dburi)
+	if err != nil {
+		return nil, fmt.Errorf("invalid dburi: %s", err)
+	}
+
+	pgConn, err := pgx.Connect(ctx, dburi)
+	if err != nil {
+		return nil, fmt.Errorf("connect: %s", err)
+	}
+
+	return &DatabaseStreamSetup{
+		publication: publication,
+		tables:      strings.Split(tables, ","),
+		pgConfig:    pgConfig,
+		pgConn:      pgConn,
+	}, nil
+}
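+
+// A minimal usage sketch (error handling elided; assumes a reachable Postgres
+// instance and existing tables t1 and t2):
+//
+//	setup, _ := NewDatabaseStreamSetup(ctx, dburi, pgrepl.Publication("ns_rel"), "t1,t2")
+//	defer setup.Close(ctx)
+//	_ = setup.CreatePublicationIfNotExists(ctx)
+//	schemas, _ := setup.TableSchemas(ctx)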
-	); err != nil {
-		if strings.Contains(err.Error(), "already exists") {
-			return true, nil
-		}
-		return false, fmt.Errorf("failed to create publication: %s", err)
-	}
-
-	if err := bp.CreateVault(ctx, req); err != nil {
-		return false, fmt.Errorf("create call: %s", err)
-	}
-
-	if err := tx.Commit(ctx); err != nil {
-		return false, fmt.Errorf("commit: %s", err)
-	}
-
-	return false, nil
-}
-
-func validateBeforeAndAfter(before, after, at string) (app.Timestamp, app.Timestamp, error) {
-	if !strings.EqualFold(at, "") {
-		before, after = at, at
-	}
-
-	b, err := app.ParseTimestamp(before)
-	if err != nil {
-		return app.Timestamp{}, app.Timestamp{}, err
-	}
-
-	a, err := app.ParseTimestamp(after)
-	if err != nil {
-		return app.Timestamp{}, app.Timestamp{}, err
-	}
-
-	return b, a, nil
-}
+
+// CreatePublicationIfNotExists creates a database publication if it does not
+// exist. The generated statement has the shape
+// CREATE PUBLICATION pub_basin_<namespace>_<relation> FOR TABLE t1,t2.
+func (s *DatabaseStreamSetup) CreatePublicationIfNotExists(ctx context.Context) error {
+	if _, err := s.pgConn.Exec(
+		ctx, fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s", s.publication.FullName(), strings.Join(s.tables, ",")),
+	); err != nil {
+		if !strings.Contains(err.Error(), "already exists") {
+			return fmt.Errorf("failed to create publication: %s", err)
+		}
+	}
+
+	return nil
+}
+
+// TableSchemas returns the schemas of the configured tables.
+func (s *DatabaseStreamSetup) TableSchemas(ctx context.Context) ([]app.TableSchema, error) {
+	schemas := []app.TableSchema{}
+
+	for _, table := range s.tables {
+		rows, err := s.pgConn.Query(ctx,
+			`
+			WITH primary_key_info AS
+			  (SELECT tc.constraint_schema,
+					  tc.table_name,
+					  ccu.column_name
+			   FROM information_schema.table_constraints tc
+			   JOIN information_schema.constraint_column_usage AS ccu USING (CONSTRAINT_SCHEMA, CONSTRAINT_NAME)
+			   WHERE constraint_type = 'PRIMARY KEY' ),
+			  array_type_info AS
+			  (SELECT c.table_name,
+					  c.column_name,
+					  pg_catalog.format_type(t.oid, NULL) AS full_data_type
+			   FROM information_schema.columns AS c
+			   JOIN pg_catalog.pg_type AS t ON c.udt_name = t.typname
+			   WHERE c.data_type = 'ARRAY')
+			SELECT
+				c.column_name,
+				CASE
+					WHEN c.data_type = 'ARRAY' THEN ati.full_data_type
+					ELSE c.data_type
+				END AS data_type,
+				c.is_nullable = 'YES' AS is_nullable,
+				pki.column_name IS NOT NULL AS is_primary
+			FROM information_schema.columns AS c
+			LEFT JOIN primary_key_info pki ON c.table_schema = pki.constraint_schema
+				AND pki.table_name = c.table_name
+				AND pki.column_name = c.column_name
+			LEFT JOIN array_type_info ati ON c.table_name = ati.table_name
+				AND c.column_name = ati.column_name
+			WHERE c.table_name = $1;
+			`, table,
+		)
+		if err != nil {
+			return []app.TableSchema{}, fmt.Errorf("failed to fetch schema: %s", err)
+		}
+		defer rows.Close()
+
+		var colName, typ string
+		var isNull, isPrimary bool
+		var columns []app.Column
+		for rows.Next() {
+			if err := rows.Scan(&colName, &typ, &isNull, &isPrimary); err != nil {
+				return []app.TableSchema{}, fmt.Errorf("scan: %s", err)
+			}
+
+			columns = append(columns, app.Column{
+				Name:      colName,
+				Typ:       typ,
+				IsNull:    isNull,
+				IsPrimary: isPrimary,
+			})
+		}
+
+		// no columns found: the table does not exist, so it cannot be part of the publication; skip it.
+		if len(columns) == 0 {
+			continue
+		}
+
+		schemas = append(schemas, app.TableSchema{
+			Table:   table,
+			Columns: columns,
+		})
+	}
+
+	return schemas, nil
+}
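+
+// For a table `t (id int primary key, name text)`, TableSchemas would return
+// (illustrative values, types as reported by information_schema):
+//
+//	[]app.TableSchema{{Table: "t", Columns: []app.Column{
+//		{Name: "id", Typ: "integer", IsNull: false, IsPrimary: true},
+//		{Name: "name", Typ: "text", IsNull: true, IsPrimary: false},
+//	}}}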
+
+// Close closes the database connection.
+func (s *DatabaseStreamSetup) Close(ctx context.Context) error {
+	return s.pgConn.Close(ctx)
+}
diff --git a/internal/app/db.go b/internal/app/db.go
index b0af866..1572c93 100644
--- a/internal/app/db.go
+++ b/internal/app/db.go
@@ -33,8 +33,7 @@ type DBManager struct {
 	db      *sql.DB
 	dbDir   string
 	dbFname string
-	table   string
-	cols    []Column
+	schemas []TableSchema
 
 	// configs
 	windowInterval time.Duration
@@ -46,14 +45,19 @@ type DBManager struct {
 	close chan struct{}
 }
 
+// TableSchema represents a table and its schema.
+type TableSchema struct {
+	Table   string
+	Columns []Column
+}
+
 // NewDBManager creates a new DBManager.
 func NewDBManager(
-	dbDir, table string, cols []Column, windowInterval time.Duration, uploader *VaultsUploader,
+	dbDir string, schemas []TableSchema, windowInterval time.Duration, uploader *VaultsUploader,
 ) *DBManager {
 	return &DBManager{
 		dbDir:          dbDir,
-		table:          table,
-		cols:           cols,
+		schemas:        schemas,
 		windowInterval: windowInterval,
 		uploader:       uploader,
 	}
@@ -122,52 +126,59 @@ func (dbm *DBManager) Replay(ctx context.Context, tx *pgrepl.Tx) error {
 }
 
 // Export exports the current db to a parquet file at the given path.
-func (dbm *DBManager) Export(ctx context.Context, exportPath string) (bool, error) {
+func (dbm *DBManager) Export(ctx context.Context, exportPath string) ([]string, error) {
 	var err error
 	db := dbm.db
 
 	// db is nil before replication starts.
 	// In that case, we open all existing db files
 	// and upload them.
 	if db == nil {
+		dbm.dbFname = path.Base(exportPath)
+
 		// convert the export path to a db path:
 		// .db.parquet -> .db
 		dbPath := strings.ReplaceAll(exportPath, ".parquet", "")
 		db, err = sql.Open("duckdb", dbPath)
 		if err != nil {
-			return true, err
+			return []string{}, err
 		}
 		defer func() {
 			if err := db.Close(); err != nil {
-				fmt.Printf("cannot close db %v \n", err)
+				slog.Error("cannot close db", "error", err)
 			}
 		}()
-		slog.Info("backing up db", "at", exportPath)
+		slog.Info("backing up db", "at", dbPath)
 	} else {
 		slog.Info("backing up current db")
 	}
 
-	var n int
-	if err := db.QueryRowContext(
-		ctx,
-		"select coalesce(sum(estimated_size), 0) rows_count from duckdb_tables() LIMIT 1",
-	).Scan(&n); err != nil {
-		return true, fmt.Errorf("quering row count: %s", err)
-	}
-	if n == 0 {
-		return true, nil
-	}
+	exportedFiles := []string{}
+	for _, schema := range dbm.schemas {
+		var n int
+		if err := db.QueryRowContext(
+			ctx,
+			fmt.Sprintf("select count(1) from %s LIMIT 1", schema.Table),
+		).Scan(&n); err != nil {
+			return []string{}, fmt.Errorf("querying row count: %s", err)
+		}
 
-	_, err = db.ExecContext(ctx,
-		fmt.Sprintf(
-			`INSTALL parquet;
-			LOAD parquet;
-			COPY (SELECT * FROM %s) TO '%s' (FORMAT PARQUET)`,
-			dbm.table, exportPath))
-	if err != nil {
-		return true, fmt.Errorf("cannot export to parquet file: %s", err)
+		if n == 0 {
+			continue
+		}
+
+		// prefix the export with the table name: <table>-<dbfname>.parquet
+		exportedFileName := strings.Replace(exportPath, dbm.dbFname, fmt.Sprintf("%s-%s", schema.Table, dbm.dbFname), -1)
+		exportedFiles = append(exportedFiles, exportedFileName)
+		_, err = db.ExecContext(ctx,
+			fmt.Sprintf(
+				`INSTALL parquet;
+				LOAD parquet;
+				COPY (SELECT * FROM %s) TO '%s' (FORMAT PARQUET)`,
+				schema.Table, exportedFileName))
+		if err != nil {
+			return []string{}, fmt.Errorf("cannot export to parquet file: %s", err)
+		}
 	}
 
-	return false, nil
+	return exportedFiles, nil
 }
 
 // UploadAt uploads a db dump at the given path.
@@ -214,13 +225,13 @@ func (dbm *DBManager) UploadAll(ctx context.Context) error {
 		if re.MatchString(fname) {
 			dbPath := path.Join(dbm.dbDir, fname)
 			exportAt := dbPath + ".parquet"
-			isEmpty, err := dbm.Export(ctx, exportAt)
+			files, err := dbm.Export(ctx, exportAt)
 			if err != nil {
 				return fmt.Errorf("export: %s", err)
 			}
 
-			if !isEmpty {
-				if err := dbm.UploadAt(ctx, exportAt); err != nil {
+			for _, file := range files {
+				if err := dbm.UploadAt(ctx, file); err != nil {
 					return fmt.Errorf("upload: %s", err)
 				}
 			}
@@ -244,14 +255,14 @@ func (dbm *DBManager) Close() {
 
 func (dbm *DBManager) queryFromWAL(tx *pgrepl.Tx) (string, error) {
 	var columnValsStr string
 
-	// get column names
-	cols := []string{}
-	for _, c := range tx.Records[0].Columns {
-		cols = append(cols, c.Name)
-	}
-
-	recordVals := []string{}
+	// build an insert stmt for each record inside tx
+	stmts := []string{}
 	for _, r := range tx.Records {
+		cols := []string{}
+		for _, c := range r.Columns {
+			cols = append(cols, c.Name)
+		}
+
 		columnVals := []string{}
 		for _, c := range r.Columns {
 			ddbType, err := dbm.pgToDDBType(c.Type)
@@ -262,22 +273,25 @@
 			columnVals = append(columnVals, columnVal)
 		}
 		columnValsStr = strings.Join(columnVals, ", ")
-		recordVals = append(
-			recordVals, fmt.Sprintf("(%s)", columnValsStr))
+		recordVals := fmt.Sprintf("(%s)", columnValsStr)
+
+		stmt := fmt.Sprintf(
+			"insert into %s (%s) values %s",
+			r.Table,
+			strings.Join(cols, ", "),
+			recordVals,
+		)
+
+		stmts = append(stmts, stmt)
 	}
 
-	return fmt.Sprintf(
-		"insert into %s (%s) values %s",
-		dbm.table,
-		strings.Join(cols, ", "),
-		strings.Join(recordVals, ", "),
-	), nil
+	// e.g. insert into t (id, name) values (1, 'foo');insert into t2 (id, name) values (4, 'foo2')
+	return strings.Join(stmts, ";"), nil
 }
 
 func (dbm *DBManager) replace(ctx context.Context) error {
 	// Export current db to a parquet file at a given path
 	exportAt := path.Join(dbm.dbDir, dbm.dbFname) + ".parquet"
-	isEmpty, err := dbm.Export(ctx, exportAt)
+	files, err := dbm.Export(ctx, exportAt)
 	if err != nil {
 		return err
 	}
@@ -286,10 +300,10 @@
 	slog.Info("closing current db")
 	dbm.Close()
 
-	if !isEmpty {
+	for _, file := range files {
 		// Upload the exported parquet file
-		if err := dbm.UploadAt(ctx, exportAt); err != nil {
-			fmt.Println("upload error, skipping", "err", err)
+		if err := dbm.UploadAt(ctx, file); err != nil {
+			slog.Error("upload error, skipping", "err", err)
 		}
 	}
 
@@ -345,41 +359,46 @@ func (dbm *DBManager) pgToDDBType(typ string) (duckdbType, error) {
 }
 
 func (dbm *DBManager) genCreateQuery() (string, error) {
-	var cols, pks string
-	for i, column := range dbm.cols {
-		ddbType, err := dbm.pgToDDBType(column.Typ)
-		if err != nil {
-			return "", err
-		}
-		col := fmt.Sprintf("%s %s", column.Name, ddbType.typeName)
-		if !column.IsNull {
-			col = fmt.Sprintf("%s NOT NULL", col)
-		}
-		if i == 0 {
-			cols = col
-			if column.IsPrimary {
-				pks = column.Name
+	stmts := []string{}
+	for _, schema := range dbm.schemas {
+		var cols, pks string
+		for i, column := range schema.Columns {
+			ddbType, err := dbm.pgToDDBType(column.Typ)
+			if err != nil {
+				return "", err
 			}
-		} else {
-			cols = fmt.Sprintf("%s,%s", cols, col)
-			if column.IsPrimary {
-				pks = fmt.Sprintf("%s,%s", pks, column.Name)
+			col := fmt.Sprintf("%s %s", column.Name, ddbType.typeName)
+			if !column.IsNull {
+				col = fmt.Sprintf("%s NOT NULL", col)
+			}
+			if i == 0 {
+				cols = col
+				if column.IsPrimary {
+					pks = column.Name
+				}
+			} else {
+				cols = fmt.Sprintf("%s,%s", cols, col)
+				if column.IsPrimary {
+					pks = fmt.Sprintf("%s,%s", pks, column.Name)
+				}
 			}
 		}
-	}
 
-	if pks != "" {
-		cols = fmt.Sprintf("%s,PRIMARY KEY (%s)", cols, pks)
-	}
+		if pks != "" {
+			cols = fmt.Sprintf("%s,PRIMARY KEY (%s)", cols, pks)
+		}
+
+		if cols == "" {
+			return "", errors.New("schema must have at least one column")
+		}
 
-	if cols == "" {
-		return "", errors.New("schema must have at least one column")
+		stmt := fmt.Sprintf(
+			"CREATE TABLE IF NOT EXISTS %s (%s)",
+			schema.Table, cols)
+		stmts = append(stmts, stmt)
 	}
 
-	stmt := fmt.Sprintf(
-		"CREATE TABLE IF NOT EXISTS %s (%s)",
-		dbm.table, cols)
-	return stmt, nil
+	return strings.Join(stmts, ";"), nil
 }
 
 func (dbm *DBManager) cleanup(dbPath string) error {
diff --git a/internal/app/db_test.go b/internal/app/db_test.go
index 6c7fc5a..729f40a 100644
--- a/internal/app/db_test.go
+++ b/internal/app/db_test.go
@@ -111,7 +111,7 @@ func TestGenCreateQuery(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.tableName, func(t *testing.T) {
 			dbm := NewDBManager(
-				t.TempDir(), tc.tableName, tc.cols, 3*time.Second, nil)
+				t.TempDir(), []TableSchema{{tc.tableName, tc.cols}}, 3*time.Second, nil)
 
 			query, err := dbm.genCreateQuery()
 			require.NoError(t, err)
@@ -146,7 +146,7 @@ func TestGenCreateQueryUnsupported(t *testing.T) {
 	for _, tc := range testCases {
 		t.Run(tc.tableName, func(t *testing.T) {
 			dbm := NewDBManager(
-				t.TempDir(), tc.tableName, tc.cols, 3*time.Second, nil)
+				t.TempDir(), []TableSchema{{tc.tableName, tc.cols}}, 3*time.Second, nil)
 			_, err := dbm.genCreateQuery()
 			require.EqualError(t, err, tc.expectedErr.Error())
 		})
@@ -571,7 +571,7 @@ func TestQueryFromWAL(t *testing.T) {
 				{Name: "id", Typ: tc.typ, IsNull: valIsNull, IsPrimary: false},
 			}
 			dbm := NewDBManager(
-				t.TempDir(), "t", cols, 3*time.Second, nil)
+				t.TempDir(), []TableSchema{{"t", cols}}, 3*time.Second, nil)
 			insertQuery, err := dbm.queryFromWAL(&tx)
 			require.NoError(t, err)
 			require.Equal(t, tc.expectedInsertStmts[i], insertQuery)
@@ -611,7 +611,7 @@ func TestQueryFromWALUnsupported(t *testing.T) {
 			{Name: "id", Typ: tc.typ, IsNull: valIsNull, IsPrimary: false},
 		}
 		dbm := NewDBManager(
-			t.TempDir(), "t", cols, 3*time.Second, nil)
+			t.TempDir(), []TableSchema{{"t", cols}}, 3*time.Second, nil)
 		_, err := dbm.queryFromWAL(&tx)
 		require.EqualError(t, err, tc.expectedErr.Error())
 	}
@@ -638,7 +638,7 @@ func TestReplay(t *testing.T) {
 	}
 	// use a large window for testing
 	dbm := NewDBManager(
-		t.TempDir(), "t", cols, 3*time.Hour, nil)
+		t.TempDir(), []TableSchema{{"t", cols}}, 3*time.Hour, nil)
 
 	// assert new db setup (create queries are correctly applied)
 	ctx := context.Background()
@@ -664,7 +664,7 @@ func TestReplayUnsupported(t *testing.T) {
 		{Name: "id", Typ: typ, IsNull: valIsNull, IsPrimary: false},
 	}
 	dbm := NewDBManager(
-		t.TempDir(), "t", cols, 3*time.Hour, nil)
+		t.TempDir(), []TableSchema{{"t", cols}}, 3*time.Hour, nil)
 	// assert new db setup (create queries are correctly applied)
 	ctx := context.Background()
 	err := dbm.NewDB(ctx)
diff --git a/internal/app/streamer.go b/internal/app/streamer.go
index 0fd49d7..c13e44e 100644
--- a/internal/app/streamer.go
+++ b/internal/app/streamer.go
@@ -14,7 +14,7 @@ import (
 
 // Replicator replicates Postgres txs into a channel.
 type Replicator interface {
-	StartReplication(ctx context.Context) (chan *pgrepl.Tx, string, error)
+	StartReplication(ctx context.Context) (chan *pgrepl.Tx, []string, error)
 	Commit(ctx context.Context, lsn pglogrepl.LSN) error
 	Shutdown()
 }
diff --git a/internal/app/streamer_test.go b/internal/app/streamer_test.go
index ff69905..de1e400 100644
--- a/internal/app/streamer_test.go
+++ b/internal/app/streamer_test.go
@@ -44,7 +44,7 @@ func TestVaultsStreamerOne(t *testing.T) {
 	}
 	uploader := NewVaultsUploader(testNS, testTable, providerMock, privateKey)
 	dbm := NewDBManager(
-		testDBDir, testTable, cols, winSize, uploader)
+		testDBDir, []TableSchema{{testTable, cols}}, winSize, uploader)
 	streamer := NewVaultsStreamer(testNS, &replicatorMock{feed: feed}, dbm)
 	go func() {
@@ -145,7 +145,7 @@ func TestVaultsStreamerTwo(t *testing.T) {
 	}
 	uploader := NewVaultsUploader(testNS, testTable, providerMock, privateKey)
 	dbm := NewDBManager(
-		testDBDir, testTable, cols, winSize, uploader)
+		testDBDir, []TableSchema{{testTable, cols}}, winSize, uploader)
 	streamer := NewVaultsStreamer(testNS, &replicatorMock{feed: feed}, dbm)
 	go func() {
 		// start listening to WAL records in a separate goroutine
@@ -175,21 +175,6 @@
 	// wait for window to pass
 	time.Sleep(winSize + 1)
 
-	// nothing should be uploaded because the second tx was received before
-	// the window closed. the exports should be uploaded
-	// when we replicator is started again
-	select {
-	case <-providerMock.uploaderInputs:
-		t.FailNow() // should not be reached
-	default:
-		// manually trigger uploadAll to simulate
-		// starting the replication process again
-		go func() {
-			require.NoError(
-				t, dbm.UploadAll(context.Background()))
-		}()
-	}
-
 	// Assert that the both first and second tx
 	// were replayed by importing the exported parquet file
 	file := <-providerMock.uploaderInputs
@@ -221,8 +206,8 @@ type replicatorMock struct {
 
 var _ Replicator = (*replicatorMock)(nil)
 
-func (rm *replicatorMock) StartReplication(_ context.Context) (chan *pgrepl.Tx, string, error) {
-	return rm.feed, "", nil
+func (rm *replicatorMock) StartReplication(_ context.Context) (chan *pgrepl.Tx, []string, error) {
+	return rm.feed, []string{}, nil
 }
 
 func (rm *replicatorMock) Commit(_ context.Context, _ pglogrepl.LSN) error {
diff --git a/pkg/pgrepl/conn.go b/pkg/pgrepl/conn.go
index 5965288..029b0cc 100644
--- a/pkg/pgrepl/conn.go
+++ b/pkg/pgrepl/conn.go
@@ -14,54 +14,35 @@ type Conn struct {
 	*pgx.Conn
 }
 
-// FetchPublicationTables fetches all tables that needs replication from publications.
-func (c *Conn) FetchPublicationTables(ctx context.Context) ([]string, error) {
+// GetPublicationTables returns the tables that belong to a given publication.
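+// The returned names are schema-qualified, e.g. []string{"public.t1", "public.t2"}.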
+func (c *Conn) GetPublicationTables(ctx context.Context, p Publication) ([]string, error) {
 	rows, err := c.Query(ctx,
 		`
 		SELECT schemaname, tablename
 		FROM pg_publication p
 		JOIN pg_publication_tables pt ON p.pubname = pt.pubname
-		`,
+		WHERE p.pubname = $1
+		`, p.FullName(),
 	)
 	if errors.Is(err, pgx.ErrNoRows) {
-		return []string{}, nil
+		return []string{}, fmt.Errorf("publication not found")
 	} else if err != nil {
 		return []string{}, fmt.Errorf("query: %s", err)
 	}
-
 	defer rows.Close()
 
-	var tables []string
+	tables := []string{}
 	for rows.Next() {
 		var schema, table string
 		if err := rows.Scan(&schema, &table); err != nil {
 			return []string{}, fmt.Errorf("scan: %s", err)
 		}
+
 		tables = append(tables, fmt.Sprintf("%s.%s", schema, table))
 	}
 
 	return tables, nil
 }
 
-// GetPublicationTable checks if the publication exists for a given table.
-func (c *Conn) GetPublicationTable(ctx context.Context, p Publication) (string, error) {
-	var schema, table string
-	err := c.QueryRow(ctx,
-		`
-		SELECT schemaname, tablename
-		FROM pg_publication p
-		JOIN pg_publication_tables pt ON p.pubname = pt.pubname
-		WHERE p.pubname = $1
-		`, p.FullName(),
-	).Scan(&schema, &table)
-	if errors.Is(err, pgx.ErrNoRows) {
-		return "", fmt.Errorf("publication does not exist")
-	} else if err != nil {
-		return "", fmt.Errorf("query: %s", err)
-	}
-
-	return fmt.Sprintf("%s.%s", schema, table), nil
-}
-
 // ConfirmedFlushLSN fetches the confirmed flush LSN.
 func (c *Conn) ConfirmedFlushLSN(ctx context.Context, slot string) (pglogrepl.LSN, error) {
 	var lsn pglogrepl.LSN
diff --git a/pkg/pgrepl/replicator.go b/pkg/pgrepl/replicator.go
index 7863227..07c61f3 100644
--- a/pkg/pgrepl/replicator.go
+++ b/pkg/pgrepl/replicator.go
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"strings"
 	"sync"
 
 	"github.com/jackc/pglogrepl"
@@ -26,7 +27,7 @@ type Publication string
 
 // FullName is the name used to create a publication in Postgres.
 func (p Publication) FullName() string {
-	return fmt.Sprintf("pub_basin_%s", p)
+	return fmt.Sprintf("pub_basin_%s", string(p))
 }
 
 // PgReplicator is a component that replicates Postgres data.
@@ -37,9 +38,9 @@ type PgReplicator struct {
 	// channel of replicated Txs.
 	feed chan *Tx
 
-	// The table that will be replicated.
+	// The tables that will be replicated.
 	// We get them by querying pg_publication.
-	table string
+	tables []string
 
 	// The commitLSN is the LSN used to start the replication.
 	// It either comes from the confirmed_flush_lsn of an existing replication slot
@@ -94,12 +95,12 @@ func New(connStr string, publication Publication) (*PgReplicator, error) {
 		return nil, fmt.Errorf("ping: %s", err)
 	}
 
-	// Check if publication exists
-	table, err := conn.GetPublicationTable(ctx, publication)
+	// Check that the publication exists and fetch its tables
+	tables, err := conn.GetPublicationTables(ctx, publication)
 	if err != nil {
 		return nil, err
 	}
-	r.table = table
+	r.tables = tables
 
 	// Fetch the confirmed flush lsn.
 	lsn, err := conn.ConfirmedFlushLSN(ctx, r.slot)
@@ -132,7 +133,7 @@ func New(connStr string, publication Publication) (*PgReplicator, error) {
 }
 
-// StartReplication starts replicattion.
+// StartReplication starts replication.
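+// It returns the feed of replicated txs and the tables included in the publication.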
-func (r *PgReplicator) StartReplication(ctx context.Context) (chan *Tx, string, error) {
+func (r *PgReplicator) StartReplication(ctx context.Context) (chan *Tx, []string, error) {
 	if err := pglogrepl.StartReplication(
 		ctx,
 		r.pgConn,
@@ -149,9 +150,9 @@ func (r *PgReplicator) StartReplication(ctx context.Context) (chan *Tx, string,
 			"\"include-pk\" 'true'",
 			"\"format-version\" '2'",
 			"\"include-xids\" 'true'",
-			fmt.Sprintf("\"add-tables\" '%s'", r.table),
+			fmt.Sprintf("\"add-tables\" '%s'", strings.Join(r.tables, ",")),
 		}}); err != nil {
-		return nil, r.table, err
+		return nil, r.tables, err
 	}
 
 	slog.Info("Logical replication started", "slot", r.slot)
@@ -205,7 +206,7 @@
 		}
 	}()
 
-	return r.feed, r.table, nil
+	return r.feed, r.tables, nil
 }
 
-// Commit send a signal to Postgres that the lsn was consumed.
+// Commit sends a signal to Postgres that the lsn was consumed.
diff --git a/pkg/pgrepl/replicator_test.go b/pkg/pgrepl/replicator_test.go
index 2ee6cdf..25b7c83 100644
--- a/pkg/pgrepl/replicator_test.go
+++ b/pkg/pgrepl/replicator_test.go
@@ -35,23 +35,23 @@ func TestMain(m *testing.M) {
 }
 
 func TestReplication(t *testing.T) {
-	t.Parallel()
-
 	_, err := db.ExecContext(context.Background(), `
 		create table t(id int primary key, name text);
-		create publication pub_basin_t for table t;
+		create table t2(id int primary key, name text);
+		create publication pub_basin_t for table t, t2;
 	`)
 	require.NoError(t, err)
 
 	replicator, err := New(uri, "t")
 	require.NoError(t, err)
-	feed, pubName, err := replicator.StartReplication(context.Background())
+	feed, tables, err := replicator.StartReplication(context.Background())
 	require.NoError(t, err)
-	require.Equal(t, "public.t", pubName)
+	require.Equal(t, []string{"public.t", "public.t2"}, tables)
 
 	_, err = db.ExecContext(context.Background(), `
 		insert into t values (1, 'foo');
 		insert into t values (2, 'bar');
+		insert into t2 values (4, 'foo2');
 		insert into t values (3, 'baz');
 		update t set name='quz' where id=3;
 		delete from t where id=2;
@@ -59,7 +59,7 @@ func TestReplication(t *testing.T) {
 	`)
 	require.NoError(t, err)
 
 	tx := <-feed
-	require.Equal(t, 5, len(tx.Records))
+	require.Equal(t, 6, len(tx.Records))
 	require.Equal(t, tx.Records[0].Table, "t")
 	require.Equal(t, tx.Records[0].Columns, []Column{
 		{
 			Name:  "id",
 			Type:  "integer",
 			Value: toJSON(t, 1),
 		},
@@ -74,6 +74,20 @@
 		{
 			Name:  "name",
 			Type:  "text",
 			Value: toJSON(t, "foo"),
 		},
 	})
 
+	require.Equal(t, tx.Records[2].Table, "t2")
+	require.Equal(t, tx.Records[2].Columns, []Column{
+		{
+			Name:  "id",
+			Type:  "integer",
+			Value: toJSON(t, 4),
+		},
+		{
+			Name:  "name",
+			Type:  "text",
+			Value: toJSON(t, "foo2"),
+		},
+	})
+
 	// TODO: add more assertions
 
 	replicator.Shutdown()