Merge pull request #15543 from kim-sung-jee/feature/websockets-manual-ack

feat(websockets): allow manual acknowledgement handling with @Ack() decorator
This commit is contained in:
Kamil Mysliwiec
2025-10-21 10:32:04 +02:00
committed by GitHub
15 changed files with 218 additions and 13 deletions

View File

@@ -41,5 +41,38 @@ describe('WebSocketGateway (ack)', () => {
);
});
it('should handle manual ack for async operations when @Ack() is used (success case)', async () => {
app = await createNestApp(AckGateway);
await app.listen(3000);
ws = io('http://localhost:8080');
const payload = { shouldSucceed: true };
await new Promise<void>(resolve =>
ws.emit('manual-ack', payload, response => {
expect(response).to.eql({ status: 'success', data: payload });
resolve();
}),
);
});
it('should handle manual ack for async operations when @Ack() is used (error case)', async () => {
app = await createNestApp(AckGateway);
await app.listen(3000);
ws = io('http://localhost:8080');
const payload = { shouldSucceed: false };
await new Promise<void>(resolve =>
ws.emit('manual-ack', payload, response => {
expect(response).to.eql({
status: 'error',
message: 'Operation failed',
});
resolve();
}),
);
});
afterEach(() => app.close());
});

View File

@@ -1,4 +1,9 @@
import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets';
import {
Ack,
MessageBody,
SubscribeMessage,
WebSocketGateway,
} from '@nestjs/websockets';
@WebSocketGateway(8080)
export class AckGateway {
@@ -6,4 +11,19 @@ export class AckGateway {
onPush() {
return 'pong';
}
@SubscribeMessage('manual-ack')
async handleManualAck(
@MessageBody() data: any,
@Ack() ack: (response: any) => void,
) {
await new Promise(resolve => setTimeout(resolve, 20));
if (data.shouldSucceed) {
ack({ status: 'success', data });
} else {
ack({ status: 'error', message: 'Operation failed' });
}
return { status: 'ignored' };
}
}

View File

@@ -12,4 +12,5 @@ export enum RouteParamtypes {
HOST = 10,
IP = 11,
RAW_BODY = 12,
ACK = 13,
}

View File

@@ -6,6 +6,7 @@ import { Observable } from 'rxjs';
export interface WsMessageHandler<T = string> {
message: T;
callback: (...args: any[]) => Observable<any> | Promise<any>;
isAckHandledManually: boolean;
}
/**

View File

@@ -44,22 +44,24 @@ export class IoAdapter extends AbstractWsAdapter {
first(),
);
handlers.forEach(({ message, callback }) => {
handlers.forEach(({ message, callback, isAckHandledManually }) => {
const source$ = fromEvent(socket, message).pipe(
mergeMap((payload: any) => {
const { data, ack } = this.mapPayload(payload);
return transform(callback(data, ack)).pipe(
filter((response: any) => !isNil(response)),
map((response: any) => [response, ack]),
map((response: any) => [response, ack, isAckHandledManually]),
);
}),
takeUntil(disconnect$),
);
source$.subscribe(([response, ack]) => {
source$.subscribe(([response, ack, isAckHandledManually]) => {
if (response.event) {
return socket.emit(response.event, response.data);
}
isFunction(ack) && ack(response);
if (!isAckHandledManually && isFunction(ack)) {
ack(response);
}
});
});
}

View File

@@ -1,6 +1,7 @@
import { WsParamtype } from '../enums/ws-paramtype.enum';
export const DEFAULT_CALLBACK_METADATA = {
[`${WsParamtype.ACK}:2`]: { index: 2, data: undefined, pipes: [] },
[`${WsParamtype.PAYLOAD}:1`]: { index: 1, data: undefined, pipes: [] },
[`${WsParamtype.SOCKET}:0`]: { index: 0, data: undefined, pipes: [] },
};

View File

@@ -0,0 +1,28 @@
import { WsParamtype } from '../enums/ws-paramtype.enum';
import { createPipesWsParamDecorator } from '../utils/param.utils';
/**
* WebSockets `ack` parameter decorator.
* Extracts the `ack` callback function from the arguments of a ws event.
*
* This decorator signals to the framework that the `ack` callback will be
* handled manually within the method, preventing the framework from
* automatically sending an acknowledgement based on the return value.
*
* @example
* ```typescript
* @SubscribeMessage('events')
* onEvent(
* @MessageBody() data: string,
* @Ack() ack: (response: any) => void
* ) {
* // Manually call the ack callback
* ack({ status: 'ok' });
* }
* ```
*
* @publicApi
*/
export function Ack(): ParameterDecorator {
return createPipesWsParamDecorator(WsParamtype.ACK)();
}

View File

@@ -3,3 +3,4 @@ export * from './gateway-server.decorator';
export * from './message-body.decorator';
export * from './socket-gateway.decorator';
export * from './subscribe-message.decorator';
export * from './ack.decorator';

View File

@@ -3,4 +3,5 @@ import { RouteParamtypes } from '@nestjs/common/enums/route-paramtypes.enum';
export enum WsParamtype {
SOCKET = RouteParamtypes.REQUEST,
PAYLOAD = RouteParamtypes.BODY,
ACK = RouteParamtypes.ACK,
}

View File

@@ -1,3 +1,4 @@
import { isFunction } from '@nestjs/common/utils/shared.utils';
import { WsParamtype } from '../enums/ws-paramtype.enum';
export class WsParamsFactory {
@@ -14,6 +15,9 @@ export class WsParamsFactory {
return args[0];
case WsParamtype.PAYLOAD:
return data ? args[1]?.[data] : args[1];
case WsParamtype.ACK: {
return args.find(arg => isFunction(arg));
}
default:
return null;
}

View File

@@ -5,16 +5,22 @@ import {
GATEWAY_SERVER_METADATA,
MESSAGE_MAPPING_METADATA,
MESSAGE_METADATA,
PARAM_ARGS_METADATA,
} from './constants';
import { NestGateway } from './interfaces/nest-gateway.interface';
import { ParamsMetadata } from '@nestjs/core/helpers/interfaces';
import { WsParamtype } from './enums/ws-paramtype.enum';
import { ContextUtils } from '@nestjs/core/helpers/context-utils';
export interface MessageMappingProperties {
message: any;
methodName: string;
callback: (...args: any[]) => Observable<any> | Promise<any>;
isAckHandledManually: boolean;
}
export class GatewayMetadataExplorer {
private readonly contextUtils = new ContextUtils();
constructor(private readonly metadataScanner: MetadataScanner) {}
public explore(instance: NestGateway): MessageMappingProperties[] {
@@ -38,13 +44,40 @@ export class GatewayMetadataExplorer {
return null;
}
const message = Reflect.getMetadata(MESSAGE_METADATA, callback);
const isAckHandledManually = this.hasAckDecorator(
instancePrototype,
methodName,
);
return {
callback,
message,
methodName,
isAckHandledManually,
};
}
private hasAckDecorator(
instancePrototype: object,
methodName: string,
): boolean {
const paramsMetadata: ParamsMetadata = Reflect.getMetadata(
PARAM_ARGS_METADATA,
instancePrototype.constructor,
methodName,
);
if (!paramsMetadata) {
return false;
}
const metadataKeys = Object.keys(paramsMetadata);
return metadataKeys.some(key => {
const type = this.contextUtils.mapParamType(key);
return (Number(type) as WsParamtype) === WsParamtype.ACK;
});
}
public *scanForServerHooks(instance: NestGateway): IterableIterator<string> {
for (const propertyKey in instance) {
if (isFunction(propertyKey)) {

View File

@@ -0,0 +1,28 @@
import 'reflect-metadata';
import { expect } from 'chai';
import { PARAM_ARGS_METADATA } from '../../constants';
import { Ack } from '../../decorators/ack.decorator';
import { WsParamtype } from '../../enums/ws-paramtype.enum';
class AckTest {
public test(@Ack() ack: Function) {}
}
describe('@Ack', () => {
it('should enhance class with expected request metadata', () => {
const argsMetadata = Reflect.getMetadata(
PARAM_ARGS_METADATA,
AckTest,
'test',
);
const expectedMetadata = {
[`${WsParamtype.ACK}:0`]: {
index: 0,
data: undefined,
pipes: [],
},
};
expect(argsMetadata).to.be.eql(expectedMetadata);
});
});

View File

@@ -4,11 +4,13 @@ import { MetadataScanner } from '../../core/metadata-scanner';
import { WebSocketServer } from '../decorators/gateway-server.decorator';
import { WebSocketGateway } from '../decorators/socket-gateway.decorator';
import { SubscribeMessage } from '../decorators/subscribe-message.decorator';
import { Ack } from '../decorators/ack.decorator';
import { GatewayMetadataExplorer } from '../gateway-metadata-explorer';
describe('GatewayMetadataExplorer', () => {
const message = 'test';
const secMessage = 'test2';
const ackMessage = 'ack-test';
@WebSocketGateway()
class Test {
@@ -28,6 +30,9 @@ describe('GatewayMetadataExplorer', () => {
@SubscribeMessage(secMessage)
public testSec() {}
@SubscribeMessage(ackMessage)
public testWithAck(@Ack() ack: Function) {}
public noMessage() {}
}
let instance: GatewayMetadataExplorer;
@@ -61,9 +66,22 @@ describe('GatewayMetadataExplorer', () => {
});
it(`should return message mapping properties when "isMessageMapping" metadata is not undefined`, () => {
const metadata = instance.exploreMethodMetadata(test, 'test')!;
expect(metadata).to.have.keys(['callback', 'message', 'methodName']);
expect(metadata).to.have.keys([
'callback',
'message',
'methodName',
'isAckHandledManually',
]);
expect(metadata.message).to.eql(message);
});
it('should set "isAckHandledManually" property to true when @Ack decorator is used', () => {
const metadata = instance.exploreMethodMetadata(test, 'testWithAck')!;
expect(metadata.isAckHandledManually).to.be.true;
});
it('should set "isAckHandledManually" property to false when @Ack decorator is not used', () => {
const metadata = instance.exploreMethodMetadata(test, 'test')!;
expect(metadata.isAckHandledManually).to.be.false;
});
});
describe('scanForServerHooks', () => {
it(`should return properties with @Client decorator`, () => {

View File

@@ -135,6 +135,7 @@ describe('WebSocketsController', () => {
message: 'message',
methodName: 'methodName',
callback: handlerCallback,
isAckHandledManually: false,
},
];
server = { server: 'test' };
@@ -173,6 +174,7 @@ describe('WebSocketsController', () => {
message: 'message',
methodName: 'methodName',
callback: messageHandlerCallback,
isAckHandledManually: false,
},
]);
});
@@ -188,11 +190,13 @@ describe('WebSocketsController', () => {
methodName: 'findOne',
message: 'find',
callback: null!,
isAckHandledManually: false,
},
{
methodName: 'create',
message: 'insert',
callback: null!,
isAckHandledManually: false,
},
];
const insertEntrypointDefinitionSpy = sinon.spy(
@@ -423,14 +427,40 @@ describe('WebSocketsController', () => {
client = { on: onSpy, off: onSpy };
handlers = [
{ message: 'test', callback: { bind: () => 'testCallback' } },
{ message: 'test2', callback: { bind: () => 'testCallback2' } },
{
message: 'test',
callback: { bind: () => 'testCallback' },
isAckHandledManually: true,
},
{
message: 'test2',
callback: { bind: () => 'testCallback2' },
isAckHandledManually: false,
},
];
});
it('should bind each handler to client', () => {
instance.subscribeMessages(handlers, client, gateway);
expect(onSpy.calledTwice).to.be.true;
});
it('should pass "isAckHandledManually" flag to the adapter', () => {
const adapter = config.getIoAdapter();
const bindMessageHandlersSpy = sinon.spy(adapter, 'bindMessageHandlers');
instance.subscribeMessages(handlers, client, gateway);
const handlersPassedToAdapter = bindMessageHandlersSpy.firstCall.args[1];
expect(handlersPassedToAdapter[0].message).to.equal(handlers[0].message);
expect(handlersPassedToAdapter[0].isAckHandledManually).to.equal(
handlers[0].isAckHandledManually,
);
expect(handlersPassedToAdapter[1].message).to.equal(handlers[1].message);
expect(handlersPassedToAdapter[1].isAckHandledManually).to.equal(
handlers[1].isAckHandledManually,
);
});
});
describe('pickResult', () => {
describe('when deferredResult contains value which', () => {

View File

@@ -72,7 +72,7 @@ export class WebSocketsController {
) {
const nativeMessageHandlers = this.metadataExplorer.explore(instance);
const messageHandlers = nativeMessageHandlers.map(
({ callback, message, methodName }) => ({
({ callback, isAckHandledManually, message, methodName }) => ({
message,
methodName,
callback: this.contextCreator.create(
@@ -81,6 +81,7 @@ export class WebSocketsController {
moduleKey,
methodName,
),
isAckHandledManually,
}),
);
@@ -174,10 +175,13 @@ export class WebSocketsController {
instance: NestGateway,
) {
const adapter = this.config.getIoAdapter();
const handlers = subscribersMap.map(({ callback, message }) => ({
message,
callback: callback.bind(instance, client),
}));
const handlers = subscribersMap.map(
({ callback, message, isAckHandledManually }) => ({
message,
callback: callback.bind(instance, client),
isAckHandledManually,
}),
);
adapter.bindMessageHandlers(client, handlers, data =>
fromPromise(this.pickResult(data)).pipe(mergeAll()),
);