# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for JIT compilation on the CPU and GPU devices."""

import os

import numpy as np

from tensorflow.compiler.tests import test_utils
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.compiler.xla import jit
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import test


jit_scope = jit.experimental_jit_scope


# Disable rewrites to make sure we don't end up having to update this test
# whenever we implement new ones.
def NoRewriteSessionConfig():
  rewriter_config = rewriter_config_pb2.RewriterConfig(
      disable_model_pruning=True,
      arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF,
      function_optimization=rewriter_config_pb2.RewriterConfig.OFF)
  graph_options = config_pb2.GraphOptions(rewrite_options=rewriter_config)
  return config_pb2.ConfigProto(graph_options=graph_options)


def CompiledKernel(fn, *inputs, **kwargs):
  """Execute 'fn' as a compiled XLA kernel, with 'inputs'."""
  name = kwargs.pop("name", None)
  noinline = kwargs.pop("noinline", None)

  @function.Defun(func_name=name, noinline=noinline, compiled=True)
  def Compiled(*args):
    return fn(*args)

  return Compiled(*inputs)


def RunMetadataLabels(run_metadata):
  """Returns all labels in run_metadata."""
  labels = []
  for dev_stats in run_metadata.step_stats.dev_stats:
    for node_stats in dev_stats.node_stats:
      labels.append(node_stats.timeline_label)
  return labels


def InLabels(labels, substr):
  """Returns true iff one of the labels contains substr."""
  return any(substr in x for x in labels)


def MetadataHasXlaRunOp(run_metadata):
  """Returns true if there are XlaRun kernels in run_metadata's timeline."""

  # TODO(phawkins): find a less hacky way to test whether a kernel ran.
  return InLabels(RunMetadataLabels(run_metadata), "_XlaRun")
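

# Illustrative companion to MetadataHasXlaRunOp above: a minimal sketch of how
# the "_XlaCompile" timeline marker (which LazyCompilationTest below checks
# directly via InLabels) could be wrapped the same way. This helper is not
# used by the tests in this file and its name is hypothetical.
def MetadataHasXlaCompileOp(run_metadata):
  """Returns true if there are XlaCompile kernels in run_metadata's timeline."""
  return InLabels(RunMetadataLabels(run_metadata), "_XlaCompile")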


class JitLaunchTest(test.TestCase):

  # Evaluates 'fn' on 'args' both directly and as a compiled XLA kernel.
  # Verifies that the outputs match and that XLA was invoked. 'fn' must take
  # the same number of tensor arguments as there are entries in 'args', and
  # must return a tuple of output tensors.
  #
  # If 'require_kernel_launch' is True, then we verify that an
  # XlaCompile/XlaRun node actually ran. However, it is sometimes possible for
  # XlaCompile/XlaRun ops to be constant-folded away, so the check is optional.
  def _compare(self,
               fn,
               args,
               require_kernel_launch=True,
               name=None,
               noinline=None):
    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      placeholders = []
      feeds = {}
      for arg in args:
        placeholder = array_ops.placeholder(
            dtypes.as_dtype(arg.dtype), list(arg.shape))
        placeholders.append(placeholder)
        feeds[placeholder] = arg

      compiled_op = CompiledKernel(
          fn, *placeholders, name=name, noinline=noinline)
      direct_op = fn(*placeholders)

      run_metadata = config_pb2.RunMetadata()
      compiled = test_utils.RunWithWarmup(
          sess, compiled_op, feeds,
          config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE),
          run_metadata)
      print("Compiled Result {}".format(compiled))

      if require_kernel_launch:
        self.assert_(MetadataHasXlaRunOp(run_metadata))

      direct = sess.run(direct_op, feeds)
      print("Direct Result {}".format(direct))

      if (isinstance(compiled, (tuple, list)) and
          (isinstance(direct, (tuple, list)))):
        for (x, y) in zip(compiled, direct):
          self.assertAllClose(x, y, rtol=1e-1)
      else:
        self.assertAllClose(compiled, direct, rtol=1e-2)

  def testNoOutputs(self):
    with session_lib.Session() as sess:

      # Check that calling the result as a compiled kernel doesn't crash.
      @function.Defun(compiled=True)
      def KernelWithNoOutputs():
        a = constant_op.constant(100)  # pylint: disable=unused-variable

      call = KernelWithNoOutputs()  # pylint: disable=assignment-from-no-return
      test_utils.RunWithWarmup(sess, call, {})

  def testAliasing(self):
    """Regression test for compiled functions that return an aliased buffer.

    XLA returns aliased buffers if outputs are identical. Tests that
    we handle that case.
    """

    def AddOnceReturnTwice(x):
      y = math_ops.add(x, x)
      return y, y

    # Exercises compiling a function (say, Foo) which calls another function
    # (say, Bar) which is not inlined. When the compiler compiles Foo, it
    # needs to symbolically execute Bar correctly regardless of whether Bar
    # is inlined or not.

    # Tests compiled=True and noinline=True.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_inline",
        noinline=True)

    # Tests compiled=True and noinline=False.
    self._compare(
        AddOnceReturnTwice, [np.array([[[0.5, -1.0]]], dtype=np.float32)],
        name="AddOnceReturnTwice_noinline",
        noinline=False)

  def testOneConstOutput(self):
    """Test consisting of a single constant return value."""

    def OneConstOutput():
      return constant_op.constant([-3, 44, 99])

    self._compare(OneConstOutput, [], require_kernel_launch=False)

  def testConstZeroElementOutput(self):
    """Test consisting of a constant zero element return value."""

    def ConstZeroElementOutput():
      return array_ops.fill([7, 0], 3.0)

    self._compare(ConstZeroElementOutput, [], require_kernel_launch=False)

  def testSomeConstOutputs(self):
    """Test kernels that return a mixture of const and non-const outputs."""

    def SomeConstOutputs(x):
      return constant_op.constant(
          [-2, 7]), array_ops.identity(x), constant_op.constant(3.5)

    self._compare(
        SomeConstOutputs, [np.array(
            [[1, 2, 3], [4, 5, 6]], dtype=np.float32)])

  def testInt32Input(self):
    """Test an int32-typed input.

    On a GPU, int32 tensors will be placed in host memory.
    """

    def AddToSelf(x):
      return math_ops.add(x, x)

    self._compare(AddToSelf, [np.array([7, 1, 3], dtype=np.int32)])

  def testMandatoryConstantInput(self):
    """Tests an operator that has a mandatory-constant shape input."""

    def FillWithFloat(x):
      return array_ops.fill(x, 9.5)

    self._compare(FillWithFloat, [np.array([3, 2], dtype=np.int32)])

  def testMnistForwardFunc(self):
    """Compute inference function from MNIST beginners tutorial."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    # Define a TensorFlow function to compute the forward pass.
    def MnistForward(w, b, x):
      return nn_ops.softmax(math_ops.matmul(x, w) + b)

    w = np.random.random_sample((image_size, num_classes)).astype(np.float32)
    b = np.random.random_sample((num_classes)).astype(np.float32)
    x = np.random.random_sample((batch_size, image_size)).astype(np.float32)
    self._compare(MnistForward, [w, b, x])

  def testExplicitMarking(self):
    """Test explicit marking of operators to compile."""
    batch_size = 16
    image_size = 28 * 28
    num_classes = 10

    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      w = array_ops.placeholder(dtypes.float32)
      b = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y1 = math_ops.matmul(x, w)
        y2 = math_ops.add(y1, b)
      with jit_scope():
        y = math_ops.square(y2)

      dw = np.random.random_sample((image_size, num_classes)).astype(np.float32)
      db = np.random.random_sample((num_classes)).astype(np.float32)
      dx = np.random.random_sample((batch_size, image_size)).astype(np.float32)
      with session_lib.Session() as sess:
        run_metadata = config_pb2.RunMetadata()
        output = test_utils.RunWithWarmup(
            sess,
            y, {
                x: dx,
                w: dw,
                b: db
            },
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))

        # TODO(phawkins): really we would like to test that there were exactly
        # two kernel launches. However, we have no reliable way to determine
        # that.
        self.assert_(MetadataHasXlaRunOp(run_metadata))

        expected = np.square(np.dot(dx, dw) + db)
        self.assertAllClose(expected, output, rtol=1e-1)
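

# Most tests in this file build the same FULL_TRACE RunOptions inline. As a
# purely illustrative sketch (not wired into the tests above or below; the
# helper name is hypothetical), that boilerplate could be factored out like so:
def FullTraceRunOptions():
  """Returns RunOptions that request a full timeline trace."""
  return config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE)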


class XlaCompilationTest(test.TestCase):
  """Tests for auto-compilation on CPU/GPU devices."""

  def testReshape(self):
    """Tests an operator with compile-time constant and non-constant inputs."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        # Reshape's first argument is non-constant in the JIT, but its second
        # (shape) argument will be treated as a compile-time constant for
        # each JIT compilation.
        # We do not use a tf.constant() argument since we want to ensure the
        # shape is still a run-time argument to the JIT, and not
        # statically known as part of the JIT compilation's input graph.
        z = array_ops.reshape(x, y)
      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          z, {
              x: np.array([1, 2, 3, 4, 5, 6], np.float32),
              y: [-1, 3]
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(np.array([[1, 2, 3], [4, 5, 6]], np.float32), out)

  def testIgnoredArguments(self):
    """Tests that JIT computations can ignore formal parameters."""

    with self.session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.int32)
      y = array_ops.placeholder(dtypes.int32)
      with jit_scope():
        z = math_ops.add(x, x)
        w = math_ops.add(y, y)
        # Pulls 'w' into the same compilation via control dependencies.
        with ops.control_dependencies([w]):
          n = control_flow_ops.no_op()
        with ops.control_dependencies([n]):
          t = math_ops.add(z, z)

      run_metadata = config_pb2.RunMetadata()
      out = test_utils.RunWithWarmup(
          sess,
          t, {
              x: np.int32(7),
              y: np.int32(404)
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(28, out)

  def testLoops(self):
    """Tests that compilation accepts computations containing loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        c = lambda i, _: math_ops.less(i, 5)
        b = lambda i, x: (i + 1, x * 2.0 + 1.0)
        _, y = control_flow_ops.while_loop(c, b, (constant_op.constant(0), x))

      run_metadata = config_pb2.RunMetadata()
      result = session.run(y, {x: np.float32(2)},
                           run_metadata=run_metadata,
                           options=config_pb2.RunOptions(
                               trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(result, np.float32(95), rtol=1e-1)

  def testCond(self):
    """Tests that compilation handles switch operators."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.placeholder(dtypes.float32)
      c = array_ops.placeholder(dtypes.bool)
      with jit_scope():
        z = x + 1.0
        w = control_flow_ops.cond(c, lambda: z, lambda: y)
        t = math_ops.add(z, w)

      # If JIT compilation chooses to cluster z and t, then execution will
      # deadlock.

      run_metadata = config_pb2.RunMetadata()
      result = test_utils.RunWithWarmup(
          session,
          t, {
              x: np.float32(2),
              y: np.float32(4),
              c: True
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assert_(MetadataHasXlaRunOp(run_metadata))
      self.assertAllClose(result, np.float32(6), rtol=1e-1)

  def testNestedFunction(self):
    g = ops.Graph()
    with g.as_default():

      @function.Defun(compiled=True)
      def Bar(x, y):
        return x + 2 * y

      @function.Defun(compiled=True)
      def Foo(x):
        return Bar(x * x, x * x * x)

      @function.Defun()
      def Entry(x):
        return Foo(x)

      inp = array_ops.placeholder(dtypes.float32)
      out = Entry(inp)

    with self.session(
        config=NoRewriteSessionConfig(), graph=g, use_gpu=True) as sess:
      run_metadata = config_pb2.RunMetadata()
      val = sess.run(out,
                     feed_dict={inp: [2., 10.]},
                     run_metadata=run_metadata,
                     options=config_pb2.RunOptions(
                         trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertAllClose(val, [20., 2100.])

  def testLoopDeadlock(self):
    """Regression test for bug that caused deadlocks in graphs with loops."""

    with self.session(config=NoRewriteSessionConfig()) as session:
      x = array_ops.placeholder(dtypes.float32)
      with jit_scope():
        y = x + 1.0
        c = lambda i, _x, _y: math_ops.less(i, 5)
        b = lambda i, x, _y: (i + 1, x * 2.0 + 1.0, x - 3.0)
        _, _, w = control_flow_ops.while_loop(c, b,
                                              (constant_op.constant(0), y, x))
        u = w + y
      result = session.run(u, {x: np.float32(2)})
      self.assertAllClose(result, np.float32(63), rtol=1e-1)

  def testGradient(self):
    """Tests that the backprop function is properly compiled."""

    def _Run(compiled):

      @function.Defun(compiled=compiled)
      def Forward(x):
        return math_ops.log(x)

      g = ops.Graph()
      with g.as_default():
        x = array_ops.placeholder(dtypes.float32)
        y = Forward(x)
        dx, = gradients_impl.gradients(y, [x], 1.0)

      cfg = NoRewriteSessionConfig()
      cfg.graph_options.optimizer_options.opt_level = (
          config_pb2.OptimizerOptions.L1)
      cfg.graph_options.optimizer_options.do_function_inlining = True
      with session_lib.Session(graph=g, config=cfg) as sess:
        run_metadata = config_pb2.RunMetadata()
        dx_val = test_utils.RunWithWarmup(
            sess,
            dx,
            feed_dict={x: 100.},
            run_metadata=run_metadata,
            options=config_pb2.RunOptions(
                trace_level=config_pb2.RunOptions.FULL_TRACE))
        self.assertAllClose(dx_val, 0.01)
        return RunMetadataLabels(run_metadata)

    # SymGrad[f=log(x)](x, dy) = 1/x * dy
    #
    # Note: we don't need to compute log(x) for dx due to graph pruning.

    # Do not compile the backprop. We should see one Reciprocal and one Mul.
    labels = _Run(compiled=False)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertTrue(InLabels(labels, "Reciprocal"))
    self.assertTrue(InLabels(labels, "Mul"))
    self.assertFalse(InLabels(labels, "XlaCompile"))
    self.assertFalse(InLabels(labels, "XlaRun"))

    # Compile the backprop. One XlaCompile/XlaRun pair.
    labels = _Run(compiled=True)
    self.assertFalse(InLabels(labels, "Log"))
    self.assertFalse(InLabels(labels, "Reciprocal"))
    self.assertFalse(InLabels(labels, "Mul"))
    self.assertTrue(InLabels(labels, "XlaCompile"))
    self.assertTrue(InLabels(labels, "XlaRun"))


class ElementWiseFusionTest(test.TestCase):

  # Runs a simple element-wise fusion test at the given global_jit_level.
  def simpleTest(self, arg0, arg1, global_jit_level):
    config = config_pb2.ConfigProto()
    config.graph_options.optimizer_options.global_jit_level = global_jit_level

    with session_lib.Session(config=config) as sess:
      a1 = array_ops.placeholder(dtypes.float32, [2, 2], name="a1")
      a2 = array_ops.placeholder(dtypes.float32, [2, 2], name="a2")
      # Two element-wise ops. We need at least two ops since single
      # element clusters are not passed to XLA in fusion_only mode.
      a3 = a1 * a2
      a4 = a3 + a1
      # A matmul to break XLA clustering.
      a5 = math_ops.matmul(a4, a1)
      # Two more element-wise ops.
      a6 = a5 - a4
      a7 = a6 + a2

      run_metadata = config_pb2.RunMetadata()
      output = test_utils.RunWithWarmup(
          sess,
          a7, {
              a1: arg0,
              a2: arg1
          },
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))

      labels = RunMetadataLabels(run_metadata)

      xla_compile_count = sum("XlaCompile(" in x for x in labels)
      xla_run_count = sum("XlaRun(" in x for x in labels)
      self.assertEqual(xla_compile_count, xla_run_count)

      return output, xla_run_count
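

# The LazyCompilationTest below exercises lazy compilation of XLA clusters.
# Note that it relies on --tf_xla_enable_lazy_compilation=true, which is
# appended to TF_XLA_FLAGS in the __main__ block at the bottom of this file.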
class LazyCompilationTest(test.TestCase):

  def testLazyCompilation(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # The very first run of the cluster is always compiled (non-lazily).
      run_metadata_for_first_run = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 19., 77., 100.]},
          run_metadata=run_metadata_for_first_run,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_first_run), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_for_first_run), "_XlaRun"))

      run_metadata_before_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_before_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_before_warmup), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_before_warmup), "_XlaRun"))

      # We compile when we see the same shape a second time.

      run_metadata_after_warmup = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10.]},
          run_metadata=run_metadata_after_warmup,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaCompile"))
      self.assertTrue(
          InLabels(RunMetadataLabels(run_metadata_after_warmup), "_XlaRun"))

      run_metadata_for_new_shape = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [2., 10., 12.]},
          run_metadata=run_metadata_for_new_shape,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(
          InLabels(
              RunMetadataLabels(run_metadata_for_new_shape), "_XlaCompile"))
      self.assertFalse(
          InLabels(RunMetadataLabels(run_metadata_for_new_shape), "_XlaRun"))

  def testIsNotMegamorphic(self):

    @function.Defun(compiled=True)
    def CompiledFunction(x):
      return math_ops.log(x)

    with session_lib.Session(config=NoRewriteSessionConfig()) as sess:
      x = array_ops.placeholder(dtypes.float32)
      y = CompiledFunction(x)

      # Run the cluster with lots of shape signatures, but in a way that it
      # isn't megamorphic (i.e. each shape signature sees a lot of executions).
      # Then check that the cluster has not been marked as megamorphic.

      for shape in range(10, 50):
        for _ in range(0, 1000):
          sess.run(y, feed_dict={x: [0.] * shape})

      for _ in range(0, 10):
        sess.run(y, feed_dict={x: [0.] * 60})

      run_metadata = config_pb2.RunMetadata()
      sess.run(
          y,
          feed_dict={x: [0.] * 60},
          run_metadata=run_metadata,
          options=config_pb2.RunOptions(
              trace_level=config_pb2.RunOptions.FULL_TRACE))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaCompile"))
      self.assertTrue(InLabels(RunMetadataLabels(run_metadata), "_XlaRun"))


if __name__ == "__main__":
  os.environ["TF_XLA_FLAGS"] = ("--tf_xla_enable_lazy_compilation=true " +
                                os.environ.get("TF_XLA_FLAGS", ""))
  # This test uses TensorFlow sessions, which are not compatible with eager
  # mode.
  ops.disable_eager_execution()
  test.main()