# mypy: ignore-errors

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._utils import wrapper_set_seed
import torch.utils._pytree as pytree


def make_fx_check(
    func,
    args,
    kwargs,
    tracing_mode,
    assert_close=torch.testing.assert_close,
    randomize_data=False,
):
    """Check that func and its make_fx-traced graph compute the same values
    on the given inputs."""
    f, *new_args = handle_sizes_for_dynamic_shapes(func, args, kwargs)

    def run(f, *args, **kwargs):
        return wrapper_set_seed(f, *args, **kwargs)

    traced_f = make_fx(f, tracing_mode=tracing_mode)(*new_args)

    msg = (
        "op(*args, **kwargs) and make_fx(op)(*args, **kwargs) produced different "
        "values. This could mean that your abstract impls (meta/FakeTensor impls) "
        "are incorrect, that your operator is not completely traceable (e.g., "
        "it relies on some global state), or that there is a bug in make_fx. "
        "Note that if you passed a python function (and not an operator) to "
30        "make_fx_check, it is still possible that the python function will still "
31        "work with torch.compile because it handles capturing pieces of "
32        "your python code to compile."
33    )
34
    # Randomize the data and run the traced graph with it, to catch bugs
    # where we may have baked Tensor data into the trace.
    # This is not guaranteed to succeed, because `f` might have preconditions
    # on the values of the inputs, so if the run fails on randomized data we
    # skip the check instead of raising.
    if randomize_data:
        new_args = randomize(new_args)
    try:
        expected = run(f, *new_args)
    except Exception:
        if randomize_data:
            return
        raise
    result = run(traced_f, *new_args)
    assert_close(result, expected, msg=msg)
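

# Hedged usage sketch (illustrative, not part of the original module): how one
# might call make_fx_check on a plain python callable. The function `my_op`
# and its inputs are made up for the example.
#
#     def my_op(x, y):
#         return (x + y).relu()
#
#     make_fx_check(
#         my_op,
#         args=(torch.randn(3), torch.randn(3)),
#         kwargs={},
#         tracing_mode="symbolic",
#         randomize_data=True,
#     )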


# Arguably we should make make_fx promote torch.Size() objects to symbolic shapes.
# Absent that, here is our strategy:
#
# If any argument is a torch.Size(), maybe get dynamic shapes for it by:
# - Create a temporary Tensor whose size is the torch.Size() we want. Note that
#   we allocate a real (CPU) Tensor, as we cannot pass "meta" Tensors to make_fx.
# - Pass it to make_fx so that it is converted to a proxy Tensor
# - Unpack the size in the wrapper to get a torch.Size with dynamic shapes (in
#   symbolic mode, a no-op otherwise)
def handle_sizes_for_dynamic_shapes(func, args, kwargs):
    def f(args, kwargs, extra_args, extra_kwargs):
        # Swap each placeholder Tensor back for its (possibly symbolic) size
        # before invoking the real function.
        if extra_args:
            for i, t in extra_args:
                args[i] = t.size()
        if extra_kwargs:
            for k, t in extra_kwargs.items():
                kwargs[k] = t.size()

        return func(*args, **kwargs)

    # Record a placeholder Tensor for each torch.Size argument; `f` swaps the
    # sizes back in at call time.
    extra_args = []
    extra_kwargs = {}
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Size):
            extra_args.append((i, torch.empty(arg, device="cpu")))
    for key, value in kwargs.items():
        if isinstance(value, torch.Size):
            extra_kwargs[key] = torch.empty(value, device="cpu")

    return f, args, kwargs, extra_args, extra_kwargs
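

# Illustrative sketch (not part of the original module): tracing a function
# that takes a torch.Size. handle_sizes_for_dynamic_shapes records a
# placeholder Tensor for the Size so that, under tracing_mode="symbolic",
# t.size() yields symbolic shapes inside the trace. `reshape_to` is a made-up
# example; note that args is a list because the wrapper assigns back into it.
#
#     def reshape_to(x, shape):
#         return x.reshape(shape)
#
#     make_fx_check(
#         reshape_to,
#         args=[torch.randn(2, 3), torch.Size([3, 2])],
#         kwargs={},
#         tracing_mode="symbolic",
#     )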


def randomize(args):
    def transform(x):
        # Only floating-point Tensors get new random values; other dtypes
        # (e.g. integer indices) are returned unchanged.
        if not x.dtype.is_floating_point:
            return x
        return x.detach().clone().uniform_(0, 1).requires_grad_(x.requires_grad)
    return pytree.tree_map_only(torch.Tensor, transform, args)

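
# Illustrative sketch (not part of the original module): randomize perturbs
# only floating-point Tensors, preserving requires_grad, and passes all other
# Tensors through unchanged.
#
#     args = [torch.randn(2, requires_grad=True), torch.arange(3)]
#     new_args = randomize(args)
#     assert new_args[0].requires_grad
#     assert torch.equal(new_args[1], torch.arange(3))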