fix(core): fix race condition in class dependency resolution

Fix race condition in class dependency resolution, which could
otherwise lead to undefined or null injection. Split the resolution
process in 2 parts with barrier synchronization in between to ensure all
dependencies are present in dependant's instance wrapper and the
staticity of its dependency tree is evaluated correctly.

Closes #4873
This commit is contained in:
Jiri Hajek
2025-07-14 11:50:05 +02:00
parent 138577e50a
commit a453b6375e
8 changed files with 428 additions and 107 deletions

View File

@@ -0,0 +1,130 @@
import { Test } from '@nestjs/testing';
import { expect } from 'chai';
import * as sinon from 'sinon';
import { Global, Inject, Injectable, Module, Scope } from '@nestjs/common';
@Global()
@Module({})
export class GlobalModule1 {}
@Global()
@Module({})
export class GlobalModule2 {}
@Global()
@Module({})
export class GlobalModule3 {}
@Global()
@Module({})
export class GlobalModule4 {}
@Global()
@Module({})
export class GlobalModule5 {}
@Global()
@Module({})
export class GlobalModule6 {}
@Global()
@Module({})
export class GlobalModule7 {}
@Global()
@Module({})
export class GlobalModule8 {}
@Global()
@Module({})
export class GlobalModule9 {}
@Global()
@Module({})
export class GlobalModule10 {}
@Injectable()
class TransientProvider {}
@Injectable()
class RequestProvider {}
@Injectable()
export class Dependant {
constructor(
private readonly transientProvider: TransientProvider,
@Inject(RequestProvider)
private readonly requestProvider: RequestProvider,
) {}
public checkDependencies() {
expect(this.transientProvider).to.be.instanceOf(TransientProvider);
expect(this.requestProvider).to.be.instanceOf(RequestProvider);
}
}
@Global()
@Module({
providers: [
{
provide: TransientProvider,
scope: Scope.TRANSIENT,
useClass: TransientProvider,
},
{
provide: Dependant,
scope: Scope.DEFAULT,
useClass: Dependant,
},
],
})
export class GlobalModuleWithTransientProviderAndDependant {}
@Global()
@Module({
providers: [
{
provide: RequestProvider,
scope: Scope.REQUEST,
useFactory: () => {
return new RequestProvider();
},
},
],
exports: [RequestProvider],
})
export class GlobalModuleWithRequestProvider {}
@Module({
imports: [
GlobalModule1,
GlobalModule2,
GlobalModule3,
GlobalModule4,
GlobalModule5,
GlobalModule6,
GlobalModule7,
GlobalModule8,
GlobalModule9,
GlobalModule10,
GlobalModuleWithTransientProviderAndDependant,
GlobalModuleWithRequestProvider,
],
})
export class AppModule {}
describe('Many global modules', () => {
it('should inject request-scoped useFactory provider and transient-scoped useClass provider from different modules', async () => {
const moduleBuilder = Test.createTestingModule({
imports: [AppModule],
});
const moduleRef = await moduleBuilder.compile();
const dependant = await moduleRef.resolve(Dependant);
const checkDependenciesSpy = sinon.spy(dependant, 'checkDependencies');
dependant.checkDependencies();
expect(checkDependenciesSpy.called).to.be.true;
});
});

View File

@@ -2197,22 +2197,6 @@
}, },
"id": "1976848738" "id": "1976848738"
}, },
"-2105726668": {
"source": "-1803759743",
"target": "1010833816",
"metadata": {
"type": "class-to-class",
"sourceModuleName": "PropertiesModule",
"sourceClassName": "PropertiesService",
"targetClassName": "token",
"sourceClassToken": "PropertiesService",
"targetClassToken": "token",
"targetModuleName": "PropertiesModule",
"keyOrIndex": "token",
"injectionType": "property"
},
"id": "-2105726668"
},
"-21463590": { "-21463590": {
"source": "-1378706112", "source": "-1378706112",
"target": "1004276345", "target": "1004276345",
@@ -2229,6 +2213,22 @@
}, },
"id": "-21463590" "id": "-21463590"
}, },
"-2105726668": {
"source": "-1803759743",
"target": "1010833816",
"metadata": {
"type": "class-to-class",
"sourceModuleName": "PropertiesModule",
"sourceClassName": "PropertiesService",
"targetClassName": "token",
"sourceClassToken": "PropertiesService",
"targetClassToken": "token",
"targetModuleName": "PropertiesModule",
"keyOrIndex": "token",
"injectionType": "property"
},
"id": "-2105726668"
},
"-1657371464": { "-1657371464": {
"source": "-1673986099", "source": "-1673986099",
"target": "1919157847", "target": "1919157847",

View File

@@ -0,0 +1,51 @@
/**
* A simple barrier to synchronize flow of multiple async operations.
*/
export class Barrier {
private currentCount: number;
private targetCount: number;
private promise: Promise<void>;
private resolve: () => void;
constructor(targetCount: number) {
this.currentCount = 0;
this.targetCount = targetCount;
this.promise = new Promise<void>(resolve => {
this.resolve = resolve;
});
}
/**
* Signal that a participant has reached the barrier.
*
* The barrier will be resolved once `targetCount` participants have reached it.
*/
public signal(): void {
this.currentCount += 1;
if (this.currentCount === this.targetCount) {
this.resolve();
}
}
/**
* Wait for the barrier to be resolved.
*
* @returns A promise that resolves when the barrier is resolved.
*/
public async wait(): Promise<void> {
return this.promise;
}
/**
* Signal that a participant has reached the barrier and wait for the barrier to be resolved.
*
* The barrier will be resolved once `targetCount` participants have reached it.
*
* @returns A promise that resolves when the barrier is resolved.
*/
public async signalAndWait(): Promise<void> {
this.signal();
return this.wait();
}
}

View File

@@ -42,6 +42,7 @@ import {
} from './instance-wrapper'; } from './instance-wrapper';
import { Module } from './module'; import { Module } from './module';
import { SettlementSignal } from './settlement-signal'; import { SettlementSignal } from './settlement-signal';
import { Barrier } from '../helpers/barrier';
/** /**
* The type of an injectable dependency * The type of an injectable dependency
@@ -295,10 +296,16 @@ export class Injector {
? this.getFactoryProviderDependencies(wrapper) ? this.getFactoryProviderDependencies(wrapper)
: this.getClassDependencies(wrapper); : this.getClassDependencies(wrapper);
const paramBarrier = new Barrier(dependencies.length);
let isResolved = true; let isResolved = true;
const resolveParam = async (param: unknown, index: number) => { const resolveParam = async (param: unknown, index: number) => {
try { try {
if (this.isInquirer(param, parentInquirer)) { if (this.isInquirer(param, parentInquirer)) {
/*
* Signal the barrier to make sure other dependencies do not get stuck waiting forever.
*/
paramBarrier.signal();
return parentInquirer && parentInquirer.instance; return parentInquirer && parentInquirer.instance;
} }
if (inquirer?.isTransient && parentInquirer) { if (inquirer?.isTransient && parentInquirer) {
@@ -314,15 +321,36 @@ export class Injector {
inquirer, inquirer,
index, index,
); );
const instanceHost = paramWrapper.getInstanceByContextId(
this.getContextId(contextId, paramWrapper), /*
* Ensure that all instance wrappers are resolved at this point before we continue.
* Otherwise the staticity of `wrapper`'s dependency tree may be evaluated incorrectly
* and result in undefined / null injection.
*/
await paramBarrier.signalAndWait();
const paramWrapperWithInstance = await this.resolveComponentHost(
moduleRef,
paramWrapper,
contextId,
inquirer,
);
const instanceHost = paramWrapperWithInstance.getInstanceByContextId(
this.getContextId(contextId, paramWrapperWithInstance),
inquirerId, inquirerId,
); );
if (!instanceHost.isResolved && !paramWrapper.forwardRef) { if (!instanceHost.isResolved && !paramWrapperWithInstance.forwardRef) {
isResolved = false; isResolved = false;
} }
return instanceHost?.instance; return instanceHost?.instance;
} catch (err) { } catch (err) {
/*
* Signal the barrier to make sure other dependencies do not get stuck waiting forever. We
* do not care if this occurs after `Barrier.signalAndWait()` is called in the `try` block
* because the barrier will always have been resolved by then.
*/
paramBarrier.signal();
const isOptional = optionalDependenciesIds.includes(index); const isOptional = optionalDependenciesIds.includes(index);
if (!isOptional) { if (!isOptional) {
throw err; throw err;
@@ -422,7 +450,7 @@ export class Injector {
); );
} }
const token = this.resolveParamToken(wrapper, param); const token = this.resolveParamToken(wrapper, param);
return this.resolveComponentInstance<T>( return this.resolveComponentWrapper(
moduleRef, moduleRef,
token, token,
dependencyContext, dependencyContext,
@@ -444,7 +472,7 @@ export class Injector {
return param; return param;
} }
public async resolveComponentInstance<T>( public async resolveComponentWrapper<T>(
moduleRef: Module, moduleRef: Module,
token: InjectionToken, token: InjectionToken,
dependencyContext: InjectorDependencyContext, dependencyContext: InjectorDependencyContext,
@@ -456,7 +484,7 @@ export class Injector {
this.printResolvingDependenciesLog(token, inquirer); this.printResolvingDependenciesLog(token, inquirer);
this.printLookingForProviderLog(token, moduleRef); this.printLookingForProviderLog(token, moduleRef);
const providers = moduleRef.providers; const providers = moduleRef.providers;
const instanceWrapper = await this.lookupComponent( return this.lookupComponent(
providers, providers,
moduleRef, moduleRef,
{ ...dependencyContext, name: token }, { ...dependencyContext, name: token },
@@ -465,13 +493,6 @@ export class Injector {
inquirer, inquirer,
keyOrIndex, keyOrIndex,
); );
return this.resolveComponentHost(
moduleRef,
instanceWrapper,
contextId,
inquirer,
);
} }
public async resolveComponentHost<T>( public async resolveComponentHost<T>(
@@ -671,6 +692,7 @@ export class Injector {
return this.loadPropertiesMetadata(metadata, contextId, inquirer); return this.loadPropertiesMetadata(metadata, contextId, inquirer);
} }
const properties = this.reflectProperties(wrapper.metatype as Type<any>); const properties = this.reflectProperties(wrapper.metatype as Type<any>);
const propertyBarrier = new Barrier(properties.length);
const instances = await Promise.all( const instances = await Promise.all(
properties.map(async (item: PropertyDependency) => { properties.map(async (item: PropertyDependency) => {
try { try {
@@ -679,6 +701,11 @@ export class Injector {
name: item.name as Function | string | symbol, name: item.name as Function | string | symbol,
}; };
if (this.isInquirer(item.name, parentInquirer)) { if (this.isInquirer(item.name, parentInquirer)) {
/*
* Signal the barrier to make sure other dependencies do not get stuck waiting forever.
*/
propertyBarrier.signal();
return parentInquirer && parentInquirer.instance; return parentInquirer && parentInquirer.instance;
} }
const paramWrapper = await this.resolveSingleParam<T>( const paramWrapper = await this.resolveSingleParam<T>(
@@ -690,16 +717,37 @@ export class Injector {
inquirer, inquirer,
item.key, item.key,
); );
if (!paramWrapper) {
/*
* Ensure that all instance wrappers are resolved at this point before we continue.
* Otherwise the staticity of `wrapper`'s dependency tree may be evaluated incorrectly
* and result in undefined / null injection.
*/
await propertyBarrier.signalAndWait();
const paramWrapperWithInstance = await this.resolveComponentHost(
moduleRef,
paramWrapper,
contextId,
inquirer,
);
if (!paramWrapperWithInstance) {
return undefined; return undefined;
} }
const inquirerId = this.getInquirerId(inquirer); const inquirerId = this.getInquirerId(inquirer);
const instanceHost = paramWrapper.getInstanceByContextId( const instanceHost = paramWrapperWithInstance.getInstanceByContextId(
this.getContextId(contextId, paramWrapper), this.getContextId(contextId, paramWrapperWithInstance),
inquirerId, inquirerId,
); );
return instanceHost.instance; return instanceHost.instance;
} catch (err) { } catch (err) {
/*
* Signal the barrier to make sure other dependencies do not get stuck waiting forever. We
* do not care if this occurs after `Barrier.signalAndWait()` is called in the `try` block
* because the barrier will always have been resolved by then.
*/
propertyBarrier.signal();
if (!item.isOptional) { if (!item.isOptional) {
throw err; throw err;
} }

View File

@@ -0,0 +1,93 @@
import { expect } from 'chai';
import { Barrier } from '../../../core/helpers/barrier';
import * as sinon from 'sinon';
import * as chai from 'chai';
import * as chaiAsPromised from 'chai-as-promised';
import { setTimeout } from 'timers/promises';
chai.use(chaiAsPromised);
describe('Barrier', () => {
const targetCount = 3;
let barrier: Barrier;
let barrierResolveSpy: sinon.SinonSpy;
beforeEach(() => {
barrier = new Barrier(targetCount);
barrierResolveSpy = sinon.spy(<any>barrier, 'resolve');
});
afterEach(() => {
// resolve any promises that may still be waiting in the background
(<any>barrier).resolve();
});
describe('signal', () => {
it('should resolve the barrier when target count is reached', async () => {
for (let i = 0; i < targetCount; i++) {
barrier.signal();
}
expect(barrierResolveSpy.called).to.be.true;
});
it('should not resolve the barrier when target count is not reached', async () => {
for (let i = 0; i < targetCount - 1; i++) {
barrier.signal();
}
expect(barrierResolveSpy.called).to.be.false;
expect((<any>barrier).currentCount).to.be.equal(targetCount - 1);
});
});
describe('wait', () => {
it('should resolve when target count is reached', async () => {
const waitPromise = barrier.wait();
for (let i = 0; i < targetCount; i++) {
barrier.signal();
}
expect(waitPromise).to.be.fulfilled;
});
it('should not resolve when target count is not reached', async () => {
const waitPromise = barrier.wait();
for (let i = 0; i < targetCount - 1; i++) {
barrier.signal();
}
expect(waitPromise).not.to.be.fulfilled;
});
});
describe('signalAndWait', () => {
it('should resolve when target count is reached', async () => {
const promise = Promise.all(
Array.from({ length: targetCount }, () => barrier.signalAndWait()),
);
// wait for the promise to be resolved
await promise;
expect(promise).to.be.fulfilled;
expect(barrierResolveSpy.called).to.be.true;
});
it('should not resolve when target count is not reached', async () => {
const promise = Promise.all(
Array.from({ length: targetCount - 1 }, () => barrier.signalAndWait()),
);
/*
* Give the promise some time to work. We cannot await the promise because the test case would
* get stuck.
*/
await setTimeout(5);
expect(promise).not.to.be.fulfilled;
expect(barrierResolveSpy.called).to.be.false;
});
});
});

View File

@@ -545,7 +545,7 @@ describe('Injector', () => {
}); });
}); });
describe('resolveComponentInstance', () => { describe('resolveComponentHost', () => {
let module: any; let module: any;
beforeEach(() => { beforeEach(() => {
module = { module = {
@@ -560,16 +560,8 @@ describe('Injector', () => {
const loadStub = sinon const loadStub = sinon
.stub(injector, 'loadProvider') .stub(injector, 'loadProvider')
.callsFake(() => null!); .callsFake(() => null!);
sinon
.stub(injector, 'lookupComponent')
.returns(Promise.resolve(wrapper));
await injector.resolveComponentInstance( await injector.resolveComponentHost(module, wrapper);
module,
'',
{ index: 0, dependencies: [] },
wrapper,
);
expect(loadStub.called).to.be.true; expect(loadStub.called).to.be.true;
}); });
it('should not call loadProvider (isResolved)', async () => { it('should not call loadProvider (isResolved)', async () => {
@@ -578,16 +570,7 @@ describe('Injector', () => {
.stub(injector, 'loadProvider') .stub(injector, 'loadProvider')
.callsFake(() => null!); .callsFake(() => null!);
sinon await injector.resolveComponentHost(module, wrapper);
.stub(injector, 'lookupComponent')
.returns(Promise.resolve(wrapper));
await injector.resolveComponentInstance(
module,
'',
{ index: 0, dependencies: [] },
wrapper,
);
expect(loadStub.called).to.be.false; expect(loadStub.called).to.be.false;
}); });
it('should not call loadProvider (forwardRef)', async () => { it('should not call loadProvider (forwardRef)', async () => {
@@ -599,16 +582,7 @@ describe('Injector', () => {
.stub(injector, 'loadProvider') .stub(injector, 'loadProvider')
.callsFake(() => null!); .callsFake(() => null!);
sinon await injector.resolveComponentHost(module, wrapper);
.stub(injector, 'lookupComponent')
.returns(Promise.resolve(wrapper));
await injector.resolveComponentInstance(
module,
'',
{ index: 0, dependencies: [] },
wrapper,
);
expect(loadStub.called).to.be.false; expect(loadStub.called).to.be.false;
}); });
}); });
@@ -624,16 +598,8 @@ describe('Injector', () => {
async: true, async: true,
instance, instance,
}); });
sinon
.stub(injector, 'lookupComponent')
.returns(Promise.resolve(wrapper));
const result = await injector.resolveComponentInstance( const result = await injector.resolveComponentHost(module, wrapper);
module,
'',
{ index: 0, dependencies: [] },
wrapper,
);
expect(result.instance).to.be.true; expect(result.instance).to.be.true;
}); });
}); });

View File

@@ -15,7 +15,7 @@ describe('PipesConsumer', () => {
beforeEach(() => { beforeEach(() => {
value = 0; value = 0;
data = null; data = null;
(metatype = {}), (type = RouteParamtypes.QUERY); ((metatype = {}), (type = RouteParamtypes.QUERY));
stringifiedType = 'query'; stringifiedType = 'query';
transforms = [ transforms = [
createPipe(sinon.stub().callsFake(val => val + 1)), createPipe(sinon.stub().callsFake(val => val + 1)),

View File

@@ -23,7 +23,7 @@ export class TestingInjector extends Injector {
this.container = container; this.container = container;
} }
public async resolveComponentInstance<T>( public async resolveComponentWrapper<T>(
moduleRef: Module, moduleRef: Module,
name: any, name: any,
dependencyContext: InjectorDependencyContext, dependencyContext: InjectorDependencyContext,
@@ -33,7 +33,7 @@ export class TestingInjector extends Injector {
keyOrIndex?: string | number, keyOrIndex?: string | number,
): Promise<InstanceWrapper> { ): Promise<InstanceWrapper> {
try { try {
const existingProviderWrapper = await super.resolveComponentInstance( const existingProviderWrapper = await super.resolveComponentWrapper(
moduleRef, moduleRef,
name, name,
dependencyContext, dependencyContext,
@@ -44,39 +44,72 @@ export class TestingInjector extends Injector {
); );
return existingProviderWrapper; return existingProviderWrapper;
} catch (err) { } catch (err) {
if (this.mocker) { return this.mockWrapper(err, moduleRef, name, wrapper);
const mockedInstance = this.mocker(name);
if (!mockedInstance) {
throw err;
}
const newWrapper = new InstanceWrapper({
name,
isAlias: false,
scope: wrapper.scope,
instance: mockedInstance,
isResolved: true,
host: moduleRef,
metatype: wrapper.metatype,
});
const internalCoreModule = this.container.getInternalCoreModuleRef();
if (!internalCoreModule) {
throw new Error(
'Expected to have internal core module reference at this point.',
);
}
internalCoreModule.addCustomProvider(
{
provide: name,
useValue: mockedInstance,
},
internalCoreModule.providers,
);
internalCoreModule.addExportedProviderOrModule(name);
return newWrapper;
} else {
throw err;
}
} }
} }
public async resolveComponentHost<T>(
moduleRef: Module,
instanceWrapper: InstanceWrapper<T>,
contextId = STATIC_CONTEXT,
inquirer?: InstanceWrapper,
): Promise<InstanceWrapper> {
try {
const existingProviderWrapper = await super.resolveComponentHost(
moduleRef,
instanceWrapper,
contextId,
inquirer,
);
return existingProviderWrapper;
} catch (err) {
return this.mockWrapper(
err,
moduleRef,
instanceWrapper.name,
instanceWrapper,
);
}
}
private async mockWrapper<T>(
err: Error,
moduleRef: Module,
name: any,
wrapper: InstanceWrapper<T>,
): Promise<InstanceWrapper> {
if (!this.mocker) {
throw err;
}
const mockedInstance = this.mocker(name);
if (!mockedInstance) {
throw err;
}
const newWrapper = new InstanceWrapper({
name,
isAlias: false,
scope: wrapper.scope,
instance: mockedInstance,
isResolved: true,
host: moduleRef,
metatype: wrapper.metatype,
});
const internalCoreModule = this.container.getInternalCoreModuleRef();
if (!internalCoreModule) {
throw new Error(
'Expected to have internal core module reference at this point.',
);
}
internalCoreModule.addCustomProvider(
{
provide: name,
useValue: mockedInstance,
},
internalCoreModule.providers,
);
internalCoreModule.addExportedProviderOrModule(name);
return newWrapper;
}
} }