// Copyright 2022 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.

/* eslint-env browser */

import { Status } from 'pigweedjs/pw_status';
import { MessageCreator } from 'pigweedjs/pw_protobuf_compiler';
import { Message } from 'google-protobuf';
import {
  PacketType,
  RpcPacket,
} from 'pigweedjs/protos/pw_rpc/internal/packet_pb';
import { ProtoCollection } from 'pigweedjs/protos/collection';
import { Request, Response } from 'pigweedjs/protos/pw_rpc/ts/test_pb';

import { Client } from './client';
import { Channel, Method, Service } from './descriptors';
import {
  BidirectionalStreamingMethodStub,
  ClientStreamingMethodStub,
  ServerStreamingMethodStub,
  UnaryMethodStub,
} from './method';
import * as packets from './packets';

// Call IDs used by the "open call ID" tests below to send responses that were
// not enqueued under the call ID of the request they are matched against.
const LEGACY_OPEN_CALL_ID = 0;
const OPEN_CALL_ID = 2 ** 32 - 1;

/**
 * Tests for channel/method lookup on the Client and for how
 * `Client.processPacket` reacts to packets it cannot route.
 */
describe('Client', () => {
  let protoCollection: ProtoCollection;
  let client: Client;
  let lastPacketSent: RpcPacket;

  beforeEach(() => {
    protoCollection = new ProtoCollection();
    const channels = [new Channel(1, savePacket), new Channel(5)];
    client = Client.fromProtoSet(channels, protoCollection);
  });

  /** Output function for channel 1: records the most recent outgoing packet. */
  function savePacket(packetBytes: Uint8Array): void {
    lastPacketSent = RpcPacket.deserializeBinary(packetBytes);
  }

  it('channel returns undefined for empty list', () => {
    const channels: Channel[] = [];
    const emptyChannelClient = Client.fromProtoSet(channels, protoCollection);
    expect(emptyChannelClient.channel()).toBeUndefined();
  });

  it('fetches channel or returns undefined', () => {
    expect(client.channel(1)!.channel.id).toEqual(1);
    expect(client.channel(5)!.channel.id).toEqual(5);
    expect(client.channel()!.channel.id).toEqual(1);
    expect(client.channel(2)).toBeUndefined();
  });

  it('ChannelClient fetches method by name', () => {
    const channel = client.channel()!;
    const stub = channel.methodStub('pw.rpc.test1.TheTestService.SomeUnary')!;
    expect(stub.method.name).toEqual('SomeUnary');
  });

  it('ChannelClient for unknown name returns undefined', () => {
    const channel = client.channel()!;
    expect(channel.methodStub('')).toBeUndefined();
    expect(
      channel.methodStub('pw.rpc.test1.Garbage.SomeUnary'),
    ).toBeUndefined();
    expect(
      channel.methodStub('pw.rpc.test1.TheTestService.Garbage'),
    ).toBeUndefined();
  });

  it('processPacket with invalid proto data', () => {
    const textEncoder = new TextEncoder();
    const data = textEncoder.encode('NOT a packet!');
    expect(client.processPacket(data)).toEqual(Status.DATA_LOSS);
  });

  it('processPacket not for client', () => {
    const packet = new RpcPacket();
    packet.setType(PacketType.REQUEST);
    const processStatus = client.processPacket(packet.serializeBinary());
    expect(processStatus).toEqual(Status.INVALID_ARGUMENT);
  });

  it('processPacket for unrecognized channel', () => {
    const packet = packets.encodeResponse([123, 456, 789, 456], new Request());
    expect(client.processPacket(packet)).toEqual(Status.NOT_FOUND);
  });

  it('processPacket for unrecognized service', () => {
    const packet = packets.encodeResponse([1, 456, 789, 456], new Request());
    // Process the packet exactly once and assert on that single result.
    const status = client.processPacket(packet);
    expect(status).toEqual(Status.OK);

    expect(lastPacketSent.getChannelId()).toEqual(1);
    expect(lastPacketSent.getServiceId()).toEqual(456);
    expect(lastPacketSent.getMethodId()).toEqual(789);
    expect(lastPacketSent.getCallId()).toEqual(456);
    expect(lastPacketSent.getType()).toEqual(PacketType.CLIENT_ERROR);
    expect(lastPacketSent.getStatus()).toEqual(Status.NOT_FOUND);
  });

  it('processPacket for unrecognized method', () => {
    const service = client.services.values().next().value;

    const packet = packets.encodeResponse(
      [1, service.id, 789, 456],
      new Request(),
    );
    // Process the packet exactly once and assert on that single result.
    const status = client.processPacket(packet);
    expect(status).toEqual(Status.OK);

    expect(lastPacketSent.getChannelId()).toEqual(1);
    expect(lastPacketSent.getServiceId()).toEqual(service.id);
    expect(lastPacketSent.getMethodId()).toEqual(789);
    expect(lastPacketSent.getCallId()).toEqual(456);
    expect(lastPacketSent.getType()).toEqual(PacketType.CLIENT_ERROR);
    expect(lastPacketSent.getStatus()).toEqual(Status.NOT_FOUND);
  });

  it('processPacket for non-pending method', () => {
    const service = client.services.values().next().value;
    const method = service.methods.values().next().value;

    const packet = packets.encodeResponse(
      [1, service.id, method.id, 456],
      new Request(),
    );
    // Process the packet exactly once and assert on that single result.
    const status = client.processPacket(packet);
    expect(status).toEqual(Status.OK);

    expect(lastPacketSent.getChannelId()).toEqual(1);
    expect(lastPacketSent.getServiceId()).toEqual(service.id);
    expect(lastPacketSent.getMethodId()).toEqual(method.id);
    expect(lastPacketSent.getType()).toEqual(PacketType.CLIENT_ERROR);
    expect(lastPacketSent.getStatus()).toEqual(Status.FAILED_PRECONDITION);
  });
});

/**
 * Tests for the four RPC method stub kinds. Channel 1's output function is
 * `handlePacket`, which records every outgoing packet in `requests` and then
 * feeds the packets queued in `nextPackets` back into the client.
 */
describe('RPC', () => {
  let protoCollection: ProtoCollection;
  let client: Client;
  // Packets sent by the client on channel 1, oldest first.
  let requests: RpcPacket[] = [];
  // Serialized packets to feed to the client, paired with the status that
  // `client.processPacket` is expected to return for each.
  let nextPackets: [Uint8Array, Status][] = [];
  // Set while `processEnqueuedPackets` is running to block re-entrant calls.
  let responseLock = false;
  // While greater than 1, `handlePacket` decrements this and returns without
  // processing the enqueued packets.
  let sendResponsesAfterPackets = 0;
  // When set, `handlePacket` throws this instead of recording the packet.
  let outputException: Error | undefined;

  beforeEach(async () => {
    protoCollection = new ProtoCollection();
    const channels = [
      new Channel(1, handlePacket),
      new Channel(2, () => {
        // Do nothing.
      }),
    ];
    client = Client.fromProtoSet(channels, protoCollection);
    requests = [];
    nextPackets = [];
    responseLock = false;
    sendResponsesAfterPackets = 0;
    outputException = undefined;
  });

  /** Builds a test `Request` carrying the given magic number. */
  function newRequest(magicNumber = 123): Message {
    const request = new Request();
    request.setMagicNumber(magicNumber);
    return request;
  }

  /** Builds a test `Response` carrying the given payload string. */
  function newResponse(payload = '._.'): Message {
    const response = new Response();
    response.setPayload(payload);
    return response;
  }

  /**
   * Queues a RESPONSE packet for `method`. The payload is empty when
   * `response` is omitted. Processing the packet is expected to return
   * `Status.OK`.
   */
  function enqueueResponse(
    channelId: number,
    method: Method,
    status: Status,
    callId: number,
    response?: Message,
  ) {
    const packet = new RpcPacket();
    packet.setType(PacketType.RESPONSE);
    packet.setChannelId(channelId);
    packet.setServiceId(method.service.id);
    packet.setMethodId(method.id);
    packet.setCallId(callId);
    packet.setStatus(status);
    if (response === undefined) {
      packet.setPayload(new Uint8Array(0));
    } else {
      packet.setPayload(response.serializeBinary());
    }
    nextPackets.push([packet.serializeBinary(), Status.OK]);
  }

  /**
   * Queues a SERVER_STREAM packet for `method`. `status` is used both as the
   * packet status and as the expected `processPacket` result.
   */
  function enqueueServerStream(
    channelId: number,
    method: Method,
    response: Message,
    callId: number,
    status: Status = Status.OK,
  ) {
    const packet = new RpcPacket();
    packet.setType(PacketType.SERVER_STREAM);
    packet.setChannelId(channelId);
    packet.setServiceId(method.service.id);
    packet.setMethodId(method.id);
    packet.setCallId(callId);
    packet.setPayload(response.serializeBinary());
    packet.setStatus(status);
    nextPackets.push([packet.serializeBinary(), status]);
  }

  /**
   * Queues a SERVER_ERROR packet carrying `status`. `processStatus` is the
   * value `processPacket` is expected to return for it.
   */
  function enqueueError(
    channelId: number,
    method: Method,
    status: Status,
    processStatus: Status,
    callId: number,
  ) {
    const packet = new RpcPacket();
    packet.setType(PacketType.SERVER_ERROR);
    packet.setChannelId(channelId);
    packet.setServiceId(method.service.id);
    packet.setMethodId(method.id);
    packet.setCallId(callId);
    packet.setStatus(status);
    nextPackets.push([packet.serializeBinary(), processStatus]);
  }

  /** Returns the most recently sent packet; throws if none was sent. */
  function lastRequest(): RpcPacket {
    if (requests.length === 0) {
      throw Error('Tried to fetch request from empty list');
    }
    return requests[requests.length - 1];
  }

  /** Decodes the payload of the most recently sent packet as `messageType`. */
  function sentPayload(messageType: typeof Message): any {
    return messageType.deserializeBinary(lastRequest().getPayload_asU8());
  }

  /** Channel 1 output function: records the packet, then replays responses. */
  function handlePacket(data: Uint8Array): void {
    if (outputException !== undefined) {
      throw outputException;
    }
    requests.push(packets.decode(data));

    if (sendResponsesAfterPackets > 1) {
      sendResponsesAfterPackets -= 1;
      return;
    }

    processEnqueuedPackets();
  }

  /** Feeds every queued packet to the client and then clears the queue. */
  function processEnqueuedPackets(): void {
    // Avoid infinite recursion when processing a packet causes another packet
    // to send.
    if (responseLock) return;
    responseLock = true;
    for (const [packet, status] of nextPackets) {
      expect(client.processPacket(packet)).toEqual(status);
    }
    nextPackets = [];
    responseLock = false;
  }

  describe('Unary', () => {
    let unaryStub: UnaryMethodStub;

    beforeEach(async () => {
      unaryStub = client
        .channel()
        ?.methodStub(
          'pw.rpc.test1.TheTestService.SomeUnary',
        ) as UnaryMethodStub;
    });

    const openCallIds: [string, number][] = [
      ['OPEN_CALL_ID', OPEN_CALL_ID],
      ['LEGACY_OPEN_CALL_ID', LEGACY_OPEN_CALL_ID],
    ];
    openCallIds.forEach(([idName, callId]) => {
      it(`matches responses with ${idName} to requests with arbitrary IDs`, async () => {
        const promisedResponse = unaryStub.call(newRequest(6));
        // Use the parameterized ID so that each table entry is exercised.
        enqueueResponse(
          1,
          unaryStub.method,
          Status.ABORTED,
          callId,
          newResponse('is unrequested'),
        );
        processEnqueuedPackets();
        const [status, response] = await promisedResponse;

        expect(sentPayload(Request).getMagicNumber()).toEqual(6);
        expect(status).toEqual(Status.ABORTED);
        expect(response).toEqual(newResponse('is unrequested'));
      });
    });

    it('blocking call', async () => {
      for (let i = 0; i < 3; i++) {
        enqueueResponse(
          1,
          unaryStub.method,
          Status.ABORTED,
          unaryStub.rpcs.nextCallId,
          newResponse('0_o'),
        );
        const [status, response] = await unaryStub.call(newRequest(6));

        expect(sentPayload(Request).getMagicNumber()).toEqual(6);
        expect(status).toEqual(Status.ABORTED);
        expect(response).toEqual(newResponse('0_o'));
      }
    });

    it('nonblocking call', () => {
      for (let i = 0; i < 3; i++) {
        const response = newResponse('hello world');
        enqueueResponse(
          1,
          unaryStub.method,
          Status.ABORTED,
          unaryStub.rpcs.nextCallId,
          response,
        );

        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();
        unaryStub.invoke(newRequest(5), onNext, onCompleted, onError);

        expect(sentPayload(Request).getMagicNumber()).toEqual(5);
        expect(onNext).toHaveBeenCalledWith(response);
        expect(onError).not.toHaveBeenCalled();
        expect(onCompleted).toHaveBeenCalledWith(Status.ABORTED);
      }
    });

    it('open', () => {
      outputException = Error('Error should be ignored');

      for (let i = 0; i < 3; i++) {
        const response = newResponse('hello world');
        enqueueResponse(
          1,
          unaryStub.method,
          Status.ABORTED,
          unaryStub.rpcs.nextCallId,
          response,
        );

        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();
        unaryStub.open(newRequest(5), onNext, onCompleted, onError);
        // handlePacket throws before recording, so nothing is captured.
        expect(requests).toHaveLength(0);

        processEnqueuedPackets();

        expect(onNext).toHaveBeenCalledWith(response);
        expect(onError).not.toHaveBeenCalled();
        expect(onCompleted).toHaveBeenCalledWith(Status.ABORTED);
      }
    });

    it('nonblocking concurrent call', () => {
      // Start several calls to the same method
      const pendingCalls: {
        call: ReturnType<UnaryMethodStub['invoke']>;
        onNext: jest.Mock;
        onCompleted: jest.Mock;
        onError: jest.Mock;
      }[] = [];
      for (let i = 0; i < 3; i++) {
        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();

        const call = unaryStub.invoke(
          newRequest(5),
          onNext,
          onCompleted,
          onError,
        );
        pendingCalls.push({ call, onNext, onCompleted, onError });
        expect(sentPayload(Request).getMagicNumber()).toEqual(5);
      }

      // Respond only to the last call
      const last = pendingCalls.pop();
      if (last === undefined) {
        throw Error('Expected at least one pending call');
      }
      const lastResponse = newResponse('last payload');
      enqueueResponse(
        1,
        unaryStub.method,
        Status.OK,
        last.call.callId,
        lastResponse,
      );
      processEnqueuedPackets();

      expect(last.onNext).toHaveBeenCalledWith(lastResponse);
      // The calls that were not answered must not have seen any callback.
      for (const { onNext, onCompleted, onError } of pendingCalls) {
        expect(onNext).toHaveBeenCalledTimes(0);
        expect(onCompleted).toHaveBeenCalledTimes(0);
        expect(onError).toHaveBeenCalledTimes(0);
      }
    });

    it('blocking server error', async () => {
      for (let i = 0; i < 3; i++) {
        enqueueError(
          1,
          unaryStub.method,
          Status.NOT_FOUND,
          Status.OK,
          unaryStub.rpcs.nextCallId,
        );

        try {
          await unaryStub.call(newRequest());
          fail('call expected to fail');
        } catch (e: any) {
          expect(e.status).toBe(Status.NOT_FOUND);
        }
      }
    });

    it('nonblocking call cancel', () => {
      for (let i = 0; i < 3; i++) {
        const onNext = jest.fn();
        const call = unaryStub.invoke(newRequest(), onNext);
        expect(requests.length).toBeGreaterThan(0);
        // Clear the log so lastRequest() below reflects only the cancel.
        requests = [];

        expect(call.cancel()).toBe(true);
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_ERROR);
        expect(lastRequest().getStatus()).toEqual(Status.CANCELLED);

        expect(call.cancel()).toBe(false);
        expect(onNext).not.toHaveBeenCalled();
      }
    });

    it('blocking call with timeout', async () => {
      try {
        await unaryStub.call(newRequest(), 10);
        fail('Promise should not be resolve');
      } catch (err: any) {
        expect(err.timeoutMs).toEqual(10);
      }
    });

    it('nonblocking exception in callback', () => {
      const errorCallback = () => {
        throw Error('Something went wrong!');
      };

      enqueueResponse(
        1,
        unaryStub.method,
        Status.OK,
        unaryStub.rpcs.nextCallId,
      );
      const call = unaryStub.invoke(newRequest(), errorCallback);
      expect(call.callbackException!.name).toEqual('Error');
      expect(call.callbackException!.message).toEqual('Something went wrong!');
    });
  });

  describe('ServerStreaming', () => {
    let serverStreaming: ServerStreamingMethodStub;

    beforeEach(async () => {
      serverStreaming = client
        .channel()
        ?.methodStub(
          'pw.rpc.test1.TheTestService.SomeServerStreaming',
        ) as ServerStreamingMethodStub;
    });

    it('non-blocking call', () => {
      const response1 = newResponse('!!!');
      const response2 = newResponse('?');

      for (let i = 0; i < 3; i++) {
        enqueueServerStream(
          1,
          serverStreaming.method,
          response1,
          serverStreaming.rpcs.nextCallId,
        );
        enqueueServerStream(
          1,
          serverStreaming.method,
          response2,
          serverStreaming.rpcs.nextCallId,
        );
        enqueueResponse(
          1,
          serverStreaming.method,
          Status.ABORTED,
          serverStreaming.rpcs.nextCallId,
        );

        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();
        serverStreaming.invoke(newRequest(4), onNext, onCompleted, onError);

        expect(onNext).toHaveBeenCalledWith(response1);
        expect(onNext).toHaveBeenCalledWith(response2);
        expect(onError).not.toHaveBeenCalled();
        expect(onCompleted).toHaveBeenCalledWith(Status.ABORTED);

        expect(
          sentPayload(serverStreaming.method.requestType).getMagicNumber(),
        ).toEqual(4);
      }
    });

    it('open', () => {
      outputException = Error('Error should be ignored');
      const response1 = newResponse('!!!');
      const response2 = newResponse('?');

      for (let i = 0; i < 3; i++) {
        enqueueServerStream(
          1,
          serverStreaming.method,
          response1,
          serverStreaming.rpcs.nextCallId,
        );
        enqueueServerStream(
          1,
          serverStreaming.method,
          response2,
          serverStreaming.rpcs.nextCallId,
        );
        enqueueResponse(
          1,
          serverStreaming.method,
          Status.ABORTED,
          serverStreaming.rpcs.nextCallId,
        );

        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();
        serverStreaming.open(newRequest(3), onNext, onCompleted, onError);

        // handlePacket throws before recording, so nothing is captured.
        expect(requests).toHaveLength(0);
        processEnqueuedPackets();

        expect(onNext).toHaveBeenCalledWith(response1);
        expect(onNext).toHaveBeenCalledWith(response2);
        expect(onError).not.toHaveBeenCalled();
        expect(onCompleted).toHaveBeenCalledWith(Status.ABORTED);
      }
    });

    it('blocking timeout', async () => {
      try {
        await serverStreaming.call(newRequest(), 10);
        fail('Promise should not be resolve');
      } catch (err: any) {
        expect(err.timeoutMs).toEqual(10);
      }
    });

    it('non-blocking cancel', () => {
      const testResponse = newResponse('!!!');
      enqueueServerStream(
        1,
        serverStreaming.method,
        testResponse,
        serverStreaming.rpcs.nextCallId,
      );

      const onNext = jest.fn();
      const onCompleted = jest.fn();
      const onError = jest.fn();
      const call = serverStreaming.invoke(newRequest(3), onNext);
      expect(onNext).toHaveBeenNthCalledWith(1, testResponse);

      // onNext is deliberately not reset here: the second invocation below is
      // asserted as the mock's 2nd call.

      call.cancel();
      expect(lastRequest().getType()).toEqual(PacketType.CLIENT_ERROR);
      expect(lastRequest().getStatus()).toEqual(Status.CANCELLED);

      // Ensure the RPC can be called after being cancelled.
      enqueueServerStream(
        1,
        serverStreaming.method,
        testResponse,
        serverStreaming.rpcs.nextCallId,
      );
      enqueueResponse(
        1,
        serverStreaming.method,
        Status.OK,
        serverStreaming.rpcs.nextCallId,
      );
      serverStreaming.invoke(newRequest(), onNext, onCompleted, onError);
      expect(onNext).toHaveBeenNthCalledWith(2, testResponse);
      expect(onError).not.toHaveBeenCalled();
      expect(onCompleted).toHaveBeenCalledWith(Status.OK);
    });
  });

  describe('ClientStreaming', () => {
    let clientStreaming: ClientStreamingMethodStub;

    beforeEach(async () => {
      clientStreaming = client
        .channel()
        ?.methodStub(
          'pw.rpc.test1.TheTestService.SomeClientStreaming',
        ) as ClientStreamingMethodStub;
    });

    it('non-blocking call', () => {
      const testResponse = newResponse('-.-');

      for (let i = 0; i < 3; i++) {
        const onNext = jest.fn();
        const stream = clientStreaming.invoke(onNext);
        expect(stream.completed).toBe(false);

        stream.send(newRequest(31));
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_STREAM);
        expect(sentPayload(Request).getMagicNumber()).toEqual(31);
        expect(stream.completed).toBe(false);

        // Enqueue the server response to be sent after the next message.
        enqueueResponse(
          1,
          clientStreaming.method,
          Status.OK,
          stream.callId,
          testResponse,
        );

        stream.send(newRequest(32));
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_STREAM);
        expect(sentPayload(Request).getMagicNumber()).toEqual(32);

        expect(onNext).toHaveBeenCalledWith(testResponse);
        expect(stream.completed).toBe(true);
        expect(stream.status).toEqual(Status.OK);
        expect(stream.error).toBeUndefined();
      }
    });

    it('open', () => {
      outputException = Error('Error should be ignored');
      const response = newResponse('!!!');

      for (let i = 0; i < 3; i++) {
        enqueueResponse(
          1,
          clientStreaming.method,
          Status.OK,
          clientStreaming.rpcs.nextCallId,
          response,
        );

        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();
        clientStreaming.open(onNext, onCompleted, onError);
        // handlePacket throws before recording, so nothing is captured.
        expect(requests).toHaveLength(0);

        processEnqueuedPackets();

        expect(onNext).toHaveBeenCalledWith(response);
        expect(onError).not.toHaveBeenCalled();
        expect(onCompleted).toHaveBeenCalledWith(Status.OK);
      }
    });

    it('blocking timeout', async () => {
      try {
        await clientStreaming.call([newRequest()], 10);
        fail('Promise should not be resolve');
      } catch (err: any) {
        expect(err.timeoutMs).toEqual(10);
      }
    });

    it('non-blocking call ended by client', () => {
      const testResponse = newResponse('0.o');

      for (let i = 0; i < 3; i++) {
        const onNext = jest.fn();
        const stream = clientStreaming.invoke(onNext);
        expect(stream.completed).toBe(false);

        stream.send(newRequest(31));
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_STREAM);
        expect(sentPayload(Request).getMagicNumber()).toEqual(31);
        expect(stream.completed).toBe(false);

        // Enqueue the server response to be sent after the next message.
        enqueueResponse(
          1,
          clientStreaming.method,
          Status.OK,
          stream.callId,
          testResponse,
        );

        stream.finishAndWait();
        expect(lastRequest().getType()).toEqual(
          PacketType.CLIENT_REQUEST_COMPLETION,
        );

        expect(onNext).toHaveBeenCalledWith(testResponse);
        expect(stream.completed).toBe(true);
        expect(stream.status).toEqual(Status.OK);
        expect(stream.error).toBeUndefined();
      }
    });

    it('non-blocking call cancelled', () => {
      for (let i = 0; i < 3; i++) {
        const stream = clientStreaming.invoke();
        stream.send(newRequest());

        expect(stream.cancel()).toBe(true);
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_ERROR);
        expect(lastRequest().getStatus()).toEqual(Status.CANCELLED);
        expect(stream.cancel()).toBe(false);
        expect(stream.completed).toBe(true);
        expect(stream.error).toEqual(Status.CANCELLED);
      }
    });

    it('non-blocking call server error', async () => {
      for (let i = 0; i < 3; i++) {
        const stream = clientStreaming.invoke();
        enqueueError(
          1,
          clientStreaming.method,
          Status.INVALID_ARGUMENT,
          Status.OK,
          stream.callId,
        );

        stream.send(newRequest());

        await stream
          .finishAndWait()
          .then(() => {
            fail('Promise should not be resolved');
          })
          .catch((reason) => {
            expect(reason.status).toEqual(Status.INVALID_ARGUMENT);
          });
      }
    });

    it('non-blocking call server error after stream end', async () => {
      for (let i = 0; i < 3; i++) {
        const stream = clientStreaming.invoke();
        // Error will be sent in response to the CLIENT_REQUEST_COMPLETION packet.
        enqueueError(
          1,
          clientStreaming.method,
          Status.INVALID_ARGUMENT,
          Status.OK,
          stream.callId,
        );

        await stream
          .finishAndWait()
          .then(() => {
            fail('Promise should not be resolved');
          })
          .catch((reason) => {
            expect(reason.status).toEqual(Status.INVALID_ARGUMENT);
          });
      }
    });

    it('non-blocking call send after cancelled', () => {
      // Two assertions must run: the cancel() check and the status check in
      // the catch block. If send() does not throw, the count falls short and
      // the test fails.
      expect.assertions(2);
      const stream = clientStreaming.invoke();
      expect(stream.cancel()).toBe(true);

      try {
        stream.send(newRequest());
      } catch (e: any) {
        expect(e.status).toEqual(Status.CANCELLED);
      }
    });

    it('non-blocking finish after completed', async () => {
      const enqueuedResponse = newResponse('?!');
      enqueueResponse(
        1,
        clientStreaming.method,
        Status.UNAVAILABLE,
        clientStreaming.rpcs.nextCallId,
        enqueuedResponse,
      );

      const stream = clientStreaming.invoke();
      const result = await stream.finishAndWait();
      expect(result[1]).toEqual(enqueuedResponse);

      // Repeated calls are expected to keep yielding the same result.
      expect(await stream.finishAndWait()).toEqual(result);
      expect(await stream.finishAndWait()).toEqual(result);
    });

    it('non-blocking finish after error', async () => {
      enqueueError(
        1,
        clientStreaming.method,
        Status.UNAVAILABLE,
        Status.OK,
        clientStreaming.rpcs.nextCallId,
      );

      const stream = clientStreaming.invoke();

      for (let i = 0; i < 3; i++) {
        await stream
          .finishAndWait()
          .then(() => {
            fail('Promise should not be resolved');
          })
          .catch((reason) => {
            expect(reason.status).toEqual(Status.UNAVAILABLE);
            expect(stream.error).toEqual(Status.UNAVAILABLE);
            expect(stream.response).toBeUndefined();
          });
      }
    });
  });

  describe('BidirectionalStreaming', () => {
    let bidiStreaming: BidirectionalStreamingMethodStub;

    beforeEach(async () => {
      bidiStreaming = client
        .channel()
        ?.methodStub(
          'pw.rpc.test1.TheTestService.SomeBidiStreaming',
        ) as BidirectionalStreamingMethodStub;
    });

    it('blocking call', async () => {
      const testRequests = [newRequest(123), newRequest(456)];

      // Hold back the enqueued response while handlePacket counts this down.
      sendResponsesAfterPackets = 3;
      enqueueResponse(
        1,
        bidiStreaming.method,
        Status.NOT_FOUND,
        bidiStreaming.rpcs.nextCallId,
      );

      const results = await bidiStreaming.call(testRequests);
      expect(results[0]).toEqual(Status.NOT_FOUND);
      expect(results[1]).toEqual([]);
    });

    it('blocking server error', async () => {
      const testRequests = [newRequest(123)];
      enqueueError(
        1,
        bidiStreaming.method,
        Status.NOT_FOUND,
        Status.OK,
        bidiStreaming.rpcs.nextCallId,
      );

      await bidiStreaming
        .call(testRequests)
        .then(() => {
          fail('Promise should not be resolved');
        })
        .catch((reason) => {
          expect(reason.status).toEqual(Status.NOT_FOUND);
        });
    });

    it('non-blocking call', () => {
      const rep1 = newResponse('!!!');
      const rep2 = newResponse('?');

      for (let i = 0; i < 3; i++) {
        const testResponses: Message[] = [];
        const stream = bidiStreaming.invoke((response) => {
          testResponses.push(response);
        });
        expect(stream.completed).toBe(false);

        stream.send(newRequest(55));
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_STREAM);
        expect(sentPayload(Request).getMagicNumber()).toEqual(55);
        expect(stream.completed).toBe(false);
        expect(testResponses).toEqual([]);

        enqueueServerStream(1, bidiStreaming.method, rep1, stream.callId);
        enqueueServerStream(1, bidiStreaming.method, rep2, stream.callId);

        stream.send(newRequest(66));
        expect(lastRequest().getType()).toEqual(PacketType.CLIENT_STREAM);
        expect(sentPayload(Request).getMagicNumber()).toEqual(66);
        expect(stream.completed).toBe(false);
        expect(testResponses).toEqual([rep1, rep2]);

        enqueueResponse(1, bidiStreaming.method, Status.OK, stream.callId);

        stream.send(newRequest(77));
        expect(stream.completed).toBe(true);
        expect(testResponses).toEqual([rep1, rep2]);
        expect(stream.status).toEqual(Status.OK);
        expect(stream.error).toBeUndefined();
      }
    });

    it('open', () => {
      outputException = Error('Error should be ignored');
      const response1 = newResponse('!!!');
      const response2 = newResponse('?');

      for (let i = 0; i < 3; i++) {
        enqueueServerStream(
          1,
          bidiStreaming.method,
          response1,
          bidiStreaming.rpcs.nextCallId,
        );
        enqueueServerStream(
          1,
          bidiStreaming.method,
          response2,
          bidiStreaming.rpcs.nextCallId,
        );
        enqueueResponse(
          1,
          bidiStreaming.method,
          Status.OK,
          bidiStreaming.rpcs.nextCallId,
        );

        const onNext = jest.fn();
        const onCompleted = jest.fn();
        const onError = jest.fn();
        bidiStreaming.open(onNext, onCompleted, onError);
        // handlePacket throws before recording, so nothing is captured.
        expect(requests).toHaveLength(0);

        processEnqueuedPackets();

        expect(onNext).toHaveBeenCalledWith(response1);
        expect(onNext).toHaveBeenCalledWith(response2);
        expect(onError).not.toHaveBeenCalled();
        expect(onCompleted).toHaveBeenCalledWith(Status.OK);
      }
    });

    it('blocking timeout', async () => {
      try {
        await bidiStreaming.call([newRequest()], 10);
        fail('Promise should not be resolve');
      } catch (err: any) {
        expect(err.timeoutMs).toEqual(10);
      }
    });

    it('non-blocking server error', async () => {
      const response = newResponse('!!!');

      for (let i = 0; i < 3; i++) {
        const testResponses: Message[] = [];
        const stream = bidiStreaming.invoke((received) => {
          testResponses.push(received);
        });
        expect(stream.completed).toBe(false);

        enqueueServerStream(1, bidiStreaming.method, response, stream.callId);

        stream.send(newRequest(55));
        expect(stream.completed).toBe(false);
        expect(testResponses).toEqual([response]);

        enqueueError(
          1,
          bidiStreaming.method,
          Status.OUT_OF_RANGE,
          Status.OK,
          stream.callId,
        );

        stream.send(newRequest(999));
        expect(stream.completed).toBe(true);
        expect(testResponses).toEqual([response]);
        expect(stream.status).toBeUndefined();
        expect(stream.error).toEqual(Status.OUT_OF_RANGE);

        await stream
          .finishAndWait()
          .then(() => {
            fail('Promise should not be resolved');
          })
          .catch((reason) => {
            expect(reason.status).toEqual(Status.OUT_OF_RANGE);
          });
      }
    });

    it('non-blocking server error after stream end', async () => {
      for (let i = 0; i < 3; i++) {
        const stream = bidiStreaming.invoke();

        // Error is sent in response to CLIENT_REQUEST_COMPLETION packet.
        enqueueError(
          1,
          bidiStreaming.method,
          Status.INVALID_ARGUMENT,
          Status.OK,
          stream.callId,
        );

        await stream
          .finishAndWait()
          .then(() => {
            fail('Promise should not be resolved');
          })
          .catch((reason) => {
            expect(reason.status).toEqual(Status.INVALID_ARGUMENT);
          });
      }
    });

    it('non-blocking send after cancelled', async () => {
      const stream = bidiStreaming.invoke();
      expect(stream.cancel()).toBe(true);

      try {
        stream.send(newRequest());
        fail('send should have failed');
      } catch (e: any) {
        expect(e.status).toBe(Status.CANCELLED);
      }
    });

    it('non-blocking finish after completed', async () => {
      const response = newResponse('!?');
      enqueueServerStream(
        1,
        bidiStreaming.method,
        response,
        bidiStreaming.rpcs.nextCallId,
      );
      enqueueResponse(
        1,
        bidiStreaming.method,
        Status.UNAVAILABLE,
        bidiStreaming.rpcs.nextCallId,
      );

      const stream = bidiStreaming.invoke();
      const result = await stream.finishAndWait();
      expect(result[1]).toEqual([response]);

      // Repeated calls are expected to keep yielding the same result.
      expect(await stream.finishAndWait()).toEqual(result);
      expect(await stream.finishAndWait()).toEqual(result);
    });

    it('non-blocking finish after error', async () => {
      const response = newResponse('!?');
      enqueueServerStream(
        1,
        bidiStreaming.method,
        response,
        bidiStreaming.rpcs.nextCallId,
      );
      enqueueError(
        1,
        bidiStreaming.method,
        Status.UNAVAILABLE,
        Status.OK,
        bidiStreaming.rpcs.nextCallId,
      );

      const stream = bidiStreaming.invoke();

      for (let i = 0; i < 3; i++) {
        await stream
          .finishAndWait()
          .then(() => {
            fail('Promise should not be resolved');
          })
          .catch((reason) => {
            expect(reason.status).toEqual(Status.UNAVAILABLE);
            expect(stream.error).toEqual(Status.UNAVAILABLE);
          });
      }
    });
  });
});

/**
 * Tests that a method declared with `customRequestSerializer` /
 * `customResponseSerializer` has those hooks invoked. The hooks wrap the
 * default protobuf (de)serialization and record their arguments in jest mocks.
 */
describe('RPC with custom serializers', () => {
  let client: Client;
  // Packets sent by the client on channel 1, oldest first.
  let requests: RpcPacket[] = [];
  // Serialized packets to feed to the client, paired with the status that
  // `client.processPacket` is expected to return for each.
  let nextPackets: [Uint8Array, Status][] = [];
  // Set while `processEnqueuedPackets` is running to block re-entrant calls.
  let responseLock = false;
  // While greater than 1, `handlePacket` decrements this and returns without
  // processing the enqueued packets.
  let sendResponsesAfterPackets = 0;
  // When set, `handlePacket` throws this instead of recording the packet.
  let outputException: Error | undefined;
  // Spies recording what the custom serializer hooks were called with.
  const requestSerializeFn = jest.fn();
  const responseDeserializeFn = jest.fn();

  beforeEach(async () => {
    const channels = [
      new Channel(1, handlePacket),
      new Channel(2, () => {
        // Do nothing.
      }),
    ];
    const services = [
      new Service('pw.rpc.test1.TheTestService', [
        {
          name: 'SomeUnary',
          requestType: Request,
          responseType: Response,
          customRequestSerializer: {
            serialize: (msg) => {
              requestSerializeFn(msg);
              return msg.serializeBinary();
            },
            deserialize: Request.deserializeBinary,
          },
          customResponseSerializer: {
            serialize: (msg) => {
              return msg.serializeBinary();
            },
            deserialize: (bytes) => {
              responseDeserializeFn(bytes);
              return Response.deserializeBinary(bytes);
            },
          },
        },
      ]),
    ];
    client = new Client(channels, services);
    requests = [];
    nextPackets = [];
    responseLock = false;
    sendResponsesAfterPackets = 0;
    outputException = undefined;
  });

  /** Returns the most recently sent packet; throws if none was sent. */
  function lastRequest(): RpcPacket {
    if (requests.length === 0) {
      throw Error('Tried to fetch request from empty list');
    }
    return requests[requests.length - 1];
  }

  /** Builds a test `Request` carrying the given magic number. */
  function newRequest(magicNumber = 123): Message {
    const request = new Request();
    request.setMagicNumber(magicNumber);
    return request;
  }

  /** Builds a test `Response` carrying the given payload string. */
  function newResponse(payload = '._.'): Message {
    const response = new Response();
    response.setPayload(payload);
    return response;
  }

  /** Channel 1 output function: records the packet, then replays responses. */
  function handlePacket(data: Uint8Array): void {
    if (outputException !== undefined) {
      throw outputException;
    }
    requests.push(packets.decode(data));

    if (sendResponsesAfterPackets > 1) {
      sendResponsesAfterPackets -= 1;
      return;
    }

    processEnqueuedPackets();
  }

  /** Feeds every queued packet to the client and then clears the queue. */
  function processEnqueuedPackets(): void {
    // Avoid infinite recursion when processing a packet causes another packet
    // to send.
    if (responseLock) return;
    responseLock = true;
    for (const [packet, status] of nextPackets) {
      expect(client.processPacket(packet)).toEqual(status);
    }
    nextPackets = [];
    responseLock = false;
  }

  /**
   * Queues a RESPONSE packet for `method`. The payload is empty when
   * `response` is omitted. Processing the packet is expected to return
   * `Status.OK`.
   */
  function enqueueResponse(
    channelId: number,
    method: Method,
    status: Status,
    callId: number,
    response?: Message,
  ) {
    const packet = new RpcPacket();
    packet.setType(PacketType.RESPONSE);
    packet.setChannelId(channelId);
    packet.setServiceId(method.service.id);
    packet.setMethodId(method.id);
    packet.setCallId(callId);
    packet.setStatus(status);
    if (response === undefined) {
      packet.setPayload(new Uint8Array(0));
    } else {
      packet.setPayload(response.serializeBinary());
    }
    nextPackets.push([packet.serializeBinary(), Status.OK]);
  }

  /** Decodes the payload of the most recently sent packet as `messageType`. */
  function sentPayload(messageType: typeof Message): any {
    return messageType.deserializeBinary(lastRequest().getPayload_asU8());
  }

  describe('Unary', () => {
    let unaryStub: UnaryMethodStub;

    beforeEach(async () => {
      unaryStub = client
        .channel()
        ?.methodStub(
          'pw.rpc.test1.TheTestService.SomeUnary',
        ) as UnaryMethodStub;
    });

    it('blocking call', async () => {
      for (let i = 0; i < 3; i++) {
        enqueueResponse(
          1,
          unaryStub.method,
          Status.ABORTED,
          unaryStub.rpcs.nextCallId,
          newResponse('0_o'),
        );
        const [status, response] = await unaryStub.call(newRequest(6));

        expect(sentPayload(Request).getMagicNumber()).toEqual(6);
        expect(status).toEqual(Status.ABORTED);
        expect(response).toEqual(newResponse('0_o'));
        expect(requestSerializeFn).toHaveBeenCalledWith(newRequest(6));
        expect(responseDeserializeFn).toHaveBeenCalledWith(
          newResponse('0_o').serializeBinary(),
        );
      }
    });
  });
});