1# Owner(s): ["module: dynamo"] 2 3import unittest 4 5import torch 6from functorch import make_fx 7from torch._dynamo import debug_utils 8from torch._dynamo.debug_utils import aot_graph_input_parser 9from torch._dynamo.test_case import TestCase 10from torch.testing._internal.inductor_utils import HAS_CUDA 11 12 13requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") 14 15f32 = torch.float32 16i64 = torch.int64 17i32 = torch.int32 18 19 20class TestDebugUtils(TestCase): 21 def test_cast_model_to_fp64_dtype_args(self): 22 # Test that dtype arguments are converted to fp64 23 24 def fn(x): 25 return ( 26 torch.ops.prims.convert_element_type(x, torch.float16), 27 x.to(torch.float16), 28 torch.full(x.shape, 2, dtype=torch.float32, device=x.device), 29 x.new_empty(x.shape), 30 ) 31 32 x = torch.randn(32, device="cpu") 33 decomps = torch._decomp.core_aten_decompositions() 34 fx = make_fx(fn, decomposition_table=decomps)(x) 35 36 self.assertExpectedInline( 37 fx.code.lstrip(), 38 """\ 39def forward(self, x_1): 40 convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float16) 41 _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float16); x_1 = None 42 full = torch.ops.aten.full.default([32], 2, dtype = torch.float32, device = device(type='cpu'), pin_memory = False) 43 empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False) 44 return (convert_element_type, _to_copy, full, empty) 45 """, # NOQA: B950 46 ) 47 48 fp64_model, fp64_examples = debug_utils.cast_to_fp64(fx, (x,)) 49 self.assertEqual(fp64_examples, (x.to(torch.float64),)) 50 51 self.assertExpectedInline( 52 fx.code.lstrip(), 53 """\ 54def forward(self, x_1): 55 convert_element_type = torch.ops.prims.convert_element_type.default(x_1, torch.float64) 56 _to_copy = torch.ops.aten._to_copy.default(x_1, dtype = torch.float64); x_1 = None 57 full = torch.ops.aten.full.default([32], 2, dtype = torch.float64, device = device(type='cpu'), pin_memory = False) 58 empty = torch.ops.aten.empty.memory_format([32], dtype = torch.float64, layout = torch.strided, device = device(type='cpu'), pin_memory = False) 59 return (convert_element_type, _to_copy, full, empty) 60 """, # NOQA: B950 61 ) 62 63 @requires_cuda 64 def test_aot_graph_parser(self): 65 from torch import device 66 67 def forward( 68 self, 69 primals_1: "f32[1001, 6]", 70 primals_2: "f32[1001]", 71 primals_3: "f32[1001, 64]", 72 primals_4: "f32[4190]", 73 primals_5: "f32[4190]", 74 primals_6: "f32[1739, 4190]", 75 primals_48: "f32[6144, 4191]", 76 ): 77 _tensor_constant0: "i64[4190]" = self._tensor_constant0 78 lift_fresh_copy: "i64[4190]" = torch.ops.aten.lift_fresh_copy.default( 79 _tensor_constant0 80 ) 81 _tensor_constant0 = None 82 index: "f32[6144, 4190]" = torch.ops.aten.index.Tensor( 83 primals_48, [None, lift_fresh_copy] 84 ) 85 lift_fresh_copy = None 86 87 _tensor_constant1: "i64[6]" = self._tensor_constant1 88 lift_fresh_copy_1: "i64[6]" = torch.ops.aten.lift_fresh_copy.default( 89 _tensor_constant1 90 ) 91 _tensor_constant1 = None 92 index_1: "f32[6144, 6]" = torch.ops.aten.index.Tensor( 93 primals_48, [None, lift_fresh_copy_1] 94 ) 95 primals_48 = lift_fresh_copy_1 = None 96 permute: "f32[6, 1001]" = torch.ops.aten.permute.default(primals_1, [1, 0]) 97 primals_1 = None 98 addmm: "f32[6144, 1001]" = torch.ops.aten.addmm.default( 99 primals_2, index_1, permute 100 ) 101 primals_2 = permute = None 102 amax: "f32[6144, 1]" = 
            amax: "f32[6144, 1]" = torch.ops.aten.amax.default(addmm, [-1], True)
            sub: "f32[6144, 1001]" = torch.ops.aten.sub.Tensor(addmm, amax)
            exp: "f32[6144, 1001]" = torch.ops.aten.exp.default(sub)
            sub = None
            sum_1: "f32[6144, 1]" = torch.ops.aten.sum.dim_IntList(exp, [-1], True)
            div: "f32[6144, 1001]" = torch.ops.aten.div.Tensor(exp, sum_1)
            exp = None

            full_default: "i32[6144, 1001]" = torch.ops.aten.full.default(
                [6144, 1001],
                1,
                dtype=torch.int32,
                layout=torch.strided,
                device=device(type="cuda", index=0),
                pin_memory=False,
            )

            iota: "i32[1001]" = torch.ops.prims.iota.default(
                1001,
                start=0,
                step=1,
                dtype=torch.int32,
                device=device(type="cuda"),
                requires_grad=False,
            )

            mul: "i32[6144, 1001]" = torch.ops.aten.mul.Tensor(full_default, iota)
            full_default = iota = None

            iota_1: "i32[6144]" = torch.ops.prims.iota.default(
                6144,
                start=0,
                step=1001,
                dtype=torch.int32,
                device=device(type="cuda", index=0),
                requires_grad=False,
            )
            view: "i32[6150144]" = torch.ops.aten.reshape.default(mul, [-1])
            mul = None
            view_1: "f32[6150144]" = torch.ops.aten.reshape.default(div, [-1])
            div = None
            _embedding_bag = torch.ops.aten._embedding_bag.default(
                primals_3, view, iota_1, False, 0, False, view_1
            )

            return _embedding_bag

        kwargs = aot_graph_input_parser(forward, device="cuda")
        # runs successfully
        forward(**kwargs)

    @requires_cuda
    def test_sym_aot_graph_parser(self):
        def forward(
            self,
            primals_1: "f32[1001, 6]",  # noqa: F821
            primals_2: "f32[s0]",  # noqa: F821
            primals_3: "Sym(s0)",  # noqa: F821
            primals_4: "f32[s1]",  # noqa: F821
            primals_5: "Sym(s1)",  # noqa: F821
        ):
            _tensor_constant0: "i64[4190]" = self._tensor_constant0

        kwargs = aot_graph_input_parser(
            forward, device="cuda", sym_shapes={"s0": 10}, default_sym_shape=5
        )

        self.assertEqual(list(kwargs["primals_2"].shape), [10])
        self.assertEqual(kwargs["primals_3"], 10)

        self.assertEqual(list(kwargs["primals_4"].shape), [5])
        self.assertEqual(kwargs["primals_5"], 5)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()