# Owner(s): ["module: dynamo"]
import unittest

import torch._dynamo
from torch._dynamo.test_minifier_common import MinifierTestBase
from torch.testing._internal.common_utils import skipIfNNModuleInlined


# Skip decorator for tests that need a CUDA device.
requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


class MinifierTests(MinifierTestBase):
    """Tests for minifying failures that occur after dynamo.

    The snippets here are compiled with the ``*_TESTING_ONLY`` debug backends,
    which are expected to misbehave only when ``torch.relu`` is in the graph.
    Most tests hand a snippet plus the expected error name to
    ``_run_full_test`` (provided by ``MinifierTestBase``).
    """

    # Test that compile, runtime, and accuracy errors after dynamo can be repro'd (both CPU and CUDA)
    def _test_after_dynamo(self, device, backend, expected_error):
        """Run a sin -> relu -> cos snippet under ``backend`` on ``device``.

        Args:
            device: device string interpolated into the snippet ("cpu"/"cuda").
            backend: name of the dynamo backend used in the snippet.
            expected_error: error name passed through to ``_run_full_test``.
        """
        run_code = f"""\
@torch._dynamo.optimize({backend!r})
def inner(x):
    for _ in range(10):
        x = torch.sin(x)
    x = torch.relu(x)
    for _ in range(10):
        x = torch.cos(x)
    return x

inner(torch.randn(20, 20).to("{device}"))
"""
        self._run_full_test(run_code, "dynamo", expected_error, isolate=False)

    def test_after_dynamo_cpu_compile_error(self):
        self._test_after_dynamo(
            "cpu", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
        )

    def test_after_dynamo_cpu_runtime_error(self):
        self._test_after_dynamo(
            "cpu", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )

    def test_after_dynamo_cpu_accuracy_error(self):
        self._test_after_dynamo(
            "cpu", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_compile_error(self):
        self._test_after_dynamo(
            "cuda", "relu_compile_error_TESTING_ONLY", "ReluCompileError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_runtime_error(self):
        self._test_after_dynamo(
            "cuda", "relu_runtime_error_TESTING_ONLY", "ReluRuntimeError"
        )

    @requires_cuda
    def test_after_dynamo_cuda_accuracy_error(self):
        self._test_after_dynamo(
            "cuda", "relu_accuracy_error_TESTING_ONLY", "AccuracyError"
        )

    def test_after_dynamo_non_leaf_compile_error(self):
        # The input is a non-leaf tensor: ``+ 1`` applied to a tensor that
        # requires grad.
        run_code = """\
@torch._dynamo.optimize("non_leaf_compile_error_TESTING_ONLY")
def inner(x):
    return x + 1

inner(torch.randn(20, 20, requires_grad=True) + 1)
"""
        self._run_full_test(
            run_code, "dynamo", "TestingOnlyCompileError", isolate=False
        )

    # Ensure that the testing backends pass when relu is not present.
    def _test_after_dynamo_backend_passes(self, device, backend):
        """Check that ``backend`` neither fails nor changes results without relu.

        Simply running the compiled function only shows that nothing raised.
        That is not enough for an accuracy-type backend, whose failure mode is
        a wrong result rather than an exception. The compiled output is
        therefore also compared against eager execution of the same function.

        Args:
            device: device to place the input on ("cpu"/"cuda").
            backend: name of the dynamo backend under test.
        """

        def fn(x):
            for _ in range(10):
                x = torch.sin(x)
            for _ in range(10):
                x = torch.cos(x)
            return x

        # Wrap instead of decorating so ``fn`` can still be called eagerly.
        inner = torch._dynamo.optimize(backend)(fn)
        x = torch.randn(20, 20).to(device)
        self.assertEqual(inner(x), fn(x))

    def test_after_dynamo_cpu_compile_backend_passes(self):
        self._test_after_dynamo_backend_passes("cpu", "relu_compile_error_TESTING_ONLY")

    def test_after_dynamo_cpu_runtime_backend_passes(self):
        self._test_after_dynamo_backend_passes("cpu", "relu_runtime_error_TESTING_ONLY")

    def test_after_dynamo_cpu_accuracy_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cpu", "relu_accuracy_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_compile_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_compile_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_runtime_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_runtime_error_TESTING_ONLY"
        )

    @requires_cuda
    def test_after_dynamo_cuda_accuracy_backend_passes(self):
        self._test_after_dynamo_backend_passes(
            "cuda", "relu_accuracy_error_TESTING_ONLY"
        )

    # Test that a module with mixed cpu/cuda parts with an error after dynamo can be repro'd
    @skipIfNNModuleInlined()
    @requires_cuda
    def test_cpu_cuda_module_after_dynamo(self):
        backend_name = "relu_compile_error_TESTING_ONLY"
        run_code = f"""\
class CpuCudaModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.m_x = torch.nn.Linear(20, 20).cuda()
        self.m_y = torch.nn.Linear(20, 20)
        self.p_x = torch.nn.Parameter(torch.randn(20, 20).cuda())
        self.p_y = torch.nn.Parameter(torch.randn(20, 20))
        self.b_x = torch.nn.Buffer(torch.ones(20, 20).cuda())
        self.b_y = torch.nn.Buffer(torch.ones(20, 20))

    def forward(self, x, y):
        return self.m_x(x) + self.p_x + self.b_x, self.m_y(y) + self.p_y + self.b_y

mod = CpuCudaModule()

@torch._dynamo.optimize({backend_name!r})
def inner(x1, y1):
    x2 = torch.randn(20, 20).cuda()
    y2 = torch.randn(20, 20)
    x3, y3 = mod(x1 + x2, y1 + y2)
    return torch.relu(x3.cpu() + y3)

inner(torch.randn(20, 20).cuda(), torch.randn(20, 20))
"""

        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

        # The generated module must keep each submodule, buffer and parameter
        # on the same device as in the original module.
        self.assertExpectedInline(
            res.minifier_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.G__mod___m_x = Linear(in_features=20, out_features=20, bias=True).cuda()
        self.G__mod___m_y = Linear(in_features=20, out_features=20, bias=True)
        self.register_buffer('G__mod___b_x', torch.randn([20, 20], dtype=torch.float32).cuda())
        self.register_buffer('G__mod___b_y', torch.randn([20, 20], dtype=torch.float32))
        self.G__mod___p_x = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32, device="cuda"))
        self.G__mod___p_y = torch.nn.Parameter(torch.randn([20, 20], dtype=torch.float32))

    def forward(self, L_x1_ : torch.Tensor, L_y1_ : torch.Tensor):
        l_x1_ = L_x1_
        l_y1_ = L_y1_
        randn = torch.randn(20, 20)
        x2 = randn.cuda();  randn = None
        y2 = torch.randn(20, 20)
        add = l_x1_ + x2;  l_x1_ = x2 = None
        add_1 = l_y1_ + y2;  l_y1_ = y2 = None
        g__mod___m_x = self.G__mod___m_x(add);  add = None
        g__mod___p_x = self.G__mod___p_x
        add_2 = g__mod___m_x + g__mod___p_x;  g__mod___m_x = g__mod___p_x = None
        g__mod___b_x = self.G__mod___b_x
        x3 = add_2 + g__mod___b_x;  add_2 = g__mod___b_x = None
        g__mod___m_y = self.G__mod___m_y(add_1);  add_1 = None
        g__mod___p_y = self.G__mod___p_y
        add_4 = g__mod___m_y + g__mod___p_y;  g__mod___m_y = g__mod___p_y = None
        g__mod___b_y = self.G__mod___b_y
        y3 = add_4 + g__mod___b_y;  add_4 = g__mod___b_y = None
        cpu = x3.cpu();  x3 = None
        add_6 = cpu + y3;  cpu = y3 = None
        relu = torch.relu(add_6);  add_6 = None
        return (relu,)""",
        )

    # Test if we can actually get a minified graph
    def test_if_graph_minified(self):
        backend_name = "relu_compile_error_TESTING_ONLY"
        run_code = f"""\
@torch._dynamo.optimize({backend_name!r})
def inner(x):
    for _ in range(20):
        x = torch.sin(x)
    x = torch.relu(x)
    for _ in range(20):
        x = torch.cos(x)
    return x

inner(torch.randn(20, 20))
"""

        res = self._run_full_test(run_code, "dynamo", "ReluCompileError", isolate=False)

        # Only the failing relu node should survive minification.
        self.assertExpectedInline(
            res.repro_module(),
            """\
class Repro(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x_19):
        x_20 = torch.relu(x_19);  x_19 = None
        return (x_20,)""",
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()