# mypy: allow-untyped-defs
import contextlib
import dis
import functools
import logging
import os.path
import random
import re
import sys
import types
import unittest
from typing import List, Optional, Sequence, Union
from unittest.mock import patch

import torch
from torch import fx
from torch._dynamo.output_graph import OutputGraph

from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
    create_instruction,
    debug_checks,
    is_generator,
    transform_code_object,
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .utils import same


np: Optional[types.ModuleType] = None
try:
    import numpy as np
except ModuleNotFoundError:
    np = None


unsupported = eval_frame.unsupported
three = 3

log = logging.getLogger(__name__)


def clone_me(x):
    if x is None:
        return None
    return x.detach().clone().requires_grad_(x.requires_grad)
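
# A minimal usage sketch (the tensor below is hypothetical): clone_me returns
# an independent copy that preserves requires_grad.
#
#     x = torch.randn(3, requires_grad=True)
#     y = clone_me(x)
#     assert y is not x and y.requires_grad == x.requires_grad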


def remove_optimized_module_prefix(name) -> str:
    return re.sub(r"^_orig_mod[.]", "", name)


def collect_results(model, prediction, loss, example_inputs):
    results = []
    results.append(prediction)
    results.append(loss)
    # if isinstance(loss, torch.Tensor) and loss.item() > 1:
    #     log.warning(
    #         f"High loss value alert - {loss:.2f}. Can result in unstable gradients."
    #     )

    grads = {}
    params = {}
    for name, param in model.named_parameters():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        param_copy = param
        grad = param.grad
        # Treat a None grad and a zero grad as the same
        if param.grad is None:
            grad = torch.zeros_like(param)
        grads[name + ".grad"] = grad
        params[name] = param_copy
    results.append(grads)
    results.append(params)
    buffers = {}
    for name, buffer in model.named_buffers():
        if isinstance(model, eval_frame.OptimizedModule):
            name = remove_optimized_module_prefix(name)
        buffers[name] = buffer
    results.append(buffers)
    for example in example_inputs:
        if isinstance(example, (tuple, list)):
            for inp in example:
                if isinstance(inp, torch.Tensor):
                    results.append(inp.grad)
        else:
            if isinstance(example, torch.Tensor):
                results.append(example.grad)
    return results
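
# A hedged usage sketch (model and input are hypothetical): run forward and
# backward once, then gather prediction, loss, grads, params, buffers, and
# input grads into a single list that `same` (from .utils) can compare
# against another run.
#
#     mod = torch.nn.Linear(4, 4)
#     inp = torch.randn(2, 4, requires_grad=True)
#     pred = mod(inp)
#     loss = reduce_to_scalar_loss(pred)
#     loss.backward()
#     ref = collect_results(mod, pred, loss, [inp])
#     # ... rerun with an optimized copy of mod, then: assert same(ref, res)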


def requires_bwd_pass(out):
    if isinstance(out, torch.Tensor):
        return out.requires_grad
    elif isinstance(out, (list, tuple)):
        return any(requires_bwd_pass(x) for x in out)
    elif out is None:
        return False
    elif isinstance(out, int):
        return False
    raise NotImplementedError("Don't know how to reduce", type(out))


def reduce_to_scalar_loss(out):
    """Reduce the output of a model to get scalar loss"""
    if isinstance(out, torch.Tensor):
        # Mean does not work on integer tensors
        return out.sum() / out.numel()
    elif isinstance(out, (list, tuple)):
        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
    elif type(out).__name__ in (
        "MaskedLMOutput",
        "Seq2SeqLMOutput",
        "CausalLMOutputWithCrossAttentions",
    ):
        return reduce_to_scalar_loss(out.logits)
    elif type(out).__name__ == "SquashedNormal":
        return out.mean.sum()
    elif isinstance(out, dict):
        return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(
            out.keys()
        )
    raise NotImplementedError("Don't know how to reduce", type(out))
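
# A small sketch: an arbitrary nest of tensors, lists, and dicts reduces to a
# single 0-dim tensor (the shapes below are made up).
#
#     out = {"a": torch.randn(2, 3), "b": [torch.randn(4), torch.randn(4)]}
#     loss = reduce_to_scalar_loss(out)
#     assert loss.dim() == 0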


def debug_dir() -> str:
    path = os.path.join(os.path.dirname(__file__), "../debug")
    if not os.path.exists(path):
        os.mkdir(path)
    return path


def debug_dump(name, code: types.CodeType, extra="") -> None:
    with open(os.path.join(debug_dir(), name), "w") as fd:
        fd.write(
            f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
        )


def debug_insert_nops(
    frame, cache_size, hooks, _, *, skip: int = 0
) -> Optional[GuardedCode]:
    """used to debug jump updates"""

    def insert_nops(instructions, code_options):
        instructions.insert(0, create_instruction("NOP"))
        instructions.insert(0, create_instruction("NOP"))

    if is_generator(frame.f_code):
        return None

    debug_checks(frame.f_code)
    code = transform_code_object(frame.f_code, insert_nops)
    graph = OutputGraph(
        code_options={},
        compiler_fn=None,
        root_tx=None,
        export=False,
        export_constraints=None,
        frame_state={"_id": 0},
        # TODO: shouldn't this be f_locals/f_globals from frame?
        local_scope=locals(),
        global_scope=globals(),
        f_code=frame.f_code,
    )

    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))


class CompileCounter:
    def __init__(self):
        self.frame_count = 0
        self.op_count = 0

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        return gm.forward

    def clear(self):
        self.frame_count = 0
        self.op_count = 0
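
# A hedged usage sketch (the function below is hypothetical): pass the counter
# itself as the Dynamo backend, then assert on what was compiled.
#
#     cnt = CompileCounter()
#
#     @torch._dynamo.optimize(cnt)
#     def f(x):
#         return x.sin() + x.cos()
#
#     f(torch.randn(8))
#     assert cnt.frame_count == 1  # one frame compiled
#     assert cnt.op_count == 3  # sin, cos, add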


class CompileCounterWithBackend:
    def __init__(self, backend):
        self.frame_count = 0
        self.op_count = 0
        self.backend = backend
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        from .backends.registry import lookup_backend

        self.frame_count += 1
        for node in gm.graph.nodes:
            if "call" in node.op:
                self.op_count += 1
        self.graphs.append(gm)
        return lookup_backend(self.backend)(gm, example_inputs)
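
# A hedged sketch: same counting as CompileCounter, but compilation is
# delegated to a named backend ("eager" here) and every captured graph is
# recorded (`f` is any hypothetical traceable function).
#
#     cnt = CompileCounterWithBackend("eager")
#     opt_f = torch.compile(f, backend=cnt)
#     opt_f(torch.randn(4))
#     assert cnt.frame_count == 1 and len(cnt.graphs) == 1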


# Equivalent to backend="eager", but also records graphs that
# we can assert on
class EagerAndRecordGraphs:
    def __init__(self):
        self.graphs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
        self.graphs.append(gm)
        return gm.forward
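
# A hedged usage sketch: run eagerly while keeping each captured
# fx.GraphModule around, so a test can assert on its (normalized) source.
#
#     backend = EagerAndRecordGraphs()
#
#     @torch.compile(backend=backend)
#     def f(x):
#         return x * 2
#
#     f(torch.randn(4))
#     src = normalize_gm(backend.graphs[0].print_readable(print_output=False))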


def strip_comment(code) -> str:
    code = str(code)
    return re.sub(r"(?m)^ *#.*\n?", "", code)


def remove_trailing_space(code) -> str:
    return "\n".join([line.rstrip() for line in code.split("\n")])


def normalize_gm(gm_str) -> str:
    # Strip comments, since they contain file paths that can differ from
    # system to system.
    return remove_trailing_space(strip_comment(gm_str))


def empty_line_normalizer(code: str) -> str:
    """
    Normalize code: remove empty lines.
    """
    normal_code = re.sub(r"[\r\n]+", "\n", code)
    return normal_code


def standard_test(
    self,
    fn,
    nargs,
    expected_ops=None,
    expected_ops_dynamic=None,
    expected_frame_count=1,
):
    if not config.assume_static_by_default and expected_ops_dynamic is not None:
        expected_ops = expected_ops_dynamic

    actual = CompileCounter()

    args1 = [torch.randn(10, 10) for _ in range(nargs)]
    args2 = [torch.randn(10, 10) for _ in range(nargs)]
    correct1 = fn(*args1)
    correct2 = fn(*args2)
    reset()
    opt_fn = optimize_assert(actual)(fn)
    val1a = opt_fn(*args1)
    val2a = opt_fn(*args2)
    val1b = opt_fn(*args1)
    val2b = opt_fn(*args2)
    reset()
    self.assertTrue(same(val1a, correct1))
    self.assertTrue(same(val1b, correct1))
    self.assertTrue(same(val2a, correct2))
    self.assertTrue(same(val2b, correct2))
    self.assertEqual(actual.frame_count, expected_frame_count)
    if expected_ops is not None:
        self.assertEqual(actual.op_count, expected_ops)
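
# A hypothetical test-method sketch: standard_test runs fn eagerly on two sets
# of random inputs, compiles it with optimize_assert, and checks that outputs
# match and that frame/op counts are as expected.
#
#     class MyTests(torch._dynamo.test_case.TestCase):
#         def test_add(self):
#             def fn(a, b):
#                 return a + b
#
#             standard_test(self, fn, nargs=2, expected_ops=1)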


def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
    return gm.forward


def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1):
    if not is_correct:
        return "ERROR"
    if pvalue > pvalue_threshold:
        return f"{speedup:.3f}x SAME"
    return f"{speedup:.3f}x p={pvalue:.2f}"


def rand_strided(
    size: Sequence[int],
    stride: Sequence[int],
    dtype: torch.dtype = torch.float32,
    device: Union[str, torch.device] = "cpu",
    extra_size: int = 0,
):
    needed_size = (
        sum((shape - 1) * stride for shape, stride in zip(size, stride))
        + 1
        + extra_size
    )
    if dtype.is_floating_point:
        if dtype.itemsize == 1:
299            """
300            normal distribution kernel is not implemented for fp8..
301            Workaround that by creating a fp16 tensor and then cast.
302            """
            buffer = torch.randn(needed_size, dtype=torch.float16, device=device).to(
                dtype=dtype
            )
        else:
            buffer = torch.randn(needed_size, dtype=dtype, device=device)
    else:
        buffer = torch.zeros(size=[needed_size], dtype=dtype, device=device)
    return torch.as_strided(buffer, size, stride)
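
# A small sketch (shape/stride values are arbitrary): build a non-contiguous
# tensor with exactly the requested layout.
#
#     t = rand_strided((2, 3), (1, 2))
#     assert t.size() == (2, 3) and t.stride() == (1, 2)
#     assert not t.is_contiguous()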


def _make_fn_with_patches(fn, *patches):
    @functools.wraps(fn)
    def _fn(*args, **kwargs):
        with contextlib.ExitStack() as stack:
            for module, attr, val in patches:
                stack.enter_context(patch.object(module, attr, val))

            return fn(*args, **kwargs)

    return _fn


def make_test_cls_with_patches(
    cls, cls_prefix, fn_suffix, *patches, xfail_prop=None, decorator=lambda x: x
):
    DummyTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    DummyTestClass.__qualname__ = DummyTestClass.__name__

    for name in dir(cls):
        if name.startswith("test_"):
            fn = getattr(cls, name)
            if not callable(fn):
                setattr(DummyTestClass, name, getattr(cls, name))
                continue
            new_name = f"{name}{fn_suffix}"
            new_fn = _make_fn_with_patches(fn, *patches)
            new_fn.__name__ = new_name
            if xfail_prop is not None and hasattr(fn, xfail_prop):
                new_fn = unittest.expectedFailure(new_fn)
            setattr(DummyTestClass, new_name, decorator(new_fn))
        # NB: Doesn't handle slots correctly, but whatever
        elif not hasattr(DummyTestClass, name):
            setattr(DummyTestClass, name, getattr(cls, name))

    return DummyTestClass
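
# A hedged sketch mirroring how the dynamic-shapes test files use this helper
# (MyTests is a hypothetical TestCase subclass): clone the class so every
# test_* method runs with the given config patched, suffixing test names and
# honoring the xfail marker set by expectedFailureDynamic below.
#
#     DynamicShapesTests = make_test_cls_with_patches(
#         MyTests,
#         "DynamicShapes",
#         "_dynamic_shapes",
#         (config, "assume_static_by_default", False),
#         xfail_prop="_expected_failure_dynamic",
#     )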


# Test Python 3.11+ specific features
def skipIfNotPy311(fn):
    if sys.version_info >= (3, 11):
        return fn
    return unittest.skip("requires Python 3.11+")(fn)


def skipIfNotPy312(fn):
    if sys.version_info >= (3, 12):
        return fn
    return unittest.skip("requires Python 3.12+")(fn)


def xfailIfPy312(fn):
    if sys.version_info >= (3, 12):
        return unittest.expectedFailure(fn)
    return fn


def skipIfPy312(fn):
    if sys.version_info >= (3, 12):
        return unittest.skip(fn)
    return fn


def requiresPy310(fn):
    if sys.version_info >= (3, 10):
        return fn
    return unittest.skip("requires Python 3.10+")(fn)


# Controls tests generated in test/inductor/test_torchinductor_dynamic_shapes.py
# and test/dynamo/test_dynamic_shapes.py
def expectedFailureDynamic(fn):
    fn._expected_failure_dynamic = True
    return fn


# Controls tests generated in test/inductor/test_torchinductor_codegen_dynamic_shapes.py
def expectedFailureCodegenDynamic(fn):
    fn._expected_failure_codegen_dynamic = True
    return fn


# Controls tests generated in test/inductor/test_cpp_wrapper.py
def expectedFailureDynamicWrapper(fn):
    fn._expected_failure_dynamic_wrapper = True
    return fn


def reset_rng_state(use_xla=False):
    torch.manual_seed(1337)
    random.seed(1337)
    if np:
        np.random.seed(1337)
    if use_xla:
        import torch_xla.core.xla_model as xm

        xm.set_rng_state(1337, str(xm.xla_device()))