diff --git a/src/advisory-lock.test.ts b/src/advisory-lock.test.ts new file mode 100644 index 0000000..690fd99 --- /dev/null +++ b/src/advisory-lock.test.ts @@ -0,0 +1,169 @@ +import { describe, expect } from "vitest"; +import { createPool } from "slonik"; + +import { pgsliceTest as test } from "./testing/index.js"; +import { AdvisoryLock, AdvisoryLockError } from "./advisory-lock.js"; +import { Table } from "./table.js"; + +describe("AdvisoryLock.withLock", () => { + test("executes handler and returns result", async ({ transaction }) => { + const table = Table.parse("test_table"); + const result = await AdvisoryLock.withLock( + transaction, + table, + "test_op", + async () => { + return "success"; + }, + ); + + expect(result).toBe("success"); + }); + + test("releases lock even if handler throws", async ({ transaction }) => { + const table = Table.parse("test_table"); + + await expect( + AdvisoryLock.withLock(transaction, table, "test_op", async () => { + throw new Error("handler error"); + }), + ).rejects.toThrow("handler error"); + + // Should be able to acquire the lock again since it was released + const result = await AdvisoryLock.withLock( + transaction, + table, + "test_op", + async () => "acquired again", + ); + expect(result).toBe("acquired again"); + }); + + test("throws AdvisoryLockError when lock is held by another session", async ({ + databaseUrl, + }) => { + const table = Table.parse("test_table"); + const operation = "test_op"; + + // Create two separate pools - each will hold a separate session + const pool1 = await createPool(databaseUrl.toString(), { + maximumPoolSize: 1, + queryRetryLimit: 0, + }); + const pool2 = await createPool(databaseUrl.toString(), { + maximumPoolSize: 1, + queryRetryLimit: 0, + }); + + try { + // Use a transaction in pool1 to hold the connection open while we hold the lock + await pool1.transaction(async (tx1) => { + // Acquire lock in the first session + const release = await AdvisoryLock.acquire(tx1, table, operation); + 
+ // Try to acquire the same lock in the second session + await pool2.transaction(async (tx2) => { + await expect( + AdvisoryLock.acquire(tx2, table, operation), + ).rejects.toThrow(AdvisoryLockError); + }); + + await release(); + }); + } finally { + await pool1.end(); + await pool2.end(); + } + }); +}); + +describe("AdvisoryLock.acquire", () => { + test("returns a release function", async ({ transaction }) => { + const table = Table.parse("test_table"); + const release = await AdvisoryLock.acquire(transaction, table, "test_op"); + + expect(typeof release).toBe("function"); + await release(); + }); + + test("same table + different operation = different locks", async ({ + databaseUrl, + }) => { + const table = Table.parse("test_table"); + + const pool1 = await createPool(databaseUrl.toString(), { + maximumPoolSize: 1, + queryRetryLimit: 0, + }); + const pool2 = await createPool(databaseUrl.toString(), { + maximumPoolSize: 1, + queryRetryLimit: 0, + }); + + try { + // Use transactions to hold connections open + await pool1.transaction(async (tx1) => { + // Acquire lock for operation1 + const release1 = await AdvisoryLock.acquire(tx1, table, "operation1"); + + // Should be able to acquire lock for operation2 on same table in different session + await pool2.transaction(async (tx2) => { + const release2 = await AdvisoryLock.acquire(tx2, table, "operation2"); + await release2(); + }); + + await release1(); + }); + } finally { + await pool1.end(); + await pool2.end(); + } + }); + + test("different table + same operation = different locks", async ({ + databaseUrl, + }) => { + const table1 = Table.parse("table_one"); + const table2 = Table.parse("table_two"); + + const pool1 = await createPool(databaseUrl.toString(), { + maximumPoolSize: 1, + queryRetryLimit: 0, + }); + const pool2 = await createPool(databaseUrl.toString(), { + maximumPoolSize: 1, + queryRetryLimit: 0, + }); + + try { + // Use transactions to hold connections open + await pool1.transaction(async (tx1) => { 
+ // Acquire lock for table1 + const release1 = await AdvisoryLock.acquire(tx1, table1, "same_op"); + + // Should be able to acquire lock for table2 with same operation + await pool2.transaction(async (tx2) => { + const release2 = await AdvisoryLock.acquire(tx2, table2, "same_op"); + await release2(); + }); + + await release1(); + }); + } finally { + await pool1.end(); + await pool2.end(); + } + }); +}); + +describe("AdvisoryLockError", () => { + test("has descriptive error message", () => { + const table = Table.parse("my_schema.my_table"); + const error = new AdvisoryLockError(table, "prep"); + + expect(error.message).toContain("prep"); + expect(error.message).toContain("my_schema.my_table"); + expect(error.message).toContain("Another pgslice operation"); + expect(error.name).toBe("AdvisoryLockError"); + }); +}); diff --git a/src/advisory-lock.ts b/src/advisory-lock.ts new file mode 100644 index 0000000..e6d5aa2 --- /dev/null +++ b/src/advisory-lock.ts @@ -0,0 +1,96 @@ +import { CommonQueryMethods } from "slonik"; +import { z } from "zod"; +import { Table } from "./table.js"; +import { sql } from "./sql-utils.js"; + +export class AdvisoryLockError extends Error { + override name = "AdvisoryLockError"; + + constructor(table: Table, operation: string) { + super( + `Could not acquire advisory lock for "${operation}" on table "${table.toString()}". ` + + `Another pgslice operation may be in progress.`, + ); + } +} + +export abstract class AdvisoryLock { + /** + * Executes a handler while holding an advisory lock. + * The lock is automatically released when the handler completes or throws. + */ + static async withLock( + connection: CommonQueryMethods, + table: Table, + operation: string, + handler: () => Promise, + ): Promise { + const release = await this.acquire(connection, table, operation); + try { + return await handler(); + } finally { + await release(); + } + } + + /** + * Acquires an advisory lock and returns a release function. 
+ * Use this for generators that need to hold a lock across yields. + */ + static async acquire( + connection: CommonQueryMethods, + table: Table, + operation: string, + ): Promise<() => Promise> { + const key = await this.#getKey(connection, table, operation); + const acquired = await this.#tryAcquire(connection, key); + + if (!acquired) { + throw new AdvisoryLockError(table, operation); + } + + return async () => { + await this.#release(connection, key); + }; + } + + static async #getKey( + connection: CommonQueryMethods, + table: Table, + operation: string, + ): Promise { + const lockName = `${table.toString()}:${operation}`; + const result = await connection.one( + sql.type(z.object({ key: z.coerce.bigint() }))` + SELECT hashtext(${lockName})::bigint AS key + `, + ); + return result.key; + } + + static async #tryAcquire( + connection: CommonQueryMethods, + key: bigint, + ): Promise { + const result = await connection.one( + sql.type(z.object({ acquired: z.boolean() }))` + SELECT pg_try_advisory_lock(${key}) AS acquired + `, + ); + return result.acquired; + } + + static async #release( + connection: CommonQueryMethods, + key: bigint, + ): Promise { + const { acquired } = await connection.one( + sql.type( + z.object({ acquired: z.boolean() }), + )`SELECT pg_advisory_unlock(${key}) AS acquired`, + ); + if (!acquired) { + throw new Error("Attempted to release lock that was never held."); + } + } +} diff --git a/src/commands/disable-mirroring.ts b/src/commands/disable-mirroring.ts index 961d4db..9c8917f 100644 --- a/src/commands/disable-mirroring.ts +++ b/src/commands/disable-mirroring.ts @@ -25,7 +25,7 @@ export class DisableMirroringCommand extends BaseCommand { override async perform(pgslice: Pgslice): Promise { await pgslice.start(async (tx) => { - await this.context.pgslice.disableMirroring(tx, { table: this.table }); + await pgslice.disableMirroring(tx, { table: this.table }); this.context.stdout.write( `Mirroring triggers disabled for ${this.table}\n`, ); 
diff --git a/src/commands/fill.ts b/src/commands/fill.ts index 2096fd9..9dd560e 100644 --- a/src/commands/fill.ts +++ b/src/commands/fill.ts @@ -56,25 +56,27 @@ export class FillCommand extends BaseCommand { }); async perform(pgslice: Pgslice) { - let hasBatches = false; - for await (const batch of pgslice.fill({ - table: this.table, - swapped: this.swapped, - batchSize: this.batchSize, - start: this.start, - })) { - hasBatches = true; + await pgslice.start(async (conn) => { + let hasBatches = false; + for await (const batch of pgslice.fill(conn, { + table: this.table, + swapped: this.swapped, + batchSize: this.batchSize, + start: this.start, + })) { + hasBatches = true; - this.context.stdout.write(`/* batch ${batch.batchNumber} */\n`); + this.context.stdout.write(`/* batch ${batch.batchNumber} */\n`); - // Sleep between batches if requested - if (this.sleep) { - await sleep(this.sleep * 1000); + // Sleep between batches if requested + if (this.sleep) { + await sleep(this.sleep * 1000); + } } - } - if (!hasBatches) { - this.context.stdout.write("/* nothing to fill */\n"); - } + if (!hasBatches) { + this.context.stdout.write("/* nothing to fill */\n"); + } + }); } } diff --git a/src/commands/synchronize.ts b/src/commands/synchronize.ts index 50d59aa..f03bc8b 100644 --- a/src/commands/synchronize.ts +++ b/src/commands/synchronize.ts @@ -75,36 +75,38 @@ export class SynchronizeCommand extends BaseCommand { let targetName: string | null = null; let headerPrinted = false; - for await (const batch of pgslice.synchronize({ - table: this.table, - start: this.start, - windowSize: this.windowSize, - dryRun: this.dryRun, - })) { - // Print header on first batch (we need synchronizer to know table names) - if (!headerPrinted) { - // Get table names from the batch (inferred from the command options) - sourceName = this.table; - targetName = `${this.table}_intermediate`; - this.#printHeader(sourceName, targetName); - headerPrinted = true; + await pgslice.start(async (conn) => { 
+ for await (const batch of pgslice.synchronize(conn, { + table: this.table, + start: this.start, + windowSize: this.windowSize, + dryRun: this.dryRun, + })) { + // Print header on first batch (we need synchronizer to know table names) + if (!headerPrinted) { + // Get table names from the batch (inferred from the command options) + sourceName = this.table; + targetName = `${this.table}_intermediate`; + this.#printHeader(sourceName, targetName); + headerPrinted = true; + } + + stats.totalBatches++; + stats.totalRowsCompared += batch.rowsCompared; + stats.matchingRows += batch.matchingRows; + stats.rowsWithDifferences += batch.rowsUpdated; + stats.missingRows += batch.rowsInserted; + stats.extraRows += batch.rowsDeleted; + + this.#printBatchResult(batch); + + // Calculate and apply adaptive delay + const sleepTime = this.#calculateSleepTime(batch.batchDurationMs); + if (sleepTime > 0) { + await sleep(sleepTime * 1000); + } } - - stats.totalBatches++; - stats.totalRowsCompared += batch.rowsCompared; - stats.matchingRows += batch.matchingRows; - stats.rowsWithDifferences += batch.rowsUpdated; - stats.missingRows += batch.rowsInserted; - stats.extraRows += batch.rowsDeleted; - - this.#printBatchResult(batch); - - // Calculate and apply adaptive delay - const sleepTime = this.#calculateSleepTime(batch.batchDurationMs); - if (sleepTime > 0) { - await sleep(sleepTime * 1000); - } - } + }); // Print summary this.#printSummary(stats); diff --git a/src/fill.test.ts b/src/fill.test.ts index 2219c41..34d2832 100644 --- a/src/fill.test.ts +++ b/src/fill.test.ts @@ -209,7 +209,10 @@ describe("Pgslice.fill", () => { // Fill data const batches = []; - for await (const batch of pgslice.fill({ table: "posts", batchSize: 10 })) { + for await (const batch of pgslice.fill(transaction, { + table: "posts", + batchSize: 10, + })) { batches.push(batch); } @@ -258,7 +261,10 @@ describe("Pgslice.fill", () => { // Fill data const batches = []; - for await (const batch of pgslice.fill({ table: 
"posts", batchSize: 10 })) { + for await (const batch of pgslice.fill(transaction, { + table: "posts", + batchSize: 10, + })) { batches.push(batch); } @@ -287,7 +293,7 @@ describe("Pgslice.fill", () => { }); const batches = []; - for await (const batch of pgslice.fill({ table: "posts" })) { + for await (const batch of pgslice.fill(transaction, { table: "posts" })) { batches.push(batch); } @@ -310,7 +316,7 @@ describe("Pgslice.fill", () => { // Start from ID 5 const batches = []; - for await (const batch of pgslice.fill({ + for await (const batch of pgslice.fill(transaction, { table: "posts", start: "5", batchSize: 100, @@ -335,7 +341,9 @@ describe("Pgslice.fill", () => { `); const error = await (async () => { - for await (const _batch of pgslice.fill({ table: "posts" })) { + for await (const _batch of pgslice.fill(transaction, { + table: "posts", + })) { // should not reach here } })().catch((e) => e); @@ -352,7 +360,9 @@ describe("Pgslice.fill", () => { `); const error = await (async () => { - for await (const _batch of pgslice.fill({ table: "posts" })) { + for await (const _batch of pgslice.fill(transaction, { + table: "posts", + })) { // should not reach here } })().catch((e) => e); @@ -373,7 +383,9 @@ describe("Pgslice.fill", () => { `); const error = await (async () => { - for await (const _batch of pgslice.fill({ table: "posts" })) { + for await (const _batch of pgslice.fill(transaction, { + table: "posts", + })) { // should not reach here } })().catch((e) => e); diff --git a/src/index.ts b/src/index.ts index 6da5b0c..ab77bba 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,6 +2,7 @@ export { createCli } from "./cli.js"; export { Pgslice } from "./pgslice.js"; export { Table } from "./table.js"; export { Filler } from "./filler.js"; +export { AdvisoryLockError } from "./advisory-lock.js"; export type { Period, Cast, diff --git a/src/pgslice.test.ts b/src/pgslice.test.ts index 2510ff1..b661fcd 100644 --- a/src/pgslice.test.ts +++ b/src/pgslice.test.ts @@ 
-683,7 +683,7 @@ describe("Pgslice.synchronize", () => { }); // Fill initial data - for await (const _batch of pgslice.fill({ table: "posts" })) { + for await (const _batch of pgslice.fill(transaction, { table: "posts" })) { // consume } @@ -694,7 +694,9 @@ describe("Pgslice.synchronize", () => { // Synchronize const batches = []; - for await (const batch of pgslice.synchronize({ table: "posts" })) { + for await (const batch of pgslice.synchronize(transaction, { + table: "posts", + })) { batches.push(batch); } @@ -715,7 +717,9 @@ describe("Pgslice.synchronize", () => { `); const error = await (async () => { - for await (const _batch of pgslice.synchronize({ table: "posts" })) { + for await (const _batch of pgslice.synchronize(transaction, { + table: "posts", + })) { // should not reach here } })().catch((e) => e); @@ -732,7 +736,9 @@ describe("Pgslice.synchronize", () => { `); const error = await (async () => { - for await (const _batch of pgslice.synchronize({ table: "posts" })) { + for await (const _batch of pgslice.synchronize(transaction, { + table: "posts", + })) { // should not reach here } })().catch((e) => e); diff --git a/src/pgslice.ts b/src/pgslice.ts index badc762..2482e41 100644 --- a/src/pgslice.ts +++ b/src/pgslice.ts @@ -1,6 +1,7 @@ import { CommonQueryMethods, createPool, + DatabasePoolConnection, DatabaseTransactionConnection, type DatabasePool, } from "slonik"; @@ -29,21 +30,27 @@ import { Mirroring } from "./mirroring.js"; import { Filler } from "./filler.js"; import { Synchronizer } from "./synchronizer.js"; import { Swapper } from "./swapper.js"; +import { AdvisoryLock } from "./advisory-lock.js"; interface PgsliceOptions { dryRun?: boolean; + + /** + * Whether to use Postgres advisory locks to prevent concurrent operations + * on the same table for the same operation. Defaults to true. 
+ */ + advisoryLocks?: boolean; } export class Pgslice { - #connection: DatabasePool | CommonQueryMethods | null = null; + #pool: DatabasePool | null = null; #dryRun: boolean; + #advisoryLocks: boolean; - constructor( - connection: DatabasePool | CommonQueryMethods, - options: PgsliceOptions, - ) { + constructor(pool: DatabasePool, options: PgsliceOptions) { this.#dryRun = options.dryRun ?? false; - this.#connection = connection; + this.#advisoryLocks = options.advisoryLocks ?? true; + this.#pool = pool; } static async connect( @@ -56,7 +63,7 @@ export class Pgslice { url.searchParams.set("application_name", "pgslice"); } - const connection = await createPool(url.toString(), { + const pool = await createPool(url.toString(), { // We don't want to perform any operations in parallel, and should // only ever need a single connection at a time. maximumPoolSize: 1, @@ -64,34 +71,57 @@ export class Pgslice { // Never retry queries. queryRetryLimit: 0, }); - const instance = new Pgslice(connection, options); + const instance = new Pgslice(pool, options); return instance; } - private get connection() { - if (!this.#connection) { + private get pool() { + if (!this.#pool) { throw new Error("Not connected to the database"); } - return this.#connection; + return this.#pool; } async start( - handler: (transaction: DatabaseTransactionConnection) => Promise, + handler: (transaction: DatabasePoolConnection) => Promise, ): Promise { if (this.#dryRun) { throw new Error("Dry run not yet supported."); } - return this.connection.transaction(handler, 0); + return this.pool.connect(handler); + } + + async #withLock( + tx: CommonQueryMethods, + table: Table, + operation: string, + handler: () => Promise, + ): Promise { + if (!this.#advisoryLocks) { + return handler(); + } + return AdvisoryLock.withLock(tx, table, operation, handler); + } + + async #acquireLock( + connection: CommonQueryMethods, + table: Table, + operation: string, + ): Promise<() => Promise> { + if (!this.#advisoryLocks) { 
+ return async () => {}; + } + return AdvisoryLock.acquire(connection, table, operation); } async close(): Promise { - if (this.#connection) { - if ("end" in this.#connection) { - await this.#connection.end(); + if (this.#pool) { + if ("end" in this.#pool) { + await this.#pool.end(); } - this.#connection = null; + this.#pool = null; } } @@ -103,41 +133,50 @@ export class Pgslice { * with `partition: false`. */ async prep( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: PrepOptions, ): Promise { const table = Table.parse(options.table); - const intermediate = table.intermediate; - - if (!(await table.exists(tx))) { - throw new Error(`Table not found: ${table.toString()}`); - } - if (await intermediate.exists(tx)) { - throw new Error(`Table already exists: ${intermediate.toString()}`); - } - - if (options.partition) { - const columns = await table.columns(tx); - const columnInfo = columns.find((c) => c.name === options.column); - if (!columnInfo) { - throw new Error(`Column not found: ${options.column}`); - } - - if (!isPeriod(options.period)) { - throw new Error(`Invalid period: ${options.period}`); - } - - await this.#createPartitionedIntermediateTable( - tx, - table, - intermediate, - columnInfo, - options.period, - ); - } else { - await this.#createUnpartitionedIntermediateTable(tx, table, intermediate); - } + return connection.transaction(async (tx) => + this.#withLock(tx, table, "prep", async () => { + const intermediate = table.intermediate; + + if (!(await table.exists(tx))) { + throw new Error(`Table not found: ${table.toString()}`); + } + + if (await intermediate.exists(tx)) { + throw new Error(`Table already exists: ${intermediate.toString()}`); + } + + if (options.partition) { + const columns = await table.columns(tx); + const columnInfo = columns.find((c) => c.name === options.column); + if (!columnInfo) { + throw new Error(`Column not found: ${options.column}`); + } + + if (!isPeriod(options.period)) { + throw new 
Error(`Invalid period: ${options.period}`); + } + + await this.#createPartitionedIntermediateTable( + tx, + table, + intermediate, + columnInfo, + options.period, + ); + } else { + await this.#createUnpartitionedIntermediateTable( + tx, + table, + intermediate, + ); + } + }), + ); } async #createPartitionedIntermediateTable( @@ -232,82 +271,87 @@ export class Pgslice { * Adds partitions to a partitioned table. */ async addPartitions( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: AddPartitionsOptions, ): Promise { const originalTable = Table.parse(options.table); - const targetTable = options.intermediate - ? originalTable.intermediate - : originalTable; - if (!(await targetTable.exists(tx))) { - throw new Error(`Table not found: ${targetTable.toString()}`); - } - - const settings = await targetTable.fetchSettings(tx); - if (!settings) { - let message = `No settings found: ${targetTable.toString()}`; - if (!options.intermediate) { - message += "\nDid you mean to use --intermediate?"; - } - throw new Error(message); - } - - const past = options.past ?? 0; - const future = options.future ?? 0; - - // Determine which table to get the primary key from. - // For intermediate tables, use the original table. - // For swapped tables, use the last existing partition (if any) or the original. - let schemaTable: Table; - if (options.intermediate) { - schemaTable = originalTable; - } else { - const existingPartitions = await targetTable.partitions(tx); - schemaTable = - existingPartitions.length > 0 - ? existingPartitions[existingPartitions.length - 1] + return connection.transaction(async (tx) => + this.#withLock(tx, originalTable, "add_partitions", async () => { + const targetTable = options.intermediate + ? 
originalTable.intermediate : originalTable; - } - - const primaryKeyColumn = await schemaTable.primaryKey(tx); - - const dateRanges = new DateRanges({ - period: settings.period, - past, - future, - }); - - for (const range of dateRanges) { - const partitionTable = originalTable.partition(range.suffix); - - if (await partitionTable.exists(tx)) { - continue; - } - - const startDate = formatDateForSql(range.start, settings.cast); - const endDate = formatDateForSql(range.end, settings.cast); - // Build the CREATE TABLE statement - let createSql = sql.fragment` - CREATE TABLE ${partitionTable.sqlIdentifier} - PARTITION OF ${targetTable.sqlIdentifier} - FOR VALUES FROM (${startDate}) TO (${endDate}) - `; - - if (options.tablespace) { - createSql = sql.fragment`${createSql} TABLESPACE ${sql.identifier([options.tablespace])}`; - } - - await tx.query(sql.typeAlias("void")`${createSql}`); - - await tx.query( - sql.typeAlias("void")` - ALTER TABLE ${partitionTable.sqlIdentifier} - ADD PRIMARY KEY (${sql.identifier([primaryKeyColumn])}) - `, - ); - } + if (!(await targetTable.exists(tx))) { + throw new Error(`Table not found: ${targetTable.toString()}`); + } + + const settings = await targetTable.fetchSettings(tx); + if (!settings) { + let message = `No settings found: ${targetTable.toString()}`; + if (!options.intermediate) { + message += "\nDid you mean to use --intermediate?"; + } + throw new Error(message); + } + + const past = options.past ?? 0; + const future = options.future ?? 0; + + // Determine which table to get the primary key from. + // For intermediate tables, use the original table. + // For swapped tables, use the last existing partition (if any) or the original. + let schemaTable: Table; + if (options.intermediate) { + schemaTable = originalTable; + } else { + const existingPartitions = await targetTable.partitions(tx); + schemaTable = + existingPartitions.length > 0 + ? 
existingPartitions[existingPartitions.length - 1] + : originalTable; + } + + const primaryKeyColumn = await schemaTable.primaryKey(tx); + + const dateRanges = new DateRanges({ + period: settings.period, + past, + future, + }); + + for (const range of dateRanges) { + const partitionTable = originalTable.partition(range.suffix); + + if (await partitionTable.exists(tx)) { + continue; + } + + const startDate = formatDateForSql(range.start, settings.cast); + const endDate = formatDateForSql(range.end, settings.cast); + + // Build the CREATE TABLE statement + let createSql = sql.fragment` + CREATE TABLE ${partitionTable.sqlIdentifier} + PARTITION OF ${targetTable.sqlIdentifier} + FOR VALUES FROM (${startDate}) TO (${endDate}) + `; + + if (options.tablespace) { + createSql = sql.fragment`${createSql} TABLESPACE ${sql.identifier([options.tablespace])}`; + } + + await tx.query(sql.typeAlias("void")`${createSql}`); + + await tx.query( + sql.typeAlias("void")` + ALTER TABLE ${partitionTable.sqlIdentifier} + ADD PRIMARY KEY (${sql.identifier([primaryKeyColumn])}) + `, + ); + } + }), + ); } /** @@ -316,21 +360,26 @@ export class Pgslice { * table are automatically replicated to the target table. */ async enableMirroring( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: EnableMirroringOptions, ): Promise { const table = Table.parse(options.table); - const targetType = options.targetType ?? "intermediate"; - const target = table[targetType]; - if (!(await table.exists(tx))) { - throw new Error(`Table not found: ${table.toString()}`); - } - if (!(await target.exists(tx))) { - throw new Error(`Table not found: ${target.toString()}`); - } + return connection.transaction(async (tx) => + this.#withLock(tx, table, "enable_mirroring", async () => { + const targetType = options.targetType ?? 
"intermediate"; + const target = table[targetType]; + + if (!(await table.exists(tx))) { + throw new Error(`Table not found: ${table.toString()}`); + } + if (!(await target.exists(tx))) { + throw new Error(`Table not found: ${target.toString()}`); + } - await new Mirroring({ source: table, targetType }).enable(tx, target); + await new Mirroring({ source: table, targetType }).enable(tx, target); + }), + ); } /** @@ -338,17 +387,22 @@ export class Pgslice { * This removes the triggers that were created by enableMirroring. */ async disableMirroring( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: DisableMirroringOptions, ): Promise { const table = Table.parse(options.table); - const targetType = options.targetType ?? "intermediate"; - if (!(await table.exists(tx))) { - throw new Error(`Table not found: ${table.toString()}`); - } + return connection.transaction(async (tx) => + this.#withLock(tx, table, "disable_mirroring", async () => { + const targetType = options.targetType ?? 
"intermediate"; + + if (!(await table.exists(tx))) { + throw new Error(`Table not found: ${table.toString()}`); + } - await new Mirroring({ source: table, targetType }).disable(tx); + await new Mirroring({ source: table, targetType }).disable(tx); + }), + ); } /** @@ -358,11 +412,24 @@ export class Pgslice { * @param options - Fill options including table names and batch configuration * @yields FillBatchResult after each batch is processed */ - async *fill(options: FillOptions): AsyncGenerator { - const filler = await this.start((tx) => Filler.init(tx, options)); + async *fill( + connection: DatabasePoolConnection, + options: FillOptions, + ): AsyncGenerator { + const releaseLock = await this.#acquireLock( + connection, + Table.parse(options.table), + "fill", + ); + + try { + const filler = await Filler.init(connection, options); - for await (const batch of filler.fill(this.connection)) { - yield batch; + for await (const batch of filler.fill(connection)) { + yield batch; + } + } finally { + await releaseLock(); } } @@ -374,14 +441,22 @@ export class Pgslice { * @yields SynchronizeBatchResult after each batch is processed */ async *synchronize( + connection: DatabasePoolConnection, options: SynchronizeOptions, ): AsyncGenerator { - const synchronizer = await this.start((tx) => - Synchronizer.init(tx, options), + const releaseLock = await this.#acquireLock( + connection, + Table.parse(options.table), + "synchronize", ); + try { + const synchronizer = await Synchronizer.init(connection, options); - for await (const batch of synchronizer.synchronize(this.connection)) { - yield batch; + for await (const batch of synchronizer.synchronize(connection)) { + yield batch; + } + } finally { + await releaseLock(); } } @@ -395,15 +470,21 @@ export class Pgslice { * - A retired mirroring trigger is enabled to keep the retired table in sync */ async swap( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: SwapOptions, ): Promise { - const swapper 
= new Swapper({ - table: options.table, - direction: "forward", - lockTimeout: options.lockTimeout, - }); - await swapper.execute(tx); + const table = Table.parse(options.table); + + return connection.transaction(async (tx) => + this.#withLock(tx, table, "swap", async () => { + const swapper = new Swapper({ + table, + direction: "forward", + lockTimeout: options.lockTimeout, + }); + await swapper.execute(tx); + }), + ); } /** @@ -416,15 +497,21 @@ export class Pgslice { * - An intermediate mirroring trigger is enabled to keep the intermediate table in sync */ async unswap( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: UnswapOptions, ): Promise { - const swapper = new Swapper({ - table: options.table, - direction: "reverse", - lockTimeout: options.lockTimeout, - }); - await swapper.execute(tx); + const table = Table.parse(options.table); + + return connection.transaction(async (tx) => + this.#withLock(tx, table, "unswap", async () => { + const swapper = new Swapper({ + table, + direction: "reverse", + lockTimeout: options.lockTimeout, + }); + await swapper.execute(tx); + }), + ); } /** @@ -439,11 +526,11 @@ export class Pgslice { const table = Table.parse(options.table); const targetTable = options.swapped ? table : table.intermediate; - if (!(await targetTable.exists(this.connection))) { + if (!(await targetTable.exists(this.pool))) { throw new Error(`Table not found: ${targetTable.toString()}`); } - await this.connection.query( + await this.pool.query( sql.typeAlias("void")`ANALYZE VERBOSE ${targetTable.sqlIdentifier}`, ); @@ -457,20 +544,25 @@ export class Pgslice { * with CASCADE, which also removes any dependent objects like partitions. 
*/ async unprep( - tx: DatabaseTransactionConnection, + connection: DatabasePoolConnection, options: UnprepOptions, ): Promise { const table = Table.parse(options.table); - const intermediate = table.intermediate; - if (!(await intermediate.exists(tx))) { - throw new Error(`Table not found: ${intermediate.toString()}`); - } + return connection.transaction(async (tx) => + this.#withLock(tx, table, "unprep", async () => { + const intermediate = table.intermediate; - await tx.query( - sql.typeAlias("void")` - DROP TABLE ${intermediate.sqlIdentifier} CASCADE - `, + if (!(await intermediate.exists(tx))) { + throw new Error(`Table not found: ${intermediate.toString()}`); + } + + await tx.query( + sql.typeAlias("void")` + DROP TABLE ${intermediate.sqlIdentifier} CASCADE + `, + ); + }), ); } } diff --git a/src/swapper.test.ts b/src/swapper.test.ts index 12a8c0e..8a9f078 100644 --- a/src/swapper.test.ts +++ b/src/swapper.test.ts @@ -13,7 +13,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -33,7 +33,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -65,7 +65,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -93,7 +93,7 @@ describe("Swapper", () => { test("renames original table to retired", async ({ transaction }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -107,7 +107,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -121,7 +121,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -140,7 
+140,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -186,7 +186,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -237,7 +237,7 @@ describe("Swapper", () => { expect(beforeResult).not.toBeNull(); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -254,7 +254,7 @@ describe("Swapper", () => { test("creates retired mirroring trigger", async ({ transaction }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -273,7 +273,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -310,7 +310,7 @@ describe("Swapper", () => { test("uses default lock timeout of 5s", async ({ transaction }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); @@ -328,7 +328,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", lockTimeout: "10s", }); @@ -365,7 +365,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "myschema.posts", + table: Table.parse("myschema.posts"), direction: "forward", }); @@ -388,7 +388,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -408,7 +408,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -440,7 +440,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -470,7 
+470,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -484,7 +484,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -498,7 +498,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -517,7 +517,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -563,7 +563,7 @@ describe("Swapper", () => { `); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -631,7 +631,7 @@ describe("Swapper", () => { expect(beforeResult).not.toBeNull(); const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -650,7 +650,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -669,7 +669,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -706,7 +706,7 @@ describe("Swapper", () => { test("uses default lock timeout of 5s", async ({ transaction }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); @@ -724,7 +724,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", lockTimeout: "15s", }); @@ -761,7 +761,7 @@ describe("Swapper", () => { transaction, }) => { const swapper = new Swapper({ - table: "myschema.posts", + table: Table.parse("myschema.posts"), direction: "reverse", }); @@ -804,7 +804,7 @@ 
describe("Swapper", () => { expect(await retired.exists(transaction)).toBe(false); const forwardSwapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); await forwardSwapper.execute(transaction); @@ -814,7 +814,7 @@ describe("Swapper", () => { expect(await retired.exists(transaction)).toBe(true); const reverseSwapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); await reverseSwapper.execute(transaction); @@ -833,13 +833,13 @@ describe("Swapper", () => { `); const forwardSwapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "forward", }); await forwardSwapper.execute(transaction); const reverseSwapper = new Swapper({ - table: "posts", + table: Table.parse("posts"), direction: "reverse", }); await reverseSwapper.execute(transaction); diff --git a/src/swapper.ts b/src/swapper.ts index f0669f4..a8e4776 100644 --- a/src/swapper.ts +++ b/src/swapper.ts @@ -6,7 +6,7 @@ import type { SwapDirection } from "./types.js"; import { sql } from "./sql-utils.js"; interface SwapperOptions { - table: string; + table: Table; direction: SwapDirection; lockTimeout?: string; } @@ -37,7 +37,7 @@ export class Swapper { readonly #lockTimeout: string; constructor(options: SwapperOptions) { - this.#table = Table.parse(options.table); + this.#table = options.table; this.#direction = options.direction; this.#lockTimeout = options.lockTimeout ?? 
"5s"; } diff --git a/src/testing/pgslice.ts b/src/testing/pgslice.ts index 0141021..48742eb 100644 --- a/src/testing/pgslice.ts +++ b/src/testing/pgslice.ts @@ -1,6 +1,10 @@ -import { test as baseTest } from "vitest"; +import { test as baseTest, vi } from "vitest"; import { Pgslice } from "../pgslice.js"; -import { createPool, DatabaseTransactionConnection } from "slonik"; +import { + createPool, + DatabasePool, + DatabaseTransactionConnection, +} from "slonik"; class TestRollbackError extends Error { constructor() { @@ -8,24 +12,35 @@ class TestRollbackError extends Error { } } -function getTestDatabaseUrl(): string { +function getTestDatabaseUrl(): URL { const url = process.env.PGSLICE_URL; if (!url) { throw new Error("PGSLICE_URL environment variable must be set for tests"); } - return url; + return new URL(url); } export const pgsliceTest = baseTest.extend<{ + databaseUrl: URL; pgslice: Pgslice; + pool: DatabasePool; transaction: DatabaseTransactionConnection; }>({ - transaction: async ({}, use) => { - const connection = await createPool(getTestDatabaseUrl().toString()); + databaseUrl: getTestDatabaseUrl(), + pool: async ({ databaseUrl }, use) => { + const pool = await createPool(databaseUrl.toString()); try { - await connection.transaction(async (transaction) => { + await use(pool); + } finally { + await pool.end(); + } + }, + + transaction: async ({ pool }, use) => { + try { + await pool.transaction(async (transaction) => { await use(transaction); throw new TestRollbackError(); }); @@ -34,8 +49,38 @@ export const pgsliceTest = baseTest.extend<{ } }, - pgslice: async ({ transaction }, use) => { - const pgslice = new Pgslice(transaction, {}); + pgslice: async ({ pool, transaction }, use) => { + const transactionalizedPool = { + ...transaction, + configuration: pool.configuration, + connect: vi.fn().mockImplementation((handler) => handler(transaction)), + end: vi.fn().mockResolvedValue(undefined), + state: vi.fn().mockReturnValue(pool.state), + + // A bunch of 
event emitter stuff that we don't use but having this + // makes the compiler happier. + addListener: vi.fn().mockReturnThis(), + emit: vi.fn().mockReturnValue(false), + eventNames: vi.fn().mockReturnValue([]), + getMaxListeners: vi.fn().mockReturnValue(0), + listenerCount: vi.fn().mockReturnValue(0), + listeners: vi.fn().mockReturnThis(), + off: vi.fn().mockReturnThis(), + on: vi.fn().mockReturnThis(), + once: vi.fn().mockReturnThis(), + prependListener: vi.fn().mockReturnThis(), + prependOnceListener: vi.fn().mockReturnThis(), + rawListeners: vi.fn().mockReturnThis(), + removeAllListeners: vi.fn().mockReturnThis(), + removeListener: vi.fn().mockReturnThis(), + setMaxListeners: vi.fn().mockReturnThis(), + } satisfies DatabasePool; + + const pgslice = new Pgslice(transactionalizedPool, { + // Disable advisory locks since we run tests both transactionally + // and concurrently, which these would otherwise interfere with. + advisoryLocks: false, + }); await use(pgslice);