diff --git a/packages/common/interfaces/websockets/web-socket-adapter.interface.ts b/packages/common/interfaces/websockets/web-socket-adapter.interface.ts index e3baa1272..fe08dbe77 100644 --- a/packages/common/interfaces/websockets/web-socket-adapter.interface.ts +++ b/packages/common/interfaces/websockets/web-socket-adapter.interface.ts @@ -6,6 +6,7 @@ import { Observable } from 'rxjs'; export interface WsMessageHandler { message: T; callback: (...args: any[]) => Observable | Promise; + isAckHandledManually: boolean; } /** diff --git a/packages/platform-socket.io/adapters/io-adapter.ts b/packages/platform-socket.io/adapters/io-adapter.ts index ab08fde13..109041daf 100644 --- a/packages/platform-socket.io/adapters/io-adapter.ts +++ b/packages/platform-socket.io/adapters/io-adapter.ts @@ -44,22 +44,24 @@ export class IoAdapter extends AbstractWsAdapter { first(), ); - handlers.forEach(({ message, callback }) => { + handlers.forEach(({ message, callback, isAckHandledManually }) => { const source$ = fromEvent(socket, message).pipe( mergeMap((payload: any) => { const { data, ack } = this.mapPayload(payload); return transform(callback(data, ack)).pipe( filter((response: any) => !isNil(response)), - map((response: any) => [response, ack]), + map((response: any) => [response, ack, isAckHandledManually]), ); }), takeUntil(disconnect$), ); - source$.subscribe(([response, ack]) => { + source$.subscribe(([response, ack, isAckHandledManually]) => { if (response.event) { return socket.emit(response.event, response.data); } - isFunction(ack) && ack(response); + if (!isAckHandledManually && isFunction(ack)) { + ack(response); + } }); }); } diff --git a/packages/websockets/decorators/ack.decorator.ts b/packages/websockets/decorators/ack.decorator.ts new file mode 100644 index 000000000..728bc0964 --- /dev/null +++ b/packages/websockets/decorators/ack.decorator.ts @@ -0,0 +1,28 @@ +import { WsParamtype } from '../enums/ws-paramtype.enum'; +import { createPipesWsParamDecorator } from '../utils/param.utils'; + +/** + * WebSockets `ack` parameter decorator. + * Extracts the `ack` callback function from the arguments of a ws event. + * + * This decorator signals to the framework that the `ack` callback will be + * handled manually within the method, preventing the framework from + * automatically sending an acknowledgement based on the return value. + * + * @example + * ```typescript + * @SubscribeMessage('events') + * onEvent( + * @MessageBody() data: string, + * @Ack() ack: (response: any) => void + * ) { + * // Manually call the ack callback + * ack({ status: 'ok' }); + * } + * ``` + * + * @publicApi + */ +export function Ack(): ParameterDecorator { + return createPipesWsParamDecorator(WsParamtype.ACK)(); +} diff --git a/packages/websockets/decorators/index.ts b/packages/websockets/decorators/index.ts index 51d5d6da3..ee6f88bcd 100644 --- a/packages/websockets/decorators/index.ts +++ b/packages/websockets/decorators/index.ts @@ -3,3 +3,4 @@ export * from './gateway-server.decorator'; export * from './message-body.decorator'; export * from './socket-gateway.decorator'; export * from './subscribe-message.decorator'; +export * from './ack.decorator'; diff --git a/packages/websockets/factories/ws-params-factory.ts b/packages/websockets/factories/ws-params-factory.ts index 649226d82..2ef6fefa0 100644 --- a/packages/websockets/factories/ws-params-factory.ts +++ b/packages/websockets/factories/ws-params-factory.ts @@ -1,3 +1,4 @@ +import { isFunction } from '@nestjs/common/utils/shared.utils'; import { WsParamtype } from '../enums/ws-paramtype.enum'; export class WsParamsFactory { @@ -14,6 +15,9 @@ export class WsParamsFactory { return args[0]; case WsParamtype.PAYLOAD: return data ? args[1]?.[data] : args[1]; + case WsParamtype.ACK: { + return args.find(arg => isFunction(arg)); + } default: return null; } diff --git a/packages/websockets/gateway-metadata-explorer.ts b/packages/websockets/gateway-metadata-explorer.ts index a81e45780..66f44cc99 100644 --- a/packages/websockets/gateway-metadata-explorer.ts +++ b/packages/websockets/gateway-metadata-explorer.ts @@ -10,6 +10,7 @@ import { import { NestGateway } from './interfaces/nest-gateway.interface'; import { ParamsMetadata } from '@nestjs/core/helpers/interfaces'; import { WsParamtype } from './enums/ws-paramtype.enum'; +import { ContextUtils } from '@nestjs/core/helpers/context-utils'; export interface MessageMappingProperties { message: any; @@ -19,6 +20,7 @@ export interface MessageMappingProperties { } export class GatewayMetadataExplorer { + private readonly contextUtils = new ContextUtils(); constructor(private readonly metadataScanner: MetadataScanner) {} public explore(instance: NestGateway): MessageMappingProperties[] { @@ -68,9 +70,12 @@ export class GatewayMetadataExplorer { if (!paramsMetadata) { return false; } + const metadataKeys = Object.keys(paramsMetadata); + return metadataKeys.some(key => { + const type = this.contextUtils.mapParamType(key); - const params = Object.values(paramsMetadata); - return params.some((param: any) => param.type === WsParamtype.ACK); + return (Number(type) as WsParamtype) === WsParamtype.ACK; + }); } public *scanForServerHooks(instance: NestGateway): IterableIterator { diff --git a/packages/websockets/web-sockets-controller.ts b/packages/websockets/web-sockets-controller.ts index bcacf245f..37e21efef 100644 --- a/packages/websockets/web-sockets-controller.ts +++ b/packages/websockets/web-sockets-controller.ts @@ -72,7 +72,7 @@ export class WebSocketsController { ) { const nativeMessageHandlers = this.metadataExplorer.explore(instance); const messageHandlers = nativeMessageHandlers.map( - ({ callback, message, methodName }) => ({ + ({ callback, isAckHandledManually, message, methodName }) => ({ message, methodName, callback: this.contextCreator.create( @@ -81,6 +81,7 @@ export class WebSocketsController { moduleKey, methodName, ), + isAckHandledManually, }), ); @@ -174,10 +175,13 @@ export class WebSocketsController { instance: NestGateway, ) { const adapter = this.config.getIoAdapter(); - const handlers = subscribersMap.map(({ callback, message }) => ({ - message, - callback: callback.bind(instance, client), - })); + const handlers = subscribersMap.map( + ({ callback, message, isAckHandledManually }) => ({ + message, + callback: callback.bind(instance, client), + isAckHandledManually, + }), + ); adapter.bindMessageHandlers(client, handlers, data => fromPromise(this.pickResult(data)).pipe(mergeAll()), );