xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/experimental/rpc/rpc_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Module to expose RPC APIs in tensorflow."""
16
17from typing import Optional, Sequence, Union
18
19import tensorflow.distribute.experimental.rpc.kernels.gen_rpc_ops as gen_rpc_ops
20from tensorflow.distribute.experimental.rpc.proto import tf_rpc_service_pb2 as rpc_pb2
21from tensorflow.python.data.util import structure
22from tensorflow.python.eager import context
23from tensorflow.python.eager import def_function
24from tensorflow.python.eager import function as tf_function
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import type_spec
29from tensorflow.python.ops import math_ops
30from tensorflow.python.ops import resource_variable_ops
31from tensorflow.python.saved_model import nested_structure_coder
32from tensorflow.python.types import core as core_tf_types
33from tensorflow.python.util import nest
34from tensorflow.python.util.tf_export import tf_export
35
36
def get_output_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized output specs of a concrete function.

  Every value in `func.structured_outputs` is converted to its type spec, the
  resulting structure is encoded with `nested_structure_coder`, and the encoded
  proto is serialized.

  Args:
    func: The ConcreteFunction whose outputs are described.

  Returns:
    The serialized structure proto describing the function outputs.
  """
  encoded_specs = nested_structure_coder.encode_structure(
      nest.map_structure(type_spec.type_spec_from_value,
                         func.structured_outputs))
  return encoded_specs.SerializeToString()
42
43
def get_input_specs_from_function(func: tf_function.ConcreteFunction):
  """Returns the serialized input specs of a concrete function.

  Only the first element of `func.structured_input_signature` is encoded; the
  second element of that pair is ignored.

  Args:
    func: The ConcreteFunction whose inputs are described.

  Returns:
    The serialized structure proto describing the function inputs.
  """
  positional_specs, _ = func.structured_input_signature
  encoded_specs = nested_structure_coder.encode_structure(positional_specs)
  return encoded_specs.SerializeToString()
48
49
@tf_export("distribute.experimental.rpc.Server", v1=[])
class Server(object):
  """A Server base class for accepting RPCs for registered tf.functions.

  Functions can be registered on the server and are exposed via RPCs.

  This base class only defines the interface: `register` and `start` raise
  `NotImplementedError` here. Use `Server.create` to obtain a concrete
  implementation.
  """

  @staticmethod
  def create(rpc_layer, address):
    """Create TF RPC server at given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address where RPC server is hosted.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Server` class.

    Raises:
        A ValueError if rpc_layer other than "grpc" is used. Only GRPC
        is supported at the moment.

    Example usage:

      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    return GrpcServer(address=address)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Method for registering tf.function on server.

    Registered methods can be invoked remotely from clients.

    Args:
      method_name: Name of the tf.function. Clients use this method_name to make
        RPCs.
      func: A `tf.function` or ConcreteFunction to register.

    Raises:
      NotImplementedError: Always, in this base class. Concrete subclasses
        provide the implementation.
    """
    # The original message lacked a space between the two string fragments
    # ("aconcrete") and named a `create_server` method; the factory defined on
    # this class is `Server.create`.
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")

  def start(self):
    """Starts the RPC server on provided address.

    Server listens for new requests from client, once it is started.

    Raises:
      NotImplementedError: Always, in this base class. Concrete subclasses
        provide the implementation.
    """
    raise NotImplementedError("Please use Server.create method to create a "
                              "concrete subclass of Server.")
115
116
@tf_export("distribute.experimental.rpc.Client", v1=[])
class Client(object):
  """Client class for invoking RPCs to the server.

  This base class defines the `call` interface and the `create` factory;
  `call` itself raises `NotImplementedError` here.
  """

  @staticmethod
  def create(rpc_layer, address, name="", timeout_in_ms=0):
    """Create TF RPC client to connect to the given address.

    Args:
      rpc_layer: Communication layer between client and server. Only "grpc" rpc
        layer is supported at the moment.
      address: Address of the server to connect the RPC client to.
      name: Name of the RPC Client. You can create multiple clients connecting
        to same server and distinguish them using different names.
      timeout_in_ms: The default timeout to use for outgoing RPCs from client. 0
        indicates no timeout. Exceeding timeout during RPC will raise
        DeadlineExceeded error.

    Returns:
      An instance of `tf.distribute.experimental.rpc.Client` with the following
      dynamically added methods for eagerly created clients:
        * `Registered methods` e.g. multiply(*args):
            If Client is created when executing eagerly, client will request the
            list of registered methods from server during client creation.
            The convenience methods for RPCs will be dynamically added to the
            created Client instance.

            For example, when a server has method "multiply" registered, the
            client object created in eager mode will have 'multiply' method
            available. Users can use client.multiply(..) to make RPC, instead of
            client.call("multiply", ...)

            Both "call" and "multiply" methods are non-blocking i.e. they return
            a StatusOrResult object which should be used to wait for getting
            value or error.

            Along with the above, blocking versions of the registered
            methods are also dynamically added to client instance.
            e.g. multiply_blocking(*args). These methods block till the RPC is
            finished and return response for successful RPC. Otherwise raise
            exception.

            These methods are not available when Client is created inside a
            tf.function.

    Raises:
        A ValueError if rpc_layer other than "grpc" is used. Only GRPC
          is supported at the moment.
        A DeadlineExceeded exception in eager mode if timeout exceeds while
          creating and listing client methods.

    Example usage:
      >>> # Have server already started.
      >>> import portpicker
      >>> @tf.function(input_signature=[
      ...      tf.TensorSpec([], tf.int32),
      ...      tf.TensorSpec([], tf.int32)])
      ... def remote_fn(a, b):
      ...   return tf.add(a, b)

      >>> port = portpicker.pick_unused_port()
      >>> address = "localhost:{}".format(port)
      >>> server = tf.distribute.experimental.rpc.Server.create("grpc", address)
      >>> server.register("addition", remote_fn)
      >>> server.start()

      >>> # Start client
      >>> client = tf.distribute.experimental.rpc.Client.create("grpc",
      ...      address=address, name="test_client")

      >>> a = tf.constant(2, dtype=tf.int32)
      >>> b = tf.constant(3, dtype=tf.int32)

      >>> result = client.call(
      ...    args=[a, b],
      ...    method_name="addition",
      ...    output_specs=tf.TensorSpec((), tf.int32))

      >>> if result.is_ok():
      ...   result.get_value()

      >>> result = client.addition(a, b)

      >>> if result.is_ok():
      ...   result.get_value()

      >>> value = client.addition_blocking(a, b)
    """
    if rpc_layer != "grpc":
      raise ValueError("Only GRPC backend is supported at the moment.")
    # Registered methods are listed from the server only when running eagerly.
    return GrpcClient(
        address=address,
        name=name,
        list_registered_methods=bool(context.executing_eagerly()),
        timeout_in_ms=timeout_in_ms)

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method for making RPC calls to remote server.

    This invokes RPC to the server, executing the registered method_name
    remotely.

    Args:
      method_name: Remote registered method to invoke.
      args: List of arguments for the registered method.
      output_specs: Output specs for the output from method.
        For example, if the tf.function is:

          @tf.function(input_signature=[
              tf.TensorSpec([], tf.int32),
              tf.TensorSpec([], tf.int32)])
          def multiply_fn(a, b):
            return tf.math.multiply(a, b)

        then output_specs is: tf.TensorSpec((), tf.int32)

        If you have access to the TF Function, the output specs can be
        generated from it by calling:

          output_specs = tf.nest.map_structure(
              tf.type_spec_from_value,
              multiply_fn.get_concrete_function().structured_outputs)

        When output_specs is not provided, `StatusOrResult.get_value` in this
        module requests no output tensors and returns None.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
        used.

    Returns:
      An instance of `StatusOrResult` class with the following available
      methods.
        * `is_ok()`:
            Returns True if RPC was successful.
        * `get_error()`:
            Returns TF error_code and error message for the RPC.
        * `get_value()`:
            Returns the returned value from remote TF function execution
            when RPC is successful.

      Calling any of the above methods will block till RPC is completed and
      result is available.

    Raises:
      NotImplementedError: Always, in this base class.
    """
    raise NotImplementedError("Must be implemented in inherited classes.")
258
259
class GrpcServer(Server):
  """GrpcServer object encapsulates a resource with GRPC server.

  Functions can be registered locally and are exposed via RPCs.
  Example:
  ```
  server = rpc_ops.GrpcServer("host:port")

  @tf.function(input_signature=[
      tf.TensorSpec([], tf.int32),
      tf.TensorSpec([], tf.int32)])
  def add(a, b):
    return a + b

  server.register("add", add)
  server.start()
  ```
  """

  def __init__(self, address: str):
    """Creates the server resource for `address`.

    Args:
      address: Address passed to the `rpc_server` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    self._server_handle = gen_rpc_ops.rpc_server(address)
    if not context.executing_eagerly():
      raise NotImplementedError("Please create the server outside tf.function.")
    # Ties the lifetime of the server resource to this Python object.
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._server_handle, handle_device=self._server_handle.device)

  def register(self, method_name: str,
               func: Union[def_function.Function,
                           tf_function.ConcreteFunction]):
    """Registers a `tf.function` or ConcreteFunction under `method_name`.

    Args:
      method_name: Name under which the function is registered on the server.
      func: A `tf.function` or ConcreteFunction to register.

    Raises:
      ValueError: If `func` is a `tf.function` that takes arguments but has no
        input signature, or if `func` is neither a `tf.function` nor a
        ConcreteFunction.
    """
    if isinstance(func, def_function.Function):
      takes_arguments = func._function_spec.arg_names  # pylint: disable=protected-access
      if takes_arguments and func.input_signature is None:
        raise ValueError("Input signature not specified for the function.")
      target_fn = func.get_concrete_function()
    elif isinstance(func, tf_function.ConcreteFunction):
      target_fn = func
    else:
      # Python functions
      # TODO(b/186762191): Add an implementation to support python functions.
      raise ValueError("Only TF functions are supported with Register method")

    # Both accepted function kinds are registered via the same op call.
    gen_rpc_ops.rpc_server_register(
        self._server_handle,
        method_name=method_name,
        captured_inputs=target_fn.captured_inputs,
        input_specs=get_input_specs_from_function(target_fn),
        output_specs=get_output_specs_from_function(target_fn),
        f=target_fn)

  def start(self):
    """Starts GRPC server."""
    gen_rpc_ops.rpc_server_start(self._server_handle)
317
318
class GrpcClient(Client):
  """Client wrapper to connect to remote RPC server using GRPC.

  If Client is created with (list_registered_methods=True):
  1. Input and output specs for the methods registered till this point will be
     fetched from Server.
  2. Convenience methods are added to invoke registered methods directly from
     client.
  For example, to call a server method `add`:
    client.add(a, b) (non-blocking) or client.add_blocking(a, b) can be used
    instead of client.call(method_name="add", args=[a, b], output_specs=[..])

  Prerequisites for using list_registered_methods=True:
   1. Server should be already started with the registered methods.
   2. Client must be created in Eager mode.
  """

  def __init__(self,
               address: str,
               name: str = "",
               list_registered_methods=False,
               timeout_in_ms=0):
    """Creates the client resource and attaches per-method helpers.

    Args:
      address: Address of the server to connect to.
      name: Shared name given to the client resource.
      list_registered_methods: Passed to the `rpc_client` op; each method it
        returns gets `<method>` and `<method>_blocking` attributes on this
        instance.
      timeout_in_ms: Default timeout passed to the `rpc_client` op.

    Raises:
      NotImplementedError: If not executing eagerly.
    """
    self._client_handle, methods = gen_rpc_ops.rpc_client(
        shared_name=name,
        server_address=address,
        list_registered_methods=list_registered_methods,
        timeout_in_ms=timeout_in_ms)
    if not context.executing_eagerly():
      raise NotImplementedError(
          "Client creation is supported only in eager mode.")
    # Ties the lifetime of the client resource to this Python object.
    self._handle_deleter = resource_variable_ops.EagerResourceDeleter(
        handle=self._client_handle, handle_device=self._client_handle.device)
    self._server_address = address
    self._method_registry = {}
    for method in methods.numpy():
      m = rpc_pb2.RegisteredMethod()
      m.ParseFromString(method)
      output_specs = nested_structure_coder.decode_proto(m.output_specs)
      input_specs = nested_structure_coder.decode_proto(m.input_specs)
      self._method_registry[m.method] = output_specs
      # TODO(ishark): Perhaps doc string can also be taken as input during
      # function registration.
      doc_string = "RPC Call for " + m.method + " method to server " + address
      self._add_method(m.method, output_specs, input_specs, self._client_handle,
                       doc_string)

  def _add_method(self, method_name, output_specs, input_specs, client_handle,
                  doc_string):
    """Adds `<method_name>` and `<method_name>_blocking` to this client.

    Args:
      method_name: Name of the registered server method.
      output_specs: Output specs used to build the `StatusOrResult`.
      input_specs: Input specs; when truthy, call arguments must have the same
        nested structure.
      client_handle: Client resource handle used for the `rpc_call` op.
      doc_string: Docstring assigned to both generated callables.
    """

    def validate_and_get_flat_inputs(*args):
      # `args` is always a tuple here, so no None handling is needed.
      if input_specs:
        nest.assert_same_structure(args, input_specs)
      return nest.flatten(args)

    def call_wrapper(*args, timeout_in_ms=0):
      status_or, deleter = gen_rpc_ops.rpc_call(
          client_handle,
          args=validate_and_get_flat_inputs(*args),
          method_name=method_name,
          timeout_in_ms=timeout_in_ms)
      return StatusOrResult(status_or, deleter, output_specs)

    def call_blocking_wrapper(*args, timeout_in_ms=0):
      # Reuse the non-blocking path, then wait on its result.
      status_or = call_wrapper(*args, timeout_in_ms=timeout_in_ms)
      if status_or.is_ok():
        return status_or.get_value()
      error_code, error_msg = status_or.get_error()
      raise errors.exception_type_from_error_code(error_code.numpy())(
          None, None, error_msg.numpy())

    setattr(self, method_name, call_wrapper)
    call_wrapper.__doc__ = doc_string

    blocking_method_name = method_name + "_blocking"
    setattr(self, blocking_method_name, call_blocking_wrapper)
    call_blocking_wrapper.__doc__ = doc_string

  def call(self,
           method_name: str,
           args: Optional[Sequence[core_tf_types.Tensor]] = None,
           output_specs=None,
           timeout_in_ms=0):
    """Method to invoke remote registered functions on the connected server.

    Server should be started before making an RPC Call.

    Args:
      method_name: Registered method to invoke on Server.
      args: Input arguments for the method. `None` is treated as no arguments.
      output_specs: Output specs for the output from method.
      timeout_in_ms: Timeout for this call. If 0, default client timeout will be
       used.

    Returns:
      StatusOrResult object. This function issues the RPC call to server, it
      does not block for the duration of RPC. Please call is_ok, get_error or
      get_value methods on the returned object to block till RPC finishes.
    """
    if args is None:
      args = []
    status_or, deleter = gen_rpc_ops.rpc_call(
        self._client_handle,
        args=nest.flatten(args),
        method_name=method_name,
        timeout_in_ms=timeout_in_ms)
    return StatusOrResult(status_or, deleter, output_specs)
437
438
class StatusOrResult(object):
  """Class representing result and status from RPC Call.

  The error code and message are fetched on first use and cached on the
  instance.
  """

  def __init__(self, status_or, deleter, output_specs=None):
    """Stores the RPC future handle, its deleter and the output specs."""
    self._status_or = status_or
    self._output_specs = output_specs
    self._deleter = deleter
    # Filled in by the first call to `_check_status`.
    self._error_code = None
    self._error_message = None

  def _check_status(self):
    """Fetches the RPC status once; later calls reuse the cached values."""
    if self._error_code is not None:
      return
    status = gen_rpc_ops.rpc_check_status(self._status_or)
    self._error_code, self._error_message = status

  def __del__(self):
    # Make sure the resource is deleted in the same mode as it was created in.
    if context.executing_eagerly():
      mode_scope = context.eager_mode
    else:
      mode_scope = context.graph_mode
    with mode_scope():
      gen_rpc_ops.delete_rpc_future_resource(
          handle=self._status_or, deleter=self._deleter)

  def is_ok(self):
    """Returns True if RPC is successful, otherwise returns False.

    This call will block for RPC result.
    """
    self._check_status()
    ok_code = constant_op.constant(0, dtype=dtypes.int64)
    return math_ops.equal(self._error_code, ok_code)

  def get_error(self):
    """Returns (TF Error Code, Error Message) from RPC Response.

    This call will block for RPC result.
    """
    self._check_status()
    return self._error_code, self._error_message

  def get_value(self):
    """Returns the returned response value from RPC Call when RPC is successful.

    The returned value is tensors in the output_specs format as returned from
    the RPC call. When no output specs were given (or they are a
    `NoneTensorSpec`), no output tensors are requested and None is returned.

    This call will block for RPC result.
    """
    self._check_status()
    specs = self._output_specs
    expects_nothing = specs is None or isinstance(specs,
                                                  structure.NoneTensorSpec)
    if expects_nothing:
      flat_output_dtypes = []
    else:
      flat_output_dtypes = [spec.dtype for spec in nest.flatten(specs)]

    # The value op is invoked even when nothing is expected back.
    result = gen_rpc_ops.rpc_get_value(self._status_or, Tout=flat_output_dtypes)
    if expects_nothing:
      return None
    return nest.pack_sequence_as(specs, result)
505