# Owner(s): ["module: mtia"]

import os
import shutil
import sys
import tempfile
import unittest

import torch
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_LINUX,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_PRIVATEUSE1,
    TEST_XPU,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME


# define TEST_ROCM before changing TEST_CUDA
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None


def remove_build_path():
    """Delete the default cpp-extension build root if it exists.

    Returns without doing anything on Windows. Removal errors are ignored.
    """
    if sys.platform == "win32":
        # Not wiping extensions build folder because Windows
        return
    default_build_root = torch.utils.cpp_extension.get_default_build_root()
    if os.path.exists(default_build_root):
        shutil.rmtree(default_build_root, ignore_errors=True)


@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
    """Tests MTIA backend with C++ extensions."""

    # Result of torch.utils.cpp_extension.load(); set once in setUpClass.
    module = None
    # Temporary directory the extension is built in. Created in setUpClass
    # and deleted again by _remove_build_dir().
    _build_dir = None

    def setUp(self):
        """Switch the working directory to this file's directory."""
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        """Restore the working directory saved by setUp."""
        try:
            super().tearDown()
        finally:
            # return the working directory (see setUp), even when the base
            # class tearDown raises, so later tests do not inherit our cwd
            os.chdir(self.old_working_dir)

    @classmethod
    def _remove_build_dir(cls):
        """Delete the temporary build directory, if one is recorded.

        Safe to call more than once: the recorded path is cleared before the
        directory is removed, and removal errors are ignored.
        """
        build_dir, cls._build_dir = cls._build_dir, None
        if build_dir is not None:
            shutil.rmtree(build_dir, ignore_errors=True)

    @classmethod
    def tearDownClass(cls):
        """Remove the default build root and the temporary build directory."""
        remove_build_path()
        cls._remove_build_dir()
        super().tearDownClass()

    @classmethod
    def setUpClass(cls):
        """Build and load the MTIA test extension into a fresh temp directory."""
        super().setUpClass()
        remove_build_path()
        cls._build_dir = tempfile.mkdtemp()
        try:
            # Load the fake device guard impl.
            cls.module = torch.utils.cpp_extension.load(
                name="mtia_extension",
                sources=["cpp_extensions/mtia_extension.cpp"],
                build_directory=cls._build_dir,
                extra_include_paths=[
                    "cpp_extensions",
                    "path / with spaces in it",
                    "path with quote'",
                ],
                is_python_module=False,
                verbose=True,
            )
        except Exception:
            # unittest does not run tearDownClass when setUpClass fails, so
            # the temporary directory has to be cleaned up here.
            cls._remove_build_dir()
            raise

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_get_device_module(self):
        """The device module for "mtia:0" yields a stream of MTIA device type."""
        device = torch.device("mtia:0")
        default_stream = torch.get_device_module(device).current_stream()
        self.assertEqual(
            default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_basic(self):
        """Default and user streams differ, and the stream context switches them."""
        default_stream = torch.mtia.current_stream()
        user_stream = torch.mtia.Stream()
        self.assertEqual(torch.mtia.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        # Check mtia_extension.cpp, default stream id starts from 0.
        self.assertEqual(default_stream.stream_id, 0)
        self.assertNotEqual(user_stream.stream_id, 0)
        with torch.mtia.stream(user_stream):
            self.assertEqual(torch.mtia.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context(self):
        """Entering a stream context makes that stream the current stream."""
        mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
        mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
        # Emit both streams to the test log.
        print(mtia_stream_0)
        print(mtia_stream_1)
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)

        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context_different_device(self):
        """A stream context also switches the current device, then restores it."""
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        mtia_stream_0 = torch.mtia.Stream(device=device_0)
        mtia_stream_1 = torch.mtia.Stream(device=device_1)
        # Emit both streams to the test log.
        print(mtia_stream_0)
        print(mtia_stream_1)
        orig_current_device = torch.mtia.current_device()
        with torch.mtia.stream(mtia_stream_0):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_0.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_0}"
            self.assertEqual(current_stream, mtia_stream_0, msg=msg)
        # Leaving the context must put the original device back.
        self.assertEqual(torch.mtia.current_device(), orig_current_device)
        with torch.mtia.stream(mtia_stream_1):
            current_stream = torch.mtia.current_stream()
            self.assertEqual(torch.mtia.current_device(), device_1.index)
            msg = f"current_stream {current_stream} should be {mtia_stream_1}"
            self.assertEqual(current_stream, mtia_stream_1, msg=msg)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_device_context(self):
        """The device context manager sets the current device index."""
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        with torch.mtia.device(device_0):
            self.assertEqual(torch.mtia.current_device(), device_0.index)

        with torch.mtia.device(device_1):
            self.assertEqual(torch.mtia.current_device(), device_1.index)


if __name__ == "__main__":
    common.run_tests()