diff --git a/packages/microservices/client/client-grpc.ts b/packages/microservices/client/client-grpc.ts index d8e8797b3..1ef58acf7 100644 --- a/packages/microservices/client/client-grpc.ts +++ b/packages/microservices/client/client-grpc.ts @@ -1,8 +1,7 @@ import { Logger } from '@nestjs/common/services/logger.service'; import { loadPackage } from '@nestjs/common/utils/load-package.util'; import { isFunction, isObject } from '@nestjs/common/utils/shared.utils'; -import { Observable } from 'rxjs'; -import { takeUntil } from 'rxjs/operators'; +import { Observable, Subscription } from 'rxjs'; import { GRPC_DEFAULT_MAX_RECEIVE_MESSAGE_LENGTH, GRPC_DEFAULT_MAX_SEND_MESSAGE_LENGTH, @@ -115,6 +114,7 @@ export class ClientGrpcProxy extends ClientProxy implements ClientGrpc { const isRequestStream = client[methodName].requestStream; const stream = new Observable(observer => { let isClientCanceled = false; + let upstreamSubscription: Subscription; const upstreamSubjectOrData = args[0]; const isUpstreamSubject = @@ -126,7 +126,7 @@ export class ClientGrpcProxy extends ClientProxy implements ClientGrpc { : client[methodName](...args); if (isRequestStream && isUpstreamSubject) { - upstreamSubjectOrData.pipe(takeUntil(stream)).subscribe( + upstreamSubscription = upstreamSubjectOrData.subscribe( (val: unknown) => call.write(val), (err: unknown) => call.emit('error', err), () => call.end(), @@ -143,10 +143,19 @@ export class ClientGrpcProxy extends ClientProxy implements ClientGrpc { observer.error(error); }); call.on('end', () => { + if (upstreamSubscription) { + upstreamSubscription.unsubscribe(); + upstreamSubscription = null; + } call.removeAllListeners(); observer.complete(); }); - return (): any => { + return () => { + if (upstreamSubscription) { + upstreamSubscription.unsubscribe(); + upstreamSubscription = null; + } + if (call.finished) { return undefined; } diff --git a/packages/microservices/test/client/client-grpc.spec.ts b/packages/microservices/test/client/client-grpc.spec.ts index 1259f4d8d..f384fe62f 100644 --- a/packages/microservices/test/client/client-grpc.spec.ts +++ b/packages/microservices/test/client/client-grpc.spec.ts @@ -1,7 +1,7 @@ import { Logger } from '@nestjs/common'; import { expect } from 'chai'; import { join } from 'path'; -import { Observable } from 'rxjs'; +import { Observable, Subject } from 'rxjs'; import * as sinon from 'sinon'; import { ClientGrpcProxy } from '../../client/client-grpc'; import { InvalidGrpcPackageException } from '../../errors/invalid-grpc-package.exception'; @@ -138,6 +138,35 @@ describe('ClientGrpcProxy', () => { }); }); + describe('when stream request', () => { + const methodName = 'm'; + const writeSpy = sinon.spy(); + const obj = { + [methodName]: () => ({ on: (type, fn) => fn(), write: writeSpy }), + }; + + let stream$: Observable; + let upstream: Subject; + + beforeEach(() => { + upstream = new Subject(); + (obj[methodName] as any).requestStream = true; + stream$ = client.createStreamServiceMethod(obj, methodName)(upstream); + }); + + it('should subscribe to request upstream', () => { + const upstreamSubscribe = sinon.spy(upstream, 'subscribe'); + stream$.subscribe( + () => ({}), + () => ({}), + ); + upstream.next({ test: true }); + + expect(writeSpy.called).to.be.true; + expect(upstreamSubscribe.called).to.be.true; + }); + }); + describe('flow-control', () => { const methodName = 'm'; type EvtCallback = (...args: any[]) => void; @@ -237,6 +266,39 @@ describe('ClientGrpcProxy', () => { expect(spy.called).to.be.true; }); }); + describe('when stream request', () => { + const writeSpy = sinon.spy(); + const methodName = 'm'; + const obj = { + [methodName]: callback => { + callback(null, {}); + return { + write: writeSpy, + }; + }, + }; + + let stream$: Observable; + let upstream: Subject; + + beforeEach(() => { + upstream = new Subject(); + (obj[methodName] as any).requestStream = true; + stream$ = client.createUnaryServiceMethod(obj, methodName)(upstream); + }); + + it('should subscribe to request upstream', () => { + const upstreamSubscribe = sinon.spy(upstream, 'subscribe'); + stream$.subscribe( + () => ({}), + () => ({}), + ); + upstream.next({ test: true }); + + expect(writeSpy.called).to.be.true; + expect(upstreamSubscribe.called).to.be.true; + }); + }); }); describe('createClients', () => { diff --git a/packages/microservices/test/server/server-grpc.spec.ts b/packages/microservices/test/server/server-grpc.spec.ts index 4ca913c01..986a47ece 100644 --- a/packages/microservices/test/server/server-grpc.spec.ts +++ b/packages/microservices/test/server/server-grpc.spec.ts @@ -6,6 +6,7 @@ import { of } from 'rxjs'; import * as sinon from 'sinon'; import { InvalidGrpcPackageException } from '../../errors/invalid-grpc-package.exception'; import { ServerGrpc } from '../../server/server-grpc'; +import { CANCEL_EVENT } from '../../constants'; class NoopLogger extends Logger { log(message: any, context?: string): void {} @@ -441,6 +442,48 @@ describe('ServerGrpc', () => { expect(handler.called).to.be.true; }); + describe('when response is not a stream', () => { + it('should call callback', async () => { + const handler = async () => ({ test: true }); + const fn = server.createRequestStreamMethod(handler, false); + const call = { + on: (event, callback) => { + if (event !== CANCEL_EVENT) { + callback(); + } + }, + off: sinon.spy(), + end: sinon.spy(), + write: sinon.spy(), + }; + + const responseCallback = sinon.spy(); + await fn(call as any, responseCallback); + + expect(responseCallback.called).to.be.true; + }); + describe('when response is a stream', () => { + it('should call write() and end()', async () => { + const handler = async () => ({ test: true }); + const fn = server.createRequestStreamMethod(handler, true); + const call = { + on: (event, callback) => { + if (event !== CANCEL_EVENT) { + callback(); + } + }, + off: sinon.spy(), + end: sinon.spy(), + write: sinon.spy(), + }; + + await fn(call as any, null); + + expect(call.write.called).to.be.true; + expect(call.end.called).to.be.true; + }); + }); + }); }); describe('createStreamCallMethod', () => {