feat(ws): add ack decorator

This commit is contained in:
Seongjee Kim
2025-08-17 15:27:07 +09:00
parent 4d9c188016
commit 6559c97d12
7 changed files with 56 additions and 11 deletions

View File

@@ -6,6 +6,7 @@ import { Observable } from 'rxjs';
export interface WsMessageHandler<T = string> {
message: T;
callback: (...args: any[]) => Observable<any> | Promise<any>;
isAckHandledManually: boolean;
}
/**

View File

@@ -44,22 +44,24 @@ export class IoAdapter extends AbstractWsAdapter {
first(),
);
handlers.forEach(({ message, callback }) => {
handlers.forEach(({ message, callback, isAckHandledManually }) => {
const source$ = fromEvent(socket, message).pipe(
mergeMap((payload: any) => {
const { data, ack } = this.mapPayload(payload);
return transform(callback(data, ack)).pipe(
filter((response: any) => !isNil(response)),
map((response: any) => [response, ack]),
map((response: any) => [response, ack, isAckHandledManually]),
);
}),
takeUntil(disconnect$),
);
source$.subscribe(([response, ack]) => {
source$.subscribe(([response, ack, isAckHandledManually]) => {
if (response.event) {
return socket.emit(response.event, response.data);
}
isFunction(ack) && ack(response);
if (!isAckHandledManually && isFunction(ack)) {
ack(response);
}
});
});
}

View File

@@ -0,0 +1,28 @@
import { WsParamtype } from '../enums/ws-paramtype.enum';
import { createPipesWsParamDecorator } from '../utils/param.utils';
/**
* WebSockets `ack` parameter decorator.
* Extracts the `ack` callback function from the arguments of a ws event.
*
* This decorator signals to the framework that the `ack` callback will be
* handled manually within the method, preventing the framework from
* automatically sending an acknowledgement based on the return value.
*
* @example
* ```typescript
* @SubscribeMessage('events')
* onEvent(
* @MessageBody() data: string,
* @Ack() ack: (response: any) => void
* ) {
* // Manually call the ack callback
* ack({ status: 'ok' });
* }
* ```
*
* @publicApi
*/
export function Ack(): ParameterDecorator {
return createPipesWsParamDecorator(WsParamtype.ACK)();
}

View File

@@ -3,3 +3,4 @@ export * from './gateway-server.decorator';
export * from './message-body.decorator';
export * from './socket-gateway.decorator';
export * from './subscribe-message.decorator';
export * from './ack.decorator';

View File

@@ -1,3 +1,4 @@
import { isFunction } from '@nestjs/common/utils/shared.utils';
import { WsParamtype } from '../enums/ws-paramtype.enum';
export class WsParamsFactory {
@@ -14,6 +15,9 @@ export class WsParamsFactory {
return args[0];
case WsParamtype.PAYLOAD:
return data ? args[1]?.[data] : args[1];
case WsParamtype.ACK: {
return args.find(arg => isFunction(arg));
}
default:
return null;
}

View File

@@ -10,6 +10,7 @@ import {
import { NestGateway } from './interfaces/nest-gateway.interface';
import { ParamsMetadata } from '@nestjs/core/helpers/interfaces';
import { WsParamtype } from './enums/ws-paramtype.enum';
import { ContextUtils } from '@nestjs/core/helpers/context-utils';
export interface MessageMappingProperties {
message: any;
@@ -19,6 +20,7 @@ export interface MessageMappingProperties {
}
export class GatewayMetadataExplorer {
private readonly contextUtils = new ContextUtils();
constructor(private readonly metadataScanner: MetadataScanner) {}
public explore(instance: NestGateway): MessageMappingProperties[] {
@@ -68,9 +70,12 @@ export class GatewayMetadataExplorer {
if (!paramsMetadata) {
return false;
}
const metadataKeys = Object.keys(paramsMetadata);
return metadataKeys.some(key => {
const type = this.contextUtils.mapParamType(key);
const params = Object.values(paramsMetadata);
return params.some((param: any) => param.type === WsParamtype.ACK);
return (Number(type) as WsParamtype) === WsParamtype.ACK;
});
}
public *scanForServerHooks(instance: NestGateway): IterableIterator<string> {

View File

@@ -72,7 +72,7 @@ export class WebSocketsController {
) {
const nativeMessageHandlers = this.metadataExplorer.explore(instance);
const messageHandlers = nativeMessageHandlers.map(
({ callback, message, methodName }) => ({
({ callback, isAckHandledManually, message, methodName }) => ({
message,
methodName,
callback: this.contextCreator.create(
@@ -81,6 +81,7 @@ export class WebSocketsController {
moduleKey,
methodName,
),
isAckHandledManually,
}),
);
@@ -174,10 +175,13 @@ export class WebSocketsController {
instance: NestGateway,
) {
const adapter = this.config.getIoAdapter();
const handlers = subscribersMap.map(({ callback, message }) => ({
message,
callback: callback.bind(instance, client),
}));
const handlers = subscribersMap.map(
({ callback, message, isAckHandledManually }) => ({
message,
callback: callback.bind(instance, client),
isAckHandledManually,
}),
);
adapter.bindMessageHandlers(client, handlers, data =>
fromPromise(this.pickResult(data)).pipe(mergeAll()),
);