# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""

import os

import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.compiler.xla import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test


jit_scope = jit.experimental_jit_scope
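
# A rough usage sketch (for orientation, not executed here): ops created inside
# `with jit_scope():` are annotated for XLA clustering, e.g.
#
#   with jit_scope():
#     y = math_ops.matmul(x, w)   # eligible for XLA compilation
#   z = math_ops.add(y, b)        # outside the scope, left to the TF executor
#
# The scope works by attaching compilation attributes to the ops it encloses;
# the exact attribute names are an implementation detail of jit.py.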

# Disable rewrites to make sure we don't end up having to update this test
# whenever we implement new ones.
def NoRewriteSessionConfig():
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)


def CompiledKernel(fn, *inputs, **kwargs):
  """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)

  @function.Defun(func_name=name, noinline=noinline, compiled=True)
  def Compiled(*args):
    return fn(*args)

  return Compiled(*inputs)
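
# Usage sketch for CompiledKernel (the names p0, p1 and "AddKernel" are
# illustrative only): given two placeholders p0 and p1,
#
#   out = CompiledKernel(lambda a, b: (a + b,), p0, p1, name="AddKernel")
#
# wraps the lambda in a Defun marked compiled=True and returns the call op,
# which _compare below runs against the uncompiled version of the same fn.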


def RunMetadataLabels(run_metadata):
  """Returns all labels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)


def MetadataHasXlaRunOp(run_metadata):
  """Returns true if there are XlaRun kernels in run_metadata's timeline."""

  # TODO(phawkins): find a less hacky way to test whether a kernel ran.
  return InLabels(RunMetadataLabels(run_metadata), "_XlaRun")
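
# Note: each timeline_label typically has the form "node_name = OpName(args)",
# so substring matching on "_XlaRun" is only a heuristic proxy for "an _XlaRun
# kernel actually executed"; see the TODO above.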


class JitLaunchTest(test.TestCase):

  # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # the same number of tensors as arguments that are in 'args', and must return
  # a tuple of output tensors.
  #
  # If 'require_kernel_launch' is True, then we verify that an XlaCompile/XlaRun
  # node actually ran. However, it is sometimes possible for XlaCompile/XlaRun
  # ops to be constant-folded away, so the check is optional.
  def _compare(self,
               fn,
               args,
               require_kernel_launch=True,
               name=None,
               noinline=None):
    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      placeholders = []
      feeds = {}
      for arg in args:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(arg.dtype), list(arg.shape))
        placeholders.append(placeholder)
        feeds[placeholder] = arg

      compiled_op = CompiledKernel(
          fn, *placeholders, name=name, noinline=noinline)
      direct_op = fn(*placeholders)

      run_metadata = config_pb2.RunMetadata()
      compiled = test_utils.RunWithWarmup(
          sess, compiled_op, feeds,
          config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE),
          run_metadata)
      print("Compiled Result {}".format(compiled))

      if require_kernel_launch:
        self.assertTrue(MetadataHasXlaRunOp(run_metadata))

        direct = sess.run(direct_op, feeds)
        print("Direct Result {}".format(direct))

        if (isinstance(compiled, (tuple, list)) and
            isinstance(direct, (tuple, list))):
          for (x, y) in zip(compiled, direct):
            self.assertAllClose(x, y, rtol=1e-1)
        else:
          self.assertAllClose(compiled, direct, rtol=1e-2)

  def testNoOutputs(self):
    with session_lib.Session() as sess:

      # Check that calling the result as a compiled kernel doesn't crash.
      @function.Defun(compiled=True)
      def KernelWithNoOutputs():
        a = constant_op.constant(100)  # pylint: disable=unused-variable

      call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
      test_utils.RunWithWarmup(sess, call, {})

  def testAliasing(self):
    """Regression test for compiled functions that return an aliased buffer.

       XLA returns aliased buffers if outputs are identical. Tests that
       we handle that case.
    """

    def AddOnceReturnTwice(x):
      y = math_ops.add(x, x)
      return y, y

    # Exercises compiling a function (say, Foo) which calls another function
    # (say, Bar) which is not inlined. When the compiler compiles Foo, it needs
    # to symbolically execute Bar correctly regardless of whether Bar is inlined
    # or not.

    # Tests compiled=True and noinline=True.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_inline",
        noinline=True)

    # Tests compiled=True and noinline=False.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_noinline",
        noinline=False)

  def testOneConstOutput(self):
    """Test consisting of a single constant return value."""

    def OneConstOutput():
      return constant_op.constant([-3, 44, 99])

    self._compare(OneConstOutput, [], require_kernel_launch=False)

  def testConstZeroElementOutput(self):
    """Test consisting of a constant zero element return value."""

    def ConstZeroElementOutput():
      return array_ops.fill([7, 0], 3.0)

    self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)

  def testSomeConstOutputs(self):
    """Test kernels that return a mixture of const and non-const outputs."""

    def SomeConstOutputs(x):
      return constant_op.constant(
          [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)

    self._compare(
        SomeConstOutputs, [np.array(
            [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])

  def testInt32Input(self):
    """Test an int32-typed input.

       On a GPU, int32 tensors will be placed in host memory.
    """

    def AddToSelf(x):
      return math_ops.add(x, x)

    self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])

  def testMandatoryConstantInput(self):
    """Tests an operator that has a mandatory-constant shape input."""

    def FillWithFloat(x):
      return array_ops.fill(x, 9.5)

    self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])

  def testMnistForwardFunc(self):
    """Compute inference function from MNIST beginners tutorial."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    # Define a TensorFlow function to compute the forward pass.
    def MnistForward(w, b, x):
      return nn_ops.softmax(math_ops.matmul(x, w) + b)

    w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
    b = np.random.random_sample((num_classes)).astype(np.float32)
    x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
    self._compare(MnistForward, [w, b, x])

  def testExplicitMarking(self):
    """Test explicit marking of operators to compile."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      w = array_ops.placeholder(dtypes.float32)
      b = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y1 = math_ops.matmul(x, w)
      y2 = math_ops.add(y1, b)
      with jit_scope():
        y = math_ops.square(y2)

      dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
      db = np.random.random_sample((num_classes)).astype(np.float32)
      dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
      with session_lib.Session() as sess:
        run_metadata = config_pb2.RunMetadata()
        output = test_utils.RunWithWarmup(
            sess,
            y, {
                x: dx,
                w: dw,
                b: db
            },
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))

        # TODO(phawkins): really we would like to test that there were exactly
        # two kernel launches. However, we have no reliable way to determine
        # that.
        self.assertTrue(MetadataHasXlaRunOp(run_metadata))

        expected = np.square(np.dot(dx, dw) + db)
        self.assertAllClose(expected, output, rtol=1e-1)


class XlaCompilationTest(test.TestCase):
  """Tests for auto-compilation on CPU/GPU devices."""

  def testReshape(self):
    """Tests an operator with compile-time constant and non-constant inputs."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        # Reshape's first argument is non-constant in the JIT, but its second
        # (shape) argument will be treated as a compile-time constant for
        # each JIT compilation.
        # We do not use a tf.constant() argument since we want to ensure the
        # shape is still a run-time argument to the JIT, and not
        # statically known as part of the JIT compilation's input graph.
        z = array_ops.reshape(x, y)
      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          z, {
              x: np.array([1, 2, 3, 4, 5, 6], np.float32),
              y: [-1, 3]
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)

  def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.int32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        z = math_ops.add(x, x)
        w = math_ops.add(y, y)
        # Pulls 'w' into the same compilation via control dependencies.
        with ops.control_dependencies([w]):
          n = control_flow_ops.no_op()
        with ops.control_dependencies([n]):
          t = math_ops.add(z, z)

      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          t, {
              x: np.int32(7),
              y: np.int32(404)
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(28, out)

  def testLoops(self):
    """Tests that compilation accepts computations containing loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        c = lambda i, _: math_ops.less(i, 5)
        b = lambda i, x: (i + 1, x * 2.0 + 1.0)
        _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))

      run_metadata = config_pb2.RunMetadata()
      result = session.run(y, {x: np.float32(2)},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(result, np.float32(95), rtol=1e-1)

  def testCond(self):
    """Tests that compilation handles switch operators."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t, then execution will
      # deadlock.

      run_metadata = config_pb2.RunMetadata()
      result = test_utils.RunWithWarmup(
          session,
          t, {
              x: np.float32(2),
              y: np.float32(4),
              c: True
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(result, np.float32(6), rtol=1e-1)

  def testNestedFunction(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(compiled=True)
      def Bar(x, y):
        return x + 2 * y

      @function.Defun(compiled=True)
      def Foo(x):
        return Bar(x * x, x * x * x)

      @function.Defun()
      def Entry(x):
        return Foo(x)

      inp = array_ops.placeholder(dtypes.float32)
      out = Entry(inp)

    with self.session(
        config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess:
      run_metadata = config_pb2.RunMetadata()
      val = sess.run(out,
                     feed_dict={inp: [2., 10.]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(val, [20., 2100.])

  def testLoopDeadlock(self):
    """Regression test for bug that caused deadlocks in graphs with loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y = x + 1.0
        c = lambda i, _x, _y: math_ops.less(i, 5)
        b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
        _, _, w = control_flow_ops.while_loop(c, b,
                                              (constant_op.constant(0), y, x))
        u = w + y
      result = session.run(u, {x: np.float32(2)})
      self.assertAllClose(result, np.float32(63), rtol=1e-1)

  def testGradient(self):
    """Tests that the backprop function is properly compiled."""

    def _Run(compiled):

      @function.Defun(compiled=compiled)
      def Forward(x):
        return math_ops.log(x)

      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32)
        y = Forward(x)
        dx, = gradients_impl.gradients(y, [x], 1.0)

      cfg = NoRewriteSessionConfig()
      cfg.graph_options.optimizer_options.opt_level = (
          config_pb2.OptimizerOptions.L1)
      cfg.graph_options.optimizer_options.do_function_inlining = True
      with session_lib.Session(graph=g, config=cfg) as sess:
        run_metadata = config_pb2.RunMetadata()
        dx_val = test_utils.RunWithWarmup(
            sess,
            dx,
            feed_dict={x: 100.},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(dx_val, 0.01)
      return RunMetadataLabels(run_metadata)

    # SymGrad[f=log(x)](x, dy) = 1/x * dy
    #
    # Note: we don't need to compute log(x) for dx due to graph pruning.
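    #
    # A worked check of the numbers used in _Run above: d(log x)/dx = 1/x, so
    # with the feed x = 100 and an initial gradient dy = 1.0, the expected
    # value is dx = (1 / 100) * 1.0 = 0.01, which is what assertAllClose checks.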

    # Do not compile the backprop. We should see one Reciprocal and one Mul.
    labels = _Run(compiled=False)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertTrue(InLabels(labels, "Reciprocal"))
    self.assertTrue(InLabels(labels, "Mul"))
    self.assertFalse(InLabels(labels, "XlaCompile"))
    self.assertFalse(InLabels(labels, "XlaRun"))

    # Compile the backprop. One XlaCompile/XlaRun pair.
    labels = _Run(compiled=True)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertFalse(InLabels(labels, "Reciprocal"))
    self.assertFalse(InLabels(labels, "Mul"))
    self.assertTrue(InLabels(labels, "XlaCompile"))
    self.assertTrue(InLabels(labels, "XlaRun"))


class ElementWiseFusionTest(test.TestCase):

  # Runs a simple computation at the given global_jit_level and returns the
  # output together with the number of XlaRun kernels that were traced.
  def simpleTest(self, arg0, arg1, global_jit_level):
    config = config_pb2.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = global_jit_level

    with session_lib.Session(config=config) as sess:
      a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
      a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
      # Two element-wise ops. We need at least two ops since single
      # element clusters are not passed to XLA in fusion_only mode.
      a3 = a1 * a2
      a4 = a3 + a1
      # A matmul to break XLA clustering.
      a5 = math_ops.matmul(a4, a1)
      # Two more element-wise ops.
      a6 = a5 - a4
      a7 = a6 + a2

      run_metadata = config_pb2.RunMetadata()
      output = test_utils.RunWithWarmup(
          sess,
          a7, {
              a1: arg0,
              a2: arg1
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))

      labels = RunMetadataLabels(run_metadata)

      xla_compile_count = sum("XlaCompile(" in x for x in labels)
      xla_run_count = sum("XlaRun(" in x for x in labels)
      self.assertEqual(xla_compile_count, xla_run_count)

      return output, xla_run_count
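
  # Hypothetical driver for the helper above (not part of this file): a
  # device-specific test could feed matching 2x2 arrays and a concrete jit
  # level, e.g.
  #
  #   output, xla_run_count = self.simpleTest(
  #       np.ones([2, 2], np.float32), np.ones([2, 2], np.float32),
  #       config_pb2.OptimizerOptions.ON_1)
  #
  # and then assert on xla_run_count for the jit level under test.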


class LazyCompilationTest(test.TestCase):

  def testLazyCompilation(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # The very first run of the cluster is always compiled (non-lazily).
      run_metadata_for_first_run = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 19., 77., 100.]},
          run_metadata=run_metadata_for_first_run,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_first_run), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_for_first_run), "_XlaRun"))

      run_metadata_before_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_before_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun"))

      # We compile when we see the same shape a second time.

      run_metadata_after_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_after_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun"))

      run_metadata_for_new_shape = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 12.]},
          run_metadata=run_metadata_for_new_shape,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun"))

  def testIsNotMegamorphic(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # Run the cluster with lots of shape signatures, but in a way that it
      # isn't megamorphic (i.e. each shape signature sees a lot of executions).
      # Then check that the cluster has not been marked as megamorphic.
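      # (Roughly: a cluster is considered megamorphic when it keeps seeing new
      # shape signatures relative to how often it executes; once marked, lazy
      # compilation stops compiling it and the _XlaRun kernel no longer shows
      # up in the trace. The exact thresholds are implementation details.)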

      for shape in range(10, 50):
        for _ in range(0, 1000):
          sess.run(y, feed_dict={x: [0.] * shape})

      for _ in range(0, 10):
        sess.run(y, feed_dict={x: [0.] * 60})

      run_metadata = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [0.] * 60},
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))


if __name__ == "__main__":
  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +
                                os.environ.get("TF_XLA_FLAGS", ""))
  # This test uses TensorFlow sessions, which are not compatible with eager
  # mode.
  ops.disable_eager_execution()
  test.main()