xref: /aosp_15_r20/external/tensorflow/tensorflow/python/distribute/experimental/rpc/rpc_ops_test.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for rpc_ops.py."""
16
17import threading
18import time
19import numpy as np
20import portpicker
21
22from tensorflow.python.distribute.experimental.rpc import rpc_ops
23from tensorflow.python.eager import context
24from tensorflow.python.eager import def_function as eager_def_function
25from tensorflow.python.framework import config
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import errors
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import tensor_spec
31from tensorflow.python.framework import test_util
32from tensorflow.python.ops import data_flow_ops
33from tensorflow.python.ops import math_ops
34from tensorflow.python.ops import resource_variable_ops
35from tensorflow.python.ops import variables
36from tensorflow.python.platform import test
37from tensorflow.python.util import nest
38
39
@test_util.with_eager_op_as_function
class RpcOpsTest(test.TestCase):
  """Tests for the RPC server/client ops and their Python wrappers."""

  def setUp(self):
    """Splits the first physical CPU into two logical CPU devices.

    Multi-device tests below place resources on /device:CPU:1 while serving
    from /device:CPU:0.
    """
    super().setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

  @staticmethod
  def _pick_unused_address():
    """Returns a "localhost:<port>" address for a currently unused port."""
    return "localhost:{}".format(portpicker.pick_unused_port())

  def test_generated_rpc_ops(self):
    """Exercises the raw generated RPC ops end to end.

    Registers a multiply function through `gen_rpc_ops`, calls it through a
    client handle, checks the status and value, then releases the server,
    client and future resources.
    """
    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def remote_fn(a, b):
      return math_ops.multiply(a, b)

    concrete_remote_fn = remote_fn.get_concrete_function()

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    address = self._pick_unused_address()
    server_resource = rpc_ops.gen_rpc_ops.rpc_server(server_address=address)

    rpc_ops.gen_rpc_ops.rpc_server_register(
        server_resource,
        f=concrete_remote_fn,
        captured_inputs=concrete_remote_fn.captured_inputs,
        output_specs=rpc_ops.get_output_specs_from_function(concrete_remote_fn),
        method_name="multiply")

    rpc_ops.gen_rpc_ops.rpc_server_start(server_resource)
    client_handle, _ = rpc_ops.gen_rpc_ops.rpc_client(
        server_address=address, timeout_in_ms=5000)
    future_resource, deleter = rpc_ops.gen_rpc_ops.rpc_call(
        client_handle, args=[a, b], method_name="multiply", timeout_in_ms=0)

    # An error code of 0 means the call completed successfully.
    error_code, _ = rpc_ops.gen_rpc_ops.rpc_check_status(future_resource)
    self.assertAllEqual(error_code, 0)
    self.assertAllEqual(
        rpc_ops.gen_rpc_ops.rpc_get_value(future_resource, Tout=[dtypes.int32]),
        [6])

    # Wrap the raw handles in deleters to release the server and client
    # resources. The deleters are not kept, so cleanup presumably happens as
    # soon as they are garbage collected.
    resource_variable_ops.EagerResourceDeleter(
        handle=server_resource, handle_device=server_resource.device)

    resource_variable_ops.EagerResourceDeleter(
        handle=client_handle, handle_device=client_handle.device)

    rpc_ops.gen_rpc_ops.delete_rpc_future_resource(future_resource, deleter)

  def test_exported_rpc_api_static_factory(self):
    """Tests the `Server.create` / `Client.create` factory API."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_unused_address()
    server_resource = rpc_ops.Server.create("grpc", address)
    server_resource.register("multiply", _remote_fn)

    server_resource.start()
    client = rpc_ops.Client.create("grpc", address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    mul_or = client.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))

    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test empty client name
    client1 = rpc_ops.Client.create("grpc", address)
    mul_or = client1.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test without output_spec
    mul_or = client1.multiply(a, b)
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    self.assertEqual(client1.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)

  def test_rpc_ops_call_method(self):
    """Tests `GrpcClient.call` against a tf.function and a ConcreteFunction."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_unused_address()
    server_resource = rpc_ops.GrpcServer(address)

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def add_fn(a, b):
      return math_ops.add(a, b)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    # Register concrete Function
    server_resource.register("add", add_fn.get_concrete_function())

    server_resource.start()
    client = rpc_ops.GrpcClient(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    mul_or = client.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))

    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    add_or = client.call(
        args=[a, b],
        method_name="add",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))

    self.assertAllEqual(add_or.is_ok(), True)
    self.assertAllEqual(add_or.get_value(), 5)

    # Test empty client name
    client1 = rpc_ops.GrpcClient(address, list_registered_methods=True)
    mul_or = client1.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

  def test_rpc_ops_non_blocking_convenience_methods(self):
    """Tests the generated `client.<method>` wrapper and its docstring."""
    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_unused_address()
    server_resource = rpc_ops.GrpcServer(address)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    server_resource.start()
    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    mul_or = client.multiply(a, b)
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    self.assertEqual(client.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)

  def test_rpc_ops_blocking_convenience_methods(self):
    """Tests the generated `client.<method>_blocking` wrapper and docstring.

    The blocking wrapper's return value is compared directly with the
    expected product.
    """
    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_unused_address()
    server_resource = rpc_ops.GrpcServer(address)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    server_resource.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)
    self.assertAllEqual(client.multiply_blocking(a, b), 6)

    self.assertEqual(
        client.multiply_blocking.__doc__,
        "RPC Call for multiply method to server " + address)

  def test_output_specs(self):
    """Tests methods returning a dict, a bool, nothing, and nested values."""

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_dict(val):
      return {"key": val}

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def is_positive(a):
      if a > 0:
        return True
      return False

    @eager_def_function.function(input_signature=[])
    def do_nothing():
      return []

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_nested_structure(v):
      return {"test": (v, [v, v]), "test1": (v,)}

    address = self._pick_unused_address()
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("test_dict", test_dict)
    server_resource.register("is_positive", is_positive)
    server_resource.register("test_nested_structure", test_nested_structure)
    server_resource.register("do_nothing", do_nothing)

    server_resource.start()

    client = rpc_ops.GrpcClient(
        address=address, name="test_client", list_registered_methods=True)

    a = variables.Variable(2, dtype=dtypes.int32)

    result_or = client.test_dict(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {"key": 2})

    result_or = client.is_positive(a)
    self.assertTrue(result_or.is_ok())
    self.assertTrue(result_or.get_value())

    result_or = client.test_nested_structure(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {
        "test": (2, [2, 2]),
        "test1": (2,)
    })

    result_or = client.do_nothing()
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [])

  def test_input_specs(self):
    """Tests a dict input signature: dicts are accepted, lists are rejected."""

    @eager_def_function.function(input_signature=[{
        "a": tensor_spec.TensorSpec([], dtypes.int32),
        "b": tensor_spec.TensorSpec([], dtypes.int32)
    }])
    def test_input_dict(value):
      return math_ops.add(value["a"], value["b"])

    address = self._pick_unused_address()
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("test_input_dict", test_input_dict)

    server_resource.start()

    client = rpc_ops.GrpcClient(
        address=address, name="test_client", list_registered_methods=True)
    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)
    result_or = client.test_input_dict({"a": a, "b": b})
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), 5)

    with self.assertRaises(TypeError):
      client.test_input_dict([a, b])

  def test_call_register_ordering(self):
    """Tests client creation and registration relative to server start.

    A client without `list_registered_methods` can be created before the
    server exists. One with `list_registered_methods` times out until the
    server is up. Methods cannot be registered after the server has started.
    """
    address = self._pick_unused_address()

    # Create client succeeds before server start and registration
    client = rpc_ops.GrpcClient(address)

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.GrpcClient(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server = rpc_ops.GrpcServer(address)

    def start_server():
      # Delay server start to test whether client creation also waits
      # till server is up.
      time.sleep(1)
      server.register("assign_add", assign_add)
      server.start()

    server_thread = threading.Thread(target=start_server)
    server_thread.start()

    # Create same "client1" again should succeed.
    client1_with_listed_methods = rpc_ops.GrpcClient(
        address, name="client1", list_registered_methods=True)

    result_or = client1_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)

    # Create client with registered methods
    client2_with_listed_methods = rpc_ops.GrpcClient(
        address=address, name="client2", list_registered_methods=True)

    result_or = client2_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 6)

    # Wait for the background thread to return from server.start(), so the
    # registration check below always sees a started server and the thread
    # does not outlive the test.
    server_thread.join()

    # Register new method after server started.
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "All methods must be registered before starting the server"):
      server.register("read_var", read_var)

  def test_client_timeout(self):
    """Tests the client's default call timeout and per-call timeouts."""
    address = self._pick_unused_address()

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def add(a, b):
      return math_ops.add(a, b)

    server = rpc_ops.GrpcServer(address)

    def start_server():
      # Delay server start to simulate deadline exceeded for 1st RPC call
      # response. Client waits till server is started, thus it can trigger
      # deadline exceeded.
      # `server` is read from the enclosing scope, so after it is rebound
      # below this starts the new server.
      time.sleep(1)
      server.register("add", add)
      server.start()

    first_server_thread = threading.Thread(target=start_server)
    first_server_thread.start()

    def ensure_server_is_ready(client):
      # Keep calling "add" while the server reports UNAVAILABLE. Any other
      # outcome (success or a different error) means the server is up.
      while True:
        result_or = client.call(
            "add", [constant_op.constant(20),
                    constant_op.constant(30)])
        if result_or.is_ok():
          return
        error_code, _ = result_or.get_error()
        if error_code != errors.UNAVAILABLE:
          return

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.GrpcClient(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    # Creating the same client again succeeds with
    # list_registered_methods=False. The client's default call timeout is set
    # to 1 ms here.
    client = rpc_ops.GrpcClient(
        address, name="client1", list_registered_methods=False, timeout_in_ms=1)

    ensure_server_is_ready(client)
    # Make explicit RPC call, the timeout of 1 ms should lead to
    # deadline exceeded error.

    result_or = client.call(
        "add", [constant_op.constant(20),
                constant_op.constant(30)],
        timeout_in_ms=1)
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, error_message = result_or.get_error()
    self.assertAllEqual(error_code, errors.DEADLINE_EXCEEDED, error_message)

    # Specifying reasonable timeout for call should succeed.
    result_or = client.call(
        "add", [constant_op.constant(20),
                constant_op.constant(30)],
        timeout_in_ms=5000)
    self.assertAllEqual(result_or.is_ok(), True)
    # A successful call reports the OK error code.
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.OK)

    # The server has answered, so the first start-up thread has called
    # server.start(). Wait for it before replacing the server.
    first_server_thread.join()

    # Test timeouts for convenience methods

    # Restart the server, again with a delayed start.
    del server
    server = rpc_ops.GrpcServer(address)
    second_server_thread = threading.Thread(target=start_server)
    second_server_thread.start()

    # Client with no default timeout.
    client = rpc_ops.GrpcClient(
        address, name="client2", list_registered_methods=True)

    # Succeeds with reasonable timeout.
    result_or = client.add(
        constant_op.constant(20), constant_op.constant(30), timeout_in_ms=5000)
    self.assertAllEqual(result_or.is_ok(), True)

    second_server_thread.join()

  def test_async_call_op_wrapper(self):
    """Tests that issuing many `call`s before waiting on them works."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)
    server.start()

    client = rpc_ops.GrpcClient(address)

    futures = []
    for _ in range(10):
      futures.append(
          client.call("assign_add",
                      [variables.Variable(2, dtype=dtypes.int64)]))

    # Wait for every asynchronous call and check that each one succeeded.
    for f in futures:
      self.assertAllEqual(f.is_ok(), True)

    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])

    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [20])

  def test_rpc_call_op_in_tf_function(self):
    """Tests issuing an RPC call from inside a tf.function."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    address = self._pick_unused_address()
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("remote_fn", _remote_fn)

    server_resource.start()
    client = rpc_ops.GrpcClient(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    @eager_def_function.function
    def call_fn():
      result_or = client.call(
          args=[a, b],
          method_name="remote_fn",
          output_specs=[tensor_spec.TensorSpec([], dtypes.int32)])

      self.assertAllEqual(True, result_or.is_ok())
      result = result_or.get_value()
      self.assertEqual(len(result), 1)  # Call returns a list(tensors)
      # TODO(ishark): Shape for output tensor is unknown currently.
      # Add attribute for capturing TensorSpec for output and enable
      # check below:
      # self.assertIsNotNone(result[0].shape.rank)
      return result

    self.assertAllEqual(call_fn(), [6])

  def test_resource_deletion(self):
    """Tests that future, client and server resources are destroyed.

    After each object is released, destroying its handle again must raise
    NotFoundError.
    """
    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    server_handle = server._server_handle

    # Test Future resource deletion
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server.register("read_var", read_var)

    server.start()
    client = rpc_ops.GrpcClient(address)

    client_handle = client._client_handle

    # Check future resource deletion without calling get_value.
    def _create_and_delete_rpc_future():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    @eager_def_function.function
    def _create_and_delete_rpc_future_fn():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    for _ in range(2):
      handle = _create_and_delete_rpc_future()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    for _ in range(2):
      handle = _create_and_delete_rpc_future_fn()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    # Check future resource deletion with calling get_value.
    def _create_and_delete_with_future():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    # Check future resource deletion with calling get_value with tf.function.
    @eager_def_function.function
    def _create_and_delete_with_future_fn():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    for _ in range(2):
      resource_handle = _create_and_delete_with_future()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            resource_handle, ignore_lookup_error=False)

    for _ in range(2):
      resource_handle = _create_and_delete_with_future_fn()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            resource_handle, ignore_lookup_error=False)

    # Test client resource gets deleted.
    del client
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          client_handle, ignore_lookup_error=False)

    # Test server resource gets deleted.
    del server
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          server_handle, ignore_lookup_error=False)

  def test_rpc_error(self):
    """Tests error reporting for bad calls and calls to a deleted server."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)
    server.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    # confirm it works as expected when arguments are passed.
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [2])
    result_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(True, result_or.is_ok())

    result_or = client.read_var()
    self.assertAllEqual(True, result_or.is_ok())
    self.assertAllEqual(result_or.get_value(), 4)

    # Fails with invalid argument error when no arguments are passed.
    result_or = client.call("assign_add")
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.INVALID_ARGUMENT)

    del server
    with self.assertRaises(errors.DeadlineExceededError):
      _ = client.assign_add_blocking(
          variables.Variable(2, dtype=dtypes.int64), timeout_in_ms=1)

  def test_captured_inputs(self):
    """Tests registered functions that capture and share a variable."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)

    server.start()

    client = rpc_ops.GrpcClient(address)

    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])

    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [4])

  def test_register_method_twice(self):
    """Tests that reusing a method name for a second registration fails."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign(a):
      v.assign(a)

    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    server.register("assign", assign_add)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "assign is already registered."):
      # Reusing the same method name.
      server.register("assign", assign)

  def test_tf_function_register_without_input_signature(self):
    """Tests that an input signature is required unless there are no args."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function
    def assign(a):
      v.assign(a)

    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    with self.assertRaisesRegex(
        ValueError, "Input signature not specified for the function."):
      server.register("assign", assign)

    # Register without input signature should work for functions without input
    # args.
    @eager_def_function.function
    def read_var():
      return v.value()

    server.register("read_var", read_var)

  def test_multi_device_resource(self):
    """Tests a registered function that fills a queue on /device:CPU:1."""
    elements = np.random.randint(100, size=[200])

    with ops.device("/device:CPU:1"):
      queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    with ops.device("/device:CPU:0"):
      address = self._pick_unused_address()
      server = rpc_ops.GrpcServer(address)
      server.register("populate_queue", populate_queue)
      server.start()

      client = rpc_ops.GrpcClient(address, list_registered_methods=True)
      client.populate_queue()

    for e in elements:
      self.assertAllEqual(e, queue.dequeue())

  def test_queue_resource(self):
    """Tests a registered function that fills a queue read by the test."""
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    address = self._pick_unused_address()
    server = rpc_ops.GrpcServer(address)
    server.register("populate_queue", populate_queue)
    server.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)
    client.populate_queue()

    for e in elements:
      self.assertAllEqual(e, queue.dequeue())

  def test_multi_device_resource_cpu(self):
    """Tests a registered function updating a variable on /device:CPU:1."""
    with ops.device("/device:CPU:1"):
      v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    with ops.device("/device:CPU:0"):
      address = self._pick_unused_address()
      server = rpc_ops.GrpcServer(address)
      server.register("assign_add", assign_add)
      server.start()

      client = rpc_ops.GrpcClient(address, list_registered_methods=True)
      result_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
      self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 2)
846
847
def _main():
  """Turns on eager execution, then hands control to the test runner."""
  ops.enable_eager_execution()
  test.main()


if __name__ == "__main__":
  _main()
851