# mypy: ignore-errors

# Torch
from torch.autograd import Variable
from torch.autograd.function import _nested_map
from torch.jit.annotations import BroadcastingList2, BroadcastingList3  # noqa: F401

from torch.onnx import OperatorExportTypes
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
import zipfile
import functools

# Testing utils
from torch.testing import FileCheck
from torch.testing._internal.common_utils import IS_WINDOWS, \
    freeze_rng_state, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS, \
    is_iterable_of_tensors
from torch.testing._internal.common_jit import JitCommonTestCase
from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401

# Standard library
from contextlib import contextmanager
from functools import reduce
from io import StringIO
from collections import defaultdict

import importlib.util
import inspect
import io
import math
import os
import pickle
import sys
import tempfile
import textwrap
from importlib.abc import Loader
from typing import Any, Dict, List, Tuple, Union

RUN_CUDA = torch.cuda.is_available()
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
RUN_CUDA_HALF = RUN_CUDA
# HIP supports half, no version check necessary
if torch.cuda.is_available() and not torch.version.hip:
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        if (major < 6):
            RUN_CUDA_HALF = False

def execWrapper(code, glob, loc):
    exec(code, glob, loc)

def do_input_map(fn, input):
    return _nested_map(lambda t: isinstance(t, torch.Tensor), fn)(input)

def clear_class_registry():
    torch._C._jit_clear_class_registry()
    torch.jit._recursive.concrete_type_store = torch.jit._recursive.ConcreteTypeStore()
    torch.jit._state._clear_class_state()

def get_execution_plan(graph_executor_state):
    execution_plans = list(graph_executor_state.execution_plans.values())
    num_plans = len(execution_plans)
    if num_plans != 1:
        raise RuntimeError('This test assumes the GraphExecutor has exactly '
                           f'one execution plan, got: {num_plans}')
    return execution_plans[0]

class _AssertRaisesRegexWithHighlightContext:
    """
    A context manager that is useful for checking that error messages highlight
    the correct part of the source code.
    """

    def __init__(self, test_case, exception, regex, highlight):
        self.test_case = test_case
        self.exception_type = exception
        self.regex = regex
        self.highlight = highlight

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        with self.test_case.assertRaisesRegex(self.exception_type, self.regex):
            if type:
                raise value

        if self.highlight:
            FileCheck().check_source_highlighted(self.highlight).run(str(value))

        return True

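# Illustrative sketch: tests typically reach this context manager through
# JitTestCase.assertRaisesRegexWithHighlight; the function, regex, and highlight
# below are hypothetical.
#
#     with self.assertRaisesRegexWithHighlight(RuntimeError, "unknown_attr",
#                                              "x.unknown_attr"):
#         @torch.jit.script
#         def fn(x: torch.Tensor):
#             return x.unknown_attr
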
FUSION_GROUP = "prim::TensorExprGroup"

class JitTestCase(JitCommonTestCase):
    _do_cuda_memory_leak_check = True
    _restored_warnings = False

    class capture_stdout(list):
        """
        Replace sys.stdout with a temporary StringIO
        """
        def __enter__(self):
            self.sys_stdout = sys.stdout
            self.stringio = StringIO()
            sys.stdout = self.stringio
            return self

        def __exit__(self, *args):
            self.append(str(self.stringio.getvalue()))
            del self.stringio
            sys.stdout = self.sys_stdout

    class capture_stderr(list):
        """
        Replace sys.stderr with a temporary StringIO
        """
        def __enter__(self):
            self.sys_stderr = sys.stderr
            self.stringio = StringIO()
            sys.stderr = self.stringio
            return self

        def __exit__(self, *args):
            self.append(str(self.stringio.getvalue()))
            del self.stringio
            sys.stderr = self.sys_stderr

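    # Illustrative sketch: both helpers append the captured text to the list-like
    # object they return, so a test can do roughly:
    #
    #     with self.capture_stdout() as captured:
    #         print("hello")
    #     self.assertIn("hello", captured[0])
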
    def setHooks(self):
        torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)

    def clearHooks(self):
        torch._C._jit_set_emit_hooks(None, None)

    def setUp(self):
        super().setUp()
        # unittest overrides all warning filters and forces them all to show up
        # after we install our own filter to silence warnings coming from inside PyTorch.
        # Re-installing the filter here ensures that ours still takes precedence.
        if not JitTestCase._restored_warnings:
            torch.jit.TracerWarning.ignore_lib_warnings()
            JitTestCase._restored_warnings = True
        self.setHooks()

    def tearDown(self):
        super().tearDown()
        # needs to be cleared because Python might be unloaded before
        # the callback gets destroyed
        self.clearHooks()
        clear_class_registry()

    def assertAllFused(self, graph, except_for=()):

        # note: this helper collects nodes on the 'fast path' only,
        # i.e. the true blocks of specialized checks
        def get_nodes_and_parents_recursively(block, kind, acc):
            for node in block.nodes():
                if node.kind() == kind:
                    acc[block].append(node)
                elif node.kind() == 'prim::DifferentiableGraph':
                    get_nodes_and_parents_recursively(node.g('Subgraph'), kind, acc)
                elif node.kind() == 'prim::If' and (node.inputs().__next__().node().kind() == 'aten::all' or
                                                    node.inputs().__next__().node().kind() == 'prim::TypeCheck' or
                                                    node.inputs().__next__().node().kind() == 'prim::RequiresGradCheck'):
                    get_nodes_and_parents_recursively(node.blocks().__next__(), kind, acc)
                else:
                    for inner_block in node.blocks():
                        get_nodes_and_parents_recursively(inner_block, kind, acc)

        allowed_nodes = {'prim::Constant', FUSION_GROUP, 'prim::BailoutTemplate',
                         'prim::TupleConstruct', 'prim::If', 'prim::TypeCheck', 'prim::RequiresGradCheck'} | set(except_for)

        fusion_groups: Dict[torch._C.Block, List[torch._C.Node]] = defaultdict(list)
        get_nodes_and_parents_recursively(graph, FUSION_GROUP, fusion_groups)
        self.assertTrue(len(fusion_groups) == 1, f'got {graph}')
        (graph, fusion_nodes) = next(iter(fusion_groups.items()))
        # the block contains one FUSION_GROUP and the rest of the nodes are in `allowed_nodes`
        self.assertTrue(len(fusion_nodes) == 1, f'got {graph}')
        self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
                        f'got {graph}')

    def _isHookExceptionOk(self, e):
        se = str(e)
        allowed = ("Could not export Python function",
                   "closures are not exportable")
        for a in allowed:
            if a in se:
                return True
        return False

    def _compared_saved_loaded(self, m):
        def extract_files(buffer):
            # crack open the zip format to get at the main module code
            archive = zipfile.ZipFile(buffer)
            # check that we have no duplicate names
            self.assertEqual(len(set(archive.namelist())), len(archive.namelist()))
            files = list(filter(lambda x: x.startswith('archive/code/'), archive.namelist()))
            # unwrap all the code files into strings
            code_files_str = filter(lambda x: x.endswith('.py'), files)
            code_files_stream = (archive.open(f) for f in code_files_str)
            code_files = ("".join([line.decode() for line in file]) for file in code_files_stream)

            # unpickle all the debug files
            debug_files_str = filter(lambda f: f.endswith('.debug_pkl'), files)
            debug_files_stream = (archive.open(f) for f in debug_files_str)
            debug_files = (pickle.load(f) for f in debug_files_stream)
            return code_files, debug_files

        # disable the hook while we parse code, otherwise we will re-enter the hook
        with torch._jit_internal._disable_emit_hooks():
            try:
                # short-circuit if this is an empty function or module
                if len(m.code) == 0:
                    return
                if isinstance(m, torch._C.ScriptModule):
                    if len(m._method_names()) == 0:
                        return

                # save the module to a buffer
                buffer = io.BytesIO()
                torch.jit.save(m, buffer)
                # copy the data in the buffer so we can restore it later. This
                # is because py2 and py3 have different semantics with zipfile
                # and it's easier to just work with a fresh copy each time.
                buffer_copy = buffer.getvalue()

                code_files, debug_files = extract_files(buffer)

            except RuntimeError as e:
                if not self._isHookExceptionOk(e):
                    raise
                else:
                    return

            # import the model again (from the copy we made of the original)
            buffer2 = io.BytesIO(buffer_copy)
            imported = torch.jit.load(buffer2)

            # save it again
            saved_module_buffer_2 = io.BytesIO()
            torch.jit.save(imported, saved_module_buffer_2)

            saved_module_buffer_2.seek(0)
            code_files_2, debug_files_2 = extract_files(saved_module_buffer_2)

            for a, b in zip(code_files, code_files_2):
                self.assertMultiLineEqual(a, b)

            if isinstance(m, torch._C.ScriptModule):
                self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))


    def emitFunctionHook(self, func):
        # if func has a name that is invalid for export, skip the jitter check
        if func.name == "<lambda>" or "aten::" in func.name:
            return
        self._compared_saved_loaded(func)

    def emitModuleHook(self, module):
        self._compared_saved_loaded(module)


    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        m.apply(lambda s: s._pack() if s._c._has_method('_pack') else None)
        torch.jit.save(m, buffer)
        m.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)
        imported.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)

        if not also_test_file:
            return imported

        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times on Windows. To support Windows,
        # close the file after creation and try to remove it manually.
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            f.close()
            imported.save(f.name)
            result = torch.jit.load(f.name, map_location=map_location)
        finally:
            os.unlink(f.name)

        result.apply(lambda s: s._unpack() if s._c._has_method('_unpack') else None)
        return result

    def assertGraphContains(self, graph, kind, consider_subgraphs=False):

        if consider_subgraphs:
            strgraph = str(graph)
            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
            self.assertTrue(count > 0)
            return

        def nodes(block):
            out = []
            for node in block.nodes():
                if node.kind() == kind:
                    out.append(node)
                for block in node.blocks():
                    out += nodes(block)
            return out

        out_nodes = nodes(graph)
        self.assertTrue(len(out_nodes) > 0)

    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
            if actual == expected:
                return
            subgraph = 'including' if consider_subgraphs else 'excluding'
            raise AssertionError(
                f'{graph}\nError: graph contains {actual} {kind} nodes ({subgraph} subgraphs) but expected {expected}')

        if consider_subgraphs:
            strgraph = str(graph)
            count = strgraph.count(kind) - strgraph.count(f'with {kind}')
            perform_assert(graph, kind, count, num_kind_nodes,
                           consider_subgraphs)
            return

        def nodes(block):
            out = []
            for node in block.nodes():
                if node.kind() == kind:
                    out.append(node)
                for block in node.blocks():
                    out += nodes(block)
            return out

        out_nodes = nodes(graph)
        perform_assert(graph, kind, len(out_nodes), num_kind_nodes,
                       consider_subgraphs)

    def assertExpectedONNXGraph(self, g, *args, **kwargs):
        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
        self.assertExpectedGraph(g, *args, **kwargs)

    def assertExpectedGraph(self, trace, *args, **kwargs):
        if isinstance(trace, torch._C.Graph):
            graph = trace
        else:
            graph = trace.graph()

        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        self.assertExpected(str(graph), *args, **kwargs)

    def run_pass(self, name, trace):
        if isinstance(trace, torch._C.Graph):
            graph = trace
            set_graph = False
        else:
            set_graph = True
            graph = trace.graph()

        torch._C._jit_pass_lint(graph)
        result = getattr(torch._C, '_jit_pass_' + name)(graph)
        if result is not None and not isinstance(result, bool):
            graph = result
        torch._C._jit_pass_lint(graph)

        if set_graph:
            trace.set_graph(graph)
        return graph

    def get_frame_vars(self, frames_up):
        frame = inspect.currentframe()
        if not frame:
            raise RuntimeError("failed to inspect frame")
        i = 0
        while i < frames_up + 1:
            frame = frame.f_back
            if not frame:
                raise RuntimeError("failed to get frame")
            i += 1
        defined_vars: Dict[str, Any] = {}
        defined_vars.update(frame.f_locals)
        defined_vars.update(frame.f_globals)
        return defined_vars

    def assertRaisesRegexWithHighlight(self, exception, regex, highlight):
        return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)

    def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                               name=None, outputs=None, capture_output=False,
                               frames_up=1, profiling=ProfilingMode.PROFILING):
        """
        Checks that a given function will throw the correct exception when
        executed with normal Python, the string frontend, and the AST frontend.
        Logic taken from `checkScript` (see comments there for details).
        """
        with enable_profiling_mode_for_profiling_tests():
            # Normal Python
            with self.assertRaisesRegex(exception, regex):
                if isinstance(script, str):
                    frame = self.get_frame_vars(frames_up)
                    the_locals: Dict[str, Any] = {}
                    execWrapper(script, glob=frame, loc=the_locals)
                    frame.update(the_locals)

                    python_fn = frame[name]
                else:
                    python_fn = script

                python_fn(*inputs)

            # String frontend
            with self.assertRaisesRegex(exception, regex):
                if isinstance(script, str):
                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)
                    string_frontend = getattr(cu, name)
                else:
                    source = textwrap.dedent(inspect.getsource(script))
                    cu = torch.jit.CompilationUnit(source, _frames_up=frames_up)
                    string_frontend = getattr(cu, script.__name__)

                string_frontend(*inputs)

            # Python AST frontend
            if not isinstance(script, str):
                with self.assertRaisesRegex(exception, regex):
                    ge = torch.jit.script(python_fn)
                    ge(*inputs)

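    # Illustrative sketch: a hypothetical call that expects the same indexing error
    # from eager execution and from both TorchScript frontends might look like:
    #
    #     def bad_index(x: List[int]):
    #         return x[10]
    #
    #     self.checkScriptRaisesRegex(bad_index, ([1, 2, 3],), Exception, "out of range")
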
    def checkBailouts(self, model, inputs, expected):
        state = model.get_debug_state()
        plan = get_execution_plan(state)
        num_bailouts = plan.code.num_bailouts()
        for i in range(0, num_bailouts):
            plan.code.request_bailout(i)
            bailout_outputs = model(*inputs)
            self.assertEqual(bailout_outputs, expected)

    def checkScript(self,
                    script,
                    inputs,
                    name='func',
                    optimize=True,
                    inputs_requires_grad=False,
                    capture_output=False,
                    frames_up=1,
                    profiling=ProfilingMode.PROFILING,
                    atol=None,
                    rtol=None):
        """
        Checks that a given script generates the same output as the Python
        version using the given inputs.
        """
        with torch.jit.optimized_execution(optimize):
            with enable_profiling_mode_for_profiling_tests():
                extra_profile_runs = any(isinstance(x, torch.Tensor) and x.requires_grad for x in inputs)
                if isinstance(script, str):
                    # Compile the string to a Script function
                    # with enable_profiling_mode():
                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)

                    # Execute the Python function so we can run it later and get its
                    # outputs

                    frame = self.get_frame_vars(frames_up)
                    the_locals: Dict[str, Any] = {}
                    execWrapper(script, glob=frame, loc=the_locals)
                    frame.update(the_locals)

                    python_fn = frame[name]
                    scripted_fn = getattr(cu, name)
                else:

                    # Check the string frontend first
                    source = textwrap.dedent(inspect.getsource(script))
                    self.checkScript(
                        source,
                        inputs,
                        script.__name__,
                        optimize=optimize,
                        inputs_requires_grad=inputs_requires_grad,
                        capture_output=capture_output,
                        profiling=profiling,
                        frames_up=2)

                    # Continue checking the Python frontend
                    scripted_fn = torch.jit.script(script, _frames_up=1)
                    python_fn = script

                if inputs_requires_grad:
                    recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
                else:
                    recording_inputs = inputs

                if capture_output:
                    with self.capture_stdout() as script_stdout:
                        script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as opt_script_stdout:
                        opt_script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as _python_stdout:
                        python_outputs = python_fn(*inputs)
                    if not IS_WINDOWS:
                        self.assertExpected(script_stdout[0], subname='stdout')
                    self.assertEqual(python_outputs, opt_script_outputs, atol=atol, rtol=rtol)
                else:
                    # profiling run
                    script_outputs = scripted_fn(*recording_inputs)
                    if inputs_requires_grad or extra_profile_runs:
                        opt_script_outputs = scripted_fn(*recording_inputs)
                    # optimized run
                    opt_script_outputs = scripted_fn(*recording_inputs)
                    if TEST_BAILOUTS:
                        self.checkBailouts(scripted_fn, inputs, opt_script_outputs)
                    python_outputs = python_fn(*inputs)
                self.assertEqual(python_outputs, script_outputs, atol=atol, rtol=rtol)
                self.assertEqual(script_outputs, opt_script_outputs, atol=atol, rtol=rtol)
                return scripted_fn

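    # Illustrative sketch: `script` may be a Python function (checked through both the
    # string and AST frontends) or a source string together with `name`; a hypothetical
    # call might look like:
    #
    #     def add(a, b):
    #         return a + b
    #
    #     scripted = self.checkScript(add, (torch.rand(3), torch.rand(3)))
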
    def checkTrace(self, func, reference_tensors, input_tensors=None,
                   drop=None, allow_unused=False, verbose=False,
                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
                   _force_outplace=False, grad_atol=None, grad_rtol=None):

        # TODO: check gradients for parameters, not just inputs
        def allSum(vs):
            # `drop` lets us exclude some values from ever being used,
            # to test unused outputs
            if drop is not None:
                vs = vs[:-drop]
            # we don't want the grads for all the outputs to be the same,
            # so we multiply each by a constant
            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
        if input_tensors is None:
            input_tensors = reference_tensors

        def flatten_inputs(inputs):
            def input_reduce(input, fn, acc):
                if isinstance(input, torch.Tensor):
                    fn(input, acc)
                elif isinstance(input, dict):
                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
                else:
                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
                return acc
            return tuple(input_reduce(recording_inputs, lambda t, acc: acc.append(t), []))

        nograd_inputs = reference_tensors
        if inputs_require_grads:
            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
            flattened_recording_inputs = flatten_inputs(recording_inputs)
        else:
            recording_inputs = reference_tensors

        # `check_trace` is set to False because check_trace is run with @no_grad
        # Also, `checkTrace` already does all the checks
        # against the Python function
        ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance,
                             _force_outplace=_force_outplace, check_trace=False)

        if export_import:
            ge = self.getExportImportCopy(ge)

        if verbose:
            print(ge.graph)

        # test no gradients case
        outputs = func(*nograd_inputs)
        outputs_ge = ge(*nograd_inputs)
        self.assertEqual(outputs, outputs_ge)

        # test gradients case
        outputs = func(*recording_inputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
                                        allow_unused=allow_unused)

        outputs_ge = ge(*recording_inputs)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
                                           allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)

        # test the grad grad case
        outputs = func(*recording_inputs)
        l1 = allSum(outputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
                                        allow_unused=allow_unused)
        if inputs_require_grads:
            l2 = (allSum(grads) * l1)
            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)

        if inputs_require_grads:
            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
            flattened_recording_inputs = flatten_inputs(recording_inputs)

        outputs_ge = ge(*recording_inputs)
        l1_ge = allSum(outputs_ge)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(
                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)

        if inputs_require_grads:
            l2_ge = (allSum(grads_ge) * l1_ge)
            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge, atol=grad_atol, rtol=grad_rtol)
            for g2, g2_ge in zip(grads2, grads2_ge):
                if g2 is None and g2_ge is None:
                    continue
                self.assertEqual(g2, g2_ge, atol=8e-4, rtol=8e-4)

        return ge

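    # Illustrative sketch: checkTrace traces `func` on the reference tensors and then
    # compares outputs, gradients, and second-order gradients between the eager and
    # traced versions; a hypothetical call might be:
    #
    #     def fn(x, y):
    #         return x * y + x
    #
    #     traced = self.checkTrace(fn, (torch.rand(2, 2), torch.rand(2, 2)))
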
    def checkModule(self, nn_module, args):
        """
        Check that an nn.Module's results in script mode match eager mode and
        that it can be exported.
        """
        sm = torch.jit.script(nn_module)

        with freeze_rng_state():
            eager_out = nn_module(*args)

        with freeze_rng_state():
            script_out = sm(*args)

        self.assertEqual(eager_out, script_out)
        self.assertExportImportModule(sm, args)

        return sm
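
    # Illustrative sketch: a hypothetical use of checkModule, which scripts the module,
    # compares eager and scripted outputs under a frozen RNG state, and checks the
    # save/load round trip:
    #
    #     sm = self.checkModule(torch.nn.Linear(3, 4), (torch.rand(2, 3),))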

class NoTracerWarnContextManager:
    def __enter__(self):
        self.prev = torch._C._jit_get_tracer_state_warn()
        torch._C._jit_set_tracer_state_warn(False)

    def __exit__(self, *args):
        torch._C._jit_set_tracer_state_warn(self.prev)

@contextmanager
def inline_everything_mode(should_inline):
    old = torch._C._jit_get_inline_everything_mode()
    torch._C._jit_set_inline_everything_mode(should_inline)
    try:
        yield
    finally:
        torch._C._jit_set_inline_everything_mode(old)

@contextmanager
def set_fusion_group_inlining(inlining):
    old = torch._C._debug_get_fusion_group_inlining()
    torch._C._debug_set_fusion_group_inlining(inlining)
    try:
        yield
    finally:
        torch._C._debug_set_fusion_group_inlining(old)

# note: not re-entrant; use only when not nested
@contextmanager
def disable_autodiff_subgraph_inlining(enabled=True):
    torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
    try:
        yield
    finally:
        torch._C._debug_set_autodiff_subgraph_inlining(True)

def _inline_everything(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with inline_everything_mode(True):
            fn(*args, **kwargs)
    return wrapper

# this exists for forward compatibility reasons temporarily.
# TODO(suo) remove
def _tmp_donotuse_dont_inline_everything(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        with inline_everything_mode(False):
            fn(*args, **kwargs)
    return wrapper

# make it easy to quickly define/trace a function for these tests
def _trace(*args, **kwargs):
    def wrapper(func):
        return torch.jit.trace(func, args, **kwargs)
    return wrapper
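
# Illustrative sketch: the decorator is applied with example inputs and returns the
# traced function directly, e.g.
#
#     @_trace(torch.rand(3))
#     def traced_fn(x):
#         return x * 2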


def enable_cpu_fuser(fn):
    def wrapper(*args, **kwargs):
        torch._C._jit_override_can_fuse_on_cpu_legacy(True)
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_set_te_must_use_llvm_cpu(False)
        try:
            fn(*args, **kwargs)
        finally:
            torch._C._jit_override_can_fuse_on_cpu_legacy(False)
            torch._C._jit_override_can_fuse_on_cpu(False)
            torch._C._jit_set_te_must_use_llvm_cpu(True)
    return wrapper


def enable_cpu_fuser_if(cond):
    if cond:
        return enable_cpu_fuser
    else:
        def noop_fuser(fn):
            def wrapper(*args, **kwargs):
                return fn(*args, **kwargs)
            return wrapper
        return noop_fuser

def get_forward(c):
    return c._get_method('forward')

def get_forward_graph(c):
    return c._get_method('forward').graph

def get_module_method(m, module, method):
    return m._c.getattr(module)._get_method(method)

def attrs_with_prefix(module, prefix):
    return [x for x, _ in module._modules._c.items()
            if x.startswith(prefix)]

def warmup_backward(f, *args):
    profiling_count = 3
    results = []
    for i in range(profiling_count):
        if len(args) > 0:
            r = torch.autograd.grad(f, *args)
            results.append(r)
        else:
            f.backward(retain_graph=True)

    return results

# TODO: Remove me once https://bugs.python.org/issue42666 is resolved
def make_global(*args):
    for arg in args:
        setattr(sys.modules[arg.__module__], arg.__name__, arg)
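
# Illustrative sketch: make_global is typically used so that classes defined inside a
# test body can be resolved by TorchScript, which looks types up by qualified name;
# the class below is hypothetical.
#
#     class MyCoords(NamedTuple):
#         x: float
#         y: float
#
#     make_global(MyCoords)  # now resolvable when scripting code that refers to MyCoords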

# Helper function to eval Python 3 code without causing a syntax error for
# this file under Python 2
def _get_py3_code(code, fn_name):
    with tempfile.TemporaryDirectory() as tmp_dir:
        script_path = os.path.join(tmp_dir, 'script.py')
        with open(script_path, 'w') as f:
            f.write(code)
        spec = importlib.util.spec_from_file_location(fn_name, script_path)
        module = importlib.util.module_from_spec(spec)
        loader = spec.loader
        assert isinstance(loader, Loader)  # Assert type to meet MyPy requirement
        loader.exec_module(module)
        fn = getattr(module, fn_name)
        return fn

class TensorExprTestOptions:
    def __init__(self) -> None:
        self.old_profiling_executor = torch._C._jit_set_profiling_executor(True)
        self.old_profiling_mode = torch._C._get_graph_executor_optimize(True)

        self.old_cpu_fuser_state = torch._C._jit_can_fuse_on_cpu()
        self.old_gpu_fuser_state = torch._C._jit_can_fuse_on_gpu()
        torch._C._jit_override_can_fuse_on_cpu(True)
        torch._C._jit_override_can_fuse_on_gpu(True)
        self.texpr_fuser_state = torch._C._jit_texpr_fuser_enabled()
        torch._C._jit_set_texpr_fuser_enabled(True)
        self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
        torch._C._debug_set_fusion_group_inlining(False)
        self.old_te_must_use_llvm_cpu = torch._C._jit_get_te_must_use_llvm_cpu()
        torch._C._jit_set_te_must_use_llvm_cpu(False)

    def restore(self):
        torch._C._jit_set_profiling_executor(self.old_profiling_executor)
        torch._C._get_graph_executor_optimize(self.old_profiling_mode)

        torch._C._jit_set_texpr_fuser_enabled(self.texpr_fuser_state)
        torch._C._jit_override_can_fuse_on_gpu(self.old_gpu_fuser_state)
        torch._C._jit_override_can_fuse_on_cpu(self.old_cpu_fuser_state)
        torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
        torch._C._jit_set_te_must_use_llvm_cpu(self.old_te_must_use_llvm_cpu)

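# Illustrative sketch: a TensorExpr-oriented test case would typically construct this
# in setUp and undo it in tearDown, roughly:
#
#     def setUp(self):
#         super().setUp()
#         self.tensorexpr_options = TensorExprTestOptions()
#
#     def tearDown(self):
#         self.tensorexpr_options.restore()
#         super().tearDown()
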
def clone_inputs(args):
    inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

    for arg in args:
        if isinstance(arg, torch.Tensor):
            inputs.append(arg.detach().clone())
        elif is_iterable_of_tensors(arg):
            inputs.append([t.detach().clone() for t in arg])
        else:
            inputs.append(arg)

    return inputs

def get_traced_sample_variant_pairs(device, dtype, op):
    # tuples of (variant, sample)
    outputs: List[Tuple[Any, Any]] = []

    samples = op.sample_inputs(device, dtype)

    # Acquires variants to test
    func = op.get_op()
    method = op.get_method()
    variants = {
        # TODO: inplace tests currently fail, fix and add inplace variant
        'function': func, 'method': method,
    }

    # TODO: find a better way to standardize on op registration itself
    has_fake_function = op.name in ["resize_", 'resize_as_']

    if has_fake_function:
        variants = {'method': getattr(torch.Tensor, op.name)}

    # In eager mode, these ops can take (Tensor, bool) args; but in
    # JIT they can only take (Tensor, Scalar), and bool is not a
    # scalar in the JIT type system. So to test these in JIT, the bool
    # is converted to an int for the test.
    ops_with_unsupported_bool_args = [
        {
            "name": "div_floor_rounding",
            "arg_idx": [0],
        },
        {
            "name": "div_no_rounding_mode",
            "arg_idx": [0],
        },
        {
            "name": "div_trunc_rounding",
            "arg_idx": [0],
        },
        {
            "name": "index_fill",
            "arg_idx": [2],
        },
        {
            "name": "full_like",
            "arg_idx": [0],
        },
        {
            "name": "mul",
            "arg_idx": [0],
        },
        {
            "name": "new_full",
            "arg_idx": [1],
        },
    ]

    # doesn't support tracing
    if has_fake_function:
        return outputs

    for sample in samples:
        for variant in variants.values():
            if variant is None:
                continue

            if is_lambda(variant):
                continue

            matching_ops = filter(lambda x: op.formatted_name == x["name"], ops_with_unsupported_bool_args)
            for op_data in matching_ops:
                for idx in op_data["arg_idx"]:
                    args = list(sample.args)
                    if len(sample.args) > idx and isinstance(sample.args[idx], bool):
                        args[idx] = int(args[idx])
                    sample.args = tuple(args)

            outputs.append((variant, sample))

    return outputs

# types.LambdaType gave false positives
def is_lambda(lamb):
    LAMBDA = lambda: 0  # noqa: E731
    return isinstance(lamb, type(LAMBDA)) and lamb.__name__ == LAMBDA.__name__
894