# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for rpc_ops.py."""

import threading
import time
import numpy as np
import portpicker

from tensorflow.python.distribute.experimental.rpc import rpc_ops
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function as eager_def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import nest


@test_util.with_eager_op_as_function
class RpcOpsTest(test.TestCase):
  """End-to-end tests for the gRPC-backed RPC server/client ops."""

  def setUp(self):
    super(RpcOpsTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration()
    ])

  def test_generated_rpc_ops(self):
    """Drives the raw generated `gen_rpc_ops` ops directly."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def remote_fn(a, b):
      return math_ops.multiply(a, b)

    concrete_remote_fn = remote_fn.get_concrete_function()

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.gen_rpc_ops.rpc_server(server_address=address)

    rpc_ops.gen_rpc_ops.rpc_server_register(
        server_resource,
        f=concrete_remote_fn,
        captured_inputs=concrete_remote_fn.captured_inputs,
        output_specs=rpc_ops.get_output_specs_from_function(concrete_remote_fn),
        method_name="multiply")

    rpc_ops.gen_rpc_ops.rpc_server_start(server_resource)
    client_handle, _ = rpc_ops.gen_rpc_ops.rpc_client(
        server_address=address, timeout_in_ms=5000)
    future_resource, deleter = rpc_ops.gen_rpc_ops.rpc_call(
        client_handle, args=[a, b], method_name="multiply", timeout_in_ms=0)

    error_code, _ = rpc_ops.gen_rpc_ops.rpc_check_status(future_resource)
    self.assertAllEqual(error_code, 0)
    self.assertAllEqual(
        rpc_ops.gen_rpc_ops.rpc_get_value(future_resource, Tout=[dtypes.int32]),
        [6])

    # NOTE(review): the deleters below are discarded immediately; presumably
    # their destructors release the server/client resources — confirm.
    resource_variable_ops.EagerResourceDeleter(
        handle=server_resource, handle_device=server_resource.device)

    resource_variable_ops.EagerResourceDeleter(
        handle=client_handle, handle_device=client_handle.device)

    rpc_ops.gen_rpc_ops.delete_rpc_future_resource(future_resource, deleter)

  def test_exported_rpc_api_static_factory(self):
    """Exercises the `Server.create` / `Client.create` factory API."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.Server.create("grpc", address)
    server_resource.register("multiply", _remote_fn)

    server_resource.start()
    client = rpc_ops.Client.create("grpc", address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    mul_or = client.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))

    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test empty client name
    client1 = rpc_ops.Client.create("grpc", address)
    mul_or = client1.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    # Test without output_spec
    mul_or = client1.multiply(a, b)
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    self.assertEqual(client1.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)

  def test_rpc_ops_call_method(self):
    """Calls methods registered both as a tf.function and a concrete function."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.GrpcServer(address)

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def add_fn(a, b):
      return math_ops.add(a, b)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    # Register concrete Function
    server_resource.register("add", add_fn.get_concrete_function())

    server_resource.start()
    client = rpc_ops.GrpcClient(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    mul_or = client.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))

    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    add_or = client.call(
        args=[a, b],
        method_name="add",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))

    self.assertAllEqual(add_or.is_ok(), True)
    self.assertAllEqual(add_or.get_value(), 5)

    # Test empty client name
    client1 = rpc_ops.GrpcClient(address, list_registered_methods=True)
    mul_or = client1.call(
        args=[a, b],
        method_name="multiply",
        output_specs=tensor_spec.TensorSpec((), dtypes.int32))
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

  def test_rpc_ops_non_blocking_convenience_methods(self):
    """Checks the generated non-blocking `client.<method>` wrappers."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.GrpcServer(address)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    server_resource.start()
    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    mul_or = client.multiply(a, b)
    self.assertAllEqual(mul_or.is_ok(), True)
    self.assertAllEqual(mul_or.get_value(), 6)

    self.assertEqual(client.multiply.__doc__,
                     "RPC Call for multiply method to server " + address)

  def test_rpc_ops_blocking_convenience_methods(self):
    """Checks the generated blocking `client.<method>_blocking` wrappers."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.GrpcServer(address)

    # Register TF function
    server_resource.register("multiply", _remote_fn)

    server_resource.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)
    self.assertAllEqual(client.multiply_blocking(a, b), 6)

    self.assertEqual(
        client.multiply_blocking.__doc__,
        "RPC Call for multiply method to server " + address)

  def test_output_specs(self):
    """Round-trips dict, bool, nested and empty outputs through RPC."""

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_dict(val):
      return {"key": val}

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def is_positive(a):
      if a > 0:
        return True
      return False

    @eager_def_function.function(input_signature=[])
    def do_nothing():
      return []

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int32)])
    def test_nested_structure(v):
      return {"test": (v, [v, v]), "test1": (v,)}

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("test_dict", test_dict)
    server_resource.register("is_positive", is_positive)
    server_resource.register("test_nested_structure", test_nested_structure)
    server_resource.register("do_nothing", do_nothing)

    server_resource.start()

    client = rpc_ops.GrpcClient(
        address=address, name="test_client", list_registered_methods=True)

    a = variables.Variable(2, dtype=dtypes.int32)

    result_or = client.test_dict(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {"key": 2})

    result_or = client.is_positive(a)
    self.assertTrue(result_or.is_ok())
    self.assertTrue(result_or.get_value())

    result_or = client.test_nested_structure(a)
    self.assertAllEqual(result_or.is_ok(), True)
    nest.map_structure(self.assertAllEqual, result_or.get_value(), {
        "test": (2, [2, 2]),
        "test1": (2,)
    })

    result_or = client.do_nothing()
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [])

  def test_input_specs(self):
    """Passes a dict argument and rejects a structurally mismatched one."""

    @eager_def_function.function(input_signature=[{
        "a": tensor_spec.TensorSpec([], dtypes.int32),
        "b": tensor_spec.TensorSpec([], dtypes.int32)
    }])
    def test_input_dict(value):
      return math_ops.add(value["a"], value["b"])

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("test_input_dict", test_input_dict)

    server_resource.start()

    client = rpc_ops.GrpcClient(
        address=address, name="test_client", list_registered_methods=True)
    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)
    result_or = client.test_input_dict({"a": a, "b": b})
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), 5)

    with self.assertRaises(TypeError):
      client.test_input_dict([a, b])

  def test_call_register_ordering(self):
    """Checks client creation/calls relative to server start and registration."""
    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)

    # Create client succeeds before server start and registration
    client = rpc_ops.GrpcClient(address)

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.GrpcClient(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server = rpc_ops.GrpcServer(address)

    def start_server():
      # Delay server start to test whether client creation also waits
      # till server is up.
      time.sleep(1)
      server.register("assign_add", assign_add)
      server.start()

    t = threading.Thread(target=start_server)
    t.start()

    # Create same "client1" again should succeed.
    client1_with_listed_methods = rpc_ops.GrpcClient(
        address, name="client1", list_registered_methods=True)

    result_or = client1_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)

    # Create client with registered methods
    client2_with_listed_methods = rpc_ops.GrpcClient(
        address=address, name="client2", list_registered_methods=True)

    result_or = client2_with_listed_methods.assign_add(
        variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 6)

    # Wait for the background thread so `server.start()` has fully returned
    # before the registration check below, and so the thread never outlives
    # this test.
    t.join()

    # Register new method after server started.
    with self.assertRaisesRegex(
        errors.FailedPreconditionError,
        "All methods must be registered before starting the server"):
      server.register("read_var", read_var)

  def test_client_timeout(self):
    """Checks client-level and per-call timeouts against a delayed server."""
    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def add(a, b):
      return math_ops.add(a, b)

    server = rpc_ops.GrpcServer(address)

    def start_server():
      # Delay server start to simulate deadline exceeded for 1st RPC call
      # response. Client waits till server is started, thus it can trigger
      # deadline exceeded.
      # `server` is read from the enclosing scope at call time, so after it
      # is rebound below a new thread starts the new server.
      time.sleep(1)
      server.register("add", add)
      server.start()

    t = threading.Thread(target=start_server)
    t.start()

    def ensure_server_is_ready(client):
      """Retries an "add" call until the error is anything but UNAVAILABLE."""
      while True:
        result_or = client.call(
            "add", [constant_op.constant(20),
                    constant_op.constant(30)])
        if result_or.is_ok():
          return
        error_code, _ = result_or.get_error()
        if error_code == errors.UNAVAILABLE:
          # Back off briefly instead of spinning while the server starts.
          time.sleep(0.01)
          continue
        return

    # Create client with list_registered_methods fails before server is started.
    with self.assertRaises(errors.DeadlineExceededError):
      rpc_ops.GrpcClient(
          address,
          name="client1",
          list_registered_methods=True,
          timeout_in_ms=1)

    # Create same client again should succeed with
    # list_registered_methods=False. Default timeout for client is 1 ms.
    client = rpc_ops.GrpcClient(
        address, name="client1", list_registered_methods=False, timeout_in_ms=1)

    ensure_server_is_ready(client)
    # Make explicit RPC call, the timeout of 1 ms should lead to
    # deadline exceeded error.

    result_or = client.call(
        "add", [constant_op.constant(20),
                constant_op.constant(30)],
        timeout_in_ms=1)
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, error_message = result_or.get_error()
    self.assertAllEqual(error_code, errors.DEADLINE_EXCEEDED, error_message)

    # Specifying reasonable timeout for call should succeed.
    result_or = client.call(
        "add", [constant_op.constant(20),
                constant_op.constant(30)],
        timeout_in_ms=5000)
    self.assertAllEqual(result_or.is_ok(), True)
    error_code, _ = result_or.get_error()
    # A successful call reports an OK (0) error code.
    self.assertAllEqual(error_code, 0)

    # Join the first server thread before deleting `server`. If the thread
    # still held a reference to the old server, that server would stay alive
    # and could keep the port from the replacement server created below.
    t.join()

    # Test timeouts for convenience methods

    # Restart server again with delay to simulate deadline exceeded.
    del server
    server = rpc_ops.GrpcServer(address)
    t = threading.Thread(target=start_server)
    t.start()

    # Client with no default timeout.
    client = rpc_ops.GrpcClient(
        address, name="client2", list_registered_methods=True)

    # Succeeds with reasonable timeout.
    result_or = client.add(
        constant_op.constant(20), constant_op.constant(30), timeout_in_ms=5000)
    self.assertAllEqual(result_or.is_ok(), True)

    # Do not leak the second server thread past the end of the test.
    t.join()

  def test_async_call_op_wrapper(self):
    """Issues several asynchronous calls, then reads the accumulated value."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)
    server.start()

    client = rpc_ops.GrpcClient(address)

    futures = []
    for _ in range(10):
      futures.append(
          client.call("assign_add",
                      [variables.Variable(2, dtype=dtypes.int64)]))

    # Wait for every asynchronous call and check that each one succeeded.
    for f in futures:
      self.assertAllEqual(f.is_ok(), True)

    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])

    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [20])

  def test_rpc_call_op_in_tf_function(self):
    """Makes an RPC call from inside a tf.function."""

    @eager_def_function.function(input_signature=[
        tensor_spec.TensorSpec([], dtypes.int32),
        tensor_spec.TensorSpec([], dtypes.int32)
    ])
    def _remote_fn(a, b):
      return math_ops.multiply(a, b)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server_resource = rpc_ops.GrpcServer(address)

    server_resource.register("remote_fn", _remote_fn)

    server_resource.start()
    client = rpc_ops.GrpcClient(address=address, name="test_client")

    a = variables.Variable(2, dtype=dtypes.int32)
    b = variables.Variable(3, dtype=dtypes.int32)

    @eager_def_function.function
    def call_fn():
      result_or = client.call(
          args=[a, b],
          method_name="remote_fn",
          output_specs=[tensor_spec.TensorSpec([], dtypes.int32)])

      self.assertAllEqual(True, result_or.is_ok())
      result = result_or.get_value()
      self.assertEqual(len(result), 1)  # Call returns a list(tensors)
      # TODO(ishark): Shape for output tensor is unknown currently.
      # Add attribute for capturing TensorSpec for output and enable
      # check below:
      # self.assertIsNotNone(result[0].shape.rank)
      return result

    self.assertAllEqual(call_fn(), [6])

  def test_resource_deletion(self):
    """Checks that future, client and server resources are destroyed."""
    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    server_handle = server._server_handle

    # Test Future resource deletion
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    server.register("read_var", read_var)

    server.start()
    client = rpc_ops.GrpcClient(address)

    client_handle = client._client_handle

    # Check future resource deletion without calling get_value.
    def _create_and_delete_rpc_future():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    @eager_def_function.function
    def _create_and_delete_rpc_future_fn():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      return handle._status_or

    for _ in range(2):
      handle = _create_and_delete_rpc_future()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    for _ in range(2):
      handle = _create_and_delete_rpc_future_fn()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            handle, ignore_lookup_error=False)

    # Check future resource deletion with calling get_value.
    def _create_and_delete_with_future():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    # Check future resource deletion with calling get_value with tf.function.
    @eager_def_function.function
    def _create_and_delete_with_future_fn():
      handle = client.call(
          "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
      status_or_handle = handle._status_or
      handle.get_value()
      return status_or_handle

    for _ in range(2):
      resource_handle = _create_and_delete_with_future()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            resource_handle, ignore_lookup_error=False)

    for _ in range(2):
      resource_handle = _create_and_delete_with_future_fn()
      with self.assertRaises(errors.NotFoundError):
        resource_variable_ops.destroy_resource_op(
            resource_handle, ignore_lookup_error=False)

    # Test client resource gets deleted.
    del client
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          client_handle, ignore_lookup_error=False)

    # Test server resource gets deleted.
    del server
    with self.assertRaises(errors.NotFoundError):
      resource_variable_ops.destroy_resource_op(
          server_handle, ignore_lookup_error=False)

  def test_rpc_error(self):
    """Checks error reporting for bad arguments and for a deleted server."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)
    server.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)

    # confirm it works as expected when arguments are passed.
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [2])
    result_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
    self.assertAllEqual(True, result_or.is_ok())

    result_or = client.read_var()
    self.assertAllEqual(True, result_or.is_ok())
    self.assertAllEqual(result_or.get_value(), 4)

    # Fails with invalid argument error when no arguments are passed.
    result_or = client.call("assign_add")
    self.assertAllEqual(result_or.is_ok(), False)
    error_code, _ = result_or.get_error()
    self.assertAllEqual(error_code, errors.INVALID_ARGUMENT)

    del server
    with self.assertRaises(errors.DeadlineExceededError):
      _ = client.assign_add_blocking(
          variables.Variable(2, dtype=dtypes.int64), timeout_in_ms=1)

  def test_captured_inputs(self):
    """Checks that registered functions operate on their captured variable."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(input_signature=[])
    def read_var():
      return v.value()

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    server.register("assign_add", assign_add)
    server.register("read_var", read_var)

    server.start()

    client = rpc_ops.GrpcClient(address)

    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call("assign_add",
                            [variables.Variable(2, dtype=dtypes.int64)])
    self.assertAllEqual(result_or.is_ok(), True)
    result_or = client.call(
        "read_var", output_specs=[tensor_spec.TensorSpec([], dtypes.int64)])

    self.assertAllEqual(result_or.is_ok(), True)
    self.assertAllEqual(result_or.get_value(), [4])

  def test_register_method_twice(self):
    """Checks that registering a method name twice is rejected."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign(a):
      v.assign(a)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    server.register("assign", assign_add)
    with self.assertRaisesRegex(errors.InvalidArgumentError,
                                "assign is already registered."):
      # Reusing the same method name.
      server.register("assign", assign)

  def test_tf_function_register_without_input_signature(self):
    """Requires an input signature only for functions that take arguments."""
    v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function
    def assign(a):
      v.assign(a)

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    with self.assertRaisesRegex(
        ValueError, "Input signature not specified for the function."):
      server.register("assign", assign)

    # Register without input signature should work for functions without input
    # args.
    @eager_def_function.function
    def read_var():
      return v.value()

    server.register("read_var", read_var)

  def test_multi_device_resource(self):
    """Uses a queue placed on CPU:1 from a server created under CPU:0."""
    elements = np.random.randint(100, size=[200])

    with ops.device("/device:CPU:1"):
      queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    with ops.device("/device:CPU:0"):
      port = portpicker.pick_unused_port()
      address = "localhost:{}".format(port)
      server = rpc_ops.GrpcServer(address)
      server.register("populate_queue", populate_queue)
      server.start()

      client = rpc_ops.GrpcClient(address, list_registered_methods=True)
      client.populate_queue()

    for e in elements:
      self.assertAllEqual(e, queue.dequeue())

  def test_queue_resource(self):
    """Populates a FIFOQueue remotely and drains it locally."""
    elements = np.random.randint(100, size=[200])
    queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[])

    @eager_def_function.function()
    def populate_queue():
      queue.enqueue_many(elements)
      queue.close()

    port = portpicker.pick_unused_port()
    address = "localhost:{}".format(port)
    server = rpc_ops.GrpcServer(address)
    server.register("populate_queue", populate_queue)
    server.start()

    client = rpc_ops.GrpcClient(address, list_registered_methods=True)
    client.populate_queue()

    for e in elements:
      self.assertAllEqual(e, queue.dequeue())

  def test_multi_device_resource_cpu(self):
    """Updates a variable on cpu:1 through a server created under CPU:0."""
    with ops.device("/device:cpu:1"):
      v = variables.Variable(initial_value=0, dtype=dtypes.int64)

    @eager_def_function.function(
        input_signature=[tensor_spec.TensorSpec([], dtypes.int64)])
    def assign_add(a):
      v.assign_add(a)

    with ops.device("/device:CPU:0"):
      port = portpicker.pick_unused_port()
      address = "localhost:{}".format(port)
      server = rpc_ops.GrpcServer(address)
      server.register("assign_add", assign_add)
      server.start()

      client = rpc_ops.GrpcClient(address, list_registered_methods=True)
      result_or = client.assign_add(variables.Variable(2, dtype=dtypes.int64))
      self.assertAllEqual(result_or.is_ok(), True)

    self.assertAllEqual(v, 2)


if __name__ == "__main__":
  ops.enable_eager_execution()
  test.main()