# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Module to expose RPC APIs in tensorflow."""

from typing import Optional, Sequence, Union

import tensorflow.distribute.experimental.rpc.kernels.gen_rpc_ops as gen_rpc_ops
from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as rpc_pb2
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import function as tf_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.saved_model import nested_structure_coder
from tensorflow.python.types import core as core_tf_types
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export


def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized type specs of `func`'s structured outputs.

  Args:
    func: ConcreteFunction whose `structured_outputs` are inspected.

  Returns:
    The structure of output type specs, encoded with
    `nested_structure_coder.encode_structure` and serialized to a string.
  """
  output_specs = nest.map_structure(type_spec.type_spec_from_value,
                                    func.structured_outputs)
  output_specs_proto = nested_structure_coder.encode_structure(output_specs)
  return output_specs_proto.SerializeToString()


def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized specs of `func`'s positional inputs.

  Args:
    func: ConcreteFunction whose `structured_input_signature` is inspected.

  Returns:
    The positional-argument specs (the first element of the structured input
    signature), encoded with `nested_structure_coder.encode_structure` and
    serialized to a string.
  """
  # The second element of the signature is not part of the exported specs.
  arg_specs, _ = func.structured_input_signature
  arg_specs_proto = nested_structure_coder.encode_structure(arg_specs)
  return arg_specs_proto.SerializeToString()


@tf_export("distribute.experimental.rpc.Server", v1=[])
class Server(object):
  """A Server base class for accepting RPCs for registered tf.functions.

  Functions can be registered on the server and are exposed via RPCs.
  """

  @staticmethod
  def create(rpc_layer, address):
    """Create TF RPC server at given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address where RPC server is hosted.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Server` class.

    Raises:
        A ValueError if rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.

    Example usage:

      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    return GrpcServer(address=address)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering tf.function on server.

    Registered methods can be invoked remotely from clients.

    Args:
      method_name: Name of the tf.function. Clients use this method_name to make
        RPCs.
      func: A `tf.function` or ConcreteFunction to register.

    Raises:
      NotImplementedError: Always, on this base class. Concrete servers are
        obtained through `Server.create`.
    """
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")

  def start(self):
    """Starts the RPC server on provided address.

    Server listens for new requests from client, once it is started.

    Raises:
      NotImplementedError: Always, on this base class. Concrete servers are
        obtained through `Server.create`.
    """
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")


@tf_export("distribute.experimental.rpc.Client", v1=[])
class Client(object):
  """Client class for invoking RPCs to the server."""

  @staticmethod
  def create(rpc_layer, address, name="", timeout_in_ms=0):
    """Create TF RPC client to connect to the given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address of the server to connect the RPC client to.
      name: Name of the RPC Client. You can create multiple clients connecting
        to same server and distinguish them using different names.
      timeout_in_ms: The default timeout to use for outgoing RPCs from client. 0
        indicates no timeout. Exceeding timeout during RPC will raise
        DeadlineExceeded error.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Client` with the following
      dynamically added methods for eagerly created clients:
        * `Registered methods` e.g. multiply(**args):
            If Client is created when executing eagerly, client will request the
            list of registered methods from server during client creation.
            The convenience methods for RPCs will be dynamically added to the
            created Client instance.

            For example, when a server has method "multiply" registered, the
            client object created in eager mode will have 'multiply' method
            available. Users can use client.multiply(..) to make RPC, instead of
            client.call("multiply", ...)

            Both "call" and "multiply" methods are non-blocking i.e. they return
            a StatusOrResult object which should be used to wait for getting
            value or error.

            Along with the above, blocking versions of the registered
            methods are also dynamically added to client instance.
            e.g. multiply_blocking(**args). These methods block till the RPC is
            finished and return response for successful RPC. Otherwise raise
            exception.

            These methods are not available when Client is created inside a
            tf.function.

    Raises:
        A ValueError if rpc_layer other than "grpc" is used. Only GRPC
          is supported at the moment.
        A DeadlineExceeded exception in eager mode if timeout exceeds while
          creating and listing client methods.

    Example usage:
      >>> # Have server already started.
      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

      >>> # Start client
      >>> client = tf.distribute.experimental.rpc.Client.create("grpc",
      ...      address=address, name="test_client")

      >>> a = tf.constant(2, dtype=tf.int32)
      >>> b = tf.constant(3, dtype=tf.int32)

      >>> result = client.call(
      ...    args=[a, b],
      ...    method_name="addition",
      ...    output_specs=tf.TensorSpec((), tf.int32))

      >>> if result.is_ok():
      ...   result.get_value()

      >>> result = client.addition(a, b)

      >>> if result.is_ok():
      ...   result.get_value()

      >>> value = client.addition_blocking(a, b)
    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    # Registered methods are only listed from the server for clients that are
    # created while executing eagerly.
    return GrpcClient(
        address=address,
        name=name,
        list_registered_methods=context.executing_eagerly(),
        timeout_in_ms=timeout_in_ms)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method for making RPC calls to remote server.

    This invokes RPC to the server, executing the registered method_name
    remotely.

    Args:
      method_name: Remote registered method to invoke
      args: List of arguments for the registered method.
      output_specs: Output specs for the output from method.
        For example, if tf.function is:

          @tf.function(input_signature=[
              tf.TensorSpec([], tf.int32),
              tf.TensorSpec([], tf.int32)])
          def multiply_fn(a, b):
            return tf.math.multiply(a, b)

        output_spec is: tf.TensorSpec((), tf.int32)

        If you have access to TF Function, the output specs can be generated
        from tf.function by calling:

          output_specs = tf.nest.map_structure(
              tf.type_spec_from_value,
              tf_function.get_concrete_function().structured_outputs)

        NOTE: with the `StatusOrResult` class defined in this module,
        `get_value()` returns None when output_specs are not provided.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      An instance of `StatusOrResult` class with the following available
      methods.
        * `is_ok()`:
            Returns True if RPC was successful.
        * `get_error()`:
            Returns TF error_code and error message for the RPC.
        * `get_value()`:
            Returns the returned value from remote TF function execution
            when RPC is successful.

      Calling any of the above methods will block till RPC is completed and
      result is available.

    Raises:
      NotImplementedError: Always, on this base class.
    """
    raise NotImplementedError("Must be implemented in inherited classes.")


class GrpcServer(Server):
  """GrpcServer object encapsulates a resource with GRPC server.

    Functions can be registered locally and are exposed via RPCs.
    Example:
    ```
    server = rpc_ops.GrpcServer("host:port")
    @tf.function
    def add(a, b):
      return a + b

    server.register("add", add)
    server.start()
    ```
  """

  def __init__(self, address: str):
    """Creates the server resource for `address`.

    Args:
      address: Address passed to the `rpc_server` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    self._server_handle = gen_rpc_ops.rpc_server(address)
    if context.executing_eagerly():
      # Ties the lifetime of the server resource to this Python object.
      self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._server_handle, handle_device=self._server_handle.device)
    else:
      raise NotImplementedError("Please create the server outside tf.function.")

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering functions.

    Args:
      method_name: Name under which the function is registered on the server.
      func: A `tf.function` or ConcreteFunction to register. A `tf.function`
        that takes arguments must have an input signature.

    Raises:
      ValueError: If `func` is a `tf.function` with arguments but no input
        signature, or if `func` is not a TF function at all.
    """
    if isinstance(func, def_function.Function):
      if func._function_spec.arg_names:  # pylint: disable=protected-access
        if func.input_signature is None:
          raise ValueError("Input signature not specified for the function.")
      concrete_fn = func.get_concrete_function()
    elif isinstance(func, tf_function.ConcreteFunction):
      concrete_fn = func
    else:
      # Python functions
      # TODO(b/186762191): Add an implementation to support python functions.
      raise ValueError("Only TF functions are supported with Register method")

    gen_rpc_ops.rpc_server_register(
        self._server_handle,
        method_name=method_name,
        captured_inputs=concrete_fn.captured_inputs,
        input_specs=get_input_specs_from_function(concrete_fn),
        output_specs=get_output_specs_from_function(concrete_fn),
        f=concrete_fn)

  def start(self):
    """Starts GRPC server."""
    gen_rpc_ops.rpc_server_start(self._server_handle)


class GrpcClient(Client):
  """Client wrapper to connect to remote RPC server using GRPC.

  If Client is created with (list_registered_methods=True):
  1. Input and output specs for the methods till this point will be fetched from
  Server.
  2. convenience methods are added to invoke registered methods directly from
  client.
  For example:
    For call a server method `add`
    client.add(a, b) or client.add_blocking(a, b) can be used instead of
    client.call(args=[a,b], output_specs=[..])

  Prerequisite for using list_registered_methods=True:
  1. Server should be already started with the registered methods.
  2. Client must be created in Eager mode.
  """

  def __init__(self,
               address: str,
               name: str = "",
               list_registered_methods=False,
               timeout_in_ms=0):
    """Creates the client resource and adds the listed convenience methods.

    Args:
      address: Address of the server, passed to the `rpc_client` op.
      name: Shared name of the client resource.
      list_registered_methods: Passed to the `rpc_client` op; every method the
        op returns gets `<method>` and `<method>_blocking` attributes on this
        instance.
      timeout_in_ms: Timeout passed to the `rpc_client` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    self._client_handle, methods = gen_rpc_ops.rpc_client(
        shared_name=name,
        server_address=address,
        list_registered_methods=list_registered_methods,
        timeout_in_ms=timeout_in_ms)
    if context.executing_eagerly():
      # Ties the lifetime of the client resource to this Python object.
      self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=self._client_handle, handle_device=self._client_handle.device)
    else:
      raise NotImplementedError(
          "Client creation is supported only in eager mode.")
    self._server_address = address
    self._method_registry = {}
    for method in methods.numpy():
      m = rpc_pb2.RegisteredMethod()
      m.ParseFromString(method)
      output_specs = nested_structure_coder.decode_proto(m.output_specs)
      input_specs = nested_structure_coder.decode_proto(m.input_specs)
      self._method_registry[m.method] = output_specs
      # TODO(ishark): Perhaps doc string can also be taken as input during
      # function registration.
      doc_string = "RPC Call for {} method to server {}".format(
          m.method, address)
      self._add_method(m.method, output_specs, input_specs, self._client_handle,
                       doc_string)

  def _add_method(self, method_name, output_specs, input_specs, client_handle,
                  doc_string):
    """Method to add RPC methods to the client object.

    Adds two attributes to this instance: `<method_name>`, which issues the RPC
    and returns a `StatusOrResult`, and `<method_name>_blocking`, which returns
    the value or raises the exception mapped from the RPC error code.

    Args:
      method_name: Registered method name; also the attribute name added.
      output_specs: Decoded output specs used to pack the RPC result.
      input_specs: Decoded input specs; when truthy, call arguments must have
        the same nested structure.
      client_handle: Client resource handle used for the `rpc_call` op.
      doc_string: Docstring assigned to both added callables.
    """

    def validate_and_get_flat_inputs(*args):
      # `args` is always a tuple here, so only the structure needs checking.
      if input_specs:
        nest.assert_same_structure(args, input_specs)
      return nest.flatten(args)

    def call_wrapper(*args, timeout_in_ms=0):
      status_or, deleter = gen_rpc_ops.rpc_call(
          client_handle,
          args=validate_and_get_flat_inputs(*args),
          method_name=method_name,
          timeout_in_ms=timeout_in_ms)
      return StatusOrResult(status_or, deleter, output_specs)

    def call_blocking_wrapper(*args, timeout_in_ms=0):
      # Issue the RPC exactly like the non-blocking variant, then wait on it.
      result = call_wrapper(*args, timeout_in_ms=timeout_in_ms)
      if result.is_ok():
        return result.get_value()
      error_code, error_msg = result.get_error()
      raise errors.exception_type_from_error_code(error_code.numpy())(
          None, None, error_msg.numpy())

    call_wrapper.__doc__ = doc_string
    setattr(self, method_name, call_wrapper)

    call_blocking_wrapper.__doc__ = doc_string
    setattr(self, method_name + "_blocking", call_blocking_wrapper)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method to invoke remote registered functions on the connected server.

    Server should be started before making an RPC Call.

    Args:
      method_name: Registered method to invoke on Server.
      args: Input arguments for the method.
      output_specs: Output specs for the output from method.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      StatusOrResult object. This function issues the RPC call to server, it
      does not block for the duration of RPC. Please call is_ok, get_error or
      get_value methods on the returned object to block till RPC finishes.
    """
    if args is None:
      args = []
    status_or, deleter = gen_rpc_ops.rpc_call(
        self._client_handle,
        args=nest.flatten(args),
        method_name=method_name,
        timeout_in_ms=timeout_in_ms)
    return StatusOrResult(status_or, deleter, output_specs)


class StatusOrResult(object):
  """Class representing result and status from RPC Call."""

  def __init__(self, status_or, deleter, output_specs=None):
    """Wraps the handle and deleter returned by the `rpc_call` op.

    Args:
      status_or: Handle of the pending RPC result, as returned by `rpc_call`.
      deleter: Deleter returned by `rpc_call`, used to release `status_or`.
      output_specs: Optional nested structure of specs used by `get_value` to
        pack the result.
    """
    self._status_or = status_or
    self._output_specs = output_specs
    self._deleter = deleter
    # Filled in lazily by `_check_status` the first time status is needed.
    self._error_code, self._error_message = None, None

  def _check_status(self):
    """Fetches and caches the error code and message on first use."""
    if self._error_code is None:
      self._error_code, self._error_message = gen_rpc_ops.rpc_check_status(
          self._status_or)

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    if context.executing_eagerly():
      with context.eager_mode():
        gen_rpc_ops.delete_rpc_future_resource(
            handle=self._status_or, deleter=self._deleter)
    else:
      with context.graph_mode():
        gen_rpc_ops.delete_rpc_future_resource(
            handle=self._status_or, deleter=self._deleter)

  def is_ok(self):
    """Returns True if RPC is successful, otherwise returns False.

    This call will block for RPC result.
    """
    self._check_status()
    # Error code 0 is treated as success.
    return math_ops.equal(self._error_code,
                          constant_op.constant(0, dtype=dtypes.int64))

  def get_error(self):
    """Returns (TF Error Code, Error Message) from RPC Response.

    This call will block for RPC result.
    """
    self._check_status()
    return self._error_code, self._error_message

  def get_value(self):
    """Returns the returned response value from RPC Call when RPC is successful.

    The returned value is tensors in the output_specs format as returned from
    the RPC call.

    This call will block for RPC result.

    Returns:
      None when no output_specs were given (or they are a `NoneTensorSpec`);
      otherwise the result packed into the structure of output_specs.
    """
    self._check_status()
    return_none = (
        self._output_specs is None or
        isinstance(self._output_specs, structure.NoneTensorSpec))
    if return_none:
      flat_output_dtypes = []
    else:
      flat_output_dtypes = [s.dtype for s in nest.flatten(self._output_specs)]

    # The op is run even when nothing is returned to the caller.
    result = gen_rpc_ops.rpc_get_value(self._status_or, Tout=flat_output_dtypes)
    if return_none:
      return None
    return nest.pack_sequence_as(self._output_specs, result)