# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import gc
import re

from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.with_eager_op_as_function
class DefFunctionTest(xla_test.XLATestCase):

  def testAutoclusteringWithTfFunction(self):
    if 'tpu' in self.device.lower():
      self.skipTest('Autoclustering does not run on TPU')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=False)
      def outer(a, b, c):
        return a * inner(b, c) + c

      @def_function.function(jit_compile=True)
      def inner(b, c):
        return b + c * b

      i1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0])
      i2 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0])
      i3 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0])

      with context.collect_graphs(optimized=True) as graphs:
        outer(i1, i2, i3)

      if test_util.is_xla_enabled():
        self.assertIn('_XlaRun', [n.op for n in graphs[0].node])
      else:
        self.assertNotIn('_XlaRun', [n.op for n in graphs[0].node])

  def testBasic(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x, a):
        return x + a

      func = def_function.function(fn, jit_compile=False)
      xla_func = def_function.function(fn, jit_compile=True)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertAllClose([2, 3, 3, 4, 4], func(inputs, 1))
      self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))

  def testBasicInt32(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, a):
        return x + a

      inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
      self.assertAllClose([2, 3, 3, 4, 4], fn(inputs, 1))

  def testDerivative(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x, a):
        return 2 * x + a

      xla_func = def_function.function(fn, jit_compile=True)

      with backprop.GradientTape() as tape:
        inputs = constant_op.constant([1., 2., 2., 3., 3.])
        tape.watch(inputs)
        outputs = xla_func(inputs, 1)

      self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))

      # pylint: disable=protected-access
      (forward, backward) = xla_func.get_concrete_function(
          inputs, 1)._delayed_rewrite_functions.forward_backward()

      # Check that the must-compile attribute gets correctly propagated to the
      # created derivatives.
      self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
      self.assertTrue(forward.definition.attr['_XlaMustCompile'])

  # Calling a function with jit_compile=True from a function with
  # jit_compile=False should compile the inner function.
  def testNestedCall(self):
    if 'tpu' in self.device.lower():
      self.skipTest('b/162800687: Inner function runs on host')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, a):
        return x + a

      @def_function.function(jit_compile=False)
      def fn2(x, a):
        return fn(x, a)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertAllClose([2, 3, 3, 4, 4], fn2(inputs, 1))

  def testNestedCallUnsupportedOps(self):
    if 'tpu' in self.device.lower():
      self.skipTest('Outside compilation will extract string_length to CPU')

    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)

      def fn2(x):
        return xla_func(x)

      func = def_function.function(fn2, jit_compile=False)
      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        func(inputs)

  def testUnsupportedOps(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)

      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        xla_func(constant_op.constant([3.1, 3.2]))

  def testCollectiveReduceChannelId(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, y):
        t0 = collective_ops.all_reduce_v2(
            t=x, group_size=2, group_key=1, instance_key=1)
        t1 = collective_ops.all_reduce_v2(
            t=y, group_size=2, group_key=1, instance_key=1)
        return t0 + t1

      inputs = constant_op.constant([1.0, 2.0, 3.0])
      # Make sure 2 different channel ids are assigned to the 2 all-reduce
      # instructions generated by XLA.
      hlo_str = fn.experimental_get_compiler_ir(inputs, inputs)()
      matches = re.findall('channel_id=([0-9]*),', hlo_str)
      self.assertLen(matches, 2)
      self.assertNotEqual(matches[0], matches[1])

  def testCollectiveReduceGroupAssignment(self):
    if not test_util.is_mlir_bridge_enabled():
      self.skipTest('AssignGroup is only supported in the MLIR bridge.')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        group_size, group_key = collective_ops.assign_group_v2(
            group_assignment=[[0]], device_index=0, base_key=1000)
        t0 = collective_ops.all_reduce_v2(
            t=x, group_size=group_size, group_key=group_key, instance_key=1)
        return t0

      inputs = constant_op.constant([1.0, 2.0, 3.0])
      # Make sure the all-reduce instruction generated by XLA uses the
      # expected replica group assignment.
      hlo_str = fn.experimental_get_compiler_ir(inputs)()
      self.assertIn('replica_groups={{0}}', hlo_str)

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonLocationInMetadata(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, y):
        return x + y

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertIn('def_function_xla_jit_test',
                    fn.experimental_get_compiler_ir(inputs, inputs)())

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonLocationNestedInMetadata(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x, y):
        return x + y

      @def_function.function(jit_compile=True)
      def g(x, y):
        return f(x, y)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertIn('def_function_xla_jit_test',
                    g.experimental_get_compiler_ir(inputs, inputs)())

  def testPythonStackTrace(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT2

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT2'):
        fn(inputs)

  def testPythonStackTraceUncompiledWithinCompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT3

      @def_function.function(jit_compile=True)
      def outer(x):
        return fn(x)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT3'):
        outer(inputs)

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonStackTraceCompiledWithinUncompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT1

      @def_function.function
      def outer(x):
        return fn(x)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT1'):
        outer(inputs)

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonStackTraceCompiledWithinCompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT4

      @def_function.function
      def outer(x):
        return fn(x)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT4'):
        outer(inputs)

  def testFunctionGradient(self):
    with ops.device('device:{}:0'.format(self.device)):
      v = resource_variable_ops.ResourceVariable(2.0)

      def fn(x):
        return v * x

      func = def_function.function(fn, jit_compile=False)
      xla_func = def_function.function(fn, jit_compile=True)

      def run_and_check(test_func):
        x = constant_op.constant(3.0)
        with backprop.GradientTape() as tape:
          y = test_func(x)
        dy = tape.gradient(y, v)

        self.assertAllClose(6.0, y)
        self.assertAllClose(3.0, dy)

      run_and_check(func)
      run_and_check(xla_func)

  @test_util.disable_mlir_bridge('TODO(b/162521846): MLIR bridge fails'
                                 ' msan, function library not found')
  def testControlFlow(self):

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        assert control_flow_util.GraphOrParentsInXlaContext(
            ops.get_default_graph())
        x = ops.convert_to_tensor(x)

        def body(i, a):
          return i + 1, control_flow_ops.cond(i > 2, lambda: a + (x**2),
                                              lambda: a + 3)

        return control_flow_ops.while_loop(
            lambda i, *_: i < 10,
            body, (constant_op.constant(0), constant_op.constant(3.)),
            maximum_iterations=10)[1]

      @def_function.function(jit_compile=True)
      def g(x):
        x = ops.convert_to_tensor(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
        return y, tape.gradient(y, x)

      # Test that XLA context gets correctly propagated.
      g._get_concrete_function_garbage_collected(2.0)(2.0)

      self.assertAllClose(40.0, f(2.0))
      self.assertAllClose([40.0, 28.0], g(2.0))
      self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))
      self.assertAllClose([40.0, 28.0], g.get_concrete_function(2.0)(2.0))

  def testWhileLoopWithUnmodifiedCarriedShape(self):
    with ops.device('device:{}:0'.format(self.device)):
      signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)]

      # We define a signature that specifies unknown vector shape, then test
      # that tf.shape constness gets properly propagated into the while_loop
      # even when carried as part of the loop state.
      @def_function.function(input_signature=signature, jit_compile=True)
      def g(x):
        return control_flow_ops.while_loop_v2(
            lambda *_: True,
            lambda y, shp: (y + random_ops.random_normal(shp)**2, shp),
            (x, array_ops.shape(x)),
            maximum_iterations=3)[0]

      self.assertAllGreater(g(array_ops.zeros([7])), 0.)

  def testNestedWhileLoopWithUnmodifiedCarriedShape(self):
    with ops.device('device:{}:0'.format(self.device)):
      signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)]

      @def_function.function(input_signature=signature, jit_compile=True)
      def g(x):

        def inner(z, shp):
          return z + random_ops.random_normal(shp)**2, shp

        def outer(y, shp):
          y, shp = control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=3)
          y, shp = array_ops.identity_n([y, shp])
          return control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=5)

        shp = array_ops.shape(x, name='x_shp')
        return control_flow_ops.while_loop_v2(
            lambda *_: True, outer, (x, shp), maximum_iterations=4)[0]

      self.assertAllGreater(g(array_ops.zeros([7])), 0.)

  def testNestedWhileLoopWithUnmodifiedCarriedShapeSlice(self):
    with ops.device('device:{}:0'.format(self.device)):
      signature = [
          tensor_spec.TensorSpec(shape=[None, None], dtype=dtypes.float32)
      ]

      @def_function.function(input_signature=signature, jit_compile=True)
      def g(x):

        def inner(z, shp):
          return z + random_ops.random_normal(shp)**2, shp

        def outer(y, shp):
          y, shp = control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=3)
          return control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=4)

        shp = array_ops.shape(x, name='x_shp')
        x = control_flow_ops.while_loop_v2(
            lambda *_: True, outer, (x, shp), maximum_iterations=5)[0]

        shp2 = array_ops.shape(x, name='x_shp_after')[1:]
        w = control_flow_ops.while_loop_v2(
            lambda *_: True,
            outer, (array_ops.zeros_like(x[0]), shp2),
            maximum_iterations=6)[0]
        return x + w

      self.assertAllGreater(g(array_ops.zeros([7, 13])), 0.)

  def testMethodCompilation(self):

    with ops.device('device:{}:0'.format(self.device)):

      class C(object):

        @def_function.function(jit_compile=True)
        def f1(self, x, a):
          return x + a

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      c = C()
      self.assertAllClose([2, 3, 3, 4, 4], c.f1(inputs, 1))

  def testMethodCompilationUnsupportedFunc(self):
    with ops.device('device:{}:0'.format(self.device)):

      class C(object):

        @def_function.function(jit_compile=True)
        def f1(self, x):
          return string_ops.string_length(
              string_ops.string_format('{}', x))

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      c = C()
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        c.f1(inputs)

  def testMustBeConstantPropagation(self):
    if 'tpu' in self.device.lower():
      self.skipTest('b/162799319: Cannot resolve constant on TPU')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f():
        return constant_op.constant([0, 2, 1], dtype=dtypes.int32)

      @def_function.function(jit_compile=True)
      def g(a, b):
        return array_ops.transpose(a, b)

      @def_function.function
      def z():
        return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())

      z()

  def testArgMinMax(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def argmax(x):
        return math_ops.argmax(x)

      @def_function.function(jit_compile=True)
      def argmin(x):
        return math_ops.argmin(x)

      self.assertAllClose(0, argmax(array_ops.ones([10], dtype=dtypes.float32)))
      self.assertAllClose(0, argmax(array_ops.ones([10])))
      self.assertAllClose(0, argmin(array_ops.ones([10], dtype=dtypes.float32)))
      self.assertAllClose(0, argmin(array_ops.ones([10])))

  @test_util.disable_mlir_bridge('TensorArray support not implemented')
  def testErrorMessagePassingTensorArray(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=1, element_shape=[])
        ta = ta.write(0, 2 * x)
        y = ta.read(0)
        return y

      x = constant_op.constant(3.14)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        with self.assertRaisesRegex(errors.UnimplementedError,
                                    'TensorList crossing the XLA/TF boundary'):
          y = f(x)
          tape.gradient(y, x)

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatV2(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      compiled_f = def_function.function(jit_compile=True)(f)

      inputs = constant_op.constant([3.14, 2.68, 7.69])

      self.assertAllClose([6.28, 5.36, 15.38, 9.42, 8.04, 23.07], f(inputs))

      self.assertAllClose(compiled_f(inputs), f(inputs))

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatV2Multidim(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3, 2])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      compiled_f = def_function.function(jit_compile=True)(f)

      inputs = constant_op.constant([[3.14, 21.1], [2.68, 22.2], [7.69, 23.3]])
      self.assertAllClose(f(inputs), compiled_f(inputs))

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatV2Scalars(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[1])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      compiled_f = def_function.function(jit_compile=True)(f)
      inputs = constant_op.constant([3.14])
      self.assertAllClose(f(inputs), compiled_f(inputs))

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatGrad(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      def g():
        x = constant_op.constant([3.14, 2.68, 7.69])
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
        return tape.gradient(y, x)

      compiled_g = def_function.function(jit_compile=True)(g)

      self.assertAllClose([5.0, 5.0, 5.0], g())
      self.assertAllClose(compiled_g(), g())

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatGradNestedCompile(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      @def_function.function(jit_compile=True)
      def g():
        x = constant_op.constant([3.14, 2.68, 7.69])
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
        out = tape.gradient(y, x)
        return out

      self.assertAllClose([5.0, 5.0, 5.0], g())

  def testCumsum(self):
    if 'tpu' in self.device.lower():
      self.skipTest('b/162771302: 64bit rewrite of cumsum not supported')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        return math_ops.cumsum(x)

      f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
      self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))

  def testNoExcessiveRetracing(self):
    with ops.device('device:{}:0'.format(self.device)):
      inner_retracings = 0

      @def_function.function(jit_compile=True)
      def inner(a, b):
        nonlocal inner_retracings
        inner_retracings += 1
        return a * b + a

      def outer(a, b):
        return inner(a, b)

      func_input = random_ops.random_normal([10, 10])
      for _ in range(2):
        def_function.function(outer)(func_input, func_input)

      self.assertEqual(inner_retracings, 1)

  def testUpdateVariable(self):
    with ops.device('device:{}:0'.format(self.device)):
      v = variables.Variable([0.0, 0.0])

      @def_function.function(jit_compile=True)
      def f():
        v.assign([3.1, 2.3])

      f()
      self.assertAllClose(v, [3.1, 2.3])

  @test_util.disable_mlir_bridge('MLIR does not support resource update for'
                                 ' signature with compile-time constant.')
  def testUniqueDifferentSizes(self):
    if not 'gpu' in self.device.lower():
      self.skipTest('Currently works only on GPU')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x, y):
        return array_ops.unique(x).y + array_ops.unique(y).y

      f(constant_op.constant([3.1, 3.2]), constant_op.constant([3.3, 3.2]))

      with self.assertRaisesRegex(errors.InternalError, 'different size'):
        f(
            constant_op.constant([3.1, 3.2]),
            constant_op.constant([3.1, 3.2, 3.3]))

  def testUniqueCompilability(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        return array_ops.unique(x).y

      self.assertAllClose(f(constant_op.constant([3.1, 3.2, 3.2])), [3.1, 3.2])

  def testUpdateVariableMemoryUsage(self):
    with ops.device('device:{}:0'.format(self.device)):

      on_gpu = 'gpu' in self.device.lower()
      v = variables.Variable([3.1, 3.2])

      @def_function.function(jit_compile=True)
      def update_var(a, b):
        v.assign_add(a * b)

      arg1 = random_ops.random_normal([2])
      arg2 = random_ops.random_normal([2])

      gc.collect()
      initial_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      update_var(arg1, arg2)
      gc.collect()
      final_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      self.assertEqual(initial_usage, final_usage)

  @test_util.disable_mlir_bridge('MLIR does not support resource update for'
                                 ' signature with compile-time constant.')
  def testUpdateVariableWithCompileTimeConstMemoryUsage(self):
    with ops.device('device:{}:0'.format(self.device)):

      on_gpu = 'gpu' in self.device.lower()
      v = variables.Variable(random_ops.random_normal([1024, 1024]))

      # Test a signature of (compile-time const, arg, res_var). The compile-time
      # const will be optimized away so that the kernel signature will become
      # (arg, res_var).
      @def_function.function(jit_compile=True)
      def update_var(shape, arg):
        v.assign_add(array_ops.broadcast_to(arg, shape))

      arg = random_ops.random_normal([1])

      gc.collect()
      initial_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      update_var(constant_op.constant([1024, 1024]), arg)
      # Need to run update_var a second time so that the BFC allocator can
      # defragment the GPU memory.
      update_var(constant_op.constant([1024, 1024]), arg)
      gc.collect()
      final_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      self.assertEqual(initial_usage, final_usage)

  @test_util.disable_mlir_bridge('TODO(b/162381930): MLIR bridge renames '
                                 ' functions')
  def testUpdateVariableInClass(self):
    with ops.device('device:{}:0'.format(self.device)):

      class C(object):

        @def_function.function(jit_compile=True)
        def update_var(self, a, b):
          if not hasattr(self, 'v'):
            self.v = variables.Variable(3.1)
          self.v.assign_add(a * b)

      c = C()

      @def_function.function
      def outer():
        c.update_var(constant_op.constant(0.7), constant_op.constant(0.6))

      outer()
      self.assertAllClose(c.v, 3.52)

  def testUpdateVariableMultipleOutputs(self):
    with ops.device('device:{}:0'.format(self.device)):
      v = variables.Variable(3.1)

      @def_function.function(jit_compile=True)
      def update_var(a, b):
        v.assign_add(a * b)
        return a * b + v

      out = update_var(constant_op.constant(0.7), constant_op.constant(0.6))
      self.assertAllClose(v, 3.52)
      self.assertAllClose(out, 3.94)

  def testReturnIdentity(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(a, b):
        return (a, b)

      a = random_ops.random_normal([10, 10])
      b = random_ops.random_normal([10, 10])

      on_gpu = 'gpu' in self.device.lower()
      gc.collect()
      initial_usage = context.context().get_memory_info(
          b.backing_device)['current'] if on_gpu else 0

      f(a, b)

      gc.collect()
      final_usage = context.context().get_memory_info(
          b.backing_device)['current'] if on_gpu else 0
      self.assertEqual(initial_usage, final_usage)

  def testGetCompilerIrConstants(self):
    if 'tpu' in self.device.lower():
      self.skipTest('TPU generates different HLO')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(a, b):
        return array_ops.transpose(a, b)

      a = array_ops.ones([3, 4, 3], dtype=dtypes.float32)
      b = constant_op.constant([0, 2, 1], dtype=dtypes.int32)

      self.assertIn('{1,2,0}',
                    f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))

  @test_util.disable_mlir_bridge('TODO(b/168732524): MLIR bridge does not '
                                 ' optimize single-element tuples to scalars')
  def testGetCompilerIrResourceVars(self):
    with ops.device('device:{}:0'.format(self.device)):

      v = variables.Variable([3.1, 3.2])

      @def_function.function(jit_compile=True)
      def f(a, b):
        v.assign_add(a * b)

      a = random_ops.random_normal([2])
      b = random_ops.random_normal([2])

      self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
                    f.experimental_get_compiler_ir(a, b)('optimized_hlo'))

  def testGetCompilerIrNotCompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function
      def f(x):
        return x + 1

      a = random_ops.random_normal([10, 10])
      with self.assertRaisesRegex(ValueError,
                                  'marked with \'jit_compile'):
        f.experimental_get_compiler_ir(a)()

  def testGetCompilerIrNested(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, a):
        return x + a

      @def_function.function(jit_compile=False)
      def fn2(x, a):
        fn.experimental_get_compiler_ir(x, a)()
        return fn(x, a)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaises(TypeError):
        fn2(inputs, 1)

  def testGetCompilerIrKwargs(self):
    with ops.device('device:{}:0'.format(self.device)):

      v = variables.Variable([0.1, 0.1])

      @def_function.function(jit_compile=True)
      def f(a, b):
        return (a + b) * v

      a = constant_op.constant([1.1, 1.1])
      b = constant_op.constant([2.2, 2.2])

      self.assertIn('multiply',
                    f.experimental_get_compiler_ir(b=a, a=b)(stage='hlo'))

  def testGetCompilerIrDot(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(a, b):
        return a + b

      a = constant_op.constant([1.1, 1.1])
      b = constant_op.constant([2.2, 2.2])

      self.assertIn(
          'label',
          f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))

  def testGetCompilerIrNoDevicePlacement(self):
    if 'gpu' not in self.device.lower():
      self.skipTest('Testing get_compiler_ir on GPUs without placement')

    @def_function.function(jit_compile=True)
    def f(a, b):
      return a + b

    a = constant_op.constant([1.1, 1.1])
    b = constant_op.constant([2.2, 2.2])

    self.assertIn(
        'label',
        f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))

  def testGetCompilerIrNonTensors(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(l):
        return l[0] + l[1]

      l = [constant_op.constant(1.1), constant_op.constant(2.2)]

      self.assertIn('tuple',
                    f.experimental_get_compiler_ir(l)())

  def testGetCompilerIrSerialized(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return x - x

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      for stage in ('hlo_serialized', 'optimized_hlo_serialized'):
        hlo = fn.experimental_get_compiler_ir(inputs)(
            stage=stage, device_name=f'/device:{self.device}:0')
        self.assertIsInstance(hlo, bytes)

  def testDotOptimizedHlo(self):
    with ops.device('device:{}:0'.format(self.device)):

      a = random_ops.random_normal([100, 100])
      b = random_ops.random_normal([100, 100])

      @def_function.function(jit_compile=True)
      def f(a, b):
        return math_ops.matmul(a, b)

      self.assertRegex(f.experimental_get_compiler_ir(a, b)('optimized_hlo'),
                       '(dot)|(convolution)')

  def testConstantOnWrongDevice(self):
    with ops.device('device:{}:0'.format(self.device)):

      s = random_ops.random_uniform([2], 1, 10, dtypes.int32)
      l = random_ops.random_normal([s[0] * s[1]])

      @def_function.function(jit_compile=True)
      def f(l):
        return array_ops.reshape(l, s)

      self.assertIn('tuple',
                    f.experimental_get_compiler_ir(l)())

  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
                                 'support getting constants out of resources')
  def testGetConstantOutOfResourceVariable(self):
    with ops.device('device:{}:0'.format(self.device)):

      # Use floats to force device placement.
      a = variables.Variable(50.0)
      b = variables.Variable(2.0)

      @def_function.function(jit_compile=True)
      def f(x):
        return array_ops.reshape(
            x, [math_ops.cast(a, dtypes.int32),
                math_ops.cast(b, dtypes.int32)])

      # OK since the value is known at compile time.
      out = f(random_ops.random_normal([10, 10]))
      self.assertEqual(out.shape[0], 50)
      self.assertEqual(out.shape[1], 2)

  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
                                 'support getting constants out of resources')
  def testGetConstantOutOfResourceVariableAfterWrite(self):
    with ops.device('device:{}:0'.format(self.device)):

      # Use floats to force device placement.
      a = variables.Variable(50.0)
      b = variables.Variable(2.0)

      @def_function.function(jit_compile=True)
      def f(x, val1, val2):
        a.assign(math_ops.cast(val1, dtypes.float32))
        b.assign(math_ops.cast(val2, dtypes.float32))
        return array_ops.reshape(
            x, [math_ops.cast(a, dtypes.int32),
                math_ops.cast(b, dtypes.int32)])

      val1 = constant_op.constant(2)
      val2 = constant_op.constant(50)

      # Returns an error, since the value known at compile time was overridden.
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'concrete values at compile time'):
        f(random_ops.random_normal([10, 10]), val1, val2)

  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
                                 'support getting constants out of resources')
  def testGetConstantOutOfResourceVariableBeforeWrite(self):
    with ops.device('device:{}:0'.format(self.device)):

      # Use floats to force device placement.
      a = variables.Variable(50.0)
      b = variables.Variable(2.0)

      @def_function.function(jit_compile=True)
      def f(x, val1, val2):
        out = array_ops.reshape(
            x, [math_ops.cast(a, dtypes.int32),
                math_ops.cast(b, dtypes.int32)])
        a.assign(math_ops.cast(val1, dtypes.float32))
        b.assign(math_ops.cast(val2, dtypes.float32))
        return out

      val1 = constant_op.constant(2)
      val2 = constant_op.constant(50)

      # OK since the write happens after the reshape.
      out = f(random_ops.random_normal([10, 10]), val1, val2)
      self.assertEqual(out.shape[0], 50)
      self.assertEqual(out.shape[1], 2)

  def testTfAssert(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        control_flow_ops.Assert(x == 1, ['Wrong value'])

      f(constant_op.constant(1))

  def testTensorArrayErrorMessage(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f():
        # The error messages of the old and new bridge differ in which op they
        # flag: one points at the creation of the uninitialized tensor array,
        # the other at its use.
        ta = tensor_array_ops.TensorArray(  # EXPECTED_MESSAGE_NEW
            dtype=dtypes.float32,
            size=2,
            dynamic_size=True,
            element_shape=(None,))
        return ta.concat()  # EXPECTED_MESSAGE_OLD

      if test_util.is_mlir_bridge_enabled():
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'EXPECTED_MESSAGE_NEW'):
          f()
      else:
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'EXPECTED_MESSAGE_OLD'):
          f()

  def testCounter(self):
    cell_nojit = def_function._tf_function_counter.get_cell('0')
    cell_jit = def_function._tf_function_counter.get_cell('1')
    orig_nojit = cell_nojit.value()
    orig_jit = cell_jit.value()

    with ops.device('device:{}:0'.format(self.device)):
      @def_function.function
      def f(a):
        return a + a
      f(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 1)
      self.assertEqual(cell_jit.value(), orig_jit)
      f(constant_op.constant(1.))  # Calling again does not increment
      self.assertEqual(cell_nojit.value(), orig_nojit + 1)

      @def_function.function(jit_compile=True)
      def f1(a):
        return a + a
      f1(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 1)
      self.assertEqual(cell_jit.value(), orig_jit + 1)

      @def_function.function
      def f2(a):
        @def_function.function
        def g(a):
          return a + a
        @def_function.function(jit_compile=True)
        def h(a):
          return a + a
        return g(a) + h(a)
      f2(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 2)
      self.assertEqual(cell_jit.value(), orig_jit + 2)

      @def_function.function(jit_compile=True)
      def f3(a):
        @def_function.function
        def g(a):
          return a + a
        @def_function.function(jit_compile=True)
        def h(a):
          return a + a
        return g(a) + h(a)
      f3(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 2)
      self.assertEqual(cell_jit.value(), orig_jit + 3)

  @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns '
                                 ' wrong status type')
  def testResourceWrongDevice(self):
    if 'gpu' not in self.device.lower():
      self.skipTest('Need a GPU to have non-trivial device placement')

    with ops.device('device:CPU:0'):
      v = variables.Variable([3.1, 3.2])

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(experimental_compile=True)
      def update_var(a):
        v.assign_add(a)

      arg = random_ops.random_normal([2])
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'Trying to access resource .*'):
        update_var(arg)

  def testMustBeConstantInsideCondition(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x, d):
        if math_ops.reduce_all(
            math_ops.greater(x, random_ops.random_normal([10, 10]))):
          return array_ops.reshape(x * 2, constant_op.constant([100]))
        else:
          return array_ops.reshape(x * 3, d)

      f(random_ops.random_normal([10, 10]), constant_op.constant([100]))

  def testConditionalGradientTapeMathRegression(self):
    with ops.device('device:{}:0'.format(self.device)):
      with backprop.GradientTape():

        @def_function.function(jit_compile=True, autograph=False)
        def f(x):
          return control_flow_ops.cond(
              math_ops.reduce_all(x > 1), lambda: 1. / x, lambda: x)

        v = variables.Variable([[2.]])
        self.assertAllClose(f(v), constant_op.constant([[0.5]]))

  @test_util.disable_mlir_bridge('TODO(b/190444466): MLIR bridge seems to '
                                 'ignore resource assignments')
  def testErrMsgAssignWrongShape(self):
    with ops.device('device:{}:0'.format(self.device)):

      v = variables.Variable([3.1, 3.2])

      @def_function.function(jit_compile=True)
      def f(samples):
        v.assign(array_ops.zeros(samples))  # assignment

      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          'Shape .* cannot be changed after initialization'):
        f(constant_op.constant(6))

      with self.assertRaisesRegex(errors.InvalidArgumentError, 'assignment'):
        f(constant_op.constant(6))

  def testTfSummaryErrMsg(self):
    if 'gpu' not in self.device.lower():
      self.skipTest('Only runs on GPU')

    with ops.device('device:{}:0'.format(self.device)):
      writer = summary_ops_v2.create_file_writer(self.get_temp_dir())

      @def_function.function(jit_compile=True)
      def my_func_temp():
        with writer.as_default():
          summary_ops_v2.scalar('my_metric', 0.5, step=10)

      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'Trying to access resource .*'):
        my_func_temp()

  def testSinglePassArgmax(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        return math_ops.argmax(x)

      hlo = f.experimental_get_compiler_ir(
          array_ops.ones([10], dtype=dtypes.float32))(
              stage='hlo')

      # Test that reduction occurs only once.
      self.assertGreater(hlo.count('reduce'), 1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()