1# mypy: ignore-errors 2 3# Torch 4from torch.autograd import Variable 5from torch.autograd.function import _nested_map 6from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401 7 8from torch.onnx import OperatorExportTypes 9import torch 10import torch.cuda 11import torch.jit 12import torch.jit._logging 13import torch.jit.frontend 14import torch.jit.quantized 15import zipfile 16import functools 17 18# Testing utils 19from torch.testing import FileCheck 20from torch.testing._internal.common_utils import IS_WINDOWS, \ 21 freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \ 22 is_iterable_of_tensors 23from torch.testing._internal.common_jit import JitCommonTestCase 24from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401 25 26# Standard library 27from contextlib import contextmanager 28from functools import reduce 29from io import StringIO 30from collections import defaultdict 31 32import importlib.util 33import inspect 34import io 35import math 36import os 37import pickle 38import sys 39import tempfile 40import textwrap 41from importlib.abc import Loader 42from typing import Any, Dict, List, Tuple, Union 43 44RUN_CUDA = torch.cuda.is_available() 45RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1 46RUN_CUDA_HALF = RUN_CUDA 47# HIP supports half, no version check necessary 48if torch.cuda.is_available() and not torch.version.hip: 49 CUDA_VERSION = torch._C._cuda_getCompiledVersion() 50 for d in range(torch.cuda.device_count()): 51 major = torch.cuda.get_device_capability(d)[0] 52 if (major < 6): 53 RUN_CUDA_HALF = False 54 55def execWrapper(code, glob, loc): 56 exec(code, glob, loc) 57 58def do_input_map(fn, input): 59 return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input) 60 61def clear_class_registry(): 62 torch._C._jit_clear_class_registry() 63 torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore() 64 torch.jit._state._clear_class_state() 65 66def get_execution_plan(graph_executor_state): 67 execution_plans = list(graph_executor_state.execution_plans.values()) 68 num_plans = len(execution_plans) 69 if num_plans != 1: 70 raise RuntimeError('This test assumes this GraphExecutor should ' 71 f'only have one execution plan, got: {num_plans}') 72 return execution_plans[0] 73 74class _AssertRaisesRegexWithHighlightContext: 75 """ 76 A context manager that is useful for checking that error messages highlight 77 the correct part of the source code. 78 """ 79 80 def __init__(self, test_case, exception, regex, highlight): 81 self.test_case = test_case 82 self.exception_type = exception 83 self.regex = regex 84 self.highlight = highlight 85 86 def __enter__(self): 87 return self 88 89 def __exit__(self, type, value, traceback): 90 with self.test_case.assertRaisesRegex(self.exception_type, self.regex): 91 if type: 92 raise value 93 94 if self.highlight: 95 FileCheck().check_source_highlighted(self.highlight).run(str(value)) 96 97 return True 98 99FUSION_GROUP = "prim::TensorExprGroup" 100 101class JitTestCase(JitCommonTestCase): 102 _do_cuda_memory_leak_check = True 103 _restored_warnings = False 104 105 class capture_stdout(list): 106 """ 107 Replace sys.stdout with a temporary StringIO 108 """ 109 def __enter__(self): 110 self.sys_stdout = sys.stdout 111 self.stringio = StringIO() 112 sys.stdout = self.stringio 113 return self 114 115 def __exit__(self, *args): 116 self.append(str(self.stringio.getvalue())) 117 del self.stringio 118 sys.stdout = self.sys_stdout 119 120 class capture_stderr(list): 121 """ 122 Replace sys.stderr with a temporary StringIO 123 """ 124 def __enter__(self): 125 self.sys_stderr = sys.stderr 126 self.stringio = StringIO() 127 sys.stderr = self.stringio 128 return self 129 130 def __exit__(self, *args): 131 self.append(str(self.stringio.getvalue())) 132 del self.stringio 133 sys.stderr = self.sys_stderr 134 135 def setHooks(self): 136 torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook) 137 138 def clearHooks(self): 139 torch._C._jit_set_emit_hooks(None, None) 140 141 def setUp(self): 142 super().setUp() 143 # unittest overrides all warning filters and forces all of them to show up 144 # after we install our own to silence those coming from inside PyTorch. 145 # This will ensure that our filter still takes precedence. 146 if not JitTestCase._restored_warnings: 147 torch.jit.TracerWarning.ignore_lib_warnings() 148 JitTestCase._restored_warnings = True 149 self.setHooks() 150 151 def tearDown(self): 152 super().tearDown() 153 # needs to be cleared because python might be unloaded before 154 # the callback gets destructed 155 self.clearHooks() 156 clear_class_registry() 157 158 def assertAllFused(self, graph, except_for=()): 159 160 # note this helper collects nodes on 'fast path' only 161 # i.e. the true blocks of specialized checks 162 def get_nodes_and_parents_recursively(block, kind, acc): 163 for node in block.nodes(): 164 if node.kind() == kind: 165 acc[block].append(node) 166 elif node.kind() == 'prim::DifferentiableGraph': 167 get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc) 168 elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or 169 node.inputs().__next__().node().kind() == 'prim::TypeCheck' or 170 node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'): 171 get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc) 172 else: 173 for inner_block in node.blocks(): 174 get_nodes_and_parents_recursively(inner_block, kind, acc) 175 176 allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate', 177 'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for) 178 179 fusion_groups : Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list) 180 get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups) 181 self.assertTrue(len(fusion_groups) == 1, f'got {graph}') 182 (graph, fusion_nodes) = next(iter(fusion_groups.items())) 183 # the block contains one FUSION_GROUP and the rest of nodes are `allowed_nodes` 184 self.assertTrue(len(fusion_nodes) == 1, f'got {graph}') 185 self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()), 186 f'got {graph}') 187 188 def _isHookExceptionOk(self, e): 189 se = str(e) 190 allowed = ("Could not export Python function", 191 "closures are not exportable") 192 for a in allowed: 193 if a in se: 194 return True 195 return False 196 197 def _compared_saved_loaded(self, m): 198 def extract_files(buffer): 199 # crack open the zip format to get at the main module code 200 archive = zipfile.ZipFile(buffer) 201 # check that we have no duplicate names 202 self.assertEqual(len(set(archive.namelist())), len(archive.namelist())) 203 files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist())) 204 # unwrap all the code files into strings 205 code_files_str = filter(lambda x: x.endswith('.py'), files) 206 code_files_stream = (archive.open(f) for f in code_files_str) 207 code_files = ("".join([line.decode() for line in file]) for file in code_files_stream) 208 209 # unpickled all the debug files 210 debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files) 211 debug_files_stream = (archive.open(f) for f in debug_files_str) 212 debug_files = (pickle.load(f) for f in debug_files_stream) 213 return code_files, debug_files 214 215 # disable the hook while we parse code, otherwise we will re-enter the hook 216 with torch._jit_internal._disable_emit_hooks(): 217 try: 218 # short-circuit if this is an empty function or module 219 if len(m.code) == 0: 220 return 221 if isinstance(m, torch._C.ScriptModule): 222 if len(m._method_names()) == 0: 223 return 224 225 # save the module to a buffer 226 buffer = io.BytesIO() 227 torch.jit.save(m, buffer) 228 # copy the data in the buffer so we can restore it later. This 229 # is because py2 and py3 have different semantics with zipfile 230 # and it's easier to just work with a fresh copy each time. 231 buffer_copy = buffer.getvalue() 232 233 code_files, debug_files = extract_files(buffer) 234 235 except RuntimeError as e: 236 if not self._isHookExceptionOk(e): 237 raise 238 else: 239 return 240 241 # import the model again (from a the copy we made of the original) 242 buffer2 = io.BytesIO(buffer_copy) 243 imported = torch.jit.load(buffer2) 244 245 # save it again 246 saved_module_buffer_2 = io.BytesIO() 247 torch.jit.save(imported, saved_module_buffer_2) 248 249 saved_module_buffer_2.seek(0) 250 code_files_2, debug_files_2 = extract_files(saved_module_buffer_2) 251 252 for a, b in zip(code_files, code_files_2): 253 self.assertMultiLineEqual(a, b) 254 255 if isinstance(m, torch._C.ScriptModule): 256 self.assertTrue(torch._C._ivalue_tags_match(m, imported._c)) 257 258 259 def emitFunctionHook(self, func): 260 # func has invalid names for export, skip the jitter check 261 if func.name == "<lambda>" or "aten::" in func.name: 262 return 263 self._compared_saved_loaded(func) 264 265 def emitModuleHook(self, module): 266 self._compared_saved_loaded(module) 267 268 269 def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None): 270 buffer = io.BytesIO() 271 m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None) 272 torch.jit.save(m, buffer) 273 m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) 274 buffer.seek(0) 275 imported = torch.jit.load(buffer, map_location=map_location) 276 imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) 277 278 if not also_test_file: 279 return imported 280 281 # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile 282 # opens the file, and it cannot be opened multiple times in Windows. To support Windows, 283 # close the file after creation and try to remove it manually 284 f = tempfile.NamedTemporaryFile(delete=False) 285 try: 286 f.close() 287 imported.save(f.name) 288 result = torch.jit.load(f.name, map_location=map_location) 289 finally: 290 os.unlink(f.name) 291 292 result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None) 293 return result 294 295 def assertGraphContains(self, graph, kind, consider_subgraphs=False): 296 297 if consider_subgraphs: 298 strgraph = str(graph) 299 count = strgraph.count(kind) - strgraph.count(f'with {kind}') 300 self.assertTrue(count > 0) 301 return 302 303 def nodes(block): 304 out = [] 305 for node in block.nodes(): 306 if node.kind() == kind: 307 out.append(node) 308 for block in node.blocks(): 309 out += nodes(block) 310 return out 311 312 out_nodes = nodes(graph) 313 self.assertTrue(len(out_nodes) > 0) 314 315 def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False): 316 def perform_assert(graph, kind, actual, expected, consider_subgraphs): 317 if actual == expected: 318 return 319 subgraph = 'including' if consider_subgraphs else 'excluding' 320 raise AssertionError( 321 f'{graph}\nError: graph contains {actual} {kind} nodes ({subgraph} subgraphs) but expected {expected}') 322 323 if consider_subgraphs: 324 strgraph = str(graph) 325 count = strgraph.count(kind) - strgraph.count(f'with {kind}') 326 perform_assert(graph, kind, count, num_kind_nodes, 327 consider_subgraphs) 328 return 329 330 def nodes(block): 331 out = [] 332 for node in block.nodes(): 333 if node.kind() == kind: 334 out.append(node) 335 for block in node.blocks(): 336 out += nodes(block) 337 return out 338 339 out_nodes = nodes(graph) 340 perform_assert(graph, kind, len(out_nodes), num_kind_nodes, 341 consider_subgraphs) 342 343 def assertExpectedONNXGraph(self, g, *args, **kwargs): 344 g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX) 345 self.assertExpectedGraph(g, *args, **kwargs) 346 347 def assertExpectedGraph(self, trace, *args, **kwargs): 348 if isinstance(trace, torch._C.Graph): 349 graph = trace 350 else: 351 graph = trace.graph() 352 353 torch._C._jit_pass_lint(graph) 354 torch._C._jit_pass_dce(graph) 355 torch._C._jit_pass_lint(graph) 356 graph = torch._C._jit_pass_canonicalize(graph) 357 torch._C._jit_pass_lint(graph) 358 self.assertExpected(str(graph), *args, **kwargs) 359 360 def run_pass(self, name, trace): 361 if isinstance(trace, torch._C.Graph): 362 graph = trace 363 set_graph = False 364 else: 365 set_graph = True 366 graph = trace.graph() 367 368 torch._C._jit_pass_lint(graph) 369 result = getattr(torch._C, '_jit_pass_' + name)(graph) 370 if result is not None and not isinstance(result, bool): 371 graph = result 372 torch._C._jit_pass_lint(graph) 373 374 if set_graph: 375 trace.set_graph(graph) 376 return graph 377 378 def get_frame_vars(self, frames_up): 379 frame = inspect.currentframe() 380 if not frame: 381 raise RuntimeError("failed to inspect frame") 382 i = 0 383 while i < frames_up + 1: 384 frame = frame.f_back 385 if not frame: 386 raise RuntimeError("failed to get frame") 387 i += 1 388 defined_vars: Dict[str, Any] = {} 389 defined_vars.update(frame.f_locals) 390 defined_vars.update(frame.f_globals) 391 return defined_vars 392 393 def assertRaisesRegexWithHighlight(self, exception, regex, highlight): 394 return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight) 395 396 def checkScriptRaisesRegex(self, script, inputs, exception, regex, 397 name=None, outputs=None, capture_output=False, 398 frames_up=1, profiling=ProfilingMode.PROFILING): 399 """ 400 Checks that a given function will throw the correct exception, 401 when executed with normal python, the string frontend, and the 402 AST frontend. Logic taken from `checkScript` (see comments there 403 for details) 404 """ 405 with enable_profiling_mode_for_profiling_tests(): 406 # Normal Python 407 with self.assertRaisesRegex(exception, regex): 408 if isinstance(script, str): 409 frame = self.get_frame_vars(frames_up) 410 the_locals: Dict[str, Any] = {} 411 execWrapper(script, glob=frame, loc=the_locals) 412 frame.update(the_locals) 413 414 python_fn = frame[name] 415 else: 416 python_fn = script 417 418 python_fn(*inputs) 419 420 # String frontend 421 with self.assertRaisesRegex(exception, regex): 422 if isinstance(script, str): 423 cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) 424 string_frontend = getattr(cu, name) 425 else: 426 source = textwrap.dedent(inspect.getsource(script)) 427 cu = torch.jit.CompilationUnit(source, _frames_up=frames_up) 428 string_frontend = getattr(cu, script.__name__) 429 430 string_frontend(*inputs) 431 432 # Python AST frontend 433 if not isinstance(script, str): 434 with self.assertRaisesRegex(exception, regex): 435 ge = torch.jit.script(python_fn) 436 ge(*inputs) 437 438 def checkBailouts(self, model, inputs, expected): 439 state = model.get_debug_state() 440 plan = get_execution_plan(state) 441 num_bailouts = plan.code.num_bailouts() 442 for i in range(0, num_bailouts): 443 plan.code.request_bailout(i) 444 bailout_outputs = model(*inputs) 445 self.assertEqual(bailout_outputs, expected) 446 447 def checkScript(self, 448 script, 449 inputs, 450 name='func', 451 optimize=True, 452 inputs_requires_grad=False, 453 capture_output=False, 454 frames_up=1, 455 profiling=ProfilingMode.PROFILING, 456 atol=None, 457 rtol=None): 458 """ 459 Checks that a given script generates the same output as the Python 460 version using the given inputs. 461 """ 462 with torch.jit.optimized_execution(optimize): 463 with enable_profiling_mode_for_profiling_tests(): 464 extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs) 465 if isinstance(script, str): 466 # Compile the string to a Script function 467 # with enable_profiling_mode(): 468 cu = torch.jit.CompilationUnit(script, _frames_up=frames_up) 469 470 # Execute the Python function so we can run it later and get its 471 # outputs 472 473 frame = self.get_frame_vars(frames_up) 474 the_locals: Dict[str, Any] = {} 475 execWrapper(script, glob=frame, loc=the_locals) 476 frame.update(the_locals) 477 478 python_fn = frame[name] 479 scripted_fn = getattr(cu, name) 480 else: 481 482 # Check the string frontend first 483 source = textwrap.dedent(inspect.getsource(script)) 484 self.checkScript( 485 source, 486 inputs, 487 script.__name__, 488 optimize=optimize, 489 inputs_requires_grad=inputs_requires_grad, 490 capture_output=capture_output, 491 profiling=profiling, 492 frames_up=2) 493 494 # Continue checking the Python frontend 495 scripted_fn = torch.jit.script(script, _frames_up=1) 496 python_fn = script 497 498 if inputs_requires_grad: 499 recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs) 500 else: 501 recording_inputs = inputs 502 503 if capture_output: 504 with self.capture_stdout() as script_stdout: 505 script_outputs = scripted_fn(*recording_inputs) 506 with self.capture_stdout() as opt_script_stdout: 507 opt_script_outputs = scripted_fn(*recording_inputs) 508 with self.capture_stdout() as _python_stdout: 509 python_outputs = python_fn(*inputs) 510 if not IS_WINDOWS: 511 self.assertExpected(script_stdout[0], subname='stdout') 512 self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol) 513 else: 514 # profiling run 515 script_outputs = scripted_fn(*recording_inputs) 516 if inputs_requires_grad or extra_profile_runs: 517 opt_script_outputs = scripted_fn(*recording_inputs) 518 # optimized run 519 opt_script_outputs = scripted_fn(*recording_inputs) 520 if TEST_BAILOUTS: 521 self.checkBailouts(scripted_fn, inputs, opt_script_outputs) 522 python_outputs = python_fn(*inputs) 523 self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol) 524 self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol) 525 return scripted_fn 526 527 def checkTrace(self, func, reference_tensors, input_tensors=None, 528 drop=None, allow_unused=False, verbose=False, 529 inputs_require_grads=True, check_tolerance=1e-5, export_import=True, 530 _force_outplace=False, grad_atol=None, grad_rtol=None): 531 532 # TODO: check gradients for parameters, not just inputs 533 def allSum(vs): 534 # drop allows us to remove some values from ever being used 535 # to test unused outputs 536 if drop is not None: 537 vs = vs[:-drop] 538 # we don't want all the grad for all the outputs to be the same 539 # so we multiply each by a constant 540 return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None) 541 if input_tensors is None: 542 input_tensors = reference_tensors 543 544 def flatten_inputs(inputs): 545 def input_reduce(input, fn, acc): 546 if isinstance(input, torch.Tensor): 547 fn(input, acc) 548 elif isinstance(input, dict): 549 reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc) 550 else: 551 reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc) 552 return acc 553 return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), [])) 554 555 nograd_inputs = reference_tensors 556 if inputs_require_grads: 557 recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors) 558 flattened_recording_inputs = flatten_inputs(recording_inputs) 559 else: 560 recording_inputs = reference_tensors 561 562 # `check_trace` is set to False because check_trace is run with @no_grad 563 # Also, `checkTrace` already does all the checks 564 # against python function 565 ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance, 566 _force_outplace=_force_outplace, check_trace=False) 567 568 if export_import: 569 ge = self.getExportImportCopy(ge) 570 571 if verbose: 572 print(ge.graph) 573 574 # test no gradients case 575 outputs = func(*nograd_inputs) 576 outputs_ge = ge(*nograd_inputs) 577 self.assertEqual(outputs, outputs_ge) 578 579 # test gradients case 580 outputs = func(*recording_inputs) 581 if inputs_require_grads: 582 grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs, 583 allow_unused=allow_unused) 584 585 outputs_ge = ge(*recording_inputs) 586 if inputs_require_grads: 587 grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs, 588 allow_unused=allow_unused) 589 self.assertEqual(outputs, outputs_ge) 590 if inputs_require_grads: 591 self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol) 592 593 # test the grad grad case 594 outputs = func(*recording_inputs) 595 l1 = allSum(outputs) 596 if inputs_require_grads: 597 grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True, 598 allow_unused=allow_unused) 599 if inputs_require_grads: 600 l2 = (allSum(grads) * l1) 601 grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused) 602 603 if inputs_require_grads: 604 recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors) 605 flattened_recording_inputs = flatten_inputs(recording_inputs) 606 607 outputs_ge = ge(*recording_inputs) 608 l1_ge = allSum(outputs_ge) 609 if inputs_require_grads: 610 grads_ge = torch.autograd.grad( 611 l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused) 612 613 if inputs_require_grads: 614 l2_ge = (allSum(grads_ge) * l1_ge) 615 grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused) 616 617 self.assertEqual(outputs, outputs_ge) 618 if inputs_require_grads: 619 self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol) 620 for g2, g2_ge in zip(grads2, grads2_ge): 621 if g2 is None and g2_ge is None: 622 continue 623 self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4) 624 625 return ge 626 627 def checkModule(self, nn_module, args): 628 """ 629 Check that a nn.Module's results in Script mode match eager and that it 630 can be exported 631 """ 632 sm = torch.jit.script(nn_module) 633 634 with freeze_rng_state(): 635 eager_out = nn_module(*args) 636 637 with freeze_rng_state(): 638 script_out = sm(*args) 639 640 self.assertEqual(eager_out, script_out) 641 self.assertExportImportModule(sm, args) 642 643 return sm 644 645class NoTracerWarnContextManager: 646 def __enter__(self): 647 self.prev = torch._C._jit_get_tracer_state_warn() 648 torch._C._jit_set_tracer_state_warn(False) 649 650 def __exit__(self, *args): 651 torch._C._jit_set_tracer_state_warn(self.prev) 652 653@contextmanager 654def inline_everything_mode(should_inline): 655 old = torch._C._jit_get_inline_everything_mode() 656 torch._C._jit_set_inline_everything_mode(should_inline) 657 try: 658 yield 659 finally: 660 torch._C._jit_set_inline_everything_mode(old) 661 662@contextmanager 663def set_fusion_group_inlining(inlining): 664 old = torch._C._debug_get_fusion_group_inlining() 665 torch._C._debug_set_fusion_group_inlining(inlining) 666 try: 667 yield 668 finally: 669 torch._C._debug_set_fusion_group_inlining(old) 670 671# note: not re-entrant, use unnested only 672@contextmanager 673def disable_autodiff_subgraph_inlining(enabled=True): 674 torch._C._debug_set_autodiff_subgraph_inlining(not enabled) 675 try: 676 yield 677 finally: 678 torch._C._debug_set_autodiff_subgraph_inlining(True) 679 680def _inline_everything(fn): 681 @functools.wraps(fn) 682 def wrapper(*args, **kwargs): 683 with inline_everything_mode(True): 684 fn(*args, **kwargs) 685 return wrapper 686 687# this exists for forward compatibility reasons temporarily. 688# TODO(suo) remove 689def _tmp_donotuse_dont_inline_everything(fn): 690 @functools.wraps(fn) 691 def wrapper(*args, **kwargs): 692 with inline_everything_mode(False): 693 fn(*args, **kwargs) 694 return wrapper 695 696# make it easy to quicky define/trace a function for these tests 697def _trace(*args, **kwargs): 698 def wrapper(func): 699 return torch.jit.trace(func, args, **kwargs) 700 return wrapper 701 702 703def enable_cpu_fuser(fn): 704 def wrapper(*args, **kwargs): 705 torch._C._jit_override_can_fuse_on_cpu_legacy(True) 706 torch._C._jit_override_can_fuse_on_cpu(True) 707 torch._C._jit_set_te_must_use_llvm_cpu(False) 708 try: 709 fn(*args, **kwargs) 710 finally: 711 torch._C._jit_override_can_fuse_on_cpu_legacy(False) 712 torch._C._jit_override_can_fuse_on_cpu(False) 713 torch._C._jit_set_te_must_use_llvm_cpu(True) 714 return wrapper 715 716 717def enable_cpu_fuser_if(cond): 718 if cond: 719 return enable_cpu_fuser 720 else: 721 def noop_fuser(fn): 722 def wrapper(*args, **kwargs): 723 return fn(*args, **kwargs) 724 return wrapper 725 return noop_fuser 726 727def get_forward(c): 728 return c._get_method('forward') 729 730def get_forward_graph(c): 731 return c._get_method('forward').graph 732 733def get_module_method(m, module, method): 734 return m._c.getattr(module)._get_method(method) 735 736def attrs_with_prefix(module, prefix): 737 return [x for x, _ in module._modules._c.items() 738 if x.startswith(prefix)] 739 740def warmup_backward(f, *args): 741 profiling_count = 3 742 results = [] 743 for i in range(profiling_count): 744 if len(args) > 0: 745 r = torch.autograd.grad(f, *args) 746 results.append(r) 747 else: 748 f.backward(retain_graph=True) 749 750 return results 751 752# TODO: Remove me once https://bugs.python.org/issue42666 is resolved 753def make_global(*args): 754 for arg in args: 755 setattr(sys.modules[arg.__module__], arg.__name__, arg) 756 757# Helper function to eval Python3 code without causing a syntax error for 758# this file under py2 759def _get_py3_code(code, fn_name): 760 with tempfile.TemporaryDirectory() as tmp_dir: 761 script_path = os.path.join(tmp_dir, 'script.py') 762 with open(script_path, 'w') as f: 763 f.write(code) 764 spec = importlib.util.spec_from_file_location(fn_name, script_path) 765 module = importlib.util.module_from_spec(spec) 766 loader = spec.loader 767 assert isinstance(loader, Loader) # Assert type to meet MyPy requirement 768 loader.exec_module(module) 769 fn = getattr(module, fn_name) 770 return fn 771 772class TensorExprTestOptions: 773 def __init__(self) -> None: 774 self.old_profiling_executor = torch._C._jit_set_profiling_executor(True) 775 self.old_profiling_mode = torch._C._get_graph_executor_optimize(True) 776 777 self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu() 778 self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu() 779 torch._C._jit_override_can_fuse_on_cpu(True) 780 torch._C._jit_override_can_fuse_on_gpu(True) 781 self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled() 782 torch._C._jit_set_texpr_fuser_enabled(True) 783 self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining() 784 torch._C._debug_set_fusion_group_inlining(False) 785 self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu() 786 torch._C._jit_set_te_must_use_llvm_cpu(False) 787 788 def restore(self): 789 torch._C._jit_set_profiling_executor(self.old_profiling_executor) 790 torch._C._get_graph_executor_optimize(self.old_profiling_mode) 791 792 torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state) 793 torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state) 794 torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state) 795 torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining) 796 torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu) 797 798def clone_inputs(args): 799 inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = [] 800 801 for arg in args: 802 if isinstance(arg, torch.Tensor): 803 inputs.append(arg.detach().clone()) 804 elif is_iterable_of_tensors(arg): 805 inputs.append([t.detach().clone() for t in arg]) 806 else: 807 inputs.append(arg) 808 809 return inputs 810 811def get_traced_sample_variant_pairs(device, dtype, op): 812 # tuples of (variant, sample) 813 outputs: List[Tuple[Any, Any]] = [] 814 815 samples = op.sample_inputs(device, dtype) 816 817 # Acquires variants to test 818 func = op.get_op() 819 method = op.get_method() 820 variants = { 821 # TODO: inplace tests currently fail, fix and add inplace variant 822 'function': func, 'method': method, 823 } 824 825 # TODO: find better way to standardize on op registration itself.. 826 has_fake_function = op.name in ["resize_", 'resize_as_'] 827 828 if has_fake_function: 829 variants = {'method': getattr(torch.Tensor, op.name)} 830 831 # In eager mode, these ops can take (Tensor, bool) args; but in 832 # JIT they can only take (Tensor, Scalar), and bool is not a 833 # scalar in the JIT type system. So to test these in JIT, the bool 834 # is converted to an int for the test. 835 ops_with_unsupported_bool_args = [ 836 { 837 "name": "div_floor_rounding", 838 "arg_idx": [0], 839 }, 840 { 841 "name": "div_no_rounding_mode", 842 "arg_idx": [0], 843 }, 844 { 845 "name": "div_trunc_rounding", 846 "arg_idx": [0], 847 }, 848 { 849 "name": "index_fill", 850 "arg_idx": [2], 851 }, 852 { 853 "name": "full_like", 854 "arg_idx": [0], 855 }, 856 { 857 "name": "mul", 858 "arg_idx": [0], 859 }, 860 { 861 "name": "new_full", 862 "arg_idx": [1], 863 }, 864 ] 865 866 # doesn't support tracing 867 if has_fake_function: 868 return outputs 869 870 for sample in samples: 871 for variant in variants.values(): 872 if variant is None: 873 continue 874 875 if is_lambda(variant): 876 continue 877 878 matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args) 879 for op_data in matching_ops: 880 for idx in op_data["arg_idx"]: 881 args = list(sample.args) 882 if len(sample.args) > idx and isinstance(sample.args[idx], bool): 883 args[idx] = int(args[idx]) 884 sample.args = tuple(args) 885 886 outputs.append((variant, sample)) 887 888 return outputs 889 890# types.LambdaType gave false positives 891def is_lambda(lamb): 892 LAMBDA = lambda: 0 # noqa: E731 893 return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__ 894