mirror of
https://github.com/nestjs/nest.git
synced 2026-02-21 23:11:44 +00:00
Merge pull request #15543 from kim-sung-jee/feature/websockets-manual-ack
feat(websockets): allow manual acknowledgement handling with @Ack() decorator
This commit is contained in:
@@ -41,5 +41,38 @@ describe('WebSocketGateway (ack)', () => {
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle manual ack for async operations when @Ack() is used (success case)', async () => {
|
||||
app = await createNestApp(AckGateway);
|
||||
await app.listen(3000);
|
||||
|
||||
ws = io('http://localhost:8080');
|
||||
const payload = { shouldSucceed: true };
|
||||
|
||||
await new Promise<void>(resolve =>
|
||||
ws.emit('manual-ack', payload, response => {
|
||||
expect(response).to.eql({ status: 'success', data: payload });
|
||||
resolve();
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
it('should handle manual ack for async operations when @Ack() is used (error case)', async () => {
|
||||
app = await createNestApp(AckGateway);
|
||||
await app.listen(3000);
|
||||
|
||||
ws = io('http://localhost:8080');
|
||||
const payload = { shouldSucceed: false };
|
||||
|
||||
await new Promise<void>(resolve =>
|
||||
ws.emit('manual-ack', payload, response => {
|
||||
expect(response).to.eql({
|
||||
status: 'error',
|
||||
message: 'Operation failed',
|
||||
});
|
||||
resolve();
|
||||
}),
|
||||
);
|
||||
});
|
||||
|
||||
afterEach(() => app.close());
|
||||
});
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { SubscribeMessage, WebSocketGateway } from '@nestjs/websockets';
|
||||
import {
|
||||
Ack,
|
||||
MessageBody,
|
||||
SubscribeMessage,
|
||||
WebSocketGateway,
|
||||
} from '@nestjs/websockets';
|
||||
|
||||
@WebSocketGateway(8080)
|
||||
export class AckGateway {
|
||||
@@ -6,4 +11,19 @@ export class AckGateway {
|
||||
onPush() {
|
||||
return 'pong';
|
||||
}
|
||||
|
||||
@SubscribeMessage('manual-ack')
|
||||
async handleManualAck(
|
||||
@MessageBody() data: any,
|
||||
@Ack() ack: (response: any) => void,
|
||||
) {
|
||||
await new Promise(resolve => setTimeout(resolve, 20));
|
||||
|
||||
if (data.shouldSucceed) {
|
||||
ack({ status: 'success', data });
|
||||
} else {
|
||||
ack({ status: 'error', message: 'Operation failed' });
|
||||
}
|
||||
return { status: 'ignored' };
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,4 +12,5 @@ export enum RouteParamtypes {
|
||||
HOST = 10,
|
||||
IP = 11,
|
||||
RAW_BODY = 12,
|
||||
ACK = 13,
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import { Observable } from 'rxjs';
|
||||
export interface WsMessageHandler<T = string> {
|
||||
message: T;
|
||||
callback: (...args: any[]) => Observable<any> | Promise<any>;
|
||||
isAckHandledManually: boolean;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -44,22 +44,24 @@ export class IoAdapter extends AbstractWsAdapter {
|
||||
first(),
|
||||
);
|
||||
|
||||
handlers.forEach(({ message, callback }) => {
|
||||
handlers.forEach(({ message, callback, isAckHandledManually }) => {
|
||||
const source$ = fromEvent(socket, message).pipe(
|
||||
mergeMap((payload: any) => {
|
||||
const { data, ack } = this.mapPayload(payload);
|
||||
return transform(callback(data, ack)).pipe(
|
||||
filter((response: any) => !isNil(response)),
|
||||
map((response: any) => [response, ack]),
|
||||
map((response: any) => [response, ack, isAckHandledManually]),
|
||||
);
|
||||
}),
|
||||
takeUntil(disconnect$),
|
||||
);
|
||||
source$.subscribe(([response, ack]) => {
|
||||
source$.subscribe(([response, ack, isAckHandledManually]) => {
|
||||
if (response.event) {
|
||||
return socket.emit(response.event, response.data);
|
||||
}
|
||||
isFunction(ack) && ack(response);
|
||||
if (!isAckHandledManually && isFunction(ack)) {
|
||||
ack(response);
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import { WsParamtype } from '../enums/ws-paramtype.enum';
|
||||
|
||||
export const DEFAULT_CALLBACK_METADATA = {
|
||||
[`${WsParamtype.ACK}:2`]: { index: 2, data: undefined, pipes: [] },
|
||||
[`${WsParamtype.PAYLOAD}:1`]: { index: 1, data: undefined, pipes: [] },
|
||||
[`${WsParamtype.SOCKET}:0`]: { index: 0, data: undefined, pipes: [] },
|
||||
};
|
||||
|
||||
28
packages/websockets/decorators/ack.decorator.ts
Normal file
28
packages/websockets/decorators/ack.decorator.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import { WsParamtype } from '../enums/ws-paramtype.enum';
|
||||
import { createPipesWsParamDecorator } from '../utils/param.utils';
|
||||
|
||||
/**
|
||||
* WebSockets `ack` parameter decorator.
|
||||
* Extracts the `ack` callback function from the arguments of a ws event.
|
||||
*
|
||||
* This decorator signals to the framework that the `ack` callback will be
|
||||
* handled manually within the method, preventing the framework from
|
||||
* automatically sending an acknowledgement based on the return value.
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* @SubscribeMessage('events')
|
||||
* onEvent(
|
||||
* @MessageBody() data: string,
|
||||
* @Ack() ack: (response: any) => void
|
||||
* ) {
|
||||
* // Manually call the ack callback
|
||||
* ack({ status: 'ok' });
|
||||
* }
|
||||
* ```
|
||||
*
|
||||
* @publicApi
|
||||
*/
|
||||
export function Ack(): ParameterDecorator {
|
||||
return createPipesWsParamDecorator(WsParamtype.ACK)();
|
||||
}
|
||||
@@ -3,3 +3,4 @@ export * from './gateway-server.decorator';
|
||||
export * from './message-body.decorator';
|
||||
export * from './socket-gateway.decorator';
|
||||
export * from './subscribe-message.decorator';
|
||||
export * from './ack.decorator';
|
||||
|
||||
@@ -3,4 +3,5 @@ import { RouteParamtypes } from '@nestjs/common/enums/route-paramtypes.enum';
|
||||
export enum WsParamtype {
|
||||
SOCKET = RouteParamtypes.REQUEST,
|
||||
PAYLOAD = RouteParamtypes.BODY,
|
||||
ACK = RouteParamtypes.ACK,
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { isFunction } from '@nestjs/common/utils/shared.utils';
|
||||
import { WsParamtype } from '../enums/ws-paramtype.enum';
|
||||
|
||||
export class WsParamsFactory {
|
||||
@@ -14,6 +15,9 @@ export class WsParamsFactory {
|
||||
return args[0];
|
||||
case WsParamtype.PAYLOAD:
|
||||
return data ? args[1]?.[data] : args[1];
|
||||
case WsParamtype.ACK: {
|
||||
return args.find(arg => isFunction(arg));
|
||||
}
|
||||
default:
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -5,16 +5,22 @@ import {
|
||||
GATEWAY_SERVER_METADATA,
|
||||
MESSAGE_MAPPING_METADATA,
|
||||
MESSAGE_METADATA,
|
||||
PARAM_ARGS_METADATA,
|
||||
} from './constants';
|
||||
import { NestGateway } from './interfaces/nest-gateway.interface';
|
||||
import { ParamsMetadata } from '@nestjs/core/helpers/interfaces';
|
||||
import { WsParamtype } from './enums/ws-paramtype.enum';
|
||||
import { ContextUtils } from '@nestjs/core/helpers/context-utils';
|
||||
|
||||
export interface MessageMappingProperties {
|
||||
message: any;
|
||||
methodName: string;
|
||||
callback: (...args: any[]) => Observable<any> | Promise<any>;
|
||||
isAckHandledManually: boolean;
|
||||
}
|
||||
|
||||
export class GatewayMetadataExplorer {
|
||||
private readonly contextUtils = new ContextUtils();
|
||||
constructor(private readonly metadataScanner: MetadataScanner) {}
|
||||
|
||||
public explore(instance: NestGateway): MessageMappingProperties[] {
|
||||
@@ -38,13 +44,40 @@ export class GatewayMetadataExplorer {
|
||||
return null;
|
||||
}
|
||||
const message = Reflect.getMetadata(MESSAGE_METADATA, callback);
|
||||
const isAckHandledManually = this.hasAckDecorator(
|
||||
instancePrototype,
|
||||
methodName,
|
||||
);
|
||||
|
||||
return {
|
||||
callback,
|
||||
message,
|
||||
methodName,
|
||||
isAckHandledManually,
|
||||
};
|
||||
}
|
||||
|
||||
private hasAckDecorator(
|
||||
instancePrototype: object,
|
||||
methodName: string,
|
||||
): boolean {
|
||||
const paramsMetadata: ParamsMetadata = Reflect.getMetadata(
|
||||
PARAM_ARGS_METADATA,
|
||||
instancePrototype.constructor,
|
||||
methodName,
|
||||
);
|
||||
|
||||
if (!paramsMetadata) {
|
||||
return false;
|
||||
}
|
||||
const metadataKeys = Object.keys(paramsMetadata);
|
||||
return metadataKeys.some(key => {
|
||||
const type = this.contextUtils.mapParamType(key);
|
||||
|
||||
return (Number(type) as WsParamtype) === WsParamtype.ACK;
|
||||
});
|
||||
}
|
||||
|
||||
public *scanForServerHooks(instance: NestGateway): IterableIterator<string> {
|
||||
for (const propertyKey in instance) {
|
||||
if (isFunction(propertyKey)) {
|
||||
|
||||
28
packages/websockets/test/decorators/ack.decorator.spec.ts
Normal file
28
packages/websockets/test/decorators/ack.decorator.spec.ts
Normal file
@@ -0,0 +1,28 @@
|
||||
import 'reflect-metadata';
|
||||
import { expect } from 'chai';
|
||||
import { PARAM_ARGS_METADATA } from '../../constants';
|
||||
import { Ack } from '../../decorators/ack.decorator';
|
||||
import { WsParamtype } from '../../enums/ws-paramtype.enum';
|
||||
|
||||
class AckTest {
|
||||
public test(@Ack() ack: Function) {}
|
||||
}
|
||||
|
||||
describe('@Ack', () => {
|
||||
it('should enhance class with expected request metadata', () => {
|
||||
const argsMetadata = Reflect.getMetadata(
|
||||
PARAM_ARGS_METADATA,
|
||||
AckTest,
|
||||
'test',
|
||||
);
|
||||
|
||||
const expectedMetadata = {
|
||||
[`${WsParamtype.ACK}:0`]: {
|
||||
index: 0,
|
||||
data: undefined,
|
||||
pipes: [],
|
||||
},
|
||||
};
|
||||
expect(argsMetadata).to.be.eql(expectedMetadata);
|
||||
});
|
||||
});
|
||||
@@ -4,11 +4,13 @@ import { MetadataScanner } from '../../core/metadata-scanner';
|
||||
import { WebSocketServer } from '../decorators/gateway-server.decorator';
|
||||
import { WebSocketGateway } from '../decorators/socket-gateway.decorator';
|
||||
import { SubscribeMessage } from '../decorators/subscribe-message.decorator';
|
||||
import { Ack } from '../decorators/ack.decorator';
|
||||
import { GatewayMetadataExplorer } from '../gateway-metadata-explorer';
|
||||
|
||||
describe('GatewayMetadataExplorer', () => {
|
||||
const message = 'test';
|
||||
const secMessage = 'test2';
|
||||
const ackMessage = 'ack-test';
|
||||
|
||||
@WebSocketGateway()
|
||||
class Test {
|
||||
@@ -28,6 +30,9 @@ describe('GatewayMetadataExplorer', () => {
|
||||
@SubscribeMessage(secMessage)
|
||||
public testSec() {}
|
||||
|
||||
@SubscribeMessage(ackMessage)
|
||||
public testWithAck(@Ack() ack: Function) {}
|
||||
|
||||
public noMessage() {}
|
||||
}
|
||||
let instance: GatewayMetadataExplorer;
|
||||
@@ -61,9 +66,22 @@ describe('GatewayMetadataExplorer', () => {
|
||||
});
|
||||
it(`should return message mapping properties when "isMessageMapping" metadata is not undefined`, () => {
|
||||
const metadata = instance.exploreMethodMetadata(test, 'test')!;
|
||||
expect(metadata).to.have.keys(['callback', 'message', 'methodName']);
|
||||
expect(metadata).to.have.keys([
|
||||
'callback',
|
||||
'message',
|
||||
'methodName',
|
||||
'isAckHandledManually',
|
||||
]);
|
||||
expect(metadata.message).to.eql(message);
|
||||
});
|
||||
it('should set "isAckHandledManually" property to true when @Ack decorator is used', () => {
|
||||
const metadata = instance.exploreMethodMetadata(test, 'testWithAck')!;
|
||||
expect(metadata.isAckHandledManually).to.be.true;
|
||||
});
|
||||
it('should set "isAckHandledManually" property to false when @Ack decorator is not used', () => {
|
||||
const metadata = instance.exploreMethodMetadata(test, 'test')!;
|
||||
expect(metadata.isAckHandledManually).to.be.false;
|
||||
});
|
||||
});
|
||||
describe('scanForServerHooks', () => {
|
||||
it(`should return properties with @Client decorator`, () => {
|
||||
|
||||
@@ -135,6 +135,7 @@ describe('WebSocketsController', () => {
|
||||
message: 'message',
|
||||
methodName: 'methodName',
|
||||
callback: handlerCallback,
|
||||
isAckHandledManually: false,
|
||||
},
|
||||
];
|
||||
server = { server: 'test' };
|
||||
@@ -173,6 +174,7 @@ describe('WebSocketsController', () => {
|
||||
message: 'message',
|
||||
methodName: 'methodName',
|
||||
callback: messageHandlerCallback,
|
||||
isAckHandledManually: false,
|
||||
},
|
||||
]);
|
||||
});
|
||||
@@ -188,11 +190,13 @@ describe('WebSocketsController', () => {
|
||||
methodName: 'findOne',
|
||||
message: 'find',
|
||||
callback: null!,
|
||||
isAckHandledManually: false,
|
||||
},
|
||||
{
|
||||
methodName: 'create',
|
||||
message: 'insert',
|
||||
callback: null!,
|
||||
isAckHandledManually: false,
|
||||
},
|
||||
];
|
||||
const insertEntrypointDefinitionSpy = sinon.spy(
|
||||
@@ -423,14 +427,40 @@ describe('WebSocketsController', () => {
|
||||
client = { on: onSpy, off: onSpy };
|
||||
|
||||
handlers = [
|
||||
{ message: 'test', callback: { bind: () => 'testCallback' } },
|
||||
{ message: 'test2', callback: { bind: () => 'testCallback2' } },
|
||||
{
|
||||
message: 'test',
|
||||
callback: { bind: () => 'testCallback' },
|
||||
isAckHandledManually: true,
|
||||
},
|
||||
{
|
||||
message: 'test2',
|
||||
callback: { bind: () => 'testCallback2' },
|
||||
isAckHandledManually: false,
|
||||
},
|
||||
];
|
||||
});
|
||||
it('should bind each handler to client', () => {
|
||||
instance.subscribeMessages(handlers, client, gateway);
|
||||
expect(onSpy.calledTwice).to.be.true;
|
||||
});
|
||||
it('should pass "isAckHandledManually" flag to the adapter', () => {
|
||||
const adapter = config.getIoAdapter();
|
||||
const bindMessageHandlersSpy = sinon.spy(adapter, 'bindMessageHandlers');
|
||||
|
||||
instance.subscribeMessages(handlers, client, gateway);
|
||||
|
||||
const handlersPassedToAdapter = bindMessageHandlersSpy.firstCall.args[1];
|
||||
|
||||
expect(handlersPassedToAdapter[0].message).to.equal(handlers[0].message);
|
||||
expect(handlersPassedToAdapter[0].isAckHandledManually).to.equal(
|
||||
handlers[0].isAckHandledManually,
|
||||
);
|
||||
|
||||
expect(handlersPassedToAdapter[1].message).to.equal(handlers[1].message);
|
||||
expect(handlersPassedToAdapter[1].isAckHandledManually).to.equal(
|
||||
handlers[1].isAckHandledManually,
|
||||
);
|
||||
});
|
||||
});
|
||||
describe('pickResult', () => {
|
||||
describe('when deferredResult contains value which', () => {
|
||||
|
||||
@@ -72,7 +72,7 @@ export class WebSocketsController {
|
||||
) {
|
||||
const nativeMessageHandlers = this.metadataExplorer.explore(instance);
|
||||
const messageHandlers = nativeMessageHandlers.map(
|
||||
({ callback, message, methodName }) => ({
|
||||
({ callback, isAckHandledManually, message, methodName }) => ({
|
||||
message,
|
||||
methodName,
|
||||
callback: this.contextCreator.create(
|
||||
@@ -81,6 +81,7 @@ export class WebSocketsController {
|
||||
moduleKey,
|
||||
methodName,
|
||||
),
|
||||
isAckHandledManually,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -174,10 +175,13 @@ export class WebSocketsController {
|
||||
instance: NestGateway,
|
||||
) {
|
||||
const adapter = this.config.getIoAdapter();
|
||||
const handlers = subscribersMap.map(({ callback, message }) => ({
|
||||
message,
|
||||
callback: callback.bind(instance, client),
|
||||
}));
|
||||
const handlers = subscribersMap.map(
|
||||
({ callback, message, isAckHandledManually }) => ({
|
||||
message,
|
||||
callback: callback.bind(instance, client),
|
||||
isAckHandledManually,
|
||||
}),
|
||||
);
|
||||
adapter.bindMessageHandlers(client, handlers, data =>
|
||||
fromPromise(this.pickResult(data)).pipe(mergeAll()),
|
||||
);
|
||||
|
||||
Reference in New Issue
Block a user