1# Copyright 2017 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Helpers for :mod:`grpc`."""
16
17import collections
18import functools
19
20import grpc
21import pkg_resources
22
23from google.api_core import exceptions
24import google.auth
25import google.auth.credentials
26import google.auth.transport.grpc
27import google.auth.transport.requests
28
# ``grpc_gcp`` is optional. When it imports successfully, ``create_channel``
# builds channels with ``grpc_gcp.secure_channel``; otherwise the name stays
# unbound and ``grpc.secure_channel`` is used instead.
try:
    import grpc_gcp
except ImportError:
    HAS_GRPC_GCP = False
else:
    HAS_GRPC_GCP = True
35
def _get_google_auth_version():
    """Return the installed google-auth version string, or ``None``.

    ``google.auth.__version__`` exists only from google-auth 1.26.0 onwards.
    For older releases the version is read from the installed distribution
    metadata through :mod:`pkg_resources`. ``None`` is returned when that
    distribution cannot be found.
    """
    try:
        return google.auth.__version__
    except AttributeError:
        pass

    try:
        return pkg_resources.get_distribution("google-auth").version
    except pkg_resources.DistributionNotFound:  # pragma: NO COVER
        return None


_GOOGLE_AUTH_VERSION = _get_google_auth_version()
44
# gRPC multi-callable types whose invocations return an iterator of responses.
# ``wrap_errors`` routes instances of these to the streaming error wrapper.
_STREAM_WRAP_CLASSES = (
    grpc.UnaryStreamMultiCallable,
    grpc.StreamStreamMultiCallable,
)
47
48
def _patch_callable_name(callable_):
    """Give ``callable_`` a ``__name__`` attribute if it lacks one.

    gRPC multi-callables have no ``__name__``. Setting it lets
    :func:`functools.wraps` copy a meaningful name onto the wrapper; on
    Python 2 the missing attribute made ``wraps`` raise. Objects that already
    have a ``__name__`` are left untouched.

    Args:
        callable_: The object to patch in place.
    """
    if hasattr(callable_, "__name__"):
        return
    # Use ``__class__`` rather than ``type()`` so objects that override
    # ``__class__`` (e.g. spec'd mocks) report the class they present as.
    callable_.__name__ = callable_.__class__.__name__
57
58
def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary and Stream-Unary gRPC callables.

    The returned function calls ``callable_`` with the same arguments. Any
    :class:`grpc.RpcError` it raises is re-raised as the matching
    :mod:`google.api_core.exceptions` error, chained to the original.

    Args:
        callable_ (Callable): A gRPC unary-response callable.

    Returns:
        Callable: The error-mapping wrapper.
    """
    _patch_callable_name(callable_)

    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return functools.wraps(callable_)(error_remapped_callable)
71
72
class _StreamingResponseIterator(grpc.Call):
    """Iterator over a streaming gRPC call that remaps errors.

    Iterating yields the wrapped stream's responses. A :class:`grpc.RpcError`
    raised while iterating becomes the matching
    :mod:`google.api_core.exceptions` error. The :class:`grpc.Call` /
    :class:`grpc.RpcContext` methods are delegated to the wrapped object.

    Args:
        wrapped: The object returned by the streaming gRPC call.
        prefetch_first_result (bool): Whether to read the first response
            eagerly during construction.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        self._wrapped = wrapped

        # This object is built inside a retry context but consumed outside it.
        # gRPC does not raise until the stream is consumed, so the first
        # response is read here to let a failing call raise (and be retried).
        try:
            if prefetch_first_result:
                self._stored_first_result = next(wrapped)
        except (TypeError, StopIteration):
            # TypeError: ``wrapped`` may not be an iterator (e.g. a bare
            # grpc.Call), so there is nothing to store.
            # StopIteration: ignored here; it is left to be handled outside
            # of retry, when ``__next__`` reads the stream again.
            pass

    def __iter__(self):
        """Return ``self``; this iterator is its own iterable."""
        return self

    def __next__(self):
        """Get the next response from the stream.

        Returns:
            protobuf.Message: A single response from the stream.

        Raises:
            google.api_core.exceptions.GoogleAPICallError: If the wrapped
                stream raises :class:`grpc.RpcError`.
        """
        try:
            if not hasattr(self, "_stored_first_result"):
                return next(self._wrapped)
            # Hand out the prefetched response exactly once.
            first_result = self._stored_first_result
            del self._stored_first_result
            return first_result
        except grpc.RpcError as exc:
            # If the stream has already returned data, we cannot recover here.
            raise exceptions.from_grpc_error(exc) from exc

    # grpc.Call & grpc.RpcContext interface: pure delegation to the wrapped call.

    def add_callback(self, callback):
        return self._wrapped.add_callback(callback)

    def cancel(self):
        return self._wrapped.cancel()

    def code(self):
        return self._wrapped.code()

    def details(self):
        return self._wrapped.details()

    def initial_metadata(self):
        return self._wrapped.initial_metadata()

    def is_active(self):
        return self._wrapped.is_active()

    def time_remaining(self):
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        return self._wrapped.trailing_metadata()
136
137
def _wrap_stream_errors(callable_):
    """Wrap errors for Unary-Stream and Stream-Stream gRPC callables.

    Streaming callables return iterators, so errors must be remapped both when
    the call is made and while its responses are iterated. The wrapper maps a
    :class:`grpc.RpcError` raised by the call, or by prefetching the first
    response. It returns a :class:`_StreamingResponseIterator`, which maps
    errors raised during iteration.

    Args:
        callable_ (Callable): A gRPC streaming-response callable.

    Returns:
        Callable: The error-mapping wrapper.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            grpc_iterator = callable_(*args, **kwargs)
            # Auto-fetching the first result makes the PubSub client's
            # streaming pull hang when the stream is re-opened, so callers can
            # opt out via the hidden ``_prefetch_first_result_`` flag. The flag
            # is read on every call, not once at wrap time.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            should_prefetch = getattr(callable_, "_prefetch_first_result_", True)
            return _StreamingResponseIterator(
                grpc_iterator, prefetch_first_result=should_prefetch
            )
        except grpc.RpcError as rpc_error:
            raise exceptions.from_grpc_error(rpc_error) from rpc_error

    return error_remapped_callable
163
164
def wrap_errors(callable_):
    """Wrap a gRPC callable and map :class:`grpc.RpcError` to friendly error
    classes.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses.
    The original `grpc.RpcError` (which is usually also a `grpc.Call`) is
    available from the ``response`` property on the mapped exception. This
    is useful for extracting metadata from the original error.

    Streaming-response callables (see ``_STREAM_WRAP_CLASSES``) also have
    errors raised during iteration remapped.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    is_streaming = isinstance(callable_, _STREAM_WRAP_CLASSES)
    wrapper_factory = _wrap_stream_errors if is_streaming else _wrap_unary_errors
    return wrapper_factory(callable_)
185
186
def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Create the composite credentials for secure channels.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            neither this nor ``credentials_file`` is given, credentials are
            obtained from the environment using :func:`google.auth.default`.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use ``scopes`` for user-defined scopes.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. Together with ``default_scopes`` these are passed to
            whichever of :func:`google.auth.load_credentials_from_file`,
            :func:`google.auth.credentials.with_scopes_if_required` or
            :func:`google.auth.default` is used to obtain the credentials.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
            Defaults to :func:`grpc.ssl_channel_credentials`.
        quota_project_id (str): An optional project to use for billing and
            quota. Applied only if the credentials support quota projects.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    # Resolve the credentials. An explicit file wins, then an explicit
    # credentials object, and finally google.auth.default().
    scope_kwargs = {"scopes": scopes, "default_scopes": default_scopes}
    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, **scope_kwargs
        )
    elif not credentials:
        credentials, _ = google.auth.default(**scope_kwargs)
    else:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, **scope_kwargs
        )

    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    # Per-call credentials: a metadata plugin that inserts the authorization
    # header.
    auth_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials,
        google.auth.transport.requests.Request(),
        default_host=default_host,
    )
    call_credentials = grpc.metadata_call_credentials(auth_plugin)

    if ssl_credentials is None:
        ssl_credentials = grpc.ssl_channel_credentials()

    # Combine channel-level SSL credentials with the per-call authorization.
    return grpc.composite_channel_credentials(ssl_credentials, call_credentials)
260
261
def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    **kwargs
):
    """Create a secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            neither this nor ``credentials_file`` is given, credentials are
            obtained from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. They are applied to whichever credentials end up being
            used (explicit, loaded from file, or from
            :func:`google.auth.default`).
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        kwargs: Additional key-word args passed to
            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    channel_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    # Prefer grpc_gcp's factory when the optional package imported; the
    # ``grpc_gcp`` name is only touched when HAS_GRPC_GCP is true.
    secure_channel = (
        grpc_gcp.secure_channel if HAS_GRPC_GCP else grpc.secure_channel
    )
    return secure_channel(target, channel_credentials, **kwargs)
318
319
# One recorded invocation of a ``_CallableStub``.
_MethodCall = collections.namedtuple(
    "_MethodCall", "request timeout metadata credentials"
)

# One request recorded on a ``ChannelStub``: (method name, request message).
_ChannelRequest = collections.namedtuple("_ChannelRequest", "method request")
325
326
class _CallableStub:
    """Stub for the grpc.*MultiCallable interfaces.

    Each invocation is recorded on this stub and on the owning channel. The
    configured ``response`` or ``responses`` then determines what the call
    returns or raises.
    """

    def __init__(self, method, channel):
        self._method = method
        self._channel = channel
        # Union[protobuf.Message, Callable[protobuf.Message], exception]:
        # the response to give when invoking this callable. A callable is
        # invoked with the request protobuf; an exception is raised.
        self.response = None
        # Iterator[Union[protobuf.Message, Callable[protobuf.Message], exception]]:
        # when set, each invocation takes its response from
        # ``next(self.responses)``. Mutually exclusive with ``response``.
        self.responses = None
        # List[protobuf.Message]: all requests sent to this callable.
        self.requests = []
        # List[Tuple]: all invocations of this callable, each as
        # (request, timeout, metadata, credentials).
        self.calls = []

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        # Record the invocation before resolving the response, so it is
        # tracked even when the configuration turns out to be invalid.
        self._channel.requests.append(_ChannelRequest(self._method, request))
        self.calls.append(_MethodCall(request, timeout, metadata, credentials))
        self.requests.append(request)

        if self.responses is None:
            outcome = self.response
        elif self.response is None:
            outcome = next(self.responses)
        else:
            raise ValueError(
                "{method}.response and {method}.responses are mutually "
                "exclusive.".format(method=self._method)
            )

        if callable(outcome):
            return outcome(request)
        if isinstance(outcome, Exception):
            raise outcome
        if outcome is None:
            raise ValueError(
                'Method stub for "{}" has no response.'.format(self._method)
            )
        return outcome
375
376
def _simplify_method_name(method):
    """Reduce a full gRPC method path to its final component.

    gRPC passes the channel full method names such as
    "/google.pubsub.v1.Publisher/CreateTopic". This returns the text after the
    last "/", here "CreateTopic". A name without any "/" is returned unchanged.

    Args:
        method (str): The name of the method.

    Returns:
        str: The simplified name of the method.
    """
    _, _, simple_name = method.rpartition("/")
    return simple_name
391
392
class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel
    once the client has created the corresponding callable through one of
    the ``unary_unary`` / ``unary_stream`` / ``stream_unary`` /
    ``stream_stream`` factories. Here are some examples of various
    configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'
    """

    def __init__(self, responses=None):
        # ``responses`` is accepted for backward compatibility but is not
        # used. Configure responses per method via
        # ``channel_stub.<Method>.response`` / ``.responses``. The default is
        # ``None`` rather than a shared mutable ``[]``.
        #
        # Sequence[Tuple[str, protobuf.Message]]: all requests made on this
        # channel, in order, as (method name, request message) tuples.
        self.requests = []
        # Dict[str, _CallableStub]: method stubs keyed by simplified name.
        self._method_stubs = {}

    def _stub_for_method(self, method):
        """Create and register a fresh stub for ``method``.

        The stub is stored under the simplified method name (e.g.
        "CreateTopic"). It replaces any stub previously registered under that
        name, and is afterwards reachable as ``channel_stub.<name>`` unless
        the name collides with a real attribute of the channel.
        """
        name = _simplify_method_name(method)
        stub = _CallableStub(name, self)
        self._method_stubs[name] = stub
        return stub

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails. Reading the stub map
        # via ``__dict__`` avoids infinite recursion on instances whose
        # ``__init__`` never ran (e.g. ones created by ``copy`` or
        # ``pickle``). Those now simply raise AttributeError.
        method_stubs = self.__dict__.get("_method_stubs", {})
        try:
            return method_stubs[key]
        except KeyError:
            raise AttributeError(key) from None

    # The ``_registered_method`` keyword is accepted and ignored because code
    # generated by newer grpcio-tools passes it to these factory methods.

    def unary_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self,
        method,
        request_serializer=None,
        response_deserializer=None,
        _registered_method=False,
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation (no-op)."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation (no-op)."""
        pass

    def close(self):
        """grpc.Channel.close implementation (no-op)."""
        pass
496