Skip to content

Commit

Permalink
Enable nested beforeTemplateIsBaked calls
Browse files Browse the repository at this point in the history
  • Loading branch information
codetheweb committed Apr 8, 2024
1 parent 4f8443a commit 12fbf7c
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 48 deletions.
120 changes: 74 additions & 46 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import type {
import { Pool } from "pg"
import type { Jsonifiable } from "type-fest"
import type { ExecutionContext } from "ava"
import { once } from "node:events"
import { createBirpc } from "birpc"
import { BirpcReturn, createBirpc } from "birpc"
import { ExecResult } from "testcontainers"
import isPlainObject from "lodash/isPlainObject"

Expand Down Expand Up @@ -136,57 +135,86 @@ export const getTestPostgresDatabaseFactory = <
}

let rpcCallback: (data: any) => void
const rpc = createBirpc<SharedWorkerFunctions, TestWorkerFunctions>(
{
runBeforeTemplateIsBakedHook: async (connection, params) => {
if (options?.beforeTemplateIsBaked) {
const connectionDetails =
mapWorkerConnectionDetailsToConnectionDetails(connection)

// Ignore if the pool is terminated by the shared worker
// (This happens in CI for some reason even though we drain the pool first.)
connectionDetails.pool.on("error", (error) => {
if (
error.message.includes(
"terminating connection due to administrator command"
)
) {
return
}
const rpc: BirpcReturn<SharedWorkerFunctions, TestWorkerFunctions> =
createBirpc<SharedWorkerFunctions, TestWorkerFunctions>(
{
runBeforeTemplateIsBakedHook: async (connection, params) => {
if (options?.beforeTemplateIsBaked) {
const connectionDetails =
mapWorkerConnectionDetailsToConnectionDetails(connection)

// Ignore if the pool is terminated by the shared worker
// (This happens in CI for some reason even though we drain the pool first.)
connectionDetails.pool.on("error", (error) => {
if (
error.message.includes(
"terminating connection due to administrator command"
)
) {
return
}

throw error
})

const createdNestedConnections: ConnectionDetails[] = []
const hookResult = await options.beforeTemplateIsBaked({
params: params as any,
connection: connectionDetails,
containerExec: async (command): Promise<ExecResult> =>
rpc.execCommandInContainer(command),
// This is what allows a consumer to get a "nested" database from within their beforeTemplateIsBaked hook
beforeTemplateIsBaked: async (options) => {
const { connectionDetails, beforeTemplateIsBakedResult } =
await rpc.getTestDatabase({
params: options.params,
databaseDedupeKey: options.databaseDedupeKey,
})

throw error
})
const mappedConnection =
mapWorkerConnectionDetailsToConnectionDetails(
connectionDetails
)

const hookResult = await options.beforeTemplateIsBaked({
params: params as any,
connection: connectionDetails,
containerExec: async (command): Promise<ExecResult> =>
rpc.execCommandInContainer(command),
})
createdNestedConnections.push(mappedConnection)

await teardownConnection(connectionDetails)
return {
...mappedConnection,
beforeTemplateIsBakedResult,
}
},
})

if (hookResult && !isSerializable(hookResult)) {
throw new TypeError(
"Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values."
await Promise.all(
createdNestedConnections.map(async (connection) => {
await teardownConnection(connection)
await rpc.dropDatabase(connection.database)
})
)
}

return hookResult
}
},
},
{
post: async (data) => {
const worker = await workerPromise
await worker.available
worker.publish(data)
},
on: (data) => {
rpcCallback = data
await teardownConnection(connectionDetails)

if (hookResult && !isSerializable(hookResult)) {
throw new TypeError(
"Return value of beforeTemplateIsBaked() hook could not be serialized. Make sure it returns only JSON-serializable values."
)
}

return hookResult
}
},
},
}
)
{
post: async (data) => {
const worker = await workerPromise
await worker.available
worker.publish(data)
},
on: (data) => {
rpcCallback = data
},
}
)

// Automatically cleaned up by AVA since each test file runs in a separate worker
const _messageHandlerPromise = (async () => {
Expand Down
1 change: 1 addition & 0 deletions src/internal-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,5 @@ export interface SharedWorkerFunctions {
beforeTemplateIsBakedResult: unknown
}>
execCommandInContainer: (command: string[]) => Promise<ExecResult>
dropDatabase: (databaseName: string) => Promise<void>
}
5 changes: 5 additions & 0 deletions src/public-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ export interface GetTestPostgresDatabaseFactoryOptions<
connection: ConnectionDetails
params: Params
containerExec: (command: string[]) => Promise<ExecResult>
beforeTemplateIsBaked: (
options: {
params: Params
} & Pick<GetTestPostgresDatabaseOptions, "databaseDedupeKey">
) => Promise<GetTestPostgresDatabaseResult>
}) => Promise<any>
}

Expand Down
46 changes: 46 additions & 0 deletions src/tests/hooks.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import test from "ava"
import { getTestPostgresDatabaseFactory } from "~/index"
import { countDatabaseTemplates } from "./utils/count-database-templates"
import { doesDatabaseExist } from "./utils/does-database-exist"

test("beforeTemplateIsBaked", async (t) => {
let wasHookCalled = false
Expand Down Expand Up @@ -145,3 +146,48 @@ test("beforeTemplateIsBaked (result isn't serializable)", async (t) => {
}
)
})

test("beforeTemplateIsBaked, get nested database", async (t) => {
type DatabaseParams = {
type: "foo" | "bar"
}

let nestedDatabaseName: string | undefined = undefined

const getTestServer = getTestPostgresDatabaseFactory<DatabaseParams>({
postgresVersion: process.env.POSTGRES_VERSION,
workerDedupeKey: "beforeTemplateIsBakedHookNestedDatabase",
beforeTemplateIsBaked: async ({
params,
connection: { pool },
beforeTemplateIsBaked,
}) => {
if (params.type === "foo") {
await pool.query(`CREATE TABLE "foo" ("id" SERIAL PRIMARY KEY)`)
return { createdFoo: true }
}

await pool.query(`CREATE TABLE "bar" ("id" SERIAL PRIMARY KEY)`)
const fooDatabase = await beforeTemplateIsBaked({
params: { type: "foo" },
})
t.deepEqual(fooDatabase.beforeTemplateIsBakedResult, { createdFoo: true })

nestedDatabaseName = fooDatabase.database

await t.notThrowsAsync(async () => {
await fooDatabase.pool.query(`INSERT INTO "foo" DEFAULT VALUES`)
})

return { createdBar: true }
},
})

const database = await getTestServer(t, { type: "bar" })
t.deepEqual(database.beforeTemplateIsBakedResult, { createdBar: true })

t.false(
await doesDatabaseExist(database.pool, nestedDatabaseName!),
"Nested database should have been cleaned up after the parent hook completed"
)
})
12 changes: 12 additions & 0 deletions src/tests/utils/does-database-exist.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { Pool } from "pg"

export const doesDatabaseExist = async (pool: Pool, databaseName: string) => {
const {
rows: [{ count }],
} = await pool.query(
'SELECT COUNT(*) FROM "pg_database" WHERE "datname" = $1',
[databaseName]
)

return count > 0
}
17 changes: 15 additions & 2 deletions src/worker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ export class Worker {
const container = (await this.startContainerPromise).container
return container.exec(command)
},
dropDatabase: async (databaseName) => {
const { postgresClient } = await this.startContainerPromise
await postgresClient.query(`DROP DATABASE ${databaseName}`)
},
},
rpcChannel
)
Expand Down Expand Up @@ -148,8 +152,17 @@ export class Worker {
return
}

await this.forceDisconnectClientsFrom(databaseName!)
await postgresClient.query(`DROP DATABASE ${databaseName}`)
try {
await this.forceDisconnectClientsFrom(databaseName!)
await postgresClient.query(`DROP DATABASE ${databaseName}`)
} catch (error) {
if ((error as Error)?.message?.includes("does not exist")) {
// Database was likely a nested database and manually dropped by the test worker, ignore
return
}

throw error
}
})

return {
Expand Down

0 comments on commit 12fbf7c

Please sign in to comment.