diff --git a/packages/bsky/src/api/app/bsky/notification/listNotifications.ts b/packages/bsky/src/api/app/bsky/notification/listNotifications.ts index 106cdafd490..31dd5d87d48 100644 --- a/packages/bsky/src/api/app/bsky/notification/listNotifications.ts +++ b/packages/bsky/src/api/app/bsky/notification/listNotifications.ts @@ -1,5 +1,6 @@ import { mapDefined } from '@atproto/common' import { InvalidRequestError } from '@atproto/xrpc-server' +import { ServerConfig } from '../../../../config' import { AppContext } from '../../../../context' import { HydrateCtx, Hydrator } from '../../../../hydration/hydrator' import { Server } from '../../../../lexicon' @@ -15,7 +16,7 @@ import { import { Notification } from '../../../../proto/bsky_pb' import { uriToDid as didFromUri } from '../../../../util/uris' import { Views } from '../../../../views' -import { clearlyBadCursor, resHeaders } from '../../../util' +import { resHeaders } from '../../../util' export default function (server: Server, ctx: AppContext) { const listNotifications = createPipeline( @@ -93,6 +94,23 @@ const paginateNotifications = async (opts: { } } +/** + * Applies a configurable delay to the datetime string of a cursor, + * effectively allowing for a delay on listing the notifications. + * This is useful to allow time for services to process notifications + * before they are listed to the user. + */ +export const delayCursor = ( + cursorStr: string | undefined, + delayMs: number, +): string => { + const nowMinusDelay = Date.now() - delayMs + if (cursorStr === undefined) return new Date(nowMinusDelay).toISOString() + const cursor = new Date(cursorStr).getTime() + if (isNaN(cursor)) return cursorStr + return new Date(Math.min(cursor, nowMinusDelay)).toISOString() +} + const skeleton = async ( input: SkeletonFnInput, ): Promise => { @@ -100,17 +118,20 @@ const skeleton = async ( if (params.seenAt) { throw new InvalidRequestError('The seenAt parameter is unsupported') } + + const originalCursor = params.cursor + const delayedCursor = delayCursor( + originalCursor, + ctx.cfg.notificationsDelayMs, + ) const viewer = params.hydrateCtx.viewer const priority = params.priority ?? (await getPriority(ctx, viewer)) - if (clearlyBadCursor(params.cursor)) { - return { notifs: [], priority } - } const [res, lastSeenRes] = await Promise.all([ paginateNotifications({ ctx, priority, reasons: params.reasons, - cursor: params.cursor, + cursor: delayedCursor, limit: params.limit, viewer, }), @@ -122,7 +143,7 @@ const skeleton = async ( // @NOTE for the first page of results if there's no last-seen time, consider top notification unread // rather than all notifications. bit of a hack to be more graceful when seen times are out of sync. let lastSeenDate = lastSeenRes.timestamp?.toDate() - if (!lastSeenDate && !params.cursor) { + if (!lastSeenDate && !originalCursor) { lastSeenDate = res.notifications.at(0)?.timestamp?.toDate() } return { @@ -210,6 +231,7 @@ const presentation = ( type Context = { hydrator: Hydrator views: Views + cfg: ServerConfig } type Params = QueryParams & { diff --git a/packages/bsky/src/config.ts b/packages/bsky/src/config.ts index a94dde58bc3..e273fdd840b 100644 --- a/packages/bsky/src/config.ts +++ b/packages/bsky/src/config.ts @@ -49,6 +49,8 @@ export interface ServerConfigValues { bigThreadUris: Set bigThreadDepth?: number maxThreadDepth?: number + // notifications + notificationsDelayMs?: number // client config clientCheckEmailConfirmed?: boolean topicsEnabled?: boolean @@ -170,6 +172,10 @@ export class ServerConfig { ? parseInt(process.env.BSKY_MAX_THREAD_DEPTH || '', 10) : undefined + const notificationsDelayMs = process.env.BSKY_NOTIFICATIONS_DELAY_MS + ? parseInt(process.env.BSKY_NOTIFICATIONS_DELAY_MS || '', 10) + : 0 + const disableSsrfProtection = process.env.BSKY_DISABLE_SSRF_PROTECTION ? process.env.BSKY_DISABLE_SSRF_PROTECTION === 'true' : debugMode @@ -231,6 +237,7 @@ export class ServerConfig { bigThreadUris, bigThreadDepth, maxThreadDepth, + notificationsDelayMs, disableSsrfProtection, proxyAllowHTTP2, proxyHeadersTimeout, @@ -426,6 +433,10 @@ export class ServerConfig { return this.cfg.maxThreadDepth } + get notificationsDelayMs() { + return this.cfg.notificationsDelayMs ?? 0 + } + get disableSsrfProtection(): boolean { return this.cfg.disableSsrfProtection ?? false } diff --git a/packages/bsky/src/data-plane/server/db/pagination.ts b/packages/bsky/src/data-plane/server/db/pagination.ts index 90d7be0b84d..acbc47163fb 100644 --- a/packages/bsky/src/data-plane/server/db/pagination.ts +++ b/packages/bsky/src/data-plane/server/db/pagination.ts @@ -2,8 +2,8 @@ import { sql } from 'kysely' import { InvalidRequestError } from '@atproto/xrpc-server' import { AnyQb, DbRef } from './util' -export type Cursor = { primary: string; secondary: string } -export type LabeledResult = { +type KeysetCursor = { primary: string; secondary: string } +type KeysetLabeledResult = { primary: string | number secondary: string | number } @@ -22,14 +22,14 @@ export type LabeledResult = { * Result -*-> LabeledResult <-*-> Cursor <--> packed/string cursor * ↳ SQL Condition */ -export abstract class GenericKeyset { +export abstract class GenericKeyset { constructor( public primary: DbRef, public secondary: DbRef, ) {} abstract labelResult(result: R): LR - abstract labeledResultToCursor(labeled: LR): Cursor - abstract cursorToLabeledResult(cursor: Cursor): LR + abstract labeledResultToCursor(labeled: LR): KeysetCursor + abstract cursorToLabeledResult(cursor: KeysetCursor): LR packFromResult(results: R | R[]): string | undefined { const result = Array.isArray(results) ? results.at(-1) : results if (!result) return @@ -45,11 +45,11 @@ export abstract class GenericKeyset { if (!cursor) return return this.cursorToLabeledResult(cursor) } - packCursor(cursor?: Cursor): string | undefined { + packCursor(cursor?: KeysetCursor): string | undefined { if (!cursor) return return `${cursor.primary}__${cursor.secondary}` } - unpackCursor(cursorStr?: string): Cursor | undefined { + unpackCursor(cursorStr?: string): KeysetCursor | undefined { if (!cursorStr) return const result = cursorStr.split('__') const [primary, secondary, ...others] = result @@ -79,10 +79,43 @@ export abstract class GenericKeyset { } } } + paginate( + qb: QB, + opts: { + limit?: number + cursor?: string + direction?: 'asc' | 'desc' + tryIndex?: boolean + // By default, pg does nullsFirst + nullsLast?: boolean + }, + ): QB { + const { limit, cursor, direction = 'desc', tryIndex, nullsLast } = opts + const keysetSql = this.getSql(this.unpack(cursor), direction, tryIndex) + return qb + .if(!!limit, (q) => q.limit(limit as number)) + .if(!nullsLast, (q) => + q.orderBy(this.primary, direction).orderBy(this.secondary, direction), + ) + .if(!!nullsLast, (q) => + q + .orderBy( + direction === 'asc' + ? sql`${this.primary} asc nulls last` + : sql`${this.primary} desc nulls last`, + ) + .orderBy( + direction === 'asc' + ? sql`${this.secondary} asc nulls last` + : sql`${this.secondary} desc nulls last`, + ), + ) + .if(!!keysetSql, (qb) => (keysetSql ? qb.where(keysetSql) : qb)) as QB + } } type SortAtCidResult = { sortAt: string; cid: string } -type TimeCidLabeledResult = Cursor +type TimeCidLabeledResult = KeysetCursor export class TimeCidKeyset< TimeCidResult = SortAtCidResult, @@ -97,7 +130,7 @@ export class TimeCidKeyset< secondary: labeled.secondary, } } - cursorToLabeledResult(cursor: Cursor) { + cursorToLabeledResult(cursor: KeysetCursor) { const primaryDate = new Date(parseInt(cursor.primary, 10)) if (isNaN(primaryDate.getTime())) { throw new InvalidRequestError('Malformed cursor') @@ -127,6 +160,9 @@ export class IndexedAtDidKeyset extends TimeCidKeyset<{ } } +/** + * This is being deprecated. Use {@link GenericKeyset#paginate} instead. + */ export const paginate = < QB extends AnyQb, K extends GenericKeyset, @@ -142,32 +178,119 @@ export const paginate = < nullsLast?: boolean }, ): QB => { - const { - limit, - cursor, - keyset, - direction = 'desc', - tryIndex, - nullsLast, - } = opts - const keysetSql = keyset.getSql(keyset.unpack(cursor), direction, tryIndex) - return qb - .if(!!limit, (q) => q.limit(limit as number)) - .if(!nullsLast, (q) => - q.orderBy(keyset.primary, direction).orderBy(keyset.secondary, direction), - ) - .if(!!nullsLast, (q) => - q - .orderBy( - direction === 'asc' - ? sql`${keyset.primary} asc nulls last` - : sql`${keyset.primary} desc nulls last`, - ) - .orderBy( + return opts.keyset.paginate(qb, opts) +} + +type SingleKeyCursor = { + primary: string +} + +type SingleKeyLabeledResult = { + primary: string | number +} + +/** + * GenericSingleKey is similar to {@link GenericKeyset} but for a single key cursor. + */ +export abstract class GenericSingleKey { + constructor(public primary: DbRef) {} + abstract labelResult(result: R): LR + abstract labeledResultToCursor(labeled: LR): SingleKeyCursor + abstract cursorToLabeledResult(cursor: SingleKeyCursor): LR + packFromResult(results: R | R[]): string | undefined { + const result = Array.isArray(results) ? results.at(-1) : results + if (!result) return + return this.pack(this.labelResult(result)) + } + pack(labeled?: LR): string | undefined { + if (!labeled) return + const cursor = this.labeledResultToCursor(labeled) + return this.packCursor(cursor) + } + unpack(cursorStr?: string): LR | undefined { + const cursor = this.unpackCursor(cursorStr) + if (!cursor) return + return this.cursorToLabeledResult(cursor) + } + packCursor(cursor?: SingleKeyCursor): string | undefined { + if (!cursor) return + return cursor.primary + } + unpackCursor(cursorStr?: string): SingleKeyCursor | undefined { + if (!cursorStr) return + const result = cursorStr.split('__') + const [primary, ...others] = result + if (!primary || others.length > 0) { + throw new InvalidRequestError('Malformed cursor') + } + return { + primary, + } + } + getSql(labeled?: LR, direction?: 'asc' | 'desc') { + if (labeled === undefined) return + if (direction === 'asc') { + return sql`${this.primary} > ${labeled.primary}` + } + return sql`${this.primary} < ${labeled.primary}` + } + paginate( + qb: QB, + opts: { + limit?: number + cursor?: string + direction?: 'asc' | 'desc' + // By default, pg does nullsFirst + nullsLast?: boolean + }, + ): QB { + const { limit, cursor, direction = 'desc', nullsLast } = opts + const keySql = this.getSql(this.unpack(cursor), direction) + return qb + .if(!!limit, (q) => q.limit(limit as number)) + .if(!nullsLast, (q) => q.orderBy(this.primary, direction)) + .if(!!nullsLast, (q) => + q.orderBy( direction === 'asc' - ? sql`${keyset.secondary} asc nulls last` - : sql`${keyset.secondary} desc nulls last`, + ? sql`${this.primary} asc nulls last` + : sql`${this.primary} desc nulls last`, ), - ) - .if(!!keysetSql, (qb) => (keysetSql ? qb.where(keysetSql) : qb)) as QB + ) + .if(!!keySql, (qb) => (keySql ? qb.where(keySql) : qb)) as QB + } +} + +type SortAtResult = { sortAt: string } +type TimeLabeledResult = SingleKeyCursor + +export class IsoTimeKey extends GenericSingleKey< + TimeResult, + TimeLabeledResult +> { + labelResult(result: TimeResult): TimeLabeledResult + labelResult(result: TimeResult) { + return { primary: result.sortAt } + } + labeledResultToCursor(labeled: TimeLabeledResult) { + return { + primary: new Date(labeled.primary).toISOString(), + } + } + cursorToLabeledResult(cursor: SingleKeyCursor) { + const primaryDate = new Date(cursor.primary) + if (isNaN(primaryDate.getTime())) { + throw new InvalidRequestError('Malformed cursor') + } + return { + primary: primaryDate.toISOString(), + } + } +} + +export class IsoSortAtKey extends IsoTimeKey<{ + sortAt: string +}> { + labelResult(result: { sortAt: string }) { + return { primary: result.sortAt } + } } diff --git a/packages/bsky/src/data-plane/server/routes/notifs.ts b/packages/bsky/src/data-plane/server/routes/notifs.ts index f8a53efca03..6af5ba004e3 100644 --- a/packages/bsky/src/data-plane/server/routes/notifs.ts +++ b/packages/bsky/src/data-plane/server/routes/notifs.ts @@ -3,7 +3,7 @@ import { ServiceImpl } from '@connectrpc/connect' import { sql } from 'kysely' import { Service } from '../../../proto/bsky_connect' import { Database } from '../db' -import { TimeCidKeyset, paginate } from '../db/pagination' +import { IsoSortAtKey } from '../db/pagination' import { countAll, notSoftDeletedClause } from '../db/util' export default (db: Database): Partial> => ({ @@ -41,15 +41,10 @@ export default (db: Database): Partial> => ({ ]) .select(priorityFollowQb.as('priority')) - const keyset = new TimeCidKeyset( - ref('notif.sortAt'), - ref('notif.recordCid'), - ) - builder = paginate(builder, { + const key = new IsoSortAtKey(ref('notif.sortAt')) + builder = key.paginate(builder, { cursor, limit, - keyset, - tryIndex: true, }) const notifsRes = await builder.execute() @@ -63,7 +58,7 @@ export default (db: Database): Partial> => ({ })) return { notifications, - cursor: keyset.packFromResult(notifsRes), + cursor: key.packFromResult(notifsRes), } }, diff --git a/packages/bsky/tests/views/__snapshots__/notifications.test.ts.snap b/packages/bsky/tests/views/__snapshots__/notifications.test.ts.snap index 112b8c7ad8e..68753f6f26b 100644 --- a/packages/bsky/tests/views/__snapshots__/notifications.test.ts.snap +++ b/packages/bsky/tests/views/__snapshots__/notifications.test.ts.snap @@ -384,7 +384,7 @@ Array [ exports[`notification views fetches notifications with default priority 1`] = ` Object { - "cursor": "0000000000000__bafycid", + "cursor": "1970-01-01T00:00:00.000Z", "notifications": Array [ Object { "author": Object { @@ -486,7 +486,7 @@ Object { exports[`notification views fetches notifications with explicit priority 1`] = ` Object { - "cursor": "0000000000000__bafycid", + "cursor": "1970-01-01T00:00:00.000Z", "notifications": Array [ Object { "author": Object { @@ -588,7 +588,7 @@ Object { exports[`notification views fetches notifications with explicit priority 2`] = ` Object { - "cursor": "0000000000000__bafycid", + "cursor": "1970-01-01T00:00:00.000Z", "notifications": Array [ Object { "author": Object { diff --git a/packages/bsky/tests/views/notifications.test.ts b/packages/bsky/tests/views/notifications.test.ts index 94250ae1bb3..31f462eed16 100644 --- a/packages/bsky/tests/views/notifications.test.ts +++ b/packages/bsky/tests/views/notifications.test.ts @@ -1,5 +1,6 @@ import { AtpAgent } from '@atproto/api' import { SeedClient, TestNetwork, basicSeed } from '@atproto/dev-env' +import { delayCursor } from '../../src/api/app/bsky/notification/listNotifications' import { ids } from '../../src/lexicon/lexicons' import { Notification } from '../../src/lexicon/types/app/bsky/notification/listNotifications' import { forSnapshot, paginateAll } from '../_util' @@ -497,17 +498,196 @@ describe('notification views', () => { expect(results(paginatedAll)).toEqual(results([full.data])) }) - it('fails open on clearly bad cursor.', async () => { - const { data: notifs } = - await agent.api.app.bsky.notification.listNotifications( - { cursor: '90210::bafycid' }, - { - headers: await network.serviceHeaders( - alice, - ids.AppBskyNotificationListNotifications, - ), + describe('notifications delay', () => { + const notificationsDelayMs = 5_000 + + let delayNetwork: TestNetwork + let delayAgent: AtpAgent + let delaySc: SeedClient + let delayAlice: string + + beforeAll(async () => { + delayNetwork = await TestNetwork.create({ + bsky: { + notificationsDelayMs, }, + dbPostgresSchema: 'bsky_views_notifications_delay', + }) + delayAgent = delayNetwork.bsky.getClient() + delaySc = delayNetwork.getSeedClient() + await basicSeed(delaySc) + await delayNetwork.processAll() + delayAlice = delaySc.dids.alice + + // Add to reply chain, post ancestors: alice -> bob -> alice -> carol. + // Should have added one notification for each of alice and bob. + await delaySc.reply( + delaySc.dids.carol, + delaySc.posts[delayAlice][1].ref, + delaySc.replies[delayAlice][0].ref, + 'indeed', + ) + await delayNetwork.processAll() + + // @NOTE: Use fake timers after inserting seed data, + // to avoid inserting all notifications with the same timestamp. + jest.useFakeTimers({ + doNotFake: [ + 'nextTick', + 'performance', + 'setImmediate', + 'setInterval', + 'setTimeout', + ], + }) + }) + + afterAll(async () => { + jest.useRealTimers() + await delayNetwork.close() + }) + + it('paginates', async () => { + const firstNotification = await delayNetwork.bsky.db.db + .selectFrom('notification') + .selectAll() + .limit(1) + .orderBy('sortAt', 'asc') + .executeTakeFirstOrThrow() + // Sets the system time to when the first notification happened. + // At this point we won't have any notifications that already crossed the delay threshold. + jest.setSystemTime(new Date(firstNotification.sortAt)) + + const results = (results) => + sort(results.flatMap((res) => res.notifications)) + const paginator = async (cursor?: string) => { + const res = + await delayAgent.api.app.bsky.notification.listNotifications( + { cursor, limit: 6 }, + { + headers: await delayNetwork.serviceHeaders( + delayAlice, + ids.AppBskyNotificationListNotifications, + ), + }, + ) + return res.data + } + + const paginatedAllBeforeDelay = await paginateAll(paginator) + paginatedAllBeforeDelay.forEach((res) => + expect(res.notifications.length).toBe(0), + ) + const fullBeforeDelay = + await delayAgent.api.app.bsky.notification.listNotifications( + {}, + { + headers: await delayNetwork.serviceHeaders( + delayAlice, + ids.AppBskyNotificationListNotifications, + ), + }, + ) + + expect(fullBeforeDelay.data.notifications.length).toEqual(0) + expect(results(paginatedAllBeforeDelay)).toEqual( + results([fullBeforeDelay.data]), + ) + + const lastNotification = await delayNetwork.bsky.db.db + .selectFrom('notification') + .selectAll() + .limit(1) + .orderBy('sortAt', 'desc') + .executeTakeFirstOrThrow() + // Sets the system time to when the last notification happened and the delay has elapsed. + // At this point we all notifications already crossed the delay threshold. + jest.setSystemTime( + new Date( + new Date(lastNotification.sortAt).getTime() + + notificationsDelayMs + + 1, + ), ) - expect(notifs).toMatchObject({ notifications: [] }) + + const paginatedAllAfterDelay = await paginateAll(paginator) + paginatedAllAfterDelay.forEach((res) => + expect(res.notifications.length).toBeLessThanOrEqual(6), + ) + const fullAfterDelay = + await delayAgent.api.app.bsky.notification.listNotifications( + {}, + { + headers: await delayNetwork.serviceHeaders( + delayAlice, + ids.AppBskyNotificationListNotifications, + ), + }, + ) + + expect(fullAfterDelay.data.notifications.length).toEqual(13) + expect(results(paginatedAllAfterDelay)).toEqual( + results([fullAfterDelay.data]), + ) + }) + + describe('cursor delay', () => { + const delay0s = 0 + const delay5s = 5_000 + + const now = '2021-01-01T01:00:00.000Z' + const nowMinus2s = '2021-01-01T00:59:58.000Z' + const nowMinus5s = '2021-01-01T00:59:55.000Z' + const nowMinus8s = '2021-01-01T00:59:52.000Z' + + beforeAll(async () => { + jest.useFakeTimers({ doNotFake: ['performance'] }) + jest.setSystemTime(new Date(now)) + }) + + afterAll(async () => { + jest.useRealTimers() + }) + + describe('for undefined cursor', () => { + it('returns now minus delay', async () => { + const delayedCursor = delayCursor(undefined, delay5s) + expect(delayedCursor).toBe(nowMinus5s) + }) + + it('returns now if delay is 0', async () => { + const delayedCursor = delayCursor(undefined, delay0s) + expect(delayedCursor).toBe(now) + }) + }) + + describe('for defined cursor', () => { + it('returns original cursor if delay is 0', async () => { + const originalCursor = nowMinus2s + const delayedCursor = delayCursor(originalCursor, delay0s) + expect(delayedCursor).toBe(originalCursor) + }) + + it('returns "now minus delay" for cursor that is after that', async () => { + // Cursor is "now - 2s", should become "now - 5s" + const originalCursor = nowMinus2s + const cursor = delayCursor(originalCursor, delay5s) + expect(cursor).toBe(nowMinus5s) + }) + + it('returns original cursor for cursor that is before "now minus delay"', async () => { + // Cursor is "now - 8s", should stay like that. + const originalCursor = nowMinus8s + const cursor = delayCursor(originalCursor, delay5s) + expect(cursor).toBe(originalCursor) + }) + + it('passes through a non-date cursor', async () => { + const originalCursor = '123_abc' + const cursor = delayCursor(originalCursor, delay5s) + expect(cursor).toBe(originalCursor) + }) + }) + }) }) })