Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions integration/websockets/e2e/ws-error-gateway.spec.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { INestApplication } from '@nestjs/common';
import { WsAdapter } from '@nestjs/platform-ws';
import { Test } from '@nestjs/testing';
import WebSocket from 'ws';
import { WsErrorGateway } from '../src/ws-error.gateway.js';

async function createNestApp(...gateways: any[]): Promise<INestApplication> {
const testingModule = await Test.createTestingModule({
providers: gateways,
}).compile();
const app = testingModule.createNestApplication();
app.useWebSocketAdapter(new WsAdapter(app) as any);
return app;
}

describe('WebSocketGateway (WsAdapter) - Error Handling', () => {
let ws: WebSocket, app: INestApplication;

it('should send WsException error to client via native WebSocket', async () => {
app = await createNestApp(WsErrorGateway);
await app.listen(3000);

ws = new WebSocket('ws://localhost:8085');
await new Promise(resolve => ws.on('open', resolve));

ws.send(
JSON.stringify({
event: 'push',
data: {
test: 'test',
},
}),
);

await new Promise<void>(resolve =>
ws.on('message', data => {
const response = JSON.parse(data.toString());
expect(response).toEqual({
event: 'exception',
data: {
status: 'error',
message: 'test',
cause: {
pattern: 'push',
data: {
test: 'test',
},
},
},
});
ws.close();
resolve();
}),
);
});

afterEach(async function () {
await app.close();
});
});
14 changes: 14 additions & 0 deletions integration/websockets/src/ws-error.gateway.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import {
SubscribeMessage,
WebSocketGateway,
WsException,
} from '@nestjs/websockets';
import { throwError } from 'rxjs';

@WebSocketGateway(8085)
export class WsErrorGateway {
@SubscribeMessage('push')
onPush() {
return throwError(() => new WsException('test'));
}
}
56 changes: 46 additions & 10 deletions packages/websockets/exceptions/base-ws-exception-filter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import {
type WsExceptionFilter,
} from '@nestjs/common';
import { WsException } from '../errors/ws-exception.js';
import { isObject } from '@nestjs/common/internal';
import { isFunction, isNumber, isObject } from '@nestjs/common/internal';
import { MESSAGES } from '@nestjs/core/internal';

export interface ErrorPayload<Cause = { pattern: string; data: unknown }> {
Expand Down Expand Up @@ -64,7 +64,7 @@ export class BaseWsExceptionFilter<
});
}

public handleError<TClient extends { emit: Function }>(
public handleError<TClient extends { emit?: Function; send?: Function }>(
client: TClient,
exception: TError,
cause: ErrorPayload['cause'],
Expand All @@ -77,7 +77,7 @@ export class BaseWsExceptionFilter<
const result = exception.getError();

if (isObject(result)) {
return client.emit('exception', result);
return this.emitMessage(client, 'exception', result);
}

const payload: ErrorPayload<unknown> = {
Expand All @@ -89,14 +89,12 @@ export class BaseWsExceptionFilter<
payload.cause = this.options.causeFactory!(cause.pattern, cause.data);
}

client.emit('exception', payload);
this.emitMessage(client, 'exception', payload);
}

public handleUnknownError<TClient extends { emit: Function }>(
exception: TError,
client: TClient,
data: ErrorPayload['cause'],
) {
public handleUnknownError<
TClient extends { emit?: Function; send?: Function },
>(exception: TError, client: TClient, data: ErrorPayload['cause']) {
const status = 'error';
const payload: ErrorPayload<unknown> = {
status,
Expand All @@ -107,7 +105,7 @@ export class BaseWsExceptionFilter<
payload.cause = this.options.causeFactory!(data.pattern, data.data);
}

client.emit('exception', payload);
this.emitMessage(client, 'exception', payload);

if (!(exception instanceof IntrinsicException)) {
const logger = BaseWsExceptionFilter.logger;
Expand All @@ -118,4 +116,42 @@ export class BaseWsExceptionFilter<
public isExceptionObject(err: any): err is Error {
return isObject(err) && !!(err as Error).message;
}

/**
* Sends an error message to the client. Supports both Socket.IO clients
* (which use `emit`) and native WebSocket clients (which use `send`).
*
* Native WebSocket clients (e.g. from the `ws` package) inherit from
* EventEmitter and therefore also have an `emit` method, but that method
* only dispatches events locally. To distinguish native WebSocket clients
* from Socket.IO clients, we check for a numeric `readyState` property
* (part of the WebSocket specification) before falling back to `emit`.
*/
protected emitMessage<TClient extends { emit?: Function; send?: Function }>(
client: TClient,
event: string,
payload: unknown,
): void {
if (this.isNativeWebSocket(client)) {
client.send(
JSON.stringify({
event,
data: payload,
}),
);
} else if (isFunction(client.emit)) {
client.emit(event, payload);
}
}

/**
* Determines whether the given client is a native WebSocket (e.g. from the
* `ws` package) as opposed to a Socket.IO socket. Native WebSocket objects
* expose a numeric `readyState` property per the WebSocket specification.
*/
private isNativeWebSocket(
client: Record<string, any>,
): client is { send: Function; readyState: number } {
return isNumber(client.readyState) && isFunction(client.send);
}
}
2 changes: 1 addition & 1 deletion packages/websockets/exceptions/ws-exceptions-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export class WsExceptionsHandler extends BaseWsExceptionFilter {

public handle(exception: Error | WsException, host: ArgumentsHost) {
const client = host.switchToWs().getClient();
if (this.invokeCustomFilters(exception, host) || !client.emit) {
if (this.invokeCustomFilters(exception, host) || !client) {
return;
}
super.catch(exception, host);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ describe('WsExceptionsHandler', () => {
const message = 'Unauthorized';

handler.handle(new WsException(message), executionContextHost);
console.log(emitStub.mock.calls[0]);
expect(emitStub).toHaveBeenCalledWith('exception', {
message,
status: 'error',
Expand Down Expand Up @@ -98,6 +97,89 @@ describe('WsExceptionsHandler', () => {
});
});

describe('when client uses "send" instead of "emit" (native WebSocket)', () => {
let sendStub: ReturnType<typeof vi.fn>;
let wsClient: { send: ReturnType<typeof vi.fn>; readyState: number };
let wsExecutionContextHost: ExecutionContextHost;

beforeEach(() => {
handler = new WsExceptionsHandler();
sendStub = vi.fn();
wsClient = { send: sendStub, readyState: 1 };
wsExecutionContextHost = new ExecutionContextHost([
wsClient,
data,
pattern,
]);
});

it('should send JSON-stringified error via "send" when exception is unknown', () => {
handler.handle(new Error(), wsExecutionContextHost);
expect(sendStub).toHaveBeenCalledTimes(1);
const sent = JSON.parse(sendStub.mock.calls[0][0]);
expect(sent).toEqual({
event: 'exception',
data: {
status: 'error',
message: 'Internal server error',
cause: {
pattern,
data,
},
},
});
});

it('should send JSON-stringified error via "send" for WsException with object', () => {
const message = { custom: 'Unauthorized' };
handler.handle(new WsException(message), wsExecutionContextHost);
expect(sendStub).toHaveBeenCalledTimes(1);
const sent = JSON.parse(sendStub.mock.calls[0][0]);
expect(sent).toEqual({
event: 'exception',
data: message,
});
});

it('should send JSON-stringified error via "send" for WsException with string', () => {
const message = 'Unauthorized';
handler.handle(new WsException(message), wsExecutionContextHost);
expect(sendStub).toHaveBeenCalledTimes(1);
const sent = JSON.parse(sendStub.mock.calls[0][0]);
expect(sent).toEqual({
event: 'exception',
data: {
message,
status: 'error',
cause: {
pattern,
data,
},
},
});
});

describe('when "includeCause" is set to false', () => {
beforeEach(() => {
handler = new WsExceptionsHandler({ includeCause: false });
});

it('should send error without cause via "send"', () => {
const message = 'Unauthorized';
handler.handle(new WsException(message), wsExecutionContextHost);
expect(sendStub).toHaveBeenCalledTimes(1);
const sent = JSON.parse(sendStub.mock.calls[0][0]);
expect(sent).toEqual({
event: 'exception',
data: {
message,
status: 'error',
},
});
});
});
});

describe('when "invokeCustomFilters" returns true', () => {
beforeEach(() => {
vi.spyOn(handler, 'invokeCustomFilters').mockReturnValue(true);
Expand Down