From 10046479ed97cf4fa10bc97c39cadf68c72fff9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kamil=20My=C5=9Bliwiec?= Date: Thu, 23 Jan 2020 15:32:35 +0100 Subject: [PATCH] fix(core): fix lifecycle hooks for middleware, injectables --- packages/core/hooks/on-app-bootstrap.hook.ts | 7 ++++- packages/core/hooks/on-app-shutdown.hook.ts | 7 ++++- packages/core/hooks/on-module-destroy.hook.ts | 7 ++++- packages/core/hooks/on-module-init.hook.ts | 7 ++++- packages/core/injector/module.ts | 9 ++++-- packages/core/middleware/container.ts | 29 +++++++++++-------- packages/core/nest-application.ts | 11 +++---- .../core/test/middleware/container.spec.ts | 16 ++++++++-- ...dule.spec.ts => middleware-module.spec.ts} | 29 ++++++++++++------- .../core/test/middleware/resolver.spec.ts | 3 +- 10 files changed, 87 insertions(+), 38 deletions(-) rename packages/core/test/middleware/{middlewares-module.spec.ts => middleware-module.spec.ts} (90%) diff --git a/packages/core/hooks/on-app-bootstrap.hook.ts b/packages/core/hooks/on-app-bootstrap.hook.ts index 8ade7753d..1accb11fb 100644 --- a/packages/core/hooks/on-app-bootstrap.hook.ts +++ b/packages/core/hooks/on-app-bootstrap.hook.ts @@ -43,7 +43,12 @@ export async function callModuleBootstrapHook(module: Module): Promise { // Module (class) instance is the first element of the providers array // Lifecycle hook has to be called once all classes are properly initialized const [_, { instance: moduleClassInstance }] = providers.shift(); - const instances = [...module.controllers, ...providers]; + const instances = [ + ...module.controllers, + ...providers, + ...module.injectables, + ...module.middlewares, + ]; const nonTransientInstances = getNonTransientInstances(instances); await Promise.all(callOperator(nonTransientInstances)); diff --git a/packages/core/hooks/on-app-shutdown.hook.ts b/packages/core/hooks/on-app-shutdown.hook.ts index ac7cf69f7..8eb2a45ab 100644 --- a/packages/core/hooks/on-app-shutdown.hook.ts +++ b/packages/core/hooks/on-app-shutdown.hook.ts @@ -51,7 +51,12 @@ export async function callAppShutdownHook( // Module (class) instance is the first element of the providers array // Lifecycle hook has to be called once all classes are properly initialized const [_, { instance: moduleClassInstance }] = providers.shift(); - const instances = [...module.controllers, ...providers]; + const instances = [ + ...module.controllers, + ...providers, + ...module.injectables, + ...module.middlewares, + ]; const nonTransientInstances = getNonTransientInstances(instances); await Promise.all(callOperator(nonTransientInstances, signal)); diff --git a/packages/core/hooks/on-module-destroy.hook.ts b/packages/core/hooks/on-module-destroy.hook.ts index 5d0680cac..602e341e1 100644 --- a/packages/core/hooks/on-module-destroy.hook.ts +++ b/packages/core/hooks/on-module-destroy.hook.ts @@ -43,7 +43,12 @@ export async function callModuleDestroyHook(module: Module): Promise { // Module (class) instance is the first element of the providers array // Lifecycle hook has to be called once all classes are properly destroyed const [_, { instance: moduleClassInstance }] = providers.shift(); - const instances = [...module.controllers, ...providers]; + const instances = [ + ...module.controllers, + ...providers, + ...module.injectables, + ...module.middlewares, + ]; const nonTransientInstances = getNonTransientInstances(instances); await Promise.all(callOperator(nonTransientInstances)); diff --git a/packages/core/hooks/on-module-init.hook.ts b/packages/core/hooks/on-module-init.hook.ts index fdd9de39c..c66a7e515 100644 --- a/packages/core/hooks/on-module-init.hook.ts +++ b/packages/core/hooks/on-module-init.hook.ts @@ -39,7 +39,12 @@ export async function callModuleInitHook(module: Module): Promise { // Module (class) instance is the first element of the providers array // Lifecycle hook has to be called once all classes are properly initialized const [_, { instance: moduleClassInstance }] = providers.shift(); - const instances = [...module.controllers, ...providers]; + const instances = [ + ...module.controllers, + ...providers, + ...module.injectables, + ...module.middlewares, + ]; const nonTransientInstances = getNonTransientInstances(instances); await Promise.all(callOperator(nonTransientInstances)); diff --git a/packages/core/injector/module.ts b/packages/core/injector/module.ts index 0d8d4ed01..99eaa1b0d 100644 --- a/packages/core/injector/module.ts +++ b/packages/core/injector/module.ts @@ -39,6 +39,7 @@ export class Module { private readonly _imports = new Set(); private readonly _providers = new Map>(); private readonly _injectables = new Map>(); + private readonly _middlewares = new Map>(); private readonly _controllers = new Map< string, InstanceWrapper @@ -51,7 +52,7 @@ export class Module { private readonly _scope: Type[], private readonly container: NestContainer, ) { - this.addCoreProviders(container); + this.addCoreProviders(); this._id = randomStringGenerator(); } @@ -67,6 +68,10 @@ export class Module { return this._providers; } + get middlewares(): Map> { + return this._middlewares; + } + get imports(): Set { return this._imports; } @@ -124,7 +129,7 @@ export class Module { this._distance = value; } - public addCoreProviders(container: NestContainer) { + public addCoreProviders() { this.addModuleAsProvider(); this.addModuleRef(); this.addApplicationConfig(); diff --git a/packages/core/middleware/container.ts b/packages/core/middleware/container.ts index 824bbfec3..7a835ec4a 100644 --- a/packages/core/middleware/container.ts +++ b/packages/core/middleware/container.ts @@ -1,6 +1,7 @@ import { Scope, Type } from '@nestjs/common'; import { SCOPE_OPTIONS_METADATA } from '@nestjs/common/constants'; import { MiddlewareConfiguration } from '@nestjs/common/interfaces/middleware/middleware-configuration.interface'; +import { NestContainer } from '../injector'; import { InstanceWrapper } from '../injector/instance-wrapper'; export class MiddlewareContainer { @@ -10,17 +11,28 @@ export class MiddlewareContainer { Set >(); - public getMiddlewareCollection(module: string): Map { - return this.middleware.get(module) || new Map(); + constructor(private readonly container: NestContainer) {} + + public getMiddlewareCollection( + moduleKey: string, + ): Map { + if (!this.middleware.has(moduleKey)) { + const moduleRef = this.container.getModuleByKey(moduleKey); + this.middleware.set(moduleKey, moduleRef.middlewares); + } + return this.middleware.get(moduleKey); } public getConfigurations(): Map> { return this.configurationSets; } - public insertConfig(configList: MiddlewareConfiguration[], module: string) { - const middleware = this.getTargetMiddleware(module); - const targetConfig = this.getTargetConfig(module); + public insertConfig( + configList: MiddlewareConfiguration[], + moduleKey: string, + ) { + const middleware = this.getMiddlewareCollection(moduleKey); + const targetConfig = this.getTargetConfig(moduleKey); const configurations = configList || []; const insertMiddleware = >(metatype: T) => { @@ -40,13 +52,6 @@ export class MiddlewareContainer { }); } - private getTargetMiddleware(module: string) { - if (!this.middleware.has(module)) { - this.middleware.set(module, new Map()); - } - return this.middleware.get(module); - } - private getTargetConfig(module: string) { if (!this.configurationSets.has(module)) { this.configurationSets.set(module, new Set()); diff --git a/packages/core/nest-application.ts b/packages/core/nest-application.ts index 1434a124c..c238f1445 100644 --- a/packages/core/nest-application.ts +++ b/packages/core/nest-application.ts @@ -44,11 +44,12 @@ export class NestApplication extends NestApplicationContext implements INestApplication { private readonly logger = new Logger(NestApplication.name, true); private readonly middlewareModule = new MiddlewareModule(); - private readonly middlewareContainer = new MiddlewareContainer(); - private readonly microservicesModule = MicroservicesModule - ? new MicroservicesModule() - : null; - private readonly socketModule = SocketModule ? new SocketModule() : null; + private readonly middlewareContainer = new MiddlewareContainer( + this.container, + ); + private readonly microservicesModule = + MicroservicesModule && new MicroservicesModule(); + private readonly socketModule = SocketModule && new SocketModule(); private readonly routesResolver: Resolver; private readonly microservices: any[] = []; private httpServer: any; diff --git a/packages/core/test/middleware/container.spec.ts b/packages/core/test/middleware/container.spec.ts index 2ee293726..834b025dd 100644 --- a/packages/core/test/middleware/container.spec.ts +++ b/packages/core/test/middleware/container.spec.ts @@ -5,10 +5,14 @@ import { RequestMapping } from '../../../common/decorators/http/request-mapping. import { RequestMethod } from '../../../common/enums/request-method.enum'; import { MiddlewareConfiguration } from '../../../common/interfaces/middleware/middleware-configuration.interface'; import { NestMiddleware } from '../../../common/interfaces/middleware/nest-middleware.interface'; +import { NestContainer } from '../../injector'; import { InstanceWrapper } from '../../injector/instance-wrapper'; +import { Module } from '../../injector/module'; import { MiddlewareContainer } from '../../middleware/container'; describe('MiddlewareContainer', () => { + class ExampleModule {} + @Controller('test') class TestRoute { @RequestMapping({ path: 'test' }) @@ -26,7 +30,13 @@ describe('MiddlewareContainer', () => { let container: MiddlewareContainer; beforeEach(() => { - container = new MiddlewareContainer(); + const nestContainer = new NestContainer(); + const modules = nestContainer.getModules(); + + modules.set('Module', new Module(ExampleModule, [], nestContainer)); + modules.set('Test', new Module(ExampleModule, [], nestContainer)); + + container = new MiddlewareContainer(nestContainer); }); it('should store expected configurations for given module', () => { @@ -36,7 +46,7 @@ describe('MiddlewareContainer', () => { forRoutes: [TestRoute, 'test'], }, ]; - container.insertConfig(config, 'Module' as any); + container.insertConfig(config, 'Module'); expect([...container.getConfigurations().get('Module')]).to.deep.equal( config, ); @@ -50,7 +60,7 @@ describe('MiddlewareContainer', () => { }, ]; - const key = 'Test' as any; + const key = 'Test'; container.insertConfig(config, key); const collection = container.getMiddlewareCollection(key); diff --git a/packages/core/test/middleware/middlewares-module.spec.ts b/packages/core/test/middleware/middleware-module.spec.ts similarity index 90% rename from packages/core/test/middleware/middlewares-module.spec.ts rename to packages/core/test/middleware/middleware-module.spec.ts index 8b07a11d3..03232d33c 100644 --- a/packages/core/test/middleware/middlewares-module.spec.ts +++ b/packages/core/test/middleware/middleware-module.spec.ts @@ -56,7 +56,7 @@ describe('MiddlewareModule', () => { }; await middlewareModule.loadConfiguration( - new MiddlewareContainer(), + new MiddlewareContainer(new NestContainer()), mockModule as any, 'Test' as any, ); @@ -71,27 +71,35 @@ describe('MiddlewareModule', () => { }); describe('registerRouteMiddleware', () => { + class TestModule {} + + let nestContainer: NestContainer; + + beforeEach(() => { + nestContainer = new NestContainer(); + nestContainer + .getModules() + .set('Test', new Module(TestModule, [], nestContainer)); + }); it('should throw "RuntimeException" exception when middleware is not stored in container', () => { const route = { path: 'Test' }; const configuration = { middleware: [TestMiddleware], forRoutes: [BaseController], }; - const useSpy = sinon.spy(); const app = { use: useSpy }; - const nestContainer = new NestContainer(); // tslint:disable-next-line:no-string-literal middlewareModule['container'] = nestContainer; expect( middlewareModule.registerRouteMiddleware( - new MiddlewareContainer(), + new MiddlewareContainer(nestContainer), route as any, configuration, - 'Test' as any, - app as any, + 'Test', + app, ), ).to.eventually.be.rejectedWith(RuntimeException); }); @@ -109,8 +117,8 @@ describe('MiddlewareModule', () => { const useSpy = sinon.spy(); const app = { use: useSpy }; - const container = new MiddlewareContainer(); - const moduleKey = 'Test' as any; + const container = new MiddlewareContainer(nestContainer); + const moduleKey = 'Test'; container.insertConfig([configuration], moduleKey); const instance = new InvalidMiddleware(); @@ -125,7 +133,7 @@ describe('MiddlewareModule', () => { route as any, configuration, moduleKey, - app as any, + app, ), ).to.be.rejectedWith(InvalidMiddlewareException); }); @@ -143,7 +151,7 @@ describe('MiddlewareModule', () => { const app = { createMiddlewareFactory: createMiddlewareFactoryStub, }; - const container = new MiddlewareContainer(); + const container = new MiddlewareContainer(new NestContainer()); const moduleKey = 'Test'; container.insertConfig([configuration], moduleKey); @@ -155,7 +163,6 @@ describe('MiddlewareModule', () => { instance, }), ); - const nestContainer = new NestContainer(); sinon .stub(nestContainer, 'getModuleByKey') .callsFake(() => new Module(class {}, [], nestContainer)); diff --git a/packages/core/test/middleware/resolver.spec.ts b/packages/core/test/middleware/resolver.spec.ts index 16b7e81c1..8aafca528 100644 --- a/packages/core/test/middleware/resolver.spec.ts +++ b/packages/core/test/middleware/resolver.spec.ts @@ -2,6 +2,7 @@ import { expect } from 'chai'; import * as sinon from 'sinon'; import { Injectable } from '../../../common'; import { NestMiddleware } from '../../../common/interfaces/middleware/nest-middleware.interface'; +import { NestContainer } from '../../injector'; import { MiddlewareContainer } from '../../middleware/container'; import { MiddlewareResolver } from '../../middleware/resolver'; @@ -16,7 +17,7 @@ describe('MiddlewareResolver', () => { let mockContainer: sinon.SinonMock; beforeEach(() => { - container = new MiddlewareContainer(); + container = new MiddlewareContainer(new NestContainer()); resolver = new MiddlewareResolver(container); mockContainer = sinon.mock(container); });