xref: /aosp_15_r20/external/grpc-grpc/src/python/grpcio_tests/tests/unit/_metadata_code_details_test.py (revision cc02d7e222339f7a4f6ba5f422e6413f4bd931f2)
1# Copyright 2016 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests application-provided metadata, status code, and details."""
15
16import logging
17import threading
18import unittest
19
20import grpc
21
22from tests.unit import test_common
23from tests.unit.framework.common import test_constants
24from tests.unit.framework.common import test_control
25
26_SERIALIZED_REQUEST = b"\x46\x47\x48"
27_SERIALIZED_RESPONSE = b"\x49\x50\x51"
28
29_REQUEST_SERIALIZER = lambda unused_request: _SERIALIZED_REQUEST
30_REQUEST_DESERIALIZER = lambda unused_serialized_request: object()
31_RESPONSE_SERIALIZER = lambda unused_response: _SERIALIZED_RESPONSE
32_RESPONSE_DESERIALIZER = lambda unused_serialized_response: object()
33
34_SERVICE = "test.TestService"
35_UNARY_UNARY = "UnaryUnary"
36_UNARY_STREAM = "UnaryStream"
37_STREAM_UNARY = "StreamUnary"
38_STREAM_STREAM = "StreamStream"
39
40_CLIENT_METADATA = (
41    ("client-md-key", "client-md-key"),
42    ("client-md-key-bin", b"\x00\x01"),
43)
44
45_SERVER_INITIAL_METADATA = (
46    ("server-initial-md-key", "server-initial-md-value"),
47    ("server-initial-md-key-bin", b"\x00\x02"),
48)
49
50_SERVER_TRAILING_METADATA = (
51    ("server-trailing-md-key", "server-trailing-md-value"),
52    ("server-trailing-md-key-bin", b"\x00\x03"),
53)
54
_NON_OK_CODE = grpc.StatusCode.NOT_FOUND
_DETAILS = "Test details!"

# Calling abort should always fail an RPC, even for "invalid" codes. Each row
# below is one abort case: (code passed to context.abort(), status code the
# client should observe, details the client should observe). zip() transposes
# the rows into the three parallel tuples the tests iterate over together.
_ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS = zip(
    (_NON_OK_CODE, _NON_OK_CODE, _DETAILS),
    # A bare int rather than a grpc.StatusCode member.
    (3, grpc.StatusCode.UNKNOWN, _DETAILS),
    # Aborting with OK still fails the RPC, and the details are dropped.
    (grpc.StatusCode.OK, grpc.StatusCode.UNKNOWN, ""),
)
66
67
class _Servicer(object):
    """Handlers for all four RPC shapes with test-configurable outcomes.

    Every handler records the client's invocation metadata, sends
    _SERVER_INITIAL_METADATA, sets _SERVER_TRAILING_METADATA, and then either
    aborts or applies the configured code/details. Optionally it raises
    test_control.Defect or (unary-response handlers only) returns None. The
    set_* methods choose which of those paths subsequent RPCs take.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._abort_call = False
        self._code = None
        self._details = None
        self._exception = False
        self._return_none = False
        self._received_client_metadata = None

    def _start(self, context):
        """Record client metadata, send server metadata, snapshot the config.

        The lock guards only the shared fields and is released before any
        work that may block (draining a request stream) or suspend (a
        streaming handler's ``yield``), so a stalled RPC cannot hold the lock
        that the setters and received_client_metadata() need.

        Returns:
          A tuple (abort_call, code, details, exception, return_none) holding
          the configuration in effect when the RPC started.
        """
        with self._lock:
            self._received_client_metadata = context.invocation_metadata()
            config = (
                self._abort_call,
                self._code,
                self._details,
                self._exception,
                self._return_none,
            )
        context.send_initial_metadata(_SERVER_INITIAL_METADATA)
        context.set_trailing_metadata(_SERVER_TRAILING_METADATA)
        return config

    @staticmethod
    def _apply_status(context, abort_call, code, details):
        """Abort the RPC, or set whichever of code/details were configured."""
        if abort_call:
            context.abort(code, details)
        else:
            if code is not None:
                context.set_code(code)
            if details is not None:
                context.set_details(details)

    def unary_unary(self, request, context):
        abort_call, code, details, exception, return_none = self._start(
            context
        )
        self._apply_status(context, abort_call, code, details)
        if exception:
            raise test_control.Defect()
        return None if return_none else object()

    def unary_stream(self, request, context):
        abort_call, code, details, exception, unused_return_none = (
            self._start(context)
        )
        self._apply_status(context, abort_call, code, details)
        for _ in range(test_constants.STREAM_LENGTH // 2):
            yield _SERIALIZED_RESPONSE
        if exception:
            raise test_control.Defect()

    def stream_unary(self, request_iterator, context):
        abort_call, code, details, exception, return_none = self._start(
            context
        )
        # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
        # request iterator.
        list(request_iterator)
        self._apply_status(context, abort_call, code, details)
        if exception:
            raise test_control.Defect()
        return None if return_none else _SERIALIZED_RESPONSE

    def stream_stream(self, request_iterator, context):
        abort_call, code, details, exception, unused_return_none = (
            self._start(context)
        )
        # TODO(https://github.com/grpc/grpc/issues/6891): just ignore the
        # request iterator.
        list(request_iterator)
        self._apply_status(context, abort_call, code, details)
        for _ in range(test_constants.STREAM_LENGTH // 3):
            yield object()
        if exception:
            raise test_control.Defect()

    def set_abort_call(self):
        """Make subsequent RPCs call context.abort(code, details)."""
        with self._lock:
            self._abort_call = True

    def set_code(self, code):
        """Set the status code applied to (or aborted with) subsequent RPCs."""
        with self._lock:
            self._code = code

    def set_details(self, details):
        """Set the status details applied to (or aborted with) subsequent RPCs."""
        with self._lock:
            self._details = details

    def set_exception(self):
        """Make subsequent RPCs raise test_control.Defect after responding."""
        with self._lock:
            self._exception = True

    def set_return_none(self):
        """Make unary-response handlers return None instead of a response."""
        with self._lock:
            self._return_none = True

    def received_client_metadata(self):
        """Return the invocation metadata recorded by the most recent RPC."""
        with self._lock:
            return self._received_client_metadata
175
176
def _generic_handler(servicer):
    """Build a generic RPC handler routing _SERVICE's four methods to servicer.

    UnaryUnary and StreamStream are registered with the module's request
    deserializer and response serializer; UnaryStream and StreamUnary are
    registered without any (de)serializers.
    """
    codec = {
        "request_deserializer": _REQUEST_DESERIALIZER,
        "response_serializer": _RESPONSE_SERIALIZER,
    }
    unary_unary_handler = grpc.unary_unary_rpc_method_handler(
        servicer.unary_unary, **codec
    )
    unary_stream_handler = grpc.unary_stream_rpc_method_handler(
        servicer.unary_stream
    )
    stream_unary_handler = grpc.stream_unary_rpc_method_handler(
        servicer.stream_unary
    )
    stream_stream_handler = grpc.stream_stream_rpc_method_handler(
        servicer.stream_stream, **codec
    )
    handlers_by_method = {
        _UNARY_UNARY: unary_unary_handler,
        _UNARY_STREAM: unary_stream_handler,
        _STREAM_UNARY: stream_unary_handler,
        _STREAM_STREAM: stream_stream_handler,
    }
    return grpc.method_handlers_generic_handler(_SERVICE, handlers_by_method)
197
198
class MetadataCodeDetailsTest(unittest.TestCase):
    """End-to-end checks that metadata, status codes and details round-trip.

    Each test drives one RPC shape against _Servicer over a real channel and
    checks three things:
      - client metadata reached the server;
      - server initial and trailing metadata reached the client;
      - the client saw the expected status code and (for failures) details.
    """

    @staticmethod
    def _full_method_name(method):
        """Return the "/<service>/<method>" path for a method of _SERVICE."""
        return "/".join(("", _SERVICE, method))

    def setUp(self):
        self._servicer = _Servicer()
        self._server = test_common.test_server()
        self._server.add_generic_rpc_handlers(
            (_generic_handler(self._servicer),)
        )
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()

        self._channel = grpc.insecure_channel("localhost:{}".format(port))
        # As in _generic_handler, only UnaryUnary and StreamStream use the
        # module's (de)serializers; the other two are created without them.
        self._unary_unary = self._channel.unary_unary(
            self._full_method_name(_UNARY_UNARY),
            request_serializer=_REQUEST_SERIALIZER,
            response_deserializer=_RESPONSE_DESERIALIZER,
            _registered_method=True,
        )
        self._unary_stream = self._channel.unary_stream(
            self._full_method_name(_UNARY_STREAM),
            _registered_method=True,
        )
        self._stream_unary = self._channel.stream_unary(
            self._full_method_name(_STREAM_UNARY),
            _registered_method=True,
        )
        self._stream_stream = self._channel.stream_stream(
            self._full_method_name(_STREAM_STREAM),
            request_serializer=_REQUEST_SERIALIZER,
            response_deserializer=_RESPONSE_DESERIALIZER,
            _registered_method=True,
        )

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    # Each _invoke_* helper starts one RPC with _CLIENT_METADATA attached.
    # Unary-response helpers return with_call()'s (response, call) pair and
    # raise grpc.RpcError on failure; streaming-response helpers return the
    # response-iterator call without consuming it.

    def _invoke_unary_unary(self):
        return self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)

    def _invoke_unary_stream(self):
        return self._unary_stream(
            _SERIALIZED_REQUEST, metadata=_CLIENT_METADATA
        )

    def _invoke_stream_unary(self):
        return self._stream_unary.with_call(
            iter([_SERIALIZED_REQUEST] * test_constants.STREAM_LENGTH),
            metadata=_CLIENT_METADATA,
        )

    def _invoke_stream_stream(self):
        return self._stream_stream(
            iter([object()] * test_constants.STREAM_LENGTH),
            metadata=_CLIENT_METADATA,
        )

    def _assert_metadata_transmitted(self, initial_metadata, trailing_metadata):
        """Assert client, server-initial and server-trailing metadata arrived."""
        self.assertTrue(
            test_common.metadata_transmitted(
                _CLIENT_METADATA, self._servicer.received_client_metadata()
            )
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _SERVER_INITIAL_METADATA, initial_metadata
            )
        )
        self.assertTrue(
            test_common.metadata_transmitted(
                _SERVER_TRAILING_METADATA, trailing_metadata
            )
        )

    def _assert_unary_response_fails(
        self, invoke, expected_code, expected_details
    ):
        """Run a unary-response RPC and check the RpcError it raises."""
        with self.assertRaises(grpc.RpcError) as exception_context:
            invoke()
        error = exception_context.exception
        self._assert_metadata_transmitted(
            error.initial_metadata(), error.trailing_metadata()
        )
        self.assertIs(expected_code, error.code())
        self.assertEqual(expected_details, error.details())

    def _assert_streaming_response_fails(
        self, invoke, expected_code, expected_details
    ):
        """Run a streaming-response RPC and check its terminal status.

        Initial metadata is read before draining the stream; trailing
        metadata, code and details are read from the call afterwards.
        """
        call = invoke()
        initial_metadata = call.initial_metadata()
        with self.assertRaises(grpc.RpcError):
            list(call)
        self._assert_metadata_transmitted(
            initial_metadata, call.trailing_metadata()
        )
        self.assertIs(expected_code, call.code())
        self.assertEqual(expected_details, call.details())

    def _assert_aborts(self, assert_fails, invoke):
        """Check every abort case from _ABORT_CODES against one RPC shape."""
        for abort_code, expected_code, expected_details in zip(
            _ABORT_CODES, _EXPECTED_CLIENT_CODES, _EXPECTED_DETAILS
        ):
            # subTest names the failing abort code and keeps checking the rest.
            with self.subTest(abort_code=abort_code):
                self._servicer.set_code(abort_code)
                self._servicer.set_details(_DETAILS)
                self._servicer.set_abort_call()
                assert_fails(invoke, expected_code, expected_details)

    def _set_custom_status(self):
        """Configure the servicer to finish with _NON_OK_CODE and _DETAILS."""
        self._servicer.set_code(_NON_OK_CODE)
        self._servicer.set_details(_DETAILS)

    def testSuccessfulUnaryUnary(self):
        self._servicer.set_details(_DETAILS)

        unused_response, call = self._invoke_unary_unary()

        self._assert_metadata_transmitted(
            call.initial_metadata(), call.trailing_metadata()
        )
        self.assertIs(grpc.StatusCode.OK, call.code())

    def testSuccessfulUnaryStream(self):
        self._servicer.set_details(_DETAILS)

        call = self._invoke_unary_stream()
        initial_metadata = call.initial_metadata()
        list(call)

        self._assert_metadata_transmitted(
            initial_metadata, call.trailing_metadata()
        )
        self.assertIs(grpc.StatusCode.OK, call.code())

    def testSuccessfulStreamUnary(self):
        self._servicer.set_details(_DETAILS)

        unused_response, call = self._invoke_stream_unary()

        self._assert_metadata_transmitted(
            call.initial_metadata(), call.trailing_metadata()
        )
        self.assertIs(grpc.StatusCode.OK, call.code())

    def testSuccessfulStreamStream(self):
        self._servicer.set_details(_DETAILS)

        call = self._invoke_stream_stream()
        initial_metadata = call.initial_metadata()
        list(call)

        self._assert_metadata_transmitted(
            initial_metadata, call.trailing_metadata()
        )
        self.assertIs(grpc.StatusCode.OK, call.code())

    def testAbortedUnaryUnary(self):
        self._assert_aborts(
            self._assert_unary_response_fails, self._invoke_unary_unary
        )

    def testAbortedUnaryStream(self):
        self._assert_aborts(
            self._assert_streaming_response_fails, self._invoke_unary_stream
        )

    def testAbortedStreamUnary(self):
        self._assert_aborts(
            self._assert_unary_response_fails, self._invoke_stream_unary
        )

    def testAbortedStreamStream(self):
        self._assert_aborts(
            self._assert_streaming_response_fails, self._invoke_stream_stream
        )

    def testCustomCodeUnaryUnary(self):
        self._set_custom_status()
        self._assert_unary_response_fails(
            self._invoke_unary_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeUnaryStream(self):
        self._set_custom_status()
        self._assert_streaming_response_fails(
            self._invoke_unary_stream, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeStreamUnary(self):
        self._set_custom_status()
        self._assert_unary_response_fails(
            self._invoke_stream_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeStreamStream(self):
        self._set_custom_status()
        self._assert_streaming_response_fails(
            self._invoke_stream_stream, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionUnaryUnary(self):
        self._set_custom_status()
        self._servicer.set_exception()
        self._assert_unary_response_fails(
            self._invoke_unary_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionUnaryStream(self):
        self._set_custom_status()
        self._servicer.set_exception()
        self._assert_streaming_response_fails(
            self._invoke_unary_stream, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionStreamUnary(self):
        self._set_custom_status()
        self._servicer.set_exception()
        self._assert_unary_response_fails(
            self._invoke_stream_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeExceptionStreamStream(self):
        self._set_custom_status()
        self._servicer.set_exception()
        self._assert_streaming_response_fails(
            self._invoke_stream_stream, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeReturnNoneUnaryUnary(self):
        self._set_custom_status()
        self._servicer.set_return_none()
        self._assert_unary_response_fails(
            self._invoke_unary_unary, _NON_OK_CODE, _DETAILS
        )

    def testCustomCodeReturnNoneStreamUnary(self):
        self._set_custom_status()
        self._servicer.set_return_none()
        self._assert_unary_response_fails(
            self._invoke_stream_unary, _NON_OK_CODE, _DETAILS
        )
811
812
class _InspectServicer(_Servicer):
    """A _Servicer whose unary_unary also captures the context's final state.

    After delegating to _Servicer.unary_unary, the handler stores the code,
    details and trailing metadata left on the context in the actual_*
    attributes, so a test can compare them with what it configured.
    """

    def __init__(self):
        super(_InspectServicer, self).__init__()
        self.actual_code = None
        self.actual_details = None
        self.actual_trailing_metadata = None

    def unary_unary(self, request, context):
        # If the parent aborts or raises, nothing below runs and nothing is
        # captured.
        response = super(_InspectServicer, self).unary_unary(request, context)

        self.actual_code = context.code()
        self.actual_details = context.details()
        self.actual_trailing_metadata = context.trailing_metadata()
        # Propagate the parent's response instead of implicitly returning
        # None, so this subclass honours the same contract as _Servicer.
        return response
826
827
class InspectContextTest(unittest.TestCase):
    """Checks that a handler can read back the status it set on its context."""

    def setUp(self):
        self._servicer = _InspectServicer()
        self._server = test_common.test_server()
        handler = _generic_handler(self._servicer)
        self._server.add_generic_rpc_handlers((handler,))
        port = self._server.add_insecure_port("[::]:0")
        self._server.start()

        target = "localhost:{}".format(port)
        self._channel = grpc.insecure_channel(target)
        method_path = "/{}/{}".format(_SERVICE, _UNARY_UNARY)
        self._unary_unary = self._channel.unary_unary(
            method_path,
            request_serializer=_REQUEST_SERIALIZER,
            response_deserializer=_RESPONSE_DESERIALIZER,
            _registered_method=True,
        )

    def tearDown(self):
        self._server.stop(None)
        self._channel.close()

    def testCodeDetailsInContext(self):
        servicer = self._servicer
        servicer.set_code(_NON_OK_CODE)
        servicer.set_details(_DETAILS)

        with self.assertRaises(grpc.RpcError) as exc_info:
            self._unary_unary.with_call(object(), metadata=_CLIENT_METADATA)
        self.assertEqual(_NON_OK_CODE, exc_info.exception.code())

        # What the server-side context reported; details are decoded from
        # UTF-8 bytes before comparison.
        self.assertEqual(servicer.actual_code, _NON_OK_CODE)
        self.assertEqual(servicer.actual_details.decode("utf-8"), _DETAILS)
        self.assertEqual(
            servicer.actual_trailing_metadata, _SERVER_TRAILING_METADATA
        )
874
875
if __name__ == "__main__":
    # Attach a default stderr handler to the root logger, then run every
    # TestCase in this module with per-test output.
    logging.basicConfig()
    unittest.main(verbosity=2)
879