# Owner(s): ["module: inductor"]
"""Tests that an out-of-tree "extension_device" backend can be registered with
Inductor (scheduling + Python/C++ wrapper codegen) and used to compile and run
a simple pointwise function."""
import os
import shutil
import sys
import unittest

import torch
import torch._dynamo
import torch.utils.cpp_extension
from torch._C import FileCheck

try:
    from extension_backends.cpp.extension_codegen_backend import (
        ExtensionCppWrapperCodegen,
        ExtensionScheduling,
        ExtensionWrapperCodegen,
    )
except ImportError:
    from .extension_backends.cpp.extension_codegen_backend import (
        ExtensionCppWrapperCodegen,
        ExtensionScheduling,
        ExtensionWrapperCodegen,
    )

import torch._inductor.config as config
from torch._inductor import codecache, metrics
from torch._inductor.codegen import cpp_utils
from torch._inductor.codegen.common import (
    get_scheduling_for_device,
    get_wrapper_codegen_for_device,
    register_backend_for_device,
)
from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS

try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    # When run as a script, a skipped dependency means "nothing to run".
    if __name__ == "__main__":
        sys.exit(0)
    raise


run_and_get_cpp_code = test_torchinductor.run_and_get_cpp_code
TestCase = test_torchinductor.TestCase


def remove_build_path():
    """Delete the default ``torch.utils.cpp_extension`` build root, if present.

    The build folder is intentionally left alone on Windows. Removal errors
    are ignored (``ignore_errors=True``).
    """
    if sys.platform == "win32":
        # Deliberately do not wipe the extensions build folder on Windows.
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(TestCase):
    """Builds the ``extension_device`` C++ extension once per class and checks
    that Inductor can target it."""

    # Loaded C++ extension module; populated by setUpClass.
    module = None

    @classmethod
    def setUpClass(cls):
        """Wipe any stale build output, then JIT-build and load the extension."""
        super().setUpClass()

        # Build Extension
        remove_build_path()
        source_file_path = os.path.dirname(os.path.abspath(__file__))
        source_file = os.path.join(
            source_file_path, "extension_backends/cpp/extension_device.cpp"
        )
        cls.module = torch.utils.cpp_extension.load(
            name="extension_device",
            sources=[source_file],
            extra_cflags=["-g"],
            verbose=True,
        )

    @classmethod
    def tearDownClass(cls):
        """Close the class-level ExitStack, run base teardown, and always remove
        the build directory even if the earlier steps raise."""
        try:
            # NOTE(review): ``_stack`` is not defined in this file; presumably
            # it is created by the base TestCase.setUpClass — confirm.
            cls._stack.close()
            super().tearDownClass()
        finally:
            remove_build_path()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily.
        # Registering the restore as a cleanup guarantees it runs even if the
        # rest of setUp fails (in which case tearDown is never called) or if
        # super().tearDown() raises. Cleanups run after tearDown, preserving
        # the original "restore cwd last" ordering.
        self.old_working_dir = os.getcwd()
        self.addCleanup(os.chdir, self.old_working_dir)
        os.chdir(os.path.dirname(os.path.abspath(__file__)))
        assert self.module is not None

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()
        # The working directory is restored by the cleanup registered in setUp.

    def test_open_device_registration(self):
        """Register ``extension_device`` as the PrivateUse1 backend, then check
        that Inductor generates code for it and computes ``a * b + c``
        correctly, both with and without ``cpp_wrapper``."""
        torch.utils.rename_privateuse1_backend("extension_device")
        torch._register_device_module("extension_device", self.module)

        register_backend_for_device(
            "extension_device",
            ExtensionScheduling,
            ExtensionWrapperCodegen,
            ExtensionCppWrapperCodegen,
        )
        # assertEqual (rather than assertTrue(a == b)) reports both values on
        # failure.
        self.assertEqual(
            get_scheduling_for_device("extension_device"), ExtensionScheduling
        )
        self.assertEqual(
            get_wrapper_codegen_for_device("extension_device"),
            ExtensionWrapperCodegen,
        )
        self.assertEqual(
            get_wrapper_codegen_for_device("extension_device", True),
            ExtensionCppWrapperCodegen,
        )

        self.assertFalse(self.module.custom_op_called())
        device = self.module.custom_device()
        x = torch.empty(2, 16).to(device=device).fill_(1)
        self.assertTrue(self.module.custom_op_called())
        y = torch.empty(2, 16).to(device=device).fill_(2)
        z = torch.empty(2, 16).to(device=device).fill_(3)
        # 1 * 2 + 3 == 5
        ref = torch.empty(2, 16).fill_(5)

        self.assertEqual(x.device, device)
        self.assertEqual(y.device, device)
        self.assertEqual(z.device, device)

        def fn(a, b, c):
            return a * b + c

        cpp_utils.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1"

        # The expected load pattern depends only on host vector-ISA support,
        # not on cpp_wrapper, so decide it once outside the loop.
        if codecache.valid_vec_isa_list():
            load_expr = "loadu"
        else:
            load_expr = " = in_ptr0[static_cast<long>(i0)];"

        for cpp_wrapper_flag in [True, False]:
            with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                metrics.reset()
                opt_fn = torch.compile()(fn)
                _, code = run_and_get_cpp_code(opt_fn, x, y, z)
                FileCheck().check("void").check(load_expr).check(
                    "extension_device"
                ).run(code)
                opt_fn(x, y, z)
                res = opt_fn(x, y, z)
                self.assertEqual(ref, res.to(device="cpu"))


if __name__ == "__main__":
    from torch._inductor.test_case import run_tests
    from torch.testing._internal.inductor_utils import HAS_CPU

    # cpp_extension doesn't work in fbcode right now
    if HAS_CPU and not IS_MACOS and not IS_FBCODE:
        run_tests(needs="filelock")