# Copyright 2017 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for :mod:`grpc`."""

import collections
import functools

import grpc
import pkg_resources

from google.api_core import exceptions
import google.auth
import google.auth.credentials
import google.auth.transport.grpc
import google.auth.transport.requests

try:
    import grpc_gcp

    HAS_GRPC_GCP = True
except ImportError:
    HAS_GRPC_GCP = False

try:
    # google.auth.__version__ was added in 1.26.0
    _GOOGLE_AUTH_VERSION = google.auth.__version__
except AttributeError:
    try:  # try pkg_resources if it is available
        _GOOGLE_AUTH_VERSION = pkg_resources.get_distribution("google-auth").version
    except pkg_resources.DistributionNotFound:  # pragma: NO COVER
        _GOOGLE_AUTH_VERSION = None

# The list of gRPC Callable interfaces that return iterators.
_STREAM_WRAP_CLASSES = (grpc.UnaryStreamMultiCallable, grpc.StreamStreamMultiCallable)


def _patch_callable_name(callable_):
    """Fix-up gRPC callable attributes.

    gRPC callables lack the ``__name__`` attribute, which causes
    :func:`functools.wraps` to error. This adds the attribute if needed,
    using the name of the callable's class.

    Args:
        callable_ (Callable): The gRPC callable to patch in place.
    """
    if not hasattr(callable_, "__name__"):
        callable_.__name__ = callable_.__class__.__name__


def _wrap_unary_errors(callable_):
    """Map errors for Unary-Unary and Stream-Unary gRPC callables.

    Args:
        callable_ (Callable): The gRPC callable to wrap.

    Returns:
        Callable: A wrapper that invokes ``callable_`` and re-raises any
            :class:`grpc.RpcError` as the exception produced by
            :func:`google.api_core.exceptions.from_grpc_error`.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            return callable_(*args, **kwargs)
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc

    return error_remapped_callable


class _StreamingResponseIterator(grpc.Call):
    """Iterator over a streaming response that re-maps :class:`grpc.RpcError`.

    Iteration and the :class:`grpc.Call` / :class:`grpc.RpcContext` methods
    are delegated to the wrapped object.

    Args:
        wrapped: The object returned by the streaming gRPC callable.
        prefetch_first_result (bool): Whether to eagerly pull the first
            result from ``wrapped`` during construction.
    """

    def __init__(self, wrapped, prefetch_first_result=True):
        self._wrapped = wrapped

        # This iterator is used in a retry context, and returned outside after init.
        # gRPC will not throw an exception until the stream is consumed, so we need
        # to retrieve the first result, in order to fail, in order to trigger a retry.
        try:
            if prefetch_first_result:
                self._stored_first_result = next(self._wrapped)
        except TypeError:
            # It is possible the wrapped method isn't an iterable (a grpc.Call
            # for instance). If this happens don't store the first result.
            pass
        except StopIteration:
            # ignore stop iteration at this time. This should be handled outside of retry.
            pass

    def __iter__(self):
        """This iterator is also an iterable that returns itself."""
        return self

    def __next__(self):
        """Get the next response from the stream.

        Returns:
            protobuf.Message: A single response from the stream.
        """
        try:
            # Hand back the prefetched result (if any) exactly once.
            if hasattr(self, "_stored_first_result"):
                result = self._stored_first_result
                del self._stored_first_result
                return result
            return next(self._wrapped)
        except grpc.RpcError as exc:
            # If the stream has already returned data, we cannot recover here.
            raise exceptions.from_grpc_error(exc) from exc

    # grpc.Call & grpc.RpcContext interface

    def add_callback(self, callback):
        return self._wrapped.add_callback(callback)

    def cancel(self):
        return self._wrapped.cancel()

    def code(self):
        return self._wrapped.code()

    def details(self):
        return self._wrapped.details()

    def initial_metadata(self):
        return self._wrapped.initial_metadata()

    def is_active(self):
        return self._wrapped.is_active()

    def time_remaining(self):
        return self._wrapped.time_remaining()

    def trailing_metadata(self):
        return self._wrapped.trailing_metadata()


def _wrap_stream_errors(callable_):
    """Wrap errors for Unary-Stream and Stream-Stream gRPC callables.

    The callables that return iterators require a bit more logic to re-map
    errors when iterating. This wraps both the initial invocation and the
    iterator of the return value to re-map errors.

    Args:
        callable_ (Callable): The streaming gRPC callable to wrap.

    Returns:
        Callable: A wrapper returning a :class:`_StreamingResponseIterator`.
    """
    _patch_callable_name(callable_)

    @functools.wraps(callable_)
    def error_remapped_callable(*args, **kwargs):
        try:
            result = callable_(*args, **kwargs)
            # Auto-fetching the first result causes PubSub client's streaming pull
            # to hang when re-opening the stream, thus we need to examine the hacky
            # hidden flag to see if pre-fetching is disabled.
            # https://github.com/googleapis/python-pubsub/issues/93#issuecomment-630762257
            prefetch_first = getattr(callable_, "_prefetch_first_result_", True)
            return _StreamingResponseIterator(
                result, prefetch_first_result=prefetch_first
            )
        except grpc.RpcError as exc:
            raise exceptions.from_grpc_error(exc) from exc

    return error_remapped_callable


def wrap_errors(callable_):
    """Wrap a gRPC callable and map :class:`grpc.RpcErrors` to friendly error
    classes.

    Errors raised by the gRPC callable are mapped to the appropriate
    :class:`google.api_core.exceptions.GoogleAPICallError` subclasses.
    The original `grpc.RpcError` (which is usually also a `grpc.Call`) is
    available from the ``response`` property on the mapped exception. This
    is useful for extracting metadata from the original error.

    Args:
        callable_ (Callable): A gRPC callable.

    Returns:
        Callable: The wrapped gRPC callable.
    """
    if isinstance(callable_, _STREAM_WRAP_CLASSES):
        return _wrap_stream_errors(callable_)
    else:
        return _wrap_unary_errors(callable_)


def _create_composite_credentials(
    credentials=None,
    credentials_file=None,
    default_scopes=None,
    scopes=None,
    ssl_credentials=None,
    quota_project_id=None,
    default_host=None,
):
    """Create the composite credentials for secure channels.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        default_scopes (Sequence[str]): An optional list of default scopes.
            Forwarded as ``default_scopes`` to whichever of
            :func:`google.auth.load_credentials_from_file`,
            :func:`google.auth.credentials.with_scopes_if_required` or
            :func:`google.auth.default` is used to obtain the credentials.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. Forwarded as ``scopes`` to the same function as
            ``default_scopes``.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
            Defaults to :func:`grpc.ssl_channel_credentials`.
        quota_project_id (str): An optional project to use for billing and quota.
            Only applied when the credentials are an instance of
            :class:`google.auth.credentials.CredentialsWithQuotaProject`.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".

    Returns:
        grpc.ChannelCredentials: The composed channel credentials object.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """
    if credentials and credentials_file:
        raise exceptions.DuplicateCredentialArgs(
            "'credentials' and 'credentials_file' are mutually exclusive."
        )

    if credentials_file:
        credentials, _ = google.auth.load_credentials_from_file(
            credentials_file, scopes=scopes, default_scopes=default_scopes
        )
    elif credentials:
        credentials = google.auth.credentials.with_scopes_if_required(
            credentials, scopes=scopes, default_scopes=default_scopes
        )
    else:
        credentials, _ = google.auth.default(
            scopes=scopes, default_scopes=default_scopes
        )

    if quota_project_id and isinstance(
        credentials, google.auth.credentials.CredentialsWithQuotaProject
    ):
        credentials = credentials.with_quota_project(quota_project_id)

    request = google.auth.transport.requests.Request()

    # Create the metadata plugin for inserting the authorization header.
    metadata_plugin = google.auth.transport.grpc.AuthMetadataPlugin(
        credentials, request, default_host=default_host,
    )

    # Create a set of grpc.CallCredentials using the metadata plugin.
    google_auth_credentials = grpc.metadata_call_credentials(metadata_plugin)

    if ssl_credentials is None:
        ssl_credentials = grpc.ssl_channel_credentials()

    # Combine the ssl credentials and the authorization credentials.
    return grpc.composite_channel_credentials(ssl_credentials, google_auth_credentials)


def create_channel(
    target,
    credentials=None,
    scopes=None,
    ssl_credentials=None,
    credentials_file=None,
    quota_project_id=None,
    default_scopes=None,
    default_host=None,
    **kwargs
):
    """Create a secure channel with credentials.

    Args:
        target (str): The target service address in the format 'hostname:port'.
        credentials (google.auth.credentials.Credentials): The credentials. If
            not specified, then this function will attempt to ascertain the
            credentials from the environment using :func:`google.auth.default`.
        scopes (Sequence[str]): An optional list of scopes needed for this
            service. Forwarded to the function used to obtain or scope the
            credentials.
        ssl_credentials (grpc.ChannelCredentials): Optional SSL channel
            credentials. This can be used to specify different certificates.
        credentials_file (str): A file with credentials that can be loaded with
            :func:`google.auth.load_credentials_from_file`. This argument is
            mutually exclusive with credentials.
        quota_project_id (str): An optional project to use for billing and quota.
        default_scopes (Sequence[str]): Default scopes passed by a Google client
            library. Use 'scopes' for user-defined scopes.
        default_host (str): The default endpoint. e.g., "pubsub.googleapis.com".
        kwargs: Additional key-word args passed to
            :func:`grpc_gcp.secure_channel` or :func:`grpc.secure_channel`.

    Returns:
        grpc.Channel: The created channel.

    Raises:
        google.api_core.exceptions.DuplicateCredentialArgs: If both a
            credentials object and credentials_file are passed.
    """

    composite_credentials = _create_composite_credentials(
        credentials=credentials,
        credentials_file=credentials_file,
        default_scopes=default_scopes,
        scopes=scopes,
        ssl_credentials=ssl_credentials,
        quota_project_id=quota_project_id,
        default_host=default_host,
    )

    if HAS_GRPC_GCP:
        # If grpc_gcp module is available use grpc_gcp.secure_channel,
        # otherwise, use grpc.secure_channel to create grpc channel.
        return grpc_gcp.secure_channel(target, composite_credentials, **kwargs)
    else:
        return grpc.secure_channel(target, composite_credentials, **kwargs)


_MethodCall = collections.namedtuple(
    "_MethodCall", ("request", "timeout", "metadata", "credentials")
)

_ChannelRequest = collections.namedtuple("_ChannelRequest", ("method", "request"))


class _CallableStub(object):
    """Stub for the grpc.*MultiCallable interfaces."""

    def __init__(self, method, channel):
        self._method = method
        self._channel = channel
        self.response = None
        """Union[protobuf.Message, Callable[protobuf.Message], exception]:
        The response to give when invoking this callable. If this is a
        callable, it will be invoked with the request protobuf. If it's an
        exception, the exception will be raised when this is invoked.
        """
        self.responses = None
        """Iterator[
            Union[protobuf.Message, Callable[protobuf.Message], exception]]:
        An iterator of responses. If specified, self.response will be populated
        on each invocation by calling ``next(self.responses)``."""
        self.requests = []
        """List[protobuf.Message]: All requests sent to this callable."""
        self.calls = []
        """List[Tuple]: All invocations of this callable. Each tuple is the
        request, timeout, metadata, and credentials."""

    def __call__(self, request, timeout=None, metadata=None, credentials=None):
        # Record the invocation on both the channel and this stub before
        # resolving the response, so failed calls are tracked as well.
        self._channel.requests.append(_ChannelRequest(self._method, request))
        self.calls.append(_MethodCall(request, timeout, metadata, credentials))
        self.requests.append(request)

        response = self.response
        if self.responses is not None:
            if response is None:
                response = next(self.responses)
            else:
                raise ValueError(
                    "{method}.response and {method}.responses are mutually "
                    "exclusive.".format(method=self._method)
                )

        if callable(response):
            return response(request)

        if isinstance(response, Exception):
            raise response

        if response is not None:
            return response

        raise ValueError('Method stub for "{}" has no response.'.format(self._method))


def _simplify_method_name(method):
    """Simplifies a gRPC method name.

    When gRPC invokes the channel to create a callable, it gives a full
    method name like "/google.pubsub.v1.Publisher/CreateTopic". This
    returns just the name of the method, in this case "CreateTopic".

    Args:
        method (str): The name of the method.

    Returns:
        str: The simplified name of the method.
    """
    return method.rsplit("/", 1).pop()


class ChannelStub(grpc.Channel):
    """A testing stub for the grpc.Channel interface.

    This can be used to test any client that eventually uses a gRPC channel
    to communicate. By passing in a channel stub, you can configure which
    responses are returned and track which requests are made.

    For example:

    .. code-block:: python

        channel_stub = grpc_helpers.ChannelStub()
        client = FooClient(channel=channel_stub)

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')

        foo = client.get_foo(labels=['baz'])

        assert foo.name == 'bar'
        assert channel_stub.GetFoo.requests[0].labels == ['baz']

    Each method on the stub can be accessed and configured on the channel.
    Here are some examples of various configurations:

    .. code-block:: python

        # Return a basic response:

        channel_stub.GetFoo.response = foo_pb2.Foo(name='bar')
        assert client.get_foo().name == 'bar'

        # Raise an exception:
        channel_stub.GetFoo.response = NotFound('...')

        with pytest.raises(NotFound):
            client.get_foo()

        # Use a sequence of responses:
        channel_stub.GetFoo.responses = iter([
            foo_pb2.Foo(name='bar'),
            foo_pb2.Foo(name='baz'),
        ])

        assert client.get_foo().name == 'bar'
        assert client.get_foo().name == 'baz'

        # Use a callable

        def on_get_foo(request):
            return foo_pb2.Foo(name='bar' + request.id)

        channel_stub.GetFoo.response = on_get_foo

        assert client.get_foo(id='123').name == 'bar123'

    Args:
        responses: Currently unused by this stub.
    """

    # ``responses`` defaults to None rather than a shared mutable list.
    def __init__(self, responses=None):
        self.requests = []
        """Sequence[Tuple[str, protobuf.Message]]: A list of all requests made
        on this channel in order. The tuple is of method name, request
        message."""
        self._method_stubs = {}

    def _stub_for_method(self, method):
        """Create, register and return a fresh stub for ``method``.

        The stub is stored under the simplified method name, replacing any
        stub previously registered under that name.
        """
        method = _simplify_method_name(method)
        self._method_stubs[method] = _CallableStub(method, self)
        return self._method_stubs[method]

    def __getattr__(self, key):
        """Expose registered method stubs as attributes of the channel.

        Raises:
            AttributeError: If no stub is registered under ``key``.
        """
        # Read through ``__dict__`` so that a lookup made before ``__init__``
        # has populated ``_method_stubs`` cannot recurse into ``__getattr__``.
        stubs = self.__dict__.get("_method_stubs", {})
        try:
            return stubs[key]
        except KeyError:
            raise AttributeError(
                "{!r} object has no attribute {!r}".format(type(self).__name__, key)
            ) from None

    def unary_unary(self, method, request_serializer=None, response_deserializer=None):
        """grpc.Channel.unary_unary implementation."""
        return self._stub_for_method(method)

    def unary_stream(self, method, request_serializer=None, response_deserializer=None):
        """grpc.Channel.unary_stream implementation."""
        return self._stub_for_method(method)

    def stream_unary(self, method, request_serializer=None, response_deserializer=None):
        """grpc.Channel.stream_unary implementation."""
        return self._stub_for_method(method)

    def stream_stream(
        self, method, request_serializer=None, response_deserializer=None
    ):
        """grpc.Channel.stream_stream implementation."""
        return self._stub_for_method(method)

    def subscribe(self, callback, try_to_connect=False):
        """grpc.Channel.subscribe implementation."""
        pass

    def unsubscribe(self, callback):
        """grpc.Channel.unsubscribe implementation."""
        pass

    def close(self):
        """grpc.Channel.close implementation."""
        pass