# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import gc
import re

from tensorflow.compiler.tests import xla_test
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import collective_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test


@test_util.with_eager_op_as_function
class DefFunctionTest(xla_test.XLATestCase):

  def testAutoclusteringWithTfFunction(self):
    if 'tpu' in self.device.lower():
      self.skipTest('Autoclustering does not run on TPU')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=False)
      def outer(a, b, c):
        return a * inner(b, c) + c

      @def_function.function(jit_compile=True)
      def inner(b, c):
        return b + c * b

      i1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0])
      i2 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0])
      i3 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0])

      with context.collect_graphs(optimized=True) as graphs:
        outer(i1, i2, i3)

      if test_util.is_xla_enabled():
        self.assertIn('_XlaRun', [n.op for n in graphs[0].node])
      else:
        self.assertNotIn('_XlaRun', [n.op for n in graphs[0].node])

  def testBasic(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x, a):
        return x + a

      func = def_function.function(fn, jit_compile=False)
      xla_func = def_function.function(fn, jit_compile=True)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertAllClose([2, 3, 3, 4, 4], func(inputs, 1))
      self.assertAllClose([2, 3, 3, 4, 4], xla_func(inputs, 1))

  def testBasicInt32(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, a):
        return x + a

      inputs = constant_op.constant([1, 2, 2, 3, 3], dtype=dtypes.int32)
      self.assertAllClose([2, 3, 3, 4, 4], fn(inputs, 1))

  def testDerivative(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x, a):
        return 2 * x + a

      xla_func = def_function.function(fn, jit_compile=True)

      with backprop.GradientTape() as tape:
        inputs = constant_op.constant([1., 2., 2., 3., 3.])
        tape.watch(inputs)
        outputs = xla_func(inputs, 1)

      self.assertAllClose([2, 2, 2, 2, 2], tape.gradient(outputs, inputs))

      # pylint: disable=protected-access
      (forward, backward) = xla_func.get_concrete_function(
          inputs, 1)._delayed_rewrite_functions.forward_backward()

      # Check that the must-compile attribute gets correctly propagated to the
      # created derivatives.
      self.assertTrue(backward.function_def.attr['_XlaMustCompile'])
      self.assertTrue(forward.definition.attr['_XlaMustCompile'])

  # Calling a function with jit_compile=True from a function with
  # jit_compile=False should still compile the inner function.
  def testNestedCall(self):
    if 'tpu' in self.device.lower():
      self.skipTest('b/162800687: Inner function runs on host')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, a):
        return x + a

      @def_function.function(jit_compile=False)
      def fn2(x, a):
        return fn(x, a)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertAllClose([2, 3, 3, 4, 4], fn2(inputs, 1))

  def testNestedCallUnsupportedOps(self):
    if 'tpu' in self.device.lower():
      self.skipTest('Outside compilation will extract string_length to CPU')

    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)

      def fn2(x):
        return xla_func(x)

      func = def_function.function(fn2, jit_compile=False)
      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        func(inputs)

  def testUnsupportedOps(self):
    with ops.device('device:{}:0'.format(self.device)):

      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))

      xla_func = def_function.function(fn, jit_compile=True)

      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        xla_func(constant_op.constant([3.1, 3.2]))

  def testCollectiveReduceChannelId(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, y):
        t0 = collective_ops.all_reduce_v2(
            t=x, group_size=2, group_key=1, instance_key=1)
        t1 = collective_ops.all_reduce_v2(
            t=y, group_size=2, group_key=1, instance_key=1)
        return t0 + t1

      inputs = constant_op.constant([1.0, 2.0, 3.0])
      # Make sure 2 different channel ids are assigned to the 2 all-reduce
      # instructions generated by XLA.
      hlo_str = fn.experimental_get_compiler_ir(inputs, inputs)()
      matches = re.findall('channel_id=([0-9]*),', hlo_str)
      self.assertLen(matches, 2)
      self.assertNotEqual(matches[0], matches[1])

  def testCollectiveReduceGroupAssignment(self):
    if not test_util.is_mlir_bridge_enabled():
      self.skipTest('AssignGroup is only supported in the MLIR bridge.')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        group_size, group_key = collective_ops.assign_group_v2(
            group_assignment=[[0]], device_index=0, base_key=1000)
        t0 = collective_ops.all_reduce_v2(
            t=x, group_size=group_size, group_key=group_key, instance_key=1)
        return t0

      inputs = constant_op.constant([1.0, 2.0, 3.0])
      # Make sure the all-reduce instruction generated by XLA uses the replica
      # groups produced by AssignGroup.
      hlo_str = fn.experimental_get_compiler_ir(inputs)()
      self.assertIn('replica_groups={{0}}', hlo_str)

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonLocationInMetadata(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, y):
        return x + y

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertIn('def_function_xla_jit_test',
                    fn.experimental_get_compiler_ir(inputs, inputs)())

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonLocationNestedInMetadata(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x, y):
        return x + y

      @def_function.function(jit_compile=True)
      def g(x, y):
        return f(x, y)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      self.assertIn('def_function_xla_jit_test',
                    g.experimental_get_compiler_ir(inputs, inputs)())

  def testPythonStackTrace(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT2

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT2'):
        fn(inputs)

  def testPythonStackTraceUncompiledWithinCompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT3

      @def_function.function(jit_compile=True)
      def outer(x):
        return fn(x)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT3'):
        outer(inputs)

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonStackTraceCompiledWithinUncompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT1

      @def_function.function
      def outer(x):
        return fn(x)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT1'):
        outer(inputs)

  @test_util.disable_mlir_bridge('TODO(b/155782411): MLIR bridge does not '
                                 'support stack traces')
  def testPythonStackTraceCompiledWithinCompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return string_ops.string_length(
            string_ops.string_format('{}', x))  # COMMENT4

      @def_function.function
      def outer(x):
        return fn(x)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaisesRegex(errors.InvalidArgumentError, 'COMMENT4'):
        outer(inputs)

  def testFunctionGradient(self):
    with ops.device('device:{}:0'.format(self.device)):
      v = resource_variable_ops.ResourceVariable(2.0)

      def fn(x):
        return v * x

      func = def_function.function(fn, jit_compile=False)
      xla_func = def_function.function(fn, jit_compile=True)

      def run_and_check(test_func):
        x = constant_op.constant(3.0)
        with backprop.GradientTape() as tape:
          y = test_func(x)
        dy = tape.gradient(y, v)

        self.assertAllClose(6.0, y)
        self.assertAllClose(3.0, dy)

      run_and_check(func)
      run_and_check(xla_func)

  @test_util.disable_mlir_bridge('TODO(b/162521846): MLIR bridge fails'
                                 ' msan, function library not found')
  def testControlFlow(self):

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        assert control_flow_util.GraphOrParentsInXlaContext(
            ops.get_default_graph())
        x = ops.convert_to_tensor(x)

        def body(i, a):
          return i + 1, control_flow_ops.cond(i > 2, lambda: a + (x**2),
                                              lambda: a + 3)

        return control_flow_ops.while_loop(
            lambda i, *_: i < 10,
            body, (constant_op.constant(0), constant_op.constant(3.)),
            maximum_iterations=10)[1]

      @def_function.function(jit_compile=True)
      def g(x):
        x = ops.convert_to_tensor(x)
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
        return y, tape.gradient(y, x)

      # Test that XLA context gets correctly propagated.
      g._get_concrete_function_garbage_collected(2.0)(2.0)

      self.assertAllClose(40.0, f(2.0))
      self.assertAllClose([40.0, 28.0], g(2.0))
      self.assertAllClose(40.0, f.get_concrete_function(2.0)(2.0))
      self.assertAllClose([40.0, 28.0], g.get_concrete_function(2.0)(2.0))

  def testWhileLoopWithUnmodifiedCarriedShape(self):
    with ops.device('device:{}:0'.format(self.device)):
      signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)]

      # We define a signature that specifies unknown vector shape, then test
      # that tf.shape constness gets properly propagated into the while_loop
      # even when carried as part of the loop state.
      @def_function.function(input_signature=signature, jit_compile=True)
      def g(x):
        return control_flow_ops.while_loop_v2(
            lambda *_: True,
            lambda y, shp: (y + random_ops.random_normal(shp)**2, shp),
            (x, array_ops.shape(x)),
            maximum_iterations=3)[0]

      self.assertAllGreater(g(array_ops.zeros([7])), 0.)

  def testNestedWhileLoopWithUnmodifiedCarriedShape(self):
    with ops.device('device:{}:0'.format(self.device)):
      signature = [tensor_spec.TensorSpec(shape=[None], dtype=dtypes.float32)]

      @def_function.function(input_signature=signature, jit_compile=True)
      def g(x):

        def inner(z, shp):
          return z + random_ops.random_normal(shp)**2, shp

        def outer(y, shp):
          y, shp = control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=3)
          y, shp = array_ops.identity_n([y, shp])
          return control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=5)

        shp = array_ops.shape(x, name='x_shp')
        return control_flow_ops.while_loop_v2(
            lambda *_: True, outer, (x, shp), maximum_iterations=4)[0]

      self.assertAllGreater(g(array_ops.zeros([7])), 0.)

  def testNestedWhileLoopWithUnmodifiedCarriedShapeSlice(self):
    with ops.device('device:{}:0'.format(self.device)):
      signature = [
          tensor_spec.TensorSpec(shape=[None, None], dtype=dtypes.float32)
      ]

      @def_function.function(input_signature=signature, jit_compile=True)
      def g(x):

        def inner(z, shp):
          return z + random_ops.random_normal(shp)**2, shp

        def outer(y, shp):
          y, shp = control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=3)
          return control_flow_ops.while_loop_v2(
              lambda *_: True, inner, (y, shp), maximum_iterations=4)

        shp = array_ops.shape(x, name='x_shp')
        x = control_flow_ops.while_loop_v2(
            lambda *_: True, outer, (x, shp), maximum_iterations=5)[0]

        shp2 = array_ops.shape(x, name='x_shp_after')[1:]
        w = control_flow_ops.while_loop_v2(
            lambda *_: True,
            outer, (array_ops.zeros_like(x[0]), shp2),
            maximum_iterations=6)[0]
        return x + w

      self.assertAllGreater(g(array_ops.zeros([7, 13])), 0.)

  def testMethodCompilation(self):

    with ops.device('device:{}:0'.format(self.device)):

      class C(object):

        @def_function.function(jit_compile=True)
        def f1(self, x, a):
          return x + a

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      c = C()
      self.assertAllClose([2, 3, 3, 4, 4], c.f1(inputs, 1))

  def testMethodCompilationUnsupportedFunc(self):
    with ops.device('device:{}:0'.format(self.device)):

      class C(object):

        @def_function.function(jit_compile=True)
        def f1(self, x):
          return string_ops.string_length(
              string_ops.string_format('{}', x))

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      c = C()
      with self.assertRaisesRegex(
          errors.InvalidArgumentError, 'legalization failed'
          if test_util.is_mlir_bridge_enabled() else 'unsupported operations'):
        c.f1(inputs)

  def testMustBeConstantPropagation(self):
    if 'tpu' in self.device.lower():
      self.skipTest('b/162799319: Cannot resolve constant on TPU')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f():
        return constant_op.constant([0, 2, 1], dtype=dtypes.int32)

      @def_function.function(jit_compile=True)
      def g(a, b):
        return array_ops.transpose(a, b)

      @def_function.function
      def z():
        return g(array_ops.ones([3, 4, 3], dtype=dtypes.float32), f())

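      # The permutation passed to transpose must be a compile-time constant,
      # so compiling g() requires resolving the value returned by f() even
      # though z() itself is not jit-compiled.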
      z()

  def testArgMinMax(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def argmax(x):
        return math_ops.argmax(x)

      @def_function.function(jit_compile=True)
      def argmin(x):
        return math_ops.argmin(x)

      self.assertAllClose(0, argmax(array_ops.ones([10], dtype=dtypes.float32)))
      self.assertAllClose(0, argmax(array_ops.ones([10])))
      self.assertAllClose(0, argmin(array_ops.ones([10], dtype=dtypes.float32)))
      self.assertAllClose(0, argmin(array_ops.ones([10])))

  @test_util.disable_mlir_bridge('TensorArray support not implemented')
  def testErrorMessagePassingTensorArray(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=1, element_shape=[])
        ta = ta.write(0, 2 * x)
        y = ta.read(0)
        return y

      x = constant_op.constant(3.14)
      with backprop.GradientTape() as tape:
        tape.watch(x)
        with self.assertRaisesRegex(errors.UnimplementedError,
                                    'TensorList crossing the XLA/TF boundary'):
          y = f(x)
          tape.gradient(y, x)

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatV2(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      compiled_f = def_function.function(jit_compile=True)(f)

      inputs = constant_op.constant([3.14, 2.68, 7.69])

      self.assertAllClose([6.28, 5.36, 15.38, 9.42, 8.04, 23.07], f(inputs))

      self.assertAllClose(compiled_f(inputs), f(inputs))

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatV2Multidim(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3, 2])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      compiled_f = def_function.function(jit_compile=True)(f)

      inputs = constant_op.constant([[3.14, 21.1], [2.68, 22.2], [7.69, 23.3]])
      self.assertAllClose(f(inputs), compiled_f(inputs))

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatV2Scalars(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[1])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      compiled_f = def_function.function(jit_compile=True)(f)
      inputs = constant_op.constant([3.14])
      self.assertAllClose(f(inputs), compiled_f(inputs))

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatGrad(self):
    with ops.device('device:{}:0'.format(self.device)):

      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      def g():
        x = constant_op.constant([3.14, 2.68, 7.69])
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
          return tape.gradient(y, x)

      compiled_g = def_function.function(jit_compile=True)(g)

      self.assertAllClose([5.0, 5.0, 5.0], g())
      self.assertAllClose(compiled_g(), g())

  @test_util.disable_mlir_bridge('TODO(b/162281863): MLIR bridge errors out'
                                 ' lowering TensorListConcatV2')
  def testTensorListConcatGradNestedCompile(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        ta = tensor_array_ops.TensorArray(
            dtype=dtypes.float32, size=2, element_shape=[3])
        ta = ta.write(0, 2 * x)
        ta = ta.write(1, 3 * x)
        return ta.concat()

      @def_function.function(jit_compile=True)
      def g():
        x = constant_op.constant([3.14, 2.68, 7.69])
        with backprop.GradientTape() as tape:
          tape.watch(x)
          y = f(x)
          out = tape.gradient(y, x)
        return out

      self.assertAllClose([5.0, 5.0, 5.0], g())

  def testCumsum(self):
    if 'tpu' in self.device.lower():
      self.skipTest('b/162771302: 64bit rewrite of cumsum not supported')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        return math_ops.cumsum(x)

      f64_input = constant_op.constant([1.1, 2.2, 3.3], dtype=dtypes.float64)
      self.assertAllClose([1.1, 3.3, 6.6], f(f64_input))

  def testNoExcessiveRetracing(self):
    with ops.device('device:{}:0'.format(self.device)):
      inner_retracings = 0

      @def_function.function(jit_compile=True)
      def inner(a, b):
        nonlocal inner_retracings
        inner_retracings += 1
        return a * b + a

      def outer(a, b):
        return inner(a, b)

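      # Re-wrapping outer in a fresh tf.function on every call should still
      # reuse the same traced inner function, so inner traces exactly once.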
      func_input = random_ops.random_normal([10, 10])
      for _ in range(2):
        def_function.function(outer)(func_input, func_input)

      self.assertEqual(inner_retracings, 1)

  def testUpdateVariable(self):
    with ops.device('device:{}:0'.format(self.device)):
      v = variables.Variable([0.0, 0.0])

      @def_function.function(jit_compile=True)
      def f():
        v.assign([3.1, 2.3])

      f()
      self.assertAllClose(v, [3.1, 2.3])

  @test_util.disable_mlir_bridge('MLIR does not support resource update for'
                                 ' signature with compile-time constant.')
  def testUniqueDifferentSizes(self):
    if not 'gpu' in self.device.lower():
      self.skipTest('Currently works only on GPU')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x, y):
        return array_ops.unique(x).y + array_ops.unique(y).y

      f(constant_op.constant([3.1, 3.2]), constant_op.constant([3.3, 3.2]))

      with self.assertRaisesRegex(errors.InternalError, 'different size'):
        f(
            constant_op.constant([3.1, 3.2]),
            constant_op.constant([3.1, 3.2, 3.3]))

  def testUniqueCompilability(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        return array_ops.unique(x).y

      self.assertAllClose(f(constant_op.constant([3.1, 3.2, 3.2])), [3.1, 3.2])

  def testUpdateVariableMemoryUsage(self):
    with ops.device('device:{}:0'.format(self.device)):

      on_gpu = 'gpu' in self.device.lower()
      v = variables.Variable([3.1, 3.2])

      @def_function.function(jit_compile=True)
      def update_var(a, b):
        v.assign_add(a * b)

      arg1 = random_ops.random_normal([2])
      arg2 = random_ops.random_normal([2])

      gc.collect()
      initial_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      update_var(arg1, arg2)
      gc.collect()
      final_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      self.assertEqual(initial_usage, final_usage)

  @test_util.disable_mlir_bridge('MLIR does not support resource update for'
                                 ' signature with compile-time constant.')
  def testUpdateVariableWithCompileTimeConstMemoryUsage(self):
    with ops.device('device:{}:0'.format(self.device)):

      on_gpu = 'gpu' in self.device.lower()
      v = variables.Variable(random_ops.random_normal([1024, 1024]))

      # Test a signature of (compile-time const, arg, res_var). The compile-time
      # const will be optimized away so that the kernel signature becomes
      # (arg, res_var).
      @def_function.function(jit_compile=True)
      def update_var(shape, arg):
        v.assign_add(array_ops.broadcast_to(arg, shape))

      arg = random_ops.random_normal([1])

      gc.collect()
      initial_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      update_var(constant_op.constant([1024, 1024]), arg)
      # Run update_var a second time so that the BFC allocator can defragment
      # the GPU memory.
      update_var(constant_op.constant([1024, 1024]), arg)
      gc.collect()
      final_usage = context.context().get_memory_info(
          v.device)['current'] if on_gpu else 0
      self.assertEqual(initial_usage, final_usage)

  @test_util.disable_mlir_bridge('TODO(b/162381930): MLIR bridge renames '
                                 ' functions')
  def testUpdateVariableInClass(self):
    with ops.device('device:{}:0'.format(self.device)):

      class C(object):

        @def_function.function(jit_compile=True)
        def update_var(self, a, b):
          if not hasattr(self, 'v'):
            self.v = variables.Variable(3.1)
          self.v.assign_add(a * b)

      c = C()

      @def_function.function
      def outer():
        c.update_var(constant_op.constant(0.7), constant_op.constant(0.6))

      outer()
      self.assertAllClose(c.v, 3.52)

  def testUpdateVariableMultipleOutputs(self):
    with ops.device('device:{}:0'.format(self.device)):
      v = variables.Variable(3.1)

      @def_function.function(jit_compile=True)
      def update_var(a, b):
        v.assign_add(a * b)
        return a * b + v

      out = update_var(constant_op.constant(0.7), constant_op.constant(0.6))
      self.assertAllClose(v, 3.52)
      self.assertAllClose(out, 3.94)

  def testReturnIdentity(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(a, b):
        return (a, b)

      a = random_ops.random_normal([10, 10])
      b = random_ops.random_normal([10, 10])

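      # Returning the inputs unchanged should not allocate any new device
      # buffers, so memory usage before and after the call should match.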
      on_gpu = 'gpu' in self.device.lower()
      gc.collect()
      initial_usage = context.context().get_memory_info(
          b.backing_device)['current'] if on_gpu else 0

      f(a, b)

      gc.collect()
      final_usage = context.context().get_memory_info(
          b.backing_device)['current'] if on_gpu else 0
      self.assertEqual(initial_usage, final_usage)

  def testGetCompilerIrConstants(self):
    if 'tpu' in self.device.lower():
      self.skipTest('TPU generates different HLO')

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(a, b):
        return array_ops.transpose(a, b)

      a = array_ops.ones([3, 4, 3], dtype=dtypes.float32)
      b = constant_op.constant([0, 2, 1], dtype=dtypes.int32)

      self.assertIn('{1,2,0}',
                    f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo'))

  @test_util.disable_mlir_bridge('TODO(b/168732524): MLIR bridge does not '
                                 ' optimize single-element tuples to scalars')
  def testGetCompilerIrResourceVars(self):
    with ops.device('device:{}:0'.format(self.device)):

      v = variables.Variable([3.1, 3.2])

      @def_function.function(jit_compile=True)
      def f(a, b):
        v.assign_add(a * b)

      a = random_ops.random_normal([2])
      b = random_ops.random_normal([2])

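      # The updated resource variable should be aliased between the input and
      # the output of the compiled computation.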
      self.assertIn('input_output_alias={ {}: (2, {}, may-alias) }',
                    f.experimental_get_compiler_ir(a, b)('optimized_hlo'))

  def testGetCompilerIrNotCompiled(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function
      def f(x):
        return x + 1

      a = random_ops.random_normal([10, 10])
      with self.assertRaisesRegex(ValueError,
                                  'marked with \'jit_compile'):
        f.experimental_get_compiler_ir(a)()

  def testGetCompilerIrNested(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x, a):
        return x + a

      @def_function.function(jit_compile=False)
      def fn2(x, a):
        fn.experimental_get_compiler_ir(x, a)()
        return fn(x, a)

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      with self.assertRaises(TypeError):
        fn2(inputs, 1)

  def testGetCompilerIrKwargs(self):
    with ops.device('device:{}:0'.format(self.device)):

      v = variables.Variable([0.1, 0.1])

      @def_function.function(jit_compile=True)
      def f(a, b):
        return (a + b) * v

      a = constant_op.constant([1.1, 1.1])
      b = constant_op.constant([2.2, 2.2])

      self.assertIn('multiply',
                    f.experimental_get_compiler_ir(b=a, a=b)(stage='hlo'))

  def testGetCompilerIrDot(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(a, b):
        return a + b

      a = constant_op.constant([1.1, 1.1])
      b = constant_op.constant([2.2, 2.2])

      self.assertIn(
          'label',
          f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))

  def testGetCompilerIrNoDevicePlacement(self):
    if 'gpu' not in self.device.lower():
      self.skipTest('Testing get_compiler_ir on GPUs without placement')

    @def_function.function(jit_compile=True)
    def f(a, b):
      return a + b

    a = constant_op.constant([1.1, 1.1])
    b = constant_op.constant([2.2, 2.2])

    self.assertIn(
        'label',
        f.experimental_get_compiler_ir(a, b)(stage='optimized_hlo_dot'))

  def testGetCompilerIrNonTensors(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(l):
        return l[0] + l[1]

      l = [constant_op.constant(1.1), constant_op.constant(2.2)]

      self.assertIn('tuple',
                    f.experimental_get_compiler_ir(l)())

  def testGetCompilerIrSerialized(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def fn(x):
        return x - x

      inputs = constant_op.constant([1, 2, 2, 3, 3])
      for stage in ('hlo_serialized', 'optimized_hlo_serialized'):
        hlo = fn.experimental_get_compiler_ir(inputs)(
            stage=stage, device_name=f'/device:{self.device}:0')
        self.assertIsInstance(hlo, bytes)

  def testDotOptimizedHlo(self):
    with ops.device('device:{}:0'.format(self.device)):

      a = random_ops.random_normal([100, 100])
      b = random_ops.random_normal([100, 100])

      @def_function.function(jit_compile=True)
      def f(a, b):
        return math_ops.matmul(a, b)

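      # Depending on the backend, the matmul may show up in the optimized HLO
      # as a dot or as a convolution, so accept either form.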
      self.assertRegex(f.experimental_get_compiler_ir(a, b)('optimized_hlo'),
                       '(dot)|(convolution)')

  def testConstantOnWrongDevice(self):
    with ops.device('device:{}:0'.format(self.device)):

      s = random_ops.random_uniform([2], 1, 10, dtypes.int32)
      l = random_ops.random_normal([s[0] * s[1]])

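      # reshape() needs the target shape s at compile time, so s, which is
      # captured from outside the function, must be fetched as a constant.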
      @def_function.function(jit_compile=True)
      def f(l):
        return array_ops.reshape(l, s)

      self.assertIn('tuple',
                    f.experimental_get_compiler_ir(l)())

  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
                                 'support getting constants out of resources')
  def testGetConstantOutOfResourceVariable(self):
    with ops.device('device:{}:0'.format(self.device)):

      # Use floats to force device placement.
      a = variables.Variable(50.0)
      b = variables.Variable(2.0)

      @def_function.function(jit_compile=True)
      def f(x):
        return array_ops.reshape(
            x, [math_ops.cast(a, dtypes.int32),
                math_ops.cast(b, dtypes.int32)])

      # OK since the value is known at compile time.
      out = f(random_ops.random_normal([10, 10]))
      self.assertEqual(out.shape[0], 50)
      self.assertEqual(out.shape[1], 2)

  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
                                 'support getting constants out of resources')
  def testGetConstantOutOfResourceVariableAfterWrite(self):
    with ops.device('device:{}:0'.format(self.device)):

      # Use floats to force device placement.
      a = variables.Variable(50.0)
      b = variables.Variable(2.0)

      @def_function.function(jit_compile=True)
      def f(x, val1, val2):
        a.assign(math_ops.cast(val1, dtypes.float32))
        b.assign(math_ops.cast(val2, dtypes.float32))
        return array_ops.reshape(
            x, [math_ops.cast(a, dtypes.int32),
                math_ops.cast(b, dtypes.int32)])

      val1 = constant_op.constant(2)
      val2 = constant_op.constant(50)

      # Raises an error, since the value known at compile time was overridden.
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'concrete values at compile time'):
        f(random_ops.random_normal([10, 10]), val1, val2)

  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
                                 'support getting constants out of resources')
  def testGetConstantOutOfResourceVariableBeforeWrite(self):
    with ops.device('device:{}:0'.format(self.device)):

      # Use floats to force device placement.
      a = variables.Variable(50.0)
      b = variables.Variable(2.0)

      @def_function.function(jit_compile=True)
      def f(x, val1, val2):
        out = array_ops.reshape(
            x, [math_ops.cast(a, dtypes.int32),
                math_ops.cast(b, dtypes.int32)])
        a.assign(math_ops.cast(val1, dtypes.float32))
        b.assign(math_ops.cast(val2, dtypes.float32))
        return out

      val1 = constant_op.constant(2)
      val2 = constant_op.constant(50)

      # OK since the write happens after the reshape.
      out = f(random_ops.random_normal([10, 10]), val1, val2)
      self.assertEqual(out.shape[0], 50)
      self.assertEqual(out.shape[1], 2)

  def testTfAssert(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        control_flow_ops.Assert(x == 1, ['Wrong value'])

      f(constant_op.constant(1))

  def testTensorArrayErrorMessage(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f():
        # The error messages of the old and new bridge differ in which op they
        # flag: one points to the creation of the uninitialized tensor array,
        # the other to its use.
        ta = tensor_array_ops.TensorArray(  # EXPECTED_MESSAGE_NEW
            dtype=dtypes.float32,
            size=2,
            dynamic_size=True,
            element_shape=(None,))
        return ta.concat()  # EXPECTED_MESSAGE_OLD

      if test_util.is_mlir_bridge_enabled():
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'EXPECTED_MESSAGE_NEW'):
          f()
      else:
        with self.assertRaisesRegex(errors.InvalidArgumentError,
                                    'EXPECTED_MESSAGE_OLD'):
          f()

  def testCounter(self):
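    # Cell '0' counts tf.functions traced without jit_compile; cell '1' counts
    # those traced with jit_compile=True.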
    cell_nojit = def_function._tf_function_counter.get_cell('0')
    cell_jit = def_function._tf_function_counter.get_cell('1')
    orig_nojit = cell_nojit.value()
    orig_jit = cell_jit.value()

    with ops.device('device:{}:0'.format(self.device)):
      @def_function.function
      def f(a):
        return a + a
      f(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 1)
      self.assertEqual(cell_jit.value(), orig_jit)
      f(constant_op.constant(1.))  # Calling again does not increment
      self.assertEqual(cell_nojit.value(), orig_nojit + 1)

      @def_function.function(jit_compile=True)
      def f1(a):
        return a + a
      f1(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 1)
      self.assertEqual(cell_jit.value(), orig_jit + 1)

      @def_function.function
      def f2(a):
        @def_function.function
        def g(a):
          return a + a
        @def_function.function(jit_compile=True)
        def h(a):
          return a + a
        return g(a) + h(a)
      f2(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 2)
      self.assertEqual(cell_jit.value(), orig_jit + 2)

      @def_function.function(jit_compile=True)
      def f3(a):
        @def_function.function
        def g(a):
          return a + a
        @def_function.function(jit_compile=True)
        def h(a):
          return a + a
        return g(a) + h(a)
      f3(constant_op.constant(1))
      self.assertEqual(cell_nojit.value(), orig_nojit + 2)
      self.assertEqual(cell_jit.value(), orig_jit + 3)

  @test_util.disable_mlir_bridge('TODO(b/162272821): MLIR bridge returns '
                                 ' wrong status type')
  def testResourceWrongDevice(self):
    if 'gpu' not in self.device.lower():
      self.skipTest('Need a GPU to have non-trivial device placement')

    with ops.device('device:CPU:0'):
      v = variables.Variable([3.1, 3.2])

    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def update_var(a):
        v.assign_add(a)

      arg = random_ops.random_normal([2])
      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'Trying to access resource .*'):
        update_var(arg)

  def testMustBeConstantInsideCondition(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x, d):
        if math_ops.reduce_all(
            math_ops.greater(x, random_ops.random_normal([10, 10]))):
          return array_ops.reshape(x * 2, constant_op.constant([100]))
        else:
          return array_ops.reshape(x * 3, d)

      f(random_ops.random_normal([10, 10]), constant_op.constant([100]))

  def testConditionalGradientTapeMathRegression(self):
    with ops.device('device:{}:0'.format(self.device)):
      with backprop.GradientTape():

        @def_function.function(jit_compile=True, autograph=False)
        def f(x):
          return control_flow_ops.cond(
              math_ops.reduce_all(x > 1), lambda: 1. / x, lambda: x)

        v = variables.Variable([[2.]])
        self.assertAllClose(f(v), constant_op.constant([[0.5]]))

  @test_util.disable_mlir_bridge('TODO(b/190444466): MLIR bridge seems to '
                                 'ignore resource assignments')
  def testErrMsgAssignWrongShape(self):
    with ops.device('device:{}:0'.format(self.device)):

      v = variables.Variable([3.1, 3.2])

      @def_function.function(jit_compile=True)
      def f(samples):
        v.assign(array_ops.zeros(samples))  # assignment

      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          'Shape .* cannot be changed after initialization'):
        f(constant_op.constant(6))

      with self.assertRaisesRegex(errors.InvalidArgumentError, 'assignment'):
        f(constant_op.constant(6))

  def testTfSummaryErrMsg(self):
    if 'gpu' not in self.device.lower():
      self.skipTest('Only runs on GPU')

    with ops.device('device:{}:0'.format(self.device)):
      writer = summary_ops_v2.create_file_writer(self.get_temp_dir())

      @def_function.function(jit_compile=True)
      def my_func_temp():
        with writer.as_default():
          summary_ops_v2.scalar('my_metric', 0.5, step=10)

      with self.assertRaisesRegex(errors.InvalidArgumentError,
                                  'Trying to access resource .*'):
        my_func_temp()

  def testSinglePassArgmax(self):
    with ops.device('device:{}:0'.format(self.device)):

      @def_function.function(jit_compile=True)
      def f(x):
        return math_ops.argmax(x)

      hlo = f.experimental_get_compiler_ir(
          array_ops.ones([10], dtype=dtypes.float32))(
              stage='hlo')

      # Test that reduction occurs only once.
      self.assertGreater(hlo.count('reduce'), 1)


if __name__ == '__main__':
  ops.enable_eager_execution()
  test.main()