diff --git a/integration/websockets/e2e/ws-error-gateway.spec.ts b/integration/websockets/e2e/ws-error-gateway.spec.ts new file mode 100644 index 00000000000..293af95c504 --- /dev/null +++ b/integration/websockets/e2e/ws-error-gateway.spec.ts @@ -0,0 +1,60 @@ +import { INestApplication } from '@nestjs/common'; +import { WsAdapter } from '@nestjs/platform-ws'; +import { Test } from '@nestjs/testing'; +import WebSocket from 'ws'; +import { WsErrorGateway } from '../src/ws-error.gateway.js'; + +async function createNestApp(...gateways: any[]): Promise { + const testingModule = await Test.createTestingModule({ + providers: gateways, + }).compile(); + const app = testingModule.createNestApplication(); + app.useWebSocketAdapter(new WsAdapter(app) as any); + return app; +} + +describe('WebSocketGateway (WsAdapter) - Error Handling', () => { + let ws: WebSocket, app: INestApplication; + + it('should send WsException error to client via native WebSocket', async () => { + app = await createNestApp(WsErrorGateway); + await app.listen(3000); + + ws = new WebSocket('ws://localhost:8085'); + await new Promise(resolve => ws.on('open', resolve)); + + ws.send( + JSON.stringify({ + event: 'push', + data: { + test: 'test', + }, + }), + ); + + await new Promise(resolve => + ws.on('message', data => { + const response = JSON.parse(data.toString()); + expect(response).toEqual({ + event: 'exception', + data: { + status: 'error', + message: 'test', + cause: { + pattern: 'push', + data: { + test: 'test', + }, + }, + }, + }); + ws.close(); + resolve(); + }), + ); + }); + + afterEach(async function () { + await app.close(); + }); +}); diff --git a/integration/websockets/src/ws-error.gateway.ts b/integration/websockets/src/ws-error.gateway.ts new file mode 100644 index 00000000000..f72c69f5b52 --- /dev/null +++ b/integration/websockets/src/ws-error.gateway.ts @@ -0,0 +1,14 @@ +import { + SubscribeMessage, + WebSocketGateway, + WsException, +} from '@nestjs/websockets'; +import { throwError } from 'rxjs'; + +@WebSocketGateway(8085) +export class WsErrorGateway { + @SubscribeMessage('push') + onPush() { + return throwError(() => new WsException('test')); + } +} diff --git a/packages/websockets/exceptions/base-ws-exception-filter.ts b/packages/websockets/exceptions/base-ws-exception-filter.ts index d006286fc63..cac1f78b61c 100644 --- a/packages/websockets/exceptions/base-ws-exception-filter.ts +++ b/packages/websockets/exceptions/base-ws-exception-filter.ts @@ -5,7 +5,7 @@ import { type WsExceptionFilter, } from '@nestjs/common'; import { WsException } from '../errors/ws-exception.js'; -import { isObject } from '@nestjs/common/internal'; +import { isFunction, isNumber, isObject } from '@nestjs/common/internal'; import { MESSAGES } from '@nestjs/core/internal'; export interface ErrorPayload { @@ -64,7 +64,7 @@ export class BaseWsExceptionFilter< }); } - public handleError( + public handleError( client: TClient, exception: TError, cause: ErrorPayload['cause'], @@ -77,7 +77,7 @@ export class BaseWsExceptionFilter< const result = exception.getError(); if (isObject(result)) { - return client.emit('exception', result); + return this.emitMessage(client, 'exception', result); } const payload: ErrorPayload = { @@ -89,14 +89,12 @@ export class BaseWsExceptionFilter< payload.cause = this.options.causeFactory!(cause.pattern, cause.data); } - client.emit('exception', payload); + this.emitMessage(client, 'exception', payload); } - public handleUnknownError( - exception: TError, - client: TClient, - data: ErrorPayload['cause'], - ) { + public handleUnknownError< + TClient extends { emit?: Function; send?: Function }, + >(exception: TError, client: TClient, data: ErrorPayload['cause']) { const status = 'error'; const payload: ErrorPayload = { status, @@ -107,7 +105,7 @@ export class BaseWsExceptionFilter< payload.cause = this.options.causeFactory!(data.pattern, data.data); } - client.emit('exception', payload); + this.emitMessage(client, 'exception', payload); if (!(exception instanceof IntrinsicException)) { const logger = BaseWsExceptionFilter.logger; @@ -118,4 +116,42 @@ export class BaseWsExceptionFilter< public isExceptionObject(err: any): err is Error { return isObject(err) && !!(err as Error).message; } + + /** + * Sends an error message to the client. Supports both Socket.IO clients + * (which use `emit`) and native WebSocket clients (which use `send`). + * + * Native WebSocket clients (e.g. from the `ws` package) inherit from + * EventEmitter and therefore also have an `emit` method, but that method + * only dispatches events locally. To distinguish native WebSocket clients + * from Socket.IO clients, we check for a numeric `readyState` property + * (part of the WebSocket specification) before falling back to `emit`. + */ + protected emitMessage( + client: TClient, + event: string, + payload: unknown, + ): void { + if (this.isNativeWebSocket(client)) { + client.send( + JSON.stringify({ + event, + data: payload, + }), + ); + } else if (isFunction(client.emit)) { + client.emit(event, payload); + } + } + + /** + * Determines whether the given client is a native WebSocket (e.g. from the + * `ws` package) as opposed to a Socket.IO socket. Native WebSocket objects + * expose a numeric `readyState` property per the WebSocket specification. + */ + private isNativeWebSocket( + client: Record, + ): client is { send: Function; readyState: number } { + return isNumber(client.readyState) && isFunction(client.send); + } } diff --git a/packages/websockets/exceptions/ws-exceptions-handler.ts b/packages/websockets/exceptions/ws-exceptions-handler.ts index 4d71bdfe586..e6e4acab543 100644 --- a/packages/websockets/exceptions/ws-exceptions-handler.ts +++ b/packages/websockets/exceptions/ws-exceptions-handler.ts @@ -16,7 +16,7 @@ export class WsExceptionsHandler extends BaseWsExceptionFilter { public handle(exception: Error | WsException, host: ArgumentsHost) { const client = host.switchToWs().getClient(); - if (this.invokeCustomFilters(exception, host) || !client.emit) { + if (this.invokeCustomFilters(exception, host) || !client) { return; } super.catch(exception, host); diff --git a/packages/websockets/test/exceptions/ws-exceptions-handler.spec.ts b/packages/websockets/test/exceptions/ws-exceptions-handler.spec.ts index 3df991bafaa..48909d796fb 100644 --- a/packages/websockets/test/exceptions/ws-exceptions-handler.spec.ts +++ b/packages/websockets/test/exceptions/ws-exceptions-handler.spec.ts @@ -50,7 +50,6 @@ describe('WsExceptionsHandler', () => { const message = 'Unauthorized'; handler.handle(new WsException(message), executionContextHost); - console.log(emitStub.mock.calls[0]); expect(emitStub).toHaveBeenCalledWith('exception', { message, status: 'error', @@ -98,6 +97,89 @@ describe('WsExceptionsHandler', () => { }); }); + describe('when client uses "send" instead of "emit" (native WebSocket)', () => { + let sendStub: ReturnType; + let wsClient: { send: ReturnType; readyState: number }; + let wsExecutionContextHost: ExecutionContextHost; + + beforeEach(() => { + handler = new WsExceptionsHandler(); + sendStub = vi.fn(); + wsClient = { send: sendStub, readyState: 1 }; + wsExecutionContextHost = new ExecutionContextHost([ + wsClient, + data, + pattern, + ]); + }); + + it('should send JSON-stringified error via "send" when exception is unknown', () => { + handler.handle(new Error(), wsExecutionContextHost); + expect(sendStub).toHaveBeenCalledTimes(1); + const sent = JSON.parse(sendStub.mock.calls[0][0]); + expect(sent).toEqual({ + event: 'exception', + data: { + status: 'error', + message: 'Internal server error', + cause: { + pattern, + data, + }, + }, + }); + }); + + it('should send JSON-stringified error via "send" for WsException with object', () => { + const message = { custom: 'Unauthorized' }; + handler.handle(new WsException(message), wsExecutionContextHost); + expect(sendStub).toHaveBeenCalledTimes(1); + const sent = JSON.parse(sendStub.mock.calls[0][0]); + expect(sent).toEqual({ + event: 'exception', + data: message, + }); + }); + + it('should send JSON-stringified error via "send" for WsException with string', () => { + const message = 'Unauthorized'; + handler.handle(new WsException(message), wsExecutionContextHost); + expect(sendStub).toHaveBeenCalledTimes(1); + const sent = JSON.parse(sendStub.mock.calls[0][0]); + expect(sent).toEqual({ + event: 'exception', + data: { + message, + status: 'error', + cause: { + pattern, + data, + }, + }, + }); + }); + + describe('when "includeCause" is set to false', () => { + beforeEach(() => { + handler = new WsExceptionsHandler({ includeCause: false }); + }); + + it('should send error without cause via "send"', () => { + const message = 'Unauthorized'; + handler.handle(new WsException(message), wsExecutionContextHost); + expect(sendStub).toHaveBeenCalledTimes(1); + const sent = JSON.parse(sendStub.mock.calls[0][0]); + expect(sent).toEqual({ + event: 'exception', + data: { + message, + status: 'error', + }, + }); + }); + }); + }); + describe('when "invokeCustomFilters" returns true', () => { beforeEach(() => { vi.spyOn(handler, 'invokeCustomFilters').mockReturnValue(true);