Merge pull request #15386 from nestjs/feat/introspection-capabilities

feat: enhance introspection capabilities
This commit is contained in:
Kamil Mysliwiec
2025-07-14 17:29:58 +02:00
committed by GitHub
34 changed files with 580 additions and 159 deletions

View File

@@ -38,6 +38,7 @@ export default tseslint.config(
'@typescript-eslint/no-unused-expressions': 'off',
'@typescript-eslint/no-require-imports': 'off',
'@typescript-eslint/no-unused-vars': 'off',
'@typescript-eslint/no-non-null-asserted-optional-chain': 'warn',
"@typescript-eslint/no-misused-promises": [
"error",
{

View File

@@ -53,4 +53,18 @@ export class NestApplicationContextOptions {
* @default 'reference'
*/
moduleIdGeneratorAlgorithm?: 'deep-hash' | 'reference';
/**
* Instrument the application context.
* This option allows you to add custom instrumentation to the application context.
*/
instrument?: {
/**
* Function that decorates each instance created by the application context.
* This function can be used to add custom properties or methods to the instance.
* @param instance The instance to decorate.
* @returns The decorated instance.
*/
instanceDecorator: (instance: unknown) => unknown;
};
}

View File

@@ -346,6 +346,23 @@ export class ConsoleLogger implements LoggerService {
writeStreamType?: 'stdout' | 'stderr';
errorStack?: unknown;
},
) {
const logObject = this.getJsonLogObject(message, options);
const formattedMessage =
!this.options.colors && this.inspectOptions.compact === true
? JSON.stringify(logObject, this.stringifyReplacer)
: inspect(logObject, this.inspectOptions);
process[options.writeStreamType ?? 'stdout'].write(`${formattedMessage}\n`);
}
protected getJsonLogObject(
message: unknown,
options: {
context: string;
logLevel: LogLevel;
writeStreamType?: 'stdout' | 'stderr';
errorStack?: unknown;
},
) {
type JsonLogObject = {
level: LogLevel;
@@ -370,12 +387,7 @@ export class ConsoleLogger implements LoggerService {
if (options.errorStack) {
logObject.stack = options.errorStack;
}
const formattedMessage =
!this.options.colors && this.inspectOptions.compact === true
? JSON.stringify(logObject, this.stringifyReplacer)
: inspect(logObject, this.inspectOptions);
process[options.writeStreamType ?? 'stdout'].write(`${formattedMessage}\n`);
return logObject;
}
protected formatPid(pid: number) {

View File

@@ -12,6 +12,9 @@ export abstract class AbstractHttpAdapter<
> implements HttpServer<TRequest, TResponse>
{
protected httpServer: TServer;
protected onRouteTriggered:
| ((requestMethod: RequestMethod, path: string) => void)
| undefined;
constructor(protected instance?: any) {}
@@ -143,6 +146,20 @@ export abstract class AbstractHttpAdapter<
return path;
}
public setOnRouteTriggered(
onRouteTriggered: (requestMethod: RequestMethod, path: string) => void,
) {
this.onRouteTriggered = onRouteTriggered;
}
public getOnRouteTriggered() {
return this.onRouteTriggered;
}
public setOnRequestHook(onRequestHook: Function): void {}
public setOnResponseHook(onResponseHook: Function): void {}
abstract close();
abstract initHttpServer(options: NestApplicationOptions);
abstract useStaticAssets(...args: any[]);

View File

@@ -1,4 +1,4 @@
import { Observable, Subject } from 'rxjs';
import { Observable, ReplaySubject, Subject } from 'rxjs';
import { AbstractHttpAdapter } from '../adapters/http-adapter';
/**
@@ -18,6 +18,7 @@ export class HttpAdapterHost<
> {
private _httpAdapter?: T;
private _listen$ = new Subject<void>();
private _init$ = new ReplaySubject<void>();
private isListening = false;
/**
@@ -27,6 +28,9 @@ export class HttpAdapterHost<
*/
set httpAdapter(httpAdapter: T) {
this._httpAdapter = httpAdapter;
this._init$.next();
this._init$.complete();
}
/**
@@ -47,6 +51,14 @@ export class HttpAdapterHost<
return this._listen$.asObservable();
}
/**
* Observable that allows to subscribe to the `init` event.
* This event is emitted when the HTTP application is initialized.
*/
get init$(): Observable<void> {
return this._init$.asObservable();
}
/**
* Sets the listening state of the application.
*/

View File

@@ -67,6 +67,10 @@ export class NestContainer {
return this._applicationConfig;
}
get contextOptions(): NestApplicationContextOptions | undefined {
return this._contextOptions;
}
public setHttpAdapter(httpAdapter: any) {
this.internalProvidersStorage.httpAdapter = httpAdapter;

View File

@@ -84,8 +84,26 @@ export interface InjectorDependencyContext {
export class Injector {
private logger: LoggerService = new Logger('InjectorLogger');
private readonly instanceDecorator: (target: unknown) => unknown = (
target: unknown,
) => target;
constructor(private readonly options?: { preview: boolean }) {}
constructor(
private readonly options?: {
/**
* Whether to enable preview mode.
*/
preview: boolean;
/**
* Function to decorate a freshly created instance.
*/
instanceDecorator?: (target: unknown) => unknown;
},
) {
if (options?.instanceDecorator) {
this.instanceDecorator = options.instanceDecorator;
}
}
public loadPrototype<T>(
{ token }: InstanceWrapper<T>,
@@ -768,11 +786,14 @@ export class Injector {
new (metatype as Type<any>)(...instances),
)
: new (metatype as Type<any>)(...instances);
instanceHost.instance = this.instanceDecorator(instanceHost.instance);
} else if (isInContext) {
const factoryReturnValue = (targetMetatype.metatype as any as Function)(
...instances,
);
instanceHost.instance = await factoryReturnValue;
instanceHost.instance = this.instanceDecorator(instanceHost.instance);
}
instanceHost.isResolved = true;
return instanceHost.instance;

View File

@@ -27,7 +27,11 @@ export class InternalCoreModuleFactory {
const logger = new Logger(LazyModuleLoader.name, {
timestamp: false,
});
const injector = new Injector();
const injector = new Injector({
preview: container.contextOptions?.preview!,
instanceDecorator:
container.contextOptions?.instrument?.instanceDecorator,
});
const instanceLoader = new InstanceLoader(
container,
injector,

View File

@@ -23,7 +23,7 @@ export interface ModuleRefGetOrResolveOpts {
}
export abstract class ModuleRef extends AbstractInstanceResolver {
protected readonly injector = new Injector();
protected readonly injector: Injector;
private _instanceLinksHost: InstanceLinksHost;
protected get instanceLinksHost() {
@@ -35,6 +35,12 @@ export abstract class ModuleRef extends AbstractInstanceResolver {
constructor(protected readonly container: NestContainer) {
super();
this.injector = new Injector({
preview: container.contextOptions?.preview!,
instanceDecorator:
container.contextOptions?.instrument?.instanceDecorator,
});
}
/**

View File

@@ -384,13 +384,16 @@ export class Module {
enhancerSubtype?: EnhancerSubtype,
) {
const { useValue: value, provide: providerToken } = provider;
const instanceDecorator =
this.container.contextOptions?.instrument?.instanceDecorator;
collection.set(
providerToken,
new InstanceWrapper({
token: providerToken,
name: (providerToken as Function)?.name || providerToken,
metatype: null!,
instance: value,
instance: instanceDecorator ? instanceDecorator(value) : value,
isResolved: true,
async: value instanceof Promise,
host: this,

View File

@@ -1,14 +1,41 @@
import { Observable, ReplaySubject } from 'rxjs';
import { uid } from 'uid';
import { Module } from './module';
export class ModulesContainer extends Map<string, Module> {
private readonly _applicationId = uid(21);
private readonly _rpcTargetRegistry$ = new ReplaySubject<any>();
/**
* Unique identifier of the application instance.
*/
get applicationId(): string {
return this._applicationId;
}
/**
* Retrieves a module by its identifier.
* @param id The identifier of the module to retrieve.
* @returns The module instance if found, otherwise undefined.
*/
public getById(id: string): Module | undefined {
return Array.from(this.values()).find(moduleRef => moduleRef.id === id);
}
/**
* Returns the RPC target registry as an observable.
* This registry contains all RPC targets registered in the application.
* @returns An observable that emits the RPC target registry.
*/
public getRpcTargetRegistry<T>(): Observable<T> {
return this._rpcTargetRegistry$.asObservable();
}
/**
* Adds an RPC target to the registry.
* @param target The RPC target to add.
*/
public addRpcTarget<T>(target: T): void {
this._rpcTargetRegistry$.next(target);
}
}

View File

@@ -81,7 +81,10 @@ export class NestApplication
this.selectContextModule();
this.registerHttpServer();
this.injector = new Injector({ preview: this.appOptions.preview! });
this.injector = new Injector({
preview: this.appOptions.preview!,
instanceDecorator: appOptions.instrument?.instanceDecorator,
});
this.middlewareModule = new MiddlewareModule();
this.routesResolver = new RoutesResolver(
this.container,
@@ -452,6 +455,7 @@ export class NestApplication
this.httpAdapter.setViewEngine(engineOrOptions);
return this;
}
private host(): string | undefined {
const address = this.httpServer.address();
if (isString(address)) {

View File

@@ -209,7 +209,10 @@ export class NestFactoryStatic {
? UuidFactoryMode.Deterministic
: UuidFactoryMode.Random;
const injector = new Injector({ preview: options.preview! });
const injector = new Injector({
preview: options.preview!,
instanceDecorator: options.instrument?.instanceDecorator,
});
const instanceLoader = new InstanceLoader(
container,
injector,

View File

@@ -101,14 +101,14 @@ export class RouterExplorer {
public explore<T extends HttpServer = any>(
instanceWrapper: InstanceWrapper,
moduleKey: string,
applicationRef: T,
httpAdapterRef: T,
host: string | RegExp | Array<string | RegExp>,
routePathMetadata: RoutePathMetadata,
) {
const { instance } = instanceWrapper;
const routerPaths = this.pathsExplorer.scanForPaths(instance);
this.applyPathsToRouterProxy(
applicationRef,
httpAdapterRef,
routerPaths,
instanceWrapper,
moduleKey,
@@ -234,7 +234,17 @@ export class RouterExplorer {
const normalizedPath = router.normalizePath
? router.normalizePath(path)
: path;
routerMethodRef(normalizedPath, routeHandler);
const httpAdapter = this.container.getHttpAdapterRef();
const onRouteTriggered = httpAdapter.getOnRouteTriggered?.();
if (onRouteTriggered) {
routerMethodRef(normalizedPath, (...args: unknown[]) => {
onRouteTriggered(requestMethod, path);
return routeHandler(...args);
});
} else {
routerMethodRef(normalizedPath, routeHandler);
}
this.graphInspector.insertEntrypointDefinition<HttpEntrypointMetadata>(
entrypointDefinition,

View File

@@ -50,7 +50,11 @@ export class MicroservicesModule<
new InterceptorsConsumer(),
);
const injector = new Injector();
const injector = new Injector({
preview: container.contextOptions?.preview!,
instanceDecorator:
container.contextOptions?.instrument?.instanceDecorator,
});
this.listenersController = new ListenersController(
this.clientsContainer,
contextCreator,

View File

@@ -64,7 +64,10 @@ export class NestMicroservice
) {
super(container, config);
this.injector = new Injector({ preview: config.preview! });
this.injector = new Injector({
preview: config.preview!,
instanceDecorator: config.instrument?.instanceDecorator,
});
this.microservicesModule.register(
container,
this.graphInspector,
@@ -73,6 +76,9 @@ export class NestMicroservice
);
this.createServer(config);
this.selectContextModule();
const modulesContainer = this.container.getModules();
modulesContainer.addRpcTarget(this.serverInstance);
}
public createServer(config: CompleteMicroserviceOptions) {

View File

@@ -180,6 +180,11 @@ export class ServerGrpc extends Server<never, never> {
if (!methodHandler) {
continue;
}
Object.defineProperty(methodHandler, 'name', {
value: methodName,
writable: false,
});
service[methodName] = this.createServiceMethod(
methodHandler,
grpcService.prototype[methodName],
@@ -263,19 +268,36 @@ export class ServerGrpc extends Server<never, never> {
public createUnaryServiceMethod(methodHandler: Function): Function {
return async (call: GrpcCall, callback: Function) => {
const handler = methodHandler(call.request, call.metadata, call);
this.transformToObservable(await handler).subscribe({
next: async data => callback(null, await data),
error: (err: any) => callback(err),
});
return this.onProcessingStartHook(
this.transportId,
{ ...call, operationId: methodHandler.name } as any,
async () => {
const handler = methodHandler(call.request, call.metadata, call);
this.transformToObservable(await handler).subscribe({
next: async data => callback(null, await data),
error: (err: any) => callback(err),
complete: () => {
this.onProcessingEndHook?.(this.transportId, call.request);
},
});
},
);
};
}
public createStreamServiceMethod(methodHandler: Function): Function {
return async (call: GrpcCall, callback: Function) => {
const handler = methodHandler(call.request, call.metadata, call);
const result$ = this.transformToObservable(await handler);
await this.writeObservableToGrpc(result$, call);
return this.onProcessingStartHook(
this.transportId,
{ ...call, operationId: methodHandler.name } as any,
async () => {
const handler = methodHandler(call.request, call.metadata, call);
const result$ = this.transformToObservable(await handler);
await this.writeObservableToGrpc(result$, call);
this.onProcessingEndHook?.(this.transportId, call.request);
},
);
};
}
@@ -406,52 +428,62 @@ export class ServerGrpc extends Server<never, never> {
call: GrpcCall,
callback: (err: unknown, value: unknown) => void,
) => {
// Needs to be a Proxy in order to buffer messages that come before handler is executed
// This could happen if handler has any async guards or interceptors registered that would delay
// the execution.
const { subject, next, error, complete, cleanup } =
this.bufferUntilDrained();
call.on('data', (m: any) => next(m));
call.on('error', (e: any) => {
// Check if error means that stream ended on other end
const isCancelledError = String(e).toLowerCase().indexOf('cancelled');
return this.onProcessingStartHook(
this.transportId,
{ ...call, operationId: methodHandler.name } as any,
async () => {
// Needs to be a Proxy in order to buffer messages that come before handler is executed
// This could happen if handler has any async guards or interceptors registered that would delay
// the execution.
const { subject, next, error, complete, cleanup } =
this.bufferUntilDrained();
call.on('data', (m: any) => next(m));
call.on('error', (e: any) => {
// Check if error means that stream ended on other end
const isCancelledError = String(e)
.toLowerCase()
.indexOf('cancelled');
if (isCancelledError) {
call.end();
return;
}
// If another error then just pass it along
error(e);
});
call.on('end', () => {
complete();
cleanup();
});
if (isCancelledError) {
call.end();
return;
}
// If another error then just pass it along
error(e);
});
call.on('end', () => {
complete();
cleanup();
const handler = methodHandler(
subject.asObservable(),
call.metadata,
call,
this.onProcessingEndHook?.(this.transportId, call.request);
});
const handler = methodHandler(
subject.asObservable(),
call.metadata,
call,
);
const res = this.transformToObservable(await handler);
if (isResponseStream) {
await this.writeObservableToGrpc(res, call);
} else {
const response = await lastValueFrom(
res.pipe(
takeUntil(fromEvent(call as any, CANCELLED_EVENT)),
catchError(err => {
callback(err, null);
return EMPTY;
}),
defaultIfEmpty(undefined),
),
);
if (!isUndefined(response)) {
callback(null, response);
}
}
},
);
const res = this.transformToObservable(await handler);
if (isResponseStream) {
await this.writeObservableToGrpc(res, call);
} else {
const response = await lastValueFrom(
res.pipe(
takeUntil(fromEvent(call as any, CANCELLED_EVENT)),
catchError(err => {
callback(err, null);
return EMPTY;
}),
defaultIfEmpty(undefined),
),
);
if (!isUndefined(response)) {
callback(null, response);
}
}
};
}
@@ -463,15 +495,25 @@ export class ServerGrpc extends Server<never, never> {
call: GrpcCall,
callback: (err: unknown, value: unknown) => void,
) => {
let handlerStream: Observable<any>;
if (isResponseStream) {
handlerStream = this.transformToObservable(await methodHandler(call));
} else {
handlerStream = this.transformToObservable(
await methodHandler(call, callback),
);
}
await lastValueFrom(handlerStream);
return this.onProcessingStartHook(
this.transportId,
{ ...call, operationId: methodHandler.name } as any,
async () => {
let handlerStream: Observable<any>;
if (isResponseStream) {
handlerStream = this.transformToObservable(
await methodHandler(call),
);
} else {
handlerStream = this.transformToObservable(
await methodHandler(call, callback),
);
}
await lastValueFrom(handlerStream).finally(() => {
this.onProcessingEndHook?.(this.transportId, call.request);
});
},
);
};
}

View File

@@ -187,9 +187,16 @@ export class ServerKafka extends Server<never, KafkaStatus> {
replyTopic: string,
replyPartition: string,
correlationId: string,
context: KafkaContext,
): (data: any) => Promise<RecordMetadata[]> {
return (data: any) =>
this.sendMessage(data, replyTopic, replyPartition, correlationId);
this.sendMessage(
data,
replyTopic,
replyPartition,
correlationId,
context,
);
}
public async handleMessage(payload: EachMessagePayload) {
@@ -225,6 +232,7 @@ export class ServerKafka extends Server<never, KafkaStatus> {
replyTopic,
replyPartition,
correlationId,
kafkaContext,
);
if (!handler) {
@@ -233,15 +241,20 @@ export class ServerKafka extends Server<never, KafkaStatus> {
err: NO_MESSAGE_HANDLER,
});
}
return this.onProcessingStartHook(
this.transportId,
kafkaContext,
async () => {
const response$ = this.transformToObservable(
handler(packet.data, kafkaContext),
);
const response$ = this.transformToObservable(
handler(packet.data, kafkaContext),
const replayStream$ = new ReplaySubject();
await this.combineStreamsAndThrowIfRetriable(response$, replayStream$);
this.send(replayStream$, publish);
},
);
const replayStream$ = new ReplaySubject();
await this.combineStreamsAndThrowIfRetriable(response$, replayStream$);
this.send(replayStream$, publish);
}
public unwrap<T>(): T {
@@ -293,6 +306,7 @@ export class ServerKafka extends Server<never, KafkaStatus> {
replyTopic: string,
replyPartition: string | undefined | null,
correlationId: string,
context: KafkaContext,
): Promise<RecordMetadata[]> {
const outgoingMessage = await this.serializer.serialize(message.response);
this.assignReplyPartition(replyPartition, outgoingMessage);
@@ -307,7 +321,9 @@ export class ServerKafka extends Server<never, KafkaStatus> {
},
this.options.send || {},
);
return this.producer!.send(replyMessage);
return this.producer!.send(replyMessage).finally(() => {
this.onProcessingEndHook?.(this.transportId, context);
});
}
public assignIsDisposedHeader(
@@ -362,10 +378,14 @@ export class ServerKafka extends Server<never, KafkaStatus> {
if (!handler) {
return this.logger.error(NO_EVENT_HANDLER`${pattern}`);
}
const resultOrStream = await handler(packet.data, context);
if (isObservable(resultOrStream)) {
await lastValueFrom(resultOrStream);
}
return this.onProcessingStartHook(this.transportId, context, async () => {
const resultOrStream = await handler(packet.data, context);
if (isObservable(resultOrStream)) {
await lastValueFrom(resultOrStream);
this.onProcessingEndHook?.(this.transportId, context);
}
});
}
protected initializeSerializer(options: KafkaOptions['options']) {

View File

@@ -129,7 +129,7 @@ export class ServerMqtt extends Server<MqttEvents, MqttStatus> {
}
const publish = this.getPublisher(
pub,
channel,
mqttContext,
(packet as IncomingRequest).id,
);
const handler = this.getHandlerByPattern(channel);
@@ -143,13 +143,23 @@ export class ServerMqtt extends Server<MqttEvents, MqttStatus> {
};
return publish(noHandlerPacket);
}
const response$ = this.transformToObservable(
await handler(packet.data, mqttContext),
return this.onProcessingStartHook(
this.transportId,
mqttContext,
async () => {
const response$ = this.transformToObservable(
await handler(packet.data, mqttContext),
);
response$ && this.send(response$, publish);
},
);
response$ && this.send(response$, publish);
}
public getPublisher(client: MqttClient, pattern: any, id: string): any {
public getPublisher(
client: MqttClient,
context: MqttContext,
id: string,
): any {
return (response: any) => {
Object.assign(response, { id });
@@ -161,8 +171,10 @@ export class ServerMqtt extends Server<MqttEvents, MqttStatus> {
const outgoingResponse: string | Buffer =
this.serializer.serialize(response);
this.onProcessingEndHook?.(this.transportId, context);
return client.publish(
this.getReplyPattern(pattern),
this.getReplyPattern(context.getTopic()),
outgoingResponse,
options,
);

View File

@@ -153,7 +153,11 @@ export class ServerNats<
if (isUndefined((message as IncomingRequest).id)) {
return this.handleEvent(channel, message, natsCtx);
}
const publish = this.getPublisher(natsMsg, (message as IncomingRequest).id);
const publish = this.getPublisher(
natsMsg,
(message as IncomingRequest).id,
natsCtx,
);
const handler = this.getHandlerByPattern(channel);
if (!handler) {
const status = 'error';
@@ -164,18 +168,22 @@ export class ServerNats<
};
return publish(noHandlerPacket);
}
const response$ = this.transformToObservable(
await handler(message.data, natsCtx),
);
response$ && this.send(response$, publish);
return this.onProcessingStartHook(this.transportId, natsCtx, async () => {
const response$ = this.transformToObservable(
await handler(message.data, natsCtx),
);
response$ && this.send(response$, publish);
});
}
public getPublisher(natsMsg: NatsMsg, id: string) {
public getPublisher(natsMsg: NatsMsg, id: string, ctx: NatsContext) {
if (natsMsg.reply) {
return (response: any) => {
Object.assign(response, { id });
const outgoingResponse: NatsRecord =
this.serializer.serialize(response);
this.onProcessingEndHook?.(this.transportId, ctx);
return natsMsg.respond(outgoingResponse.data, {
headers: outgoingResponse.headers,
});

View File

@@ -145,6 +145,7 @@ export class ServerRedis extends Server<RedisEvents, RedisStatus> {
pub,
channel,
(packet as IncomingRequest).id,
redisCtx,
);
const handler = this.getHandlerByPattern(channel);
@@ -157,17 +158,24 @@ export class ServerRedis extends Server<RedisEvents, RedisStatus> {
};
return publish(noHandlerPacket);
}
const response$ = this.transformToObservable(
await handler(packet.data, redisCtx),
return this.onProcessingStartHook?.(
this.transportId,
redisCtx,
async () => {
const response$ = this.transformToObservable(
await handler(packet.data, redisCtx),
);
response$ && this.send(response$, publish);
},
);
response$ && this.send(response$, publish);
}
public getPublisher(pub: Redis, pattern: any, id: string) {
public getPublisher(pub: Redis, pattern: any, id: string, ctx: RedisContext) {
return (response: any) => {
Object.assign(response, { id });
const outgoingResponse = this.serializer.serialize(response);
this.onProcessingEndHook?.(this.transportId, ctx);
return pub.publish(
this.getReplyPattern(pattern),
JSON.stringify(outgoingResponse),

View File

@@ -292,16 +292,28 @@ export class ServerRMQ extends Server<RmqEvents, RmqStatus> {
noHandlerPacket,
properties.replyTo,
properties.correlationId,
rmqContext,
);
}
const response$ = this.transformToObservable(
await handler(packet.data, rmqContext),
return this.onProcessingStartHook(
this.transportId,
rmqContext,
async () => {
const response$ = this.transformToObservable(
await handler(packet.data, rmqContext),
);
const publish = <T>(data: T) =>
this.sendMessage(
data,
properties.replyTo,
properties.correlationId,
rmqContext,
);
response$ && this.send(response$, publish);
},
);
const publish = <T>(data: T) =>
this.sendMessage(data, properties.replyTo, properties.correlationId);
response$ && this.send(response$, publish);
}
public async handleEvent(
@@ -321,6 +333,7 @@ export class ServerRMQ extends Server<RmqEvents, RmqStatus> {
message: T,
replyTo: any,
correlationId: string,
context: RmqContext,
): void {
const outgoingResponse = this.serializer.serialize(
message as unknown as OutgoingResponse,
@@ -330,6 +343,8 @@ export class ServerRMQ extends Server<RmqEvents, RmqStatus> {
const buffer = Buffer.from(JSON.stringify(outgoingResponse));
const sendOptions = { correlationId, ...options };
this.onProcessingEndHook?.(this.transportId, context);
this.channel!.sendToQueue(replyTo, buffer, sendOptions);
}

View File

@@ -105,18 +105,26 @@ export class ServerTCP extends Server<TcpEvents, TcpStatus> {
});
return socket.sendMessage(noHandlerPacket);
}
const response$ = this.transformToObservable(
await handler(packet.data, tcpContext),
);
response$ &&
this.send(response$, data => {
Object.assign(data, { id: (packet as IncomingRequest).id });
const outgoingResponse = this.serializer.serialize(
data as WritePacket & PacketId,
return this.onProcessingStartHook(
this.transportId,
tcpContext,
async () => {
const response$ = this.transformToObservable(
await handler(packet.data, tcpContext),
);
socket.sendMessage(outgoingResponse);
});
response$ &&
this.send(response$, data => {
Object.assign(data, { id: (packet as IncomingRequest).id });
const outgoingResponse = this.serializer.serialize(
data as WritePacket & PacketId,
);
this.onProcessingEndHook?.(this.transportId, tcpContext);
socket.sendMessage(outgoingResponse);
});
},
);
}
public handleClose(): undefined | number | NodeJS.Timer {

View File

@@ -57,16 +57,21 @@ export abstract class Server<
protected readonly logger: LoggerService = new Logger(Server.name);
protected serializer: ConsumerSerializer;
protected deserializer: ConsumerDeserializer;
protected onProcessingStartHook: (
transportId: Transport | symbol,
context: BaseRpcContext,
done: () => Promise<any>,
) => void = (
transportId: Transport | symbol,
context: BaseRpcContext,
done: () => Promise<any>,
) => done();
protected onProcessingEndHook: (
transportId: Transport | symbol,
context: BaseRpcContext,
) => void;
protected _status$ = new ReplaySubject<Status>(1);
/**
* Sets the transport identifier.
* @param transportId Unique transport identifier.
*/
public setTransportId(transportId: Transport | symbol): void {
this.transportId = transportId;
}
/**
* Returns an observable that emits status changes.
*/
@@ -83,6 +88,7 @@ export abstract class Server<
EventKey extends keyof EventsMap = keyof EventsMap,
EventCallback extends EventsMap[EventKey] = EventsMap[EventKey],
>(event: EventKey, callback: EventCallback): any;
/**
* Returns an instance of the underlying server/broker instance,
* or a group of servers if there are more than one.
@@ -94,11 +100,42 @@ export abstract class Server<
* @param callback Function to be called upon initialization
*/
public abstract listen(callback: (...optionalParams: unknown[]) => any): any;
/**
* Method called when server is being terminated.
*/
public abstract close(): any;
/**
* Sets the transport identifier.
* @param transportId Unique transport identifier.
*/
public setTransportId(transportId: Transport | symbol): void {
this.transportId = transportId;
}
/**
* Sets a hook that will be called when processing starts.
*/
public setOnProcessingStartHook(
hook: (
transportId: Transport | symbol,
context: unknown,
done: () => Promise<any>,
) => void,
): void {
this.onProcessingStartHook = hook;
}
/**
* Sets a hook that will be called when processing ends.
*/
public setOnProcessingEndHook(
hook: (transportId: Transport | symbol, context: unknown) => void,
): void {
this.onProcessingEndHook = hook;
}
public addHandler(
pattern: any,
callback: MessageHandler,
@@ -177,14 +214,25 @@ export abstract class Server<
if (!handler) {
return this.logger.error(NO_EVENT_HANDLER`${pattern}`);
}
const resultOrStream = await handler(packet.data, context);
if (isObservable(resultOrStream)) {
const connectableSource = connectable(resultOrStream, {
connector: () => new Subject(),
resetOnDisconnect: false,
});
connectableSource.connect();
}
return this.onProcessingStartHook(this.transportId!, context, async () => {
const resultOrStream = await handler(packet.data, context);
if (isObservable(resultOrStream)) {
const connectableSource = connectable(
resultOrStream.pipe(
finalize(() =>
this.onProcessingEndHook?.(this.transportId!, context),
),
),
{
connector: () => new Subject(),
resetOnDisconnect: false,
},
);
connectableSource.connect();
} else {
this.onProcessingEndHook?.(this.transportId!, context);
}
});
}
public transformToObservable<T>(

View File

@@ -1,11 +1,11 @@
import { ApplicationConfig } from '@nestjs/core/application-config';
import { GraphInspector } from '@nestjs/core/inspector/graph-inspector';
import { Transport } from '@nestjs/microservices/enums';
import { AsyncMicroserviceOptions } from '@nestjs/microservices/interfaces';
import { NestMicroservice } from '@nestjs/microservices/nest-microservice';
import { Server, ServerTCP } from '@nestjs/microservices/server';
import { GraphInspector } from '@nestjs/core/inspector/graph-inspector';
import { ApplicationConfig } from '@nestjs/core/application-config';
import { Transport } from '@nestjs/microservices/enums';
import { expect } from 'chai';
import * as sinon from 'sinon';
import { AsyncMicroserviceOptions } from '@nestjs/microservices/interfaces';
const createMockGraphInspector = (): GraphInspector =>
({
@@ -23,7 +23,10 @@ const createMockAppConfig = (): ApplicationConfig =>
const mockContainer = {
getModuleCompiler: sinon.stub(),
getModules: () => new Map(),
getModules: () =>
Object.assign(new Map(), {
addRpcTarget: sinon.spy(),
}),
get: () => null,
getHttpAdapterHost: () => undefined,
} as any;

View File

@@ -2,6 +2,7 @@ import { Logger } from '@nestjs/common';
import { AssertionError, expect } from 'chai';
import * as sinon from 'sinon';
import { NO_MESSAGE_HANDLER } from '../../constants';
import { KafkaContext } from '../../ctx-host';
import { KafkaHeaders } from '../../enums';
import {
EachMessagePayload,
@@ -273,6 +274,7 @@ describe('ServerKafka', () => {
});
describe('getPublisher', () => {
const context = new KafkaContext([] as any);
let sendMessageStub: sinon.SinonStub;
let publisher;
@@ -281,15 +283,16 @@ describe('ServerKafka', () => {
replyTopic,
replyPartition,
correlationId,
context,
);
sendMessageStub = sinon
.stub(server, 'sendMessage')
.callsFake(async () => []);
});
it(`should return function`, () => {
expect(typeof server.getPublisher(null!, null!, correlationId)).to.be.eql(
'function',
);
expect(
typeof server.getPublisher(null!, null!, correlationId, context),
).to.be.eql('function');
});
it(`should call "publish" with expected arguments`, () => {
const data = {
@@ -411,10 +414,11 @@ describe('ServerKafka', () => {
});
describe('sendMessage', () => {
const context = new KafkaContext([] as any);
let sendSpy: sinon.SinonSpy;
beforeEach(() => {
sendSpy = sinon.spy();
sendSpy = sinon.stub().callsFake(() => Promise.resolve());
sinon.stub(server as any, 'producer').value({
send: sendSpy,
});
@@ -429,6 +433,7 @@ describe('ServerKafka', () => {
replyTopic,
replyPartition,
correlationId,
context,
);
expect(
@@ -455,6 +460,7 @@ describe('ServerKafka', () => {
replyTopic,
undefined,
correlationId,
context,
);
expect(
@@ -480,6 +486,7 @@ describe('ServerKafka', () => {
replyTopic,
replyPartition,
correlationId,
context,
);
expect(
@@ -507,6 +514,7 @@ describe('ServerKafka', () => {
replyTopic,
replyPartition,
correlationId,
context,
);
expect(

View File

@@ -1,6 +1,7 @@
import { expect } from 'chai';
import * as sinon from 'sinon';
import { NO_MESSAGE_HANDLER } from '../../constants';
import { MqttContext } from '../../ctx-host';
import { BaseRpcContext } from '../../ctx-host/base-rpc.context';
import { ServerMqtt } from '../../server/server-mqtt';
import { objectToMap } from './utils/object-to-map';
@@ -167,16 +168,19 @@ describe('ServerMqtt', () => {
const id = '1';
const pattern = 'test';
const context = new MqttContext([pattern, {}]);
beforeEach(() => {
publisherSpy = sinon.spy();
pub = {
publish: publisherSpy,
};
publisher = server.getPublisher(pub, pattern, id);
publisher = server.getPublisher(pub, context, id);
});
it(`should return function`, () => {
expect(typeof server.getPublisher(null, null, id)).to.be.eql('function');
expect(typeof server.getPublisher(null, context, id)).to.be.eql(
'function',
);
});
it(`should call "publish" with expected arguments`, () => {
const respond = 'test';

View File

@@ -222,6 +222,7 @@ describe('ServerNats', () => {
});
});
describe('getPublisher', () => {
const context = new NatsContext([] as any);
const id = '1';
it(`should return function`, () => {
@@ -231,7 +232,9 @@ describe('ServerNats', () => {
sid: +id,
respond: sinon.spy(),
};
expect(typeof server.getPublisher(natsMsg, id)).to.be.eql('function');
expect(typeof server.getPublisher(natsMsg, id, context)).to.be.eql(
'function',
);
});
it(`should call "respond" when reply topic provided`, () => {
const replyTo = 'test';
@@ -242,7 +245,7 @@ describe('ServerNats', () => {
respond: sinon.spy(),
reply: replyTo,
};
const publisher = server.getPublisher(natsMsg, id);
const publisher = server.getPublisher(natsMsg, id, context);
const respond = 'test';
publisher({ respond, id });
@@ -258,7 +261,7 @@ describe('ServerNats', () => {
sid: +id,
respond: sinon.spy(),
};
const publisher = server.getPublisher(natsMsg, id);
const publisher = server.getPublisher(natsMsg, id, context);
const respond = 'test';
publisher({ respond, id });

View File

@@ -1,6 +1,7 @@
import { expect } from 'chai';
import * as sinon from 'sinon';
import { NO_MESSAGE_HANDLER } from '../../constants';
import { RedisContext } from '../../ctx-host';
import { BaseRpcContext } from '../../ctx-host/base-rpc.context';
import { ServerRedis } from '../../server/server-redis';
import { objectToMap } from './utils/object-to-map';
@@ -172,16 +173,19 @@ describe('ServerRedis', () => {
const id = '1';
const pattern = 'test';
const context = new RedisContext([] as any);
beforeEach(() => {
publisherSpy = sinon.spy();
pub = {
publish: publisherSpy,
};
publisher = server.getPublisher(pub, pattern, id);
publisher = server.getPublisher(pub, pattern, id, context);
});
it(`should return function`, () => {
expect(typeof server.getPublisher(null, null, id)).to.be.eql('function');
expect(typeof server.getPublisher(null, null, id, context)).to.be.eql(
'function',
);
});
it(`should call "publish" with expected arguments`, () => {
const respond = 'test';

View File

@@ -231,6 +231,8 @@ describe('ServerRMQ', () => {
});
describe('sendMessage', () => {
const context = new RmqContext([] as any);
let channel: any;
beforeEach(() => {
@@ -245,7 +247,7 @@ describe('ServerRMQ', () => {
const replyTo = 'test';
const correlationId = '0';
server.sendMessage(message, replyTo, correlationId);
server.sendMessage(message, replyTo, correlationId, context);
expect(
channel.sendToQueue.calledWith(
replyTo,

View File

@@ -54,9 +54,50 @@ export class ExpressAdapter extends AbstractHttpAdapter<
private readonly routerMethodFactory = new RouterMethodFactory();
private readonly logger = new Logger(ExpressAdapter.name);
private readonly openConnections = new Set<Duplex>();
private onRequestHook?: (
req: express.Request,
res: express.Response,
done: () => void,
) => Promise<void> | void;
private onResponseHook?: (
req: express.Request,
res: express.Response,
) => Promise<void> | void;
constructor(instance?: any) {
super(instance || express());
this.instance!.use((req, res, next) => {
if (this.onResponseHook) {
res.on('finish', () => {
void this.onResponseHook!.apply(this, [req, res]);
});
}
if (this.onRequestHook) {
void this.onRequestHook.apply(this, [req, res, next]);
} else {
next();
}
});
}
public setOnRequestHook(
onRequestHook: (
req: express.Request,
res: express.Response,
done: () => void,
) => Promise<void> | void,
) {
this.onRequestHook = onRequestHook;
}
public setOnResponseHook(
onResponseHook: (
req: express.Request,
res: express.Response,
) => Promise<void> | void,
) {
this.onResponseHook = onResponseHook;
}
public reply(response: any, body: any, statusCode?: number) {

View File

@@ -19,6 +19,7 @@ describe('ExpressAdapter', () => {
.returns(urlencodedInstance);
const useSpy = sinon.spy(expressInstance, 'use');
const expressAdapter = new ExpressAdapter(expressInstance);
useSpy.resetHistory();
expressAdapter.registerParserMiddleware();
@@ -37,6 +38,7 @@ describe('ExpressAdapter', () => {
expressInstance.use(function urlencodedParser() {});
const useSpy = sinon.spy(expressInstance, 'use');
const expressAdapter = new ExpressAdapter(expressInstance);
useSpy.resetHistory();
expressAdapter.registerParserMiddleware();

View File

@@ -145,6 +145,16 @@ export class FastifyAdapter<
protected _pathPrefix?: string;
private _isParserRegistered: boolean;
private onRequestHook?: (
request: TRequest,
reply: TReply,
done: (err?: Error) => void,
) => void | Promise<void>;
private onResponseHook?: (
request: TRequest,
reply: TReply,
done: (err?: Error) => void,
) => void | Promise<void>;
private isMiddieRegistered: boolean;
private versioningOptions?: VersioningOptions;
private readonly versionConstraint = {
@@ -249,6 +259,42 @@ export class FastifyAdapter<
if ((instanceOrOptions as FastifyAdapterBaseOptions)?.skipMiddie) {
this.isMiddieRegistered = true;
}
this.instance.addHook('onRequest', (request, reply, done) => {
if (this.onRequestHook) {
this.onRequestHook(request as TRequest, reply as TReply, done);
} else {
done();
}
});
this.instance.addHook('onResponse', (request, reply, done) => {
if (this.onResponseHook) {
this.onResponseHook(request as TRequest, reply as TReply, done);
} else {
done();
}
});
}
public setOnRequestHook(
hook: (
request: TRequest,
reply: TReply,
done: (err?: Error) => void,
) => void | Promise<void>,
) {
this.onRequestHook = hook;
}
public setOnResponseHook(
hook: (
request: TRequest,
reply: TReply,
done: (err?: Error) => void,
) => void | Promise<void>,
) {
this.onResponseHook = hook;
}
public async init() {

View File

@@ -176,7 +176,6 @@ export class WebSocketsController {
const adapter = this.config.getIoAdapter();
const handlers = subscribersMap.map(({ callback, message }) => ({
message,
callback: callback.bind(instance, client),
}));
adapter.bindMessageHandlers(client, handlers, data =>