# Owner(s): ["module: dynamo"]

import unittest

import torch
from functorch import make_fx
from torch._dynamo import debug_utils
from torch._dynamo.debug_utils import aot_graph_input_parser
from torch._dynamo.test_case import TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA


requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")

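# Module-level aliases for the dtype shorthands ("f32[...]", "i64[...]",
# "i32[...]") used in the AOT graph annotations below.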
f32 = torch.float32
i64 = torch.int64
i32 = torch.int32


class TestDebugUtils(TestCase):
    def test_cast_model_to_fp64_dtype_args(self):
        # Test that dtype arguments are converted to fp64

        def fn(x):
            return (
                torch.ops.prims.convert_element_type(x, torch.float16),
                x.to(torch.float16),
                torch.full(x.shape, 2, dtype=torch.float32, device=x.device),
                x.new_empty(x.shape),
            )

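        # Trace fn through the core ATen decompositions so the dtype
        # arguments show up explicitly in the printed graph.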
        x = torch.randn(32, device="cpu")
        decomps = torch._decomp.core_aten_decompositions()
        fx = make_fx(fn, decomposition_table=decomps)(x)

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
    """,  # NOQA: B950
        )

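        # cast_to_fp64 rewrites the graph's dtype arguments in place and
        # upcasts the example inputs to match, so the same fx module now
        # prints with float64 everywhere float16/float32 appeared above.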
        fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,))
        self.assertEqual(fp64_examples, (x.to(torch.float64),))

        self.assertExpectedInline(
            fx.code.lstrip(),
            """\
def forward(self, x_1):
    convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64)
    _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64);  x_1 = None
    full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
    empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    return (convert_element_type, _to_copy, full, empty)
    """,  # NOQA: B950
        )

    @requires_cuda
    def test_aot_graph_parser(self):
        from torch import device

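        # forward() below is pasted from an AOT graph dump; the string
        # annotations ("f32[...]", "i64[...]") carry the dtype/shape metadata
        # that aot_graph_input_parser reads to fabricate inputs.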
        def forward(
            self,
            primals_1: "f32[1001, 6]",
            primals_2: "f32[1001]",
            primals_3: "f32[1001, 64]",
            primals_4: "f32[4190]",
            primals_5: "f32[4190]",
            primals_6: "f32[1739, 4190]",
            primals_48: "f32[6144, 4191]",
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0
            lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant0
            )
            _tensor_constant0 = None
            index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy]
            )
            lift_fresh_copy = None

            _tensor_constant1: "i64[6]" = self._tensor_constant1
            lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default(
                _tensor_constant1
            )
            _tensor_constant1 = None
            index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor(
                primals_48, [None, lift_fresh_copy_1]
            )
            primals_48 = lift_fresh_copy_1 = None
            permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0])
            primals_1 = None
            addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default(
                primals_2, index_1, permute
            )
            primals_2 = permute = None
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None

            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )

            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )

            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None

            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )

            return _embedding_bag

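        # The parser synthesizes an input for each annotated argument on the
        # requested device; since forward(**kwargs) runs, it must also supply
        # a `self` carrying the _tensor_constant* attributes the graph reads.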
        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)

    @requires_cuda
    def test_sym_aot_graph_parser(self):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821
            primals_4: "f32[s1]",  # noqa: F821
            primals_5: "Sym(s1)",  # noqa: F821
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0

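        # "s0" is pinned to 10 via sym_shapes, while "s1" is unspecified and
        # falls back to default_sym_shape=5.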
        kwargs = aot_graph_input_parser(
            forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5
        )

        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)

        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()