Skip to content

Commit

Permalink
Only track DDLs for tables that are part of the publication (#44)
Browse files Browse the repository at this point in the history
This way, we don't accidentally end up tracking DDLs from other tables, including temporary tables.
  • Loading branch information
shayonj authored Nov 12, 2024
1 parent cff3f38 commit 27ea37c
Show file tree
Hide file tree
Showing 8 changed files with 365 additions and 156 deletions.
197 changes: 156 additions & 41 deletions internal/scripts/e2e_ddl.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@ set -euo pipefail

source "$(dirname "$0")/e2e_common.sh"

create_users() {
log "Creating initial test table..."
run_sql "DROP TABLE IF EXISTS public.users;"
run_sql "CREATE TABLE public.users (id serial PRIMARY KEY, data text);"
success "Initial test table created"
create_test_tables() {
log "Creating test schemas and tables..."
run_sql "DROP SCHEMA IF EXISTS app CASCADE; CREATE SCHEMA app;"
run_sql "DROP SCHEMA IF EXISTS public CASCADE; CREATE SCHEMA public;"

run_sql "CREATE TABLE app.users (id serial PRIMARY KEY, data text);"
run_sql "CREATE TABLE app.posts (id serial PRIMARY KEY, content text);"

run_sql "CREATE TABLE app.comments (id serial PRIMARY KEY, text text);"
run_sql "CREATE TABLE public.metrics (id serial PRIMARY KEY, value numeric);"
success "Test tables created"
}

start_pg_flo_replication() {
Expand All @@ -23,8 +29,8 @@ start_pg_flo_replication() {
--user "$PG_USER" \
--password "$PG_PASSWORD" \
--group "group_ddl" \
--tables "users" \
--schema "public" \
--schema "app" \
--tables "users,posts" \
--nats-url "$NATS_URL" \
--track-ddl \
>"$pg_flo_LOG" 2>&1 &
Expand Down Expand Up @@ -61,60 +67,169 @@ start_pg_flo_worker() {

perform_ddl_operations() {
log "Performing DDL operations..."
run_sql "ALTER TABLE users ADD COLUMN new_column int;"
run_sql "CREATE INDEX CONCURRENTLY idx_users_data ON users (data);"
run_sql "ALTER TABLE users RENAME COLUMN data TO old_data;"
run_sql "DROP INDEX idx_users_data;"
run_sql "ALTER TABLE users ADD COLUMN new_column_one int;"
run_sql "ALTER TABLE users ALTER COLUMN old_data TYPE varchar(255);"

# Column operations on tracked tables
run_sql "ALTER TABLE app.users ADD COLUMN email text;"
run_sql "ALTER TABLE app.users ADD COLUMN status varchar(50) DEFAULT 'active';"
run_sql "ALTER TABLE app.posts ADD COLUMN category text;"

# Index operations on tracked tables
run_sql "CREATE INDEX CONCURRENTLY idx_users_email ON app.users (email);"
run_sql "CREATE UNIQUE INDEX idx_posts_unique ON app.posts (content) WHERE content IS NOT NULL;"

# Column modifications on tracked tables
run_sql "ALTER TABLE app.users ALTER COLUMN status SET DEFAULT 'pending';"
run_sql "ALTER TABLE app.posts ALTER COLUMN category TYPE varchar(100);"

# Rename operations on tracked tables
run_sql "ALTER TABLE app.users RENAME COLUMN data TO profile;"

# Drop operations on tracked tables
run_sql "DROP INDEX CONCURRENTLY IF EXISTS idx_users_email;"
run_sql "ALTER TABLE app.posts DROP COLUMN IF EXISTS category;"

# Operations on non-tracked tables (should be ignored)
run_sql "ALTER TABLE app.comments ADD COLUMN author text;"
run_sql "CREATE INDEX idx_comments_text ON app.comments (text);"
run_sql "ALTER TABLE public.metrics ADD COLUMN timestamp timestamptz;"

success "DDL operations performed"
}

verify_ddl_changes() {
log "Verifying DDL changes..."
log "Verifying DDL changes in target database..."
local failures=0

# Check table structure in target database
local new_column_exists=$(run_sql_target "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'new_column';")
local new_column_one_exists=$(run_sql_target "SELECT COUNT(*) FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'new_column_one';")
local old_data_type=$(run_sql_target "SELECT data_type FROM information_schema.columns WHERE table_name = 'users' AND column_name = 'old_data';")
old_data_type=$(echo "$old_data_type" | xargs)
check_column() {
local table=$1
local column=$2
local expected_exists=$3
local expected_type=${4:-""}
local expected_default=${5:-""}
local query="
SELECT COUNT(*),
data_type,
character_maximum_length,
column_default
FROM information_schema.columns
WHERE table_schema='app'
AND table_name='$table'
AND column_name='$column'
GROUP BY data_type, character_maximum_length, column_default;"

if [ "$new_column_exists" -eq 1 ]; then
success "new_column exists in target database"
else
error "new_column does not exist in target database"
return 1
fi
local result
result=$(run_sql_target "$query")

if [ "$new_column_one_exists" -eq 1 ]; then
success "new_column_one exists in target database"
else
error "new_column_one does not exist in target database"
return 1
fi
if [ -z "$result" ]; then
exists=0
data_type=""
char_length=""
default_value=""
else
read exists data_type char_length default_value < <(echo "$result" | tr '|' ' ')
fi

if [ "$old_data_type" = "character varying" ]; then
success "old_data column type is character varying"
else
error "old_data column type is not character varying (got: '$old_data_type')"
return 1
fi
exists=${exists:-0}

if [ "$exists" -eq "$expected_exists" ]; then
if [ "$expected_exists" -eq 1 ]; then
local type_ok=true
local default_ok=true

if [ -n "$expected_type" ]; then
# Handle character varying type specifically
if [ "$expected_type" = "character varying" ]; then
if [ "$data_type" = "character varying" ] || [ "$data_type" = "varchar" ] || [ "$data_type" = "character" ]; then
type_ok=true
else
type_ok=false
fi
elif [ "$data_type" != "$expected_type" ]; then
type_ok=false
fi
fi

if [ -n "$expected_default" ]; then
if [[ "$default_value" == *"$expected_default"* ]]; then
default_ok=true
else
default_ok=false
fi
fi

if [ "$type_ok" = true ] && [ "$default_ok" = true ]; then
if [[ "$expected_type" == "character varying" && -n "$char_length" ]]; then
success "Column app.$table.$column verification passed (type: $data_type($char_length), default: $default_value)"
else
success "Column app.$table.$column verification passed (type: $data_type, default: $default_value)"
fi
else
if [ "$type_ok" = false ]; then
error "Column app.$table.$column type mismatch (expected: $expected_type, got: $data_type)"
failures=$((failures + 1))
fi
if [ "$default_ok" = false ]; then
error "Column app.$table.$column default value mismatch (expected: $expected_default, got: $default_value)"
failures=$((failures + 1))
fi
fi
else
success "Column app.$table.$column verification passed (not exists)"
fi
else
error "Column app.$table.$column verification failed (expected: $expected_exists, got: $exists)"
failures=$((failures + 1))
fi
}

check_index() {
local index=$1
local expected=$2
local exists=$(run_sql_target "SELECT COUNT(*) FROM pg_indexes WHERE schemaname='app' AND indexname='$index';")

if [ "$exists" -eq "$expected" ]; then
success "Index app.$index verification passed (expected: $expected)"
else
error "Index app.$index verification failed (expected: $expected, got: $exists)"
failures=$((failures + 1))
fi
}

# Verify app.users changes
check_column "users" "email" 1 "text"
check_column "users" "status" 1 "character varying" "'pending'"
check_column "users" "data" 0
check_column "users" "profile" 1 "text"

# Verify app.posts changes
check_column "posts" "category" 0
check_column "posts" "content" 1 "text"
check_index "idx_posts_unique" 1 "unique"

# Verify non-tracked tables
check_column "comments" "author" 0
check_index "idx_comments_text" 0

# Check if internal table is empty
local remaining_rows=$(run_sql "SELECT COUNT(*) FROM internal_pg_flo.ddl_log;")
if [ "$remaining_rows" -eq 0 ]; then
success "internal_pg_flo.ddl_log table is empty"
else
error "internal_pg_flo.ddl_log table is not empty. Remaining rows: $remaining_rows"
return 1
failures=$((failures + 1))
fi

return 0
if [ "$failures" -eq 0 ]; then
success "All DDL changes verified successfully"
return 0
else
error "DDL verification failed with $failures errors"
return 1
fi
}

test_pg_flo_ddl() {
setup_postgres
create_users
create_test_tables
start_pg_flo_worker
sleep 5
start_pg_flo_replication
Expand Down
6 changes: 3 additions & 3 deletions internal/scripts/e2e_test_local.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ make build

setup_docker

log "Running e2e routing tests..."
if CI=false ./internal/scripts/e2e_routing.sh; then
success "Original e2e tests completed successfully"
log "Running e2e ddl tests..."
if CI=false ./internal/scripts/e2e_ddl.sh; then
success "e2e ddl tests completed successfully"
else
error "Original e2e tests failed"
exit 1
Expand Down
82 changes: 63 additions & 19 deletions pkg/replicator/base_replicator.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ type BaseReplicator struct {

// NewBaseReplicator creates a new BaseReplicator instance
func NewBaseReplicator(config Config, replicationConn ReplicationConnection, standardConn StandardConnection, natsClient NATSClient) *BaseReplicator {
if config.Schema == "" {
config.Schema = "public"
}

logger := log.With().Str("component", "replicator").Logger()

br := &BaseReplicator{
Expand Down Expand Up @@ -78,6 +82,26 @@ func NewBaseReplicator(config Config, replicationConn ReplicationConnection, sta
return br
}

// buildCreatePublicationQuery constructs the SQL query for creating a publication
func (r *BaseReplicator) buildCreatePublicationQuery() (string, error) {
publicationName := GeneratePublicationName(r.Config.Group)

tables, err := r.GetConfiguredTables(context.Background())
if err != nil {
return "", fmt.Errorf("failed to get configured tables: %w", err)
}

sanitizedTables := make([]string, len(tables))
for i, table := range tables {
parts := strings.Split(table, ".")
sanitizedTables[i] = pgx.Identifier{parts[0], parts[1]}.Sanitize()
}

return fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s",
pgx.Identifier{publicationName}.Sanitize(),
strings.Join(sanitizedTables, ", ")), nil
}

// CreatePublication creates a new publication if it doesn't exist
func (r *BaseReplicator) CreatePublication() error {
publicationName := GeneratePublicationName(r.Config.Group)
Expand All @@ -91,7 +115,11 @@ func (r *BaseReplicator) CreatePublication() error {
return nil
}

query := r.buildCreatePublicationQuery()
query, err := r.buildCreatePublicationQuery()
if err != nil {
return fmt.Errorf("failed to build publication query: %w", err)
}

_, err = r.StandardConn.Exec(context.Background(), query)
if err != nil {
return fmt.Errorf("failed to create publication: %w", err)
Expand All @@ -101,24 +129,6 @@ func (r *BaseReplicator) CreatePublication() error {
return nil
}

// buildCreatePublicationQuery constructs the SQL query for creating a publication
func (r *BaseReplicator) buildCreatePublicationQuery() string {
publicationName := GeneratePublicationName(r.Config.Group)
if len(r.Config.Tables) == 0 {
return fmt.Sprintf("CREATE PUBLICATION %s FOR ALL TABLES",
pgx.Identifier{publicationName}.Sanitize())
}

fullyQualifiedTables := make([]string, len(r.Config.Tables))
for i, table := range r.Config.Tables {
fullyQualifiedTables[i] = pgx.Identifier{r.Config.Schema, table}.Sanitize()
}

return fmt.Sprintf("CREATE PUBLICATION %s FOR TABLE %s",
pgx.Identifier{publicationName}.Sanitize(),
strings.Join(fullyQualifiedTables, ", "))
}

// checkPublicationExists checks if a publication with the given name exists
func (r *BaseReplicator) checkPublicationExists(publicationName string) (bool, error) {
var exists bool
Expand Down Expand Up @@ -559,3 +569,37 @@ func (r *BaseReplicator) CheckReplicationSlotStatus(ctx context.Context) error {
r.Logger.Info().Str("slotName", publicationName).Str("restartLSN", restartLSN).Msg("Replication slot status")
return nil
}

// GetConfiguredTables returns all tables based on configuration
// If no specific tables are configured, returns all tables from the configured schema
func (r *BaseReplicator) GetConfiguredTables(ctx context.Context) ([]string, error) {
if len(r.Config.Tables) > 0 {
fullyQualifiedTables := make([]string, len(r.Config.Tables))
for i, table := range r.Config.Tables {
fullyQualifiedTables[i] = fmt.Sprintf("%s.%s", r.Config.Schema, table)
}
return fullyQualifiedTables, nil
}

rows, err := r.StandardConn.Query(ctx, `
SELECT schemaname || '.' || tablename
FROM pg_tables
WHERE schemaname = $1
AND schemaname NOT IN ('pg_catalog', 'information_schema', 'internal_pg_flo')
`, r.Config.Schema)
if err != nil {
return nil, fmt.Errorf("failed to query tables: %v", err)
}
defer rows.Close()

var tables []string
for rows.Next() {
var tableName string
if err := rows.Scan(&tableName); err != nil {
return nil, fmt.Errorf("failed to scan table name: %v", err)
}
tables = append(tables, tableName)
}

return tables, nil
}
Loading

0 comments on commit 27ea37c

Please sign in to comment.