Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
"export-data": "echo 'Use the /export endpoint: curl http://localhost:8787/export?table=tablename > export.sql'",
"validate": "wrangler d1 list && echo 'Validation: Check that your D1 database ID matches wrangler.toml'",
"get-schema": "echo 'Get database schema: curl http://localhost:8787/schema'",
"get-migration": "echo 'Generate migration script: curl http://localhost:8787/migration-script > postgres-migration.sql'"
"get-migration": "echo 'Generate migration script: curl http://localhost:8787/migration-script > postgres-migration.sql'",
"test": "tsx tests/auto-mirror.test.ts"
},
"keywords": [
"cloudflare",
Expand Down
151 changes: 123 additions & 28 deletions src/auto-mirror.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Env } from './worker';
import { convertSqlitePlaceholdersToPostgres, PlaceholderMappingEntry, normalizeNamedBindings, buildPostgresParameterArray } from './sql-utils';

export class AutoMirrorDB {
constructor(private env: Env) { }
Expand All @@ -8,48 +9,69 @@ export class AutoMirrorDB {

this.env.DB.prepare = (sql: string) => {
const stmt = originalPrepare(sql);
const originalBind = stmt.bind.bind(stmt);

stmt.bind = (...params: unknown[]) => {
const bound = originalBind(...params);

// Patch all methods that can execute write operations
const originalRun = bound.run.bind(bound);
const originalAll = bound.all.bind(bound);
const originalFirst = bound.first.bind(bound);
return this.patchStatement(stmt, sql);
};
}

// Only mirror if this is a write operation (INSERT, UPDATE, DELETE)
const isWriteOperation = this.isWriteSQL(sql);
private patchStatement<T extends D1PreparedStatement>(stmt: T, sql: string): T {
const { sql: pgSql, mapping } = convertSqlitePlaceholdersToPostgres(sql);
const isWriteOperation = this.isWriteSQL(sql);

bound.run = async <T = Record<string, unknown>>(): Promise<D1Result<T>> => {
const result = await originalRun<T>();
const applyExecutionPatches = (statement: any, boundParamsProvider: () => unknown[]) => {
const baseRun = captureOriginalMethod(statement, 'run');
if (baseRun) {
statement.run = async <T = Record<string, unknown>>(...args: unknown[]): Promise<D1Result<T>> => {
const result = await baseRun.apply(statement, args) as D1Result<T>;
if (isWriteOperation) {
await this.mirrorToPostgres(sql, params);
const params = args.length > 0
? this.normalizeParams(args, mapping)
: boundParamsProvider();
await this.mirrorToPostgres(pgSql, [...params]);
}
return result;
};
}

bound.all = async <T = Record<string, unknown>>(): Promise<D1Result<T>> => {
const result = await originalAll<T>();
const baseAll = captureOriginalMethod(statement, 'all');
if (baseAll) {
statement.all = async <T = Record<string, unknown>>(...args: unknown[]): Promise<D1Result<T>> => {
const result = await baseAll.apply(statement, args) as D1Result<T>;
if (isWriteOperation) {
await this.mirrorToPostgres(sql, params);
const params = boundParamsProvider();
await this.mirrorToPostgres(pgSql, [...params]);
}
return result;
};
}

bound.first = async <T = Record<string, unknown>>(colName?: string): Promise<T | null> => {
const result = await (colName ? originalFirst<T>(colName) : originalFirst<T>());
const baseFirst = captureOriginalMethod(statement, 'first');
if (baseFirst) {
statement.first = async <T = Record<string, unknown>>(...args: unknown[]): Promise<T | null> => {
const result = await baseFirst.apply(statement, args) as T | null;
if (isWriteOperation) {
await this.mirrorToPostgres(sql, params);
const params = boundParamsProvider();
await this.mirrorToPostgres(pgSql, [...params]);
}
return result;
};
}

return statement;
};

applyExecutionPatches(stmt, () => []);

if (typeof stmt.bind === 'function') {
const originalBind = stmt.bind.bind(stmt);
stmt.bind = (...bindArgs: unknown[]) => {
const normalized = this.normalizeParams(bindArgs, mapping);
const bound = originalBind(...bindArgs);
applyExecutionPatches(bound, () => normalized);
return bound;
};
}

return stmt;
};
return stmt;
}

private isWriteSQL(sql: string): boolean {
Expand All @@ -75,11 +97,10 @@ export class AutoMirrorDB {

private async mirrorToPostgres(sql: string, params: unknown[]) {
try {
const pgSql = this.convertPlaceholders(sql, params.length);
const opId = crypto.randomUUID();

await this.env.MIRROR_QUEUE.send({
sql: pgSql,
sql,
params,
opId
});
Expand All @@ -89,8 +110,82 @@ export class AutoMirrorDB {
}
}

private convertPlaceholders(sql: string, count: number): string {
let i = 1;
return sql.replace(/\?/g, () => `$${i++}`);
private normalizeParams(rawParams: unknown[], mapping: PlaceholderMappingEntry[]): unknown[] {
if (mapping.length === 0) {
return [];
}

const positional: unknown[] = [];
const named = new Map<string, unknown>();

const addNamed = (key: string, value: unknown) => {
normalizeNamedBindings(named, key, value);
};

for (const param of rawParams) {
if (Array.isArray(param)) {
positional.push(...param);
continue;
}

if (param instanceof Map) {
for (const [key, value] of param.entries()) {
if (typeof key === 'string') {
addNamed(key, value);
}
}
continue;
}

if (isPlainObject(param)) {
for (const [key, value] of Object.entries(param)) {
addNamed(key, value);
}
continue;
}

positional.push(param);
}

return buildPostgresParameterArray(mapping, positional, named);
}
}

function isPlainObject(value: unknown): value is Record<string, unknown> {
if (value === null || typeof value !== 'object') {
return false;
}

const prototype = Object.getPrototypeOf(value);
return prototype === Object.prototype || prototype === null;
}

type StatementMethod = (...args: unknown[]) => Promise<unknown>;

type OriginalMethodRegistry = {
run?: StatementMethod;
all?: StatementMethod;
first?: StatementMethod;
};

const originalMethodRegistry = new WeakMap<object, OriginalMethodRegistry>();

function captureOriginalMethod(statement: any, method: 'run' | 'all' | 'first'): StatementMethod | undefined {
let registry = originalMethodRegistry.get(statement);
if (!registry) {
registry = {};
originalMethodRegistry.set(statement, registry);
}

if (registry[method]) {
return registry[method];
}

const current = statement[method];
if (typeof current === 'function') {
registry[method] = current as StatementMethod;
return registry[method];
}
}

return undefined;
}
15 changes: 7 additions & 8 deletions src/db-router.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { Env } from './worker';
import { buildPostgresParameterArray, convertSqlitePlaceholdersToPostgres } from './sql-utils';

export async function executeQuery(env: Env, sql: string, params: unknown[] = []) {
if (env.PRIMARY_DB === "pg") {
Expand All @@ -16,8 +17,11 @@ export async function executeQuery(env: Env, sql: string, params: unknown[] = []
await client.connect();

// Convert D1 placeholders to Postgres placeholders
const pgSql = convertPlaceholders(sql, params.length);
const result = await client.query(pgSql, params);
const { sql: pgSql, mapping } = convertSqlitePlaceholdersToPostgres(sql);
const normalizedParams = mapping.length > 0
? buildPostgresParameterArray(mapping, params)
: params;
const result = await client.query(pgSql, normalizedParams);

// Return in D1-compatible format
return {
Expand Down Expand Up @@ -180,9 +184,4 @@ export async function getAllTables(env: Env): Promise<string[]> {
throw new Error('Failed to get table list');
}
}
}

function convertPlaceholders(sql: string, count: number): string {
let i = 1;
return sql.replace(/\?/g, () => `$${i++}`);
}
}
Loading