Merge pull request #10390 from CodyTseng/fix-middleware-global-prefix

fix(core): let the middleware can  get the params in the global prefix
This commit is contained in:
Kamil Mysliwiec
2023-02-03 10:24:47 +01:00
committed by GitHub
19 changed files with 300 additions and 87 deletions

View File

@@ -77,6 +77,26 @@ describe('Middleware', () => {
});
});
describe('when using default URI versioning with the global prefix', () => {
beforeEach(async () => {
app = await createAppWithVersioning(
{
type: VersioningType.URI,
defaultVersion: VERSION_NEUTRAL,
},
async (app: INestApplication) => {
app.setGlobalPrefix('api');
},
);
});
it(`forRoutes({ path: '/versioned', version: '1', method: RequestMethod.ALL })`, () => {
return request(app.getHttpServer())
.get('/api/v1/versioned')
.expect(200, VERSIONED_VALUE);
});
});
describe('when using HEADER versioning', () => {
beforeEach(async () => {
app = await createAppWithVersioning({
@@ -133,6 +153,7 @@ describe('Middleware', () => {
async function createAppWithVersioning(
versioningOptions: VersioningOptions,
beforeInit?: (app: INestApplication) => Promise<void>,
): Promise<INestApplication> {
const app = (
await Test.createTestingModule({
@@ -141,6 +162,9 @@ async function createAppWithVersioning(
).createNestApplication();
app.enableVersioning(versioningOptions);
if (beforeInit) {
await beforeInit(app);
}
await app.init();
return app;

View File

@@ -119,6 +119,17 @@ describe('Global prefix', () => {
await request(server).get('/api/v1/middleware/foo').expect(404);
});
it(`should get the params in the global prefix`, async () => {
app.setGlobalPrefix('/api/:tenantId');
server = app.getHttpServer();
await app.init();
await request(server)
.get('/api/test/params')
.expect(200, { '0': 'params', tenantId: 'test' });
});
afterEach(async () => {
await app.close();
});

View File

@@ -7,6 +7,11 @@ export class AppController {
return 'Hello: ' + req.extras?.data;
}
@Get('params')
getParams(@Req() req): any {
return req.middlewareParams;
}
@Get('health')
getHealth(): string {
return 'up';

View File

@@ -22,6 +22,11 @@ export class AppModule {
req.extras = { data: 'Data attached in middleware' };
next();
})
.forRoutes({ path: '*', method: RequestMethod.GET })
.apply((req, res, next) => {
req.middlewareParams = req.params;
next();
})
.forRoutes({ path: '*', method: RequestMethod.GET });
}
}

View File

@@ -35,7 +35,7 @@ export const ENHANCER_KEY_TO_SUBTYPE_MAP = {
} as const;
export type EnhancerSubtype =
typeof ENHANCER_KEY_TO_SUBTYPE_MAP[keyof typeof ENHANCER_KEY_TO_SUBTYPE_MAP];
(typeof ENHANCER_KEY_TO_SUBTYPE_MAP)[keyof typeof ENHANCER_KEY_TO_SUBTYPE_MAP];
export const RENDER_METADATA = '__renderTemplate__';
export const HTTP_CODE_METADATA = '__httpCode__';

View File

@@ -99,9 +99,9 @@ export class HttpException extends Error {
) {
this.message = (this.response as Record<string, any>).message;
} else if (this.constructor) {
this.message = this.constructor.name
.match(/[A-Z][a-z]+|[0-9]+/g)
?.join(' ') ?? 'Error';
this.message =
this.constructor.name.match(/[A-Z][a-z]+|[0-9]+/g)?.join(' ') ??
'Error';
}
}

View File

@@ -353,7 +353,9 @@ export class Injector {
}
public reflectConstructorParams<T>(type: Type<T>): any[] {
const paramtypes = [...(Reflect.getMetadata(PARAMTYPES_METADATA, type) || [])];
const paramtypes = [
...(Reflect.getMetadata(PARAMTYPES_METADATA, type) || []),
];
const selfParams = this.reflectSelfParams<T>(type);
selfParams.forEach(({ index, param }) => (paramtypes[index] = param));

View File

@@ -1,4 +1,4 @@
import { HttpServer, Logger, VersioningType } from '@nestjs/common';
import { HttpServer, Logger } from '@nestjs/common';
import { RequestMethod } from '@nestjs/common/enums/request-method.enum';
import {
MiddlewareConfiguration,
@@ -6,10 +6,7 @@ import {
RouteInfo,
} from '@nestjs/common/interfaces/middleware';
import { NestApplicationContextOptions } from '@nestjs/common/interfaces/nest-application-context-options.interface';
import {
addLeadingSlash,
isUndefined,
} from '@nestjs/common/utils/shared.utils';
import { isUndefined } from '@nestjs/common/utils/shared.utils';
import { ApplicationConfig } from '../application-config';
import { InvalidMiddlewareException } from '../errors/exceptions/invalid-middleware.exception';
import { RuntimeException } from '../errors/exceptions/runtime.exception';
@@ -26,13 +23,13 @@ import {
MiddlewareEntrypointMetadata,
} from '../inspector/interfaces/entrypoint.interface';
import { REQUEST_CONTEXT_ID } from '../router/request/request-constants';
import { RoutePathFactory } from '../router/route-path-factory';
import { RouterExceptionFilters } from '../router/router-exception-filters';
import { RouterProxy } from '../router/router-proxy';
import { isRequestMethodAll, isRouteExcluded } from '../router/utils';
import { isRequestMethodAll } from '../router/utils';
import { MiddlewareBuilder } from './builder';
import { MiddlewareContainer } from './container';
import { MiddlewareResolver } from './resolver';
import { RouteInfoPathExtractor } from './route-info-path-extractor';
import { RoutesMapper } from './routes-mapper';
export class MiddlewareModule<
@@ -46,13 +43,11 @@ export class MiddlewareModule<
private routerExceptionFilter: RouterExceptionFilters;
private routesMapper: RoutesMapper;
private resolver: MiddlewareResolver;
private config: ApplicationConfig;
private container: NestContainer;
private httpAdapter: HttpServer;
private graphInspector: GraphInspector;
private appOptions: TAppOptions;
constructor(private readonly routePathFactory: RoutePathFactory) {}
private routeInfoPathExtractor: RouteInfoPathExtractor;
public async register(
middlewareContainer: MiddlewareContainer,
@@ -73,8 +68,7 @@ export class MiddlewareModule<
);
this.routesMapper = new RoutesMapper(container);
this.resolver = new MiddlewareResolver(middlewareContainer, injector);
this.config = config;
this.routeInfoPathExtractor = new RouteInfoPathExtractor(config);
this.injector = injector;
this.container = container;
this.httpAdapter = httpAdapter;
@@ -307,44 +301,18 @@ export class MiddlewareModule<
private async registerHandler(
applicationRef: HttpServer,
{ path, method, version }: RouteInfo,
routeInfo: RouteInfo,
proxy: <TRequest, TResponse>(
req: TRequest,
res: TResponse,
next: () => void,
) => void,
) {
const prefix = this.config.getGlobalPrefix();
const excludedRoutes = this.config.getGlobalPrefixOptions().exclude;
const isAWildcard = ['*', '/*', '(.*)', '/(.*)'].includes(path);
if (
(Array.isArray(excludedRoutes) &&
isRouteExcluded(excludedRoutes, path, method)) ||
isAWildcard
) {
path = addLeadingSlash(path);
} else {
const basePath = addLeadingSlash(prefix);
if (basePath?.endsWith('/') && path?.startsWith('/')) {
// strip slash when a wildcard is being used
// and global prefix has been set
path = path?.slice(1);
}
path = basePath + path;
}
const applicationVersioningConfig = this.config.getVersioning();
if (version && applicationVersioningConfig.type === VersioningType.URI) {
const versionPrefix = this.routePathFactory.getVersionPrefix(
applicationVersioningConfig,
);
path = `/${versionPrefix}${version.toString()}${path}`;
}
const { method } = routeInfo;
const paths = this.routeInfoPathExtractor.extractPathsFrom(routeInfo);
const isMethodAll = isRequestMethodAll(method);
const requestMethod = RequestMethod[method];
const router = await applicationRef.createMiddlewareFactory(method);
const middlewareFunction = isMethodAll
? proxy
: <TRequest, TResponse>(
@@ -357,8 +325,7 @@ export class MiddlewareModule<
}
return next();
};
router(path, middlewareFunction);
paths.forEach(path => router(path, middlewareFunction));
}
private getContextId(request: unknown, isTreeDurable: boolean): ContextId {

View File

@@ -0,0 +1,55 @@
import { VersioningType } from '@nestjs/common';
import { RouteInfo } from '@nestjs/common/interfaces';
import {
addLeadingSlash,
stripEndSlash,
} from '@nestjs/common/utils/shared.utils';
import { ApplicationConfig } from '../application-config';
import { isRouteExcluded } from '../router/utils';
import { RoutePathFactory } from './../router/route-path-factory';
export class RouteInfoPathExtractor {
private routePathFactory: RoutePathFactory;
constructor(private readonly applicationConfig: ApplicationConfig) {
this.routePathFactory = new RoutePathFactory(applicationConfig);
}
public extractPathsFrom({ path, method, version }: RouteInfo) {
const prefixPath = stripEndSlash(
addLeadingSlash(this.applicationConfig.getGlobalPrefix()),
);
const excludedRoutes =
this.applicationConfig.getGlobalPrefixOptions().exclude;
const applicationVersioningConfig = this.applicationConfig.getVersioning();
let versionPath = '';
if (version && applicationVersioningConfig?.type === VersioningType.URI) {
const versionPrefix = this.routePathFactory.getVersionPrefix(
applicationVersioningConfig,
);
versionPath = addLeadingSlash(versionPrefix + version.toString());
}
const isAWildcard = ['*', '/*', '/*/', '(.*)', '/(.*)'].includes(path);
if (isAWildcard) {
return Array.isArray(excludedRoutes)
? [
prefixPath + versionPath + addLeadingSlash(path),
...excludedRoutes.map(
route => versionPath + addLeadingSlash(route.path),
),
]
: [prefixPath + versionPath + addLeadingSlash(path)];
}
if (
Array.isArray(excludedRoutes) &&
isRouteExcluded(excludedRoutes, path, method)
) {
return [versionPath + addLeadingSlash(path)];
}
return [prefixPath + versionPath + addLeadingSlash(path)];
}
}

View File

@@ -1,6 +1,10 @@
import { RequestMethod } from '@nestjs/common';
import { HttpServer, RouteInfo, Type } from '@nestjs/common/interfaces';
import { isFunction } from '@nestjs/common/utils/shared.utils';
import {
addLeadingSlash,
isFunction,
isString,
} from '@nestjs/common/utils/shared.utils';
import { iterate } from 'iterare';
import * as pathToRegexp from 'path-to-regexp';
import { v4 as uuid } from 'uuid';
@@ -8,13 +12,22 @@ import { ExcludeRouteMetadata } from '../router/interfaces/exclude-route-metadat
import { isRouteExcluded } from '../router/utils';
export const mapToExcludeRoute = (
routes: RouteInfo[],
routes: (string | RouteInfo)[],
): ExcludeRouteMetadata[] => {
return routes.map(({ path, method }) => ({
pathRegex: pathToRegexp(path),
requestMethod: method,
path,
}));
return routes.map(route => {
if (isString(route)) {
return {
path: route,
requestMethod: RequestMethod.ALL,
pathRegex: pathToRegexp(addLeadingSlash(route)),
};
}
return {
path: route.path,
requestMethod: route.method,
pathRegex: pathToRegexp(addLeadingSlash(route.path)),
};
});
};
export const filterMiddleware = <T extends Function | Type<any> = any>(

View File

@@ -7,7 +7,6 @@ import {
NestHybridApplicationOptions,
NestInterceptor,
PipeTransform,
RequestMethod,
VersioningOptions,
VersioningType,
WebSocketAdapter,
@@ -15,7 +14,6 @@ import {
import {
GlobalPrefixOptions,
NestApplicationOptions,
RouteInfo,
} from '@nestjs/common/interfaces';
import {
CorsOptions,
@@ -31,7 +29,6 @@ import {
} from '@nestjs/common/utils/shared.utils';
import { iterate } from 'iterare';
import { platform } from 'os';
import * as pathToRegexp from 'path-to-regexp';
import { AbstractHttpAdapter } from './adapters';
import { ApplicationConfig } from './application-config';
import { MESSAGES } from './constants';
@@ -41,10 +38,9 @@ import { Injector } from './injector/injector';
import { GraphInspector } from './inspector/graph-inspector';
import { MiddlewareContainer } from './middleware/container';
import { MiddlewareModule } from './middleware/middleware-module';
import { mapToExcludeRoute } from './middleware/utils';
import { NestApplicationContext } from './nest-application-context';
import { ExcludeRouteMetadata } from './router/interfaces/exclude-route-metadata.interface';
import { Resolver } from './router/interfaces/resolver.interface';
import { RoutePathFactory } from './router/route-path-factory';
import { RoutesResolver } from './router/routes-resolver';
const { SocketModule } = optionalRequire(
@@ -89,9 +85,8 @@ export class NestApplication
this.selectContextModule();
this.registerHttpServer();
this.injector = new Injector({ preview: this.appOptions.preview });
this.middlewareModule = new MiddlewareModule(new RoutePathFactory(config));
this.middlewareModule = new MiddlewareModule();
this.routesResolver = new RoutesResolver(
this.container,
this.config,
@@ -372,22 +367,9 @@ export class NestApplication
public setGlobalPrefix(prefix: string, options?: GlobalPrefixOptions): this {
this.config.setGlobalPrefix(prefix);
if (options) {
const exclude = options?.exclude.map(
(route: string | RouteInfo): ExcludeRouteMetadata => {
if (isString(route)) {
return {
requestMethod: RequestMethod.ALL,
pathRegex: pathToRegexp(addLeadingSlash(route)),
path: route,
};
}
return {
requestMethod: route.method,
pathRegex: pathToRegexp(addLeadingSlash(route.path)),
path: route.path,
};
},
);
const exclude = options?.exclude
? mapToExcludeRoute(options.exclude)
: [];
this.config.setGlobalPrefixOptions({
...options,
exclude,

View File

@@ -1,7 +1,11 @@
import { RequestMethod } from '@nestjs/common';
export interface ExcludeRouteMetadata {
/**
* Route path.
*/
path: string;
/**
* Regular expression representing the route path.
*/

View File

@@ -1,4 +1,5 @@
import { RequestMethod } from '@nestjs/common';
import { addLeadingSlash } from '@nestjs/common/utils/shared.utils';
import { ExcludeRouteMetadata } from '../interfaces/exclude-route-metadata.interface';
export const isRequestMethodAll = (method: RequestMethod) => {
@@ -15,7 +16,7 @@ export function isRouteExcluded(
isRequestMethodAll(route.requestMethod) ||
route.requestMethod === requestMethod
) {
return route.pathRegex.exec(path);
return route.pathRegex.exec(addLeadingSlash(path));
}
return false;
});

View File

@@ -24,9 +24,9 @@ describe('ApplicationConfig', () => {
const options: GlobalPrefixOptions<ExcludeRouteMetadata> = {
exclude: [
{
path: '/health',
pathRegex: new RegExp(/health/),
requestMethod: RequestMethod.GET,
path: 'health',
},
],
};

View File

@@ -1,5 +1,5 @@
import { Injectable } from '@nestjs/common';
import { RoutePathFactory } from '@nestjs/core/router/route-path-factory';
import { RouteInfoPathExtractor } from '@nestjs/core/middleware/route-info-path-extractor';
import * as chai from 'chai';
import { expect } from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
@@ -47,15 +47,21 @@ describe('MiddlewareModule', () => {
beforeEach(() => {
const container = new NestContainer();
const appConfig = new ApplicationConfig();
graphInspector = new GraphInspector(container);
middlewareModule = new MiddlewareModule(new RoutePathFactory(appConfig));
middlewareModule = new MiddlewareModule();
middlewareModule['routerExceptionFilter'] = new RouterExceptionFilters(
new NestContainer(),
appConfig,
new NoopHttpAdapter({}),
);
middlewareModule['routeInfoPathExtractor'] = new RouteInfoPathExtractor(
appConfig,
);
middlewareModule['routerExceptionFilter'] = new RouterExceptionFilters(
container,
appConfig,
new NoopHttpAdapter({}),
);
middlewareModule['config'] = appConfig;
middlewareModule['graphInspector'] = graphInspector;
});

View File

@@ -0,0 +1,93 @@
import { RequestMethod, VersioningType } from '@nestjs/common';
import { ApplicationConfig } from '@nestjs/core';
import { mapToExcludeRoute } from '@nestjs/core/middleware/utils';
import { expect } from 'chai';
import { RouteInfoPathExtractor } from './../../middleware/route-info-path-extractor';
describe('RouteInfoPathExtractor', () => {
describe('getPaths', () => {
let appConfig: ApplicationConfig;
let routeInfoPathExtractor: RouteInfoPathExtractor;
beforeEach(() => {
appConfig = new ApplicationConfig();
appConfig.enableVersioning({
type: VersioningType.URI,
});
routeInfoPathExtractor = new RouteInfoPathExtractor(appConfig);
});
it(`should return correct paths`, () => {
expect(
routeInfoPathExtractor.extractPathsFrom({
path: '*',
method: RequestMethod.ALL,
}),
).to.eql(['/*']);
expect(
routeInfoPathExtractor.extractPathsFrom({
path: '*',
method: RequestMethod.ALL,
version: '1',
}),
).to.eql(['/v1/*']);
});
it(`should return correct paths when set global prefix`, () => {
appConfig.setGlobalPrefix('api');
expect(
routeInfoPathExtractor.extractPathsFrom({
path: '*',
method: RequestMethod.ALL,
}),
).to.eql(['/api/*']);
expect(
routeInfoPathExtractor.extractPathsFrom({
path: '*',
method: RequestMethod.ALL,
version: '1',
}),
).to.eql(['/api/v1/*']);
});
it(`should return correct paths when set global prefix and global prefix options`, () => {
appConfig.setGlobalPrefix('api');
appConfig.setGlobalPrefixOptions({
exclude: mapToExcludeRoute(['foo']),
});
expect(
routeInfoPathExtractor.extractPathsFrom({
path: '*',
method: RequestMethod.ALL,
}),
).to.eql(['/api/*', '/foo']);
expect(
routeInfoPathExtractor.extractPathsFrom({
path: '*',
method: RequestMethod.ALL,
version: '1',
}),
).to.eql(['/api/v1/*', '/v1/foo']);
expect(
routeInfoPathExtractor.extractPathsFrom({
path: 'foo',
method: RequestMethod.ALL,
version: '1',
}),
).to.eql(['/v1/foo']);
expect(
routeInfoPathExtractor.extractPathsFrom({
path: 'bar',
method: RequestMethod.ALL,
version: '1',
}),
).to.eql(['/api/v1/bar']);
});
});
});

View File

@@ -1,4 +1,5 @@
import { RequestMethod, Type } from '@nestjs/common';
import { addLeadingSlash } from '@nestjs/common/utils/shared.utils';
import { expect } from 'chai';
import * as sinon from 'sinon';
import {
@@ -10,6 +11,7 @@ import {
mapToExcludeRoute,
} from '../../middleware/utils';
import { NoopHttpAdapter } from '../utils/noop-adapter.spec';
import * as pathToRegexp from 'path-to-regexp';
describe('middleware utils', () => {
const noopAdapter = new NoopHttpAdapter({});
@@ -17,6 +19,27 @@ describe('middleware utils', () => {
class Test {}
function fnMiddleware(req, res, next) {}
describe('mapToExcludeRoute', () => {
it('should return exclude route metadata', () => {
const stringRoute = 'foo';
const routeInfo = {
path: 'bar',
method: RequestMethod.GET,
};
expect(mapToExcludeRoute([stringRoute, routeInfo])).to.eql([
{
path: stringRoute,
requestMethod: RequestMethod.ALL,
pathRegex: pathToRegexp(addLeadingSlash(stringRoute)),
},
{
path: routeInfo.path,
requestMethod: routeInfo.method,
pathRegex: pathToRegexp(addLeadingSlash(routeInfo.path)),
},
]);
});
});
describe('filterMiddleware', () => {
let middleware: any[];
beforeEach(() => {

View File

@@ -1,8 +1,10 @@
import { RequestMethod } from '@nestjs/common';
import { expect } from 'chai';
import { ApplicationConfig } from '../application-config';
import { NestContainer } from '../injector/container';
import { GraphInspector } from '../inspector/graph-inspector';
import { NestApplication } from '../nest-application';
import { mapToExcludeRoute } from './../middleware/utils';
import { NoopHttpAdapter } from './utils/noop-adapter.spec';
describe('NestApplication', () => {
@@ -54,4 +56,24 @@ describe('NestApplication', () => {
).to.equal(1);
});
});
describe('Global Prefix', () => {
it('should get correct global prefix options', () => {
const applicationConfig = new ApplicationConfig();
const container = new NestContainer(applicationConfig);
const instance = new NestApplication(
container,
new NoopHttpAdapter({}),
applicationConfig,
new GraphInspector(container),
{},
);
const excludeRoute = ['foo', { path: 'bar', method: RequestMethod.GET }];
instance.setGlobalPrefix('api', {
exclude: excludeRoute,
});
expect(applicationConfig.getGlobalPrefixOptions()).to.eql({
exclude: mapToExcludeRoute(excludeRoute),
});
});
});
});

View File

@@ -247,9 +247,9 @@ describe('RoutePathFactory', () => {
sinon.stub(applicationConfig, 'getGlobalPrefixOptions').returns({
exclude: [
{
path: '/random',
pathRegex: pathToRegexp('/random'),
requestMethod: RequestMethod.ALL,
path: '/random',
},
],
});
@@ -266,9 +266,9 @@ describe('RoutePathFactory', () => {
sinon.stub(applicationConfig, 'getGlobalPrefixOptions').returns({
exclude: [
{
path: '/cats',
pathRegex: pathToRegexp('/cats'),
requestMethod: RequestMethod.ALL,
path: '/cats',
},
],
});
@@ -285,9 +285,9 @@ describe('RoutePathFactory', () => {
sinon.stub(applicationConfig, 'getGlobalPrefixOptions').returns({
exclude: [
{
path: '/cats',
pathRegex: pathToRegexp('/cats'),
requestMethod: RequestMethod.GET,
path: '/cats',
},
],
});