# Owner(s): ["module: cpp-extensions"]

import os
import re
import unittest
from itertools import repeat
from typing import get_args, get_origin, Union

import torch
import torch.backends.cudnn
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfTorchDynamo


try:
    import pytest

    HAS_PYTEST = True
except ImportError:
    HAS_PYTEST = False

# TODO: Rewrite these tests so that they can be collected via pytest without
# using run_test.py
try:
    if HAS_PYTEST:
        cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
        maia_extension = pytest.importorskip("torch_test_cpp_extension.maia")
        rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
    else:
        import torch_test_cpp_extension.cpp as cpp_extension
        import torch_test_cpp_extension.maia as maia_extension
        import torch_test_cpp_extension.rng as rng_extension
except ImportError as e:
    raise RuntimeError(
        "test_cpp_extensions_aot.py cannot be invoked directly. Run "
        "`python run_test.py -i test_cpp_extensions_aot_ninja` instead."
    ) from e


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionAOT(common.TestCase):
    """Tests ahead-of-time cpp extensions

    NOTE: run_test.py's test_cpp_extensions_aot_ninja target
    also runs this test case, but with ninja enabled. If you are debugging
    a test failure here from the CI, check the logs for which target
    (test_cpp_extensions_aot_no_ninja vs test_cpp_extensions_aot_ninja)
    failed.
    """

    def test_extension_function(self):
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())
        # Test that pybind supports casting torch.dtype.
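        # (The extension is expected to map the reduced-precision torch.half
        # to its higher-precision "math" dtype, torch.float32, which is what
        # the assertion below checks.)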
        self.assertEqual(
            str(torch.float32), str(cpp_extension.get_math_type(torch.half))
        )

    def test_extension_module(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

    def test_backward(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double, requires_grad=True)
        result = mm.forward(weights)
        result.sum().backward()
        tensor = mm.get()

        expected_weights_grad = tensor.t().mm(torch.ones([4, 4], dtype=torch.double))
        self.assertEqual(weights.grad, expected_weights_grad)

        expected_tensor_grad = torch.ones([4, 4], dtype=torch.double).mm(weights.t())
        self.assertEqual(tensor.grad, expected_tensor_grad)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cpp_extension.cuda as cuda_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not torch.backends.mps.is_available(), "MPS not found")
    def test_mps_extension(self):
        import torch_test_cpp_extension.mps as mps_extension

        tensor_length = 100000
        x = torch.randn(tensor_length, device="cpu", dtype=torch.float32)
        y = torch.randn(tensor_length, device="cpu", dtype=torch.float32)

        cpu_output = mps_extension.get_cpu_add_output(x, y)
        mps_output = mps_extension.get_mps_add_output(x.to("mps"), y.to("mps"))

        self.assertEqual(cpu_output, mps_output.to("cpu"))

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cublas_extension(self):
        from torch_test_cpp_extension import cublas_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        z = cublas_extension.noop_cublas_function(x)
        self.assertEqual(z, x)

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cusolver_extension(self):
        from torch_test_cpp_extension import cusolver_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        z = cusolver_extension.noop_cusolver_function(x)
        self.assertEqual(z, x)

    @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
    def test_no_python_abi_suffix_sets_the_correct_library_name(self):
        # For this test, run_test.py will call `python setup.py install` in the
        # cpp_extensions/no_python_abi_suffix_test folder, where the
        # `BuildExtension` class has a `no_python_abi_suffix` option set to
        # `True`. This *should* mean that on Python 3, the produced shared
        # library does not have an ABI suffix like
        # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
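        # A minimal sketch of the setup.py shape that produces such a library
        # (the module and source file names below are illustrative, not the
        # exact contents of that test folder):
        #
        #   from setuptools import setup
        #   from torch.utils.cpp_extension import BuildExtension, CppExtension
        #
        #   setup(
        #       name="no_python_abi_suffix_test",
        #       ext_modules=[
        #           CppExtension("no_python_abi_suffix_test", ["main.cpp"]),
        #       ],
        #       cmdclass={
        #           "build_ext": BuildExtension.with_options(
        #               no_python_abi_suffix=True
        #           )
        #       },
        #   )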
        root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
        matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith("so")]
        self.assertEqual(len(matches), 1, msg=str(matches))
        self.assertEqual(matches[0], "no_python_abi_suffix_test.so", msg=str(matches))

    def test_optional(self):
        has_value = cpp_extension.function_taking_optional(torch.ones(5))
        self.assertTrue(has_value)
        has_value = cpp_extension.function_taking_optional(None)
        self.assertFalse(has_value)

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(
        os.getenv("USE_NINJA", "0") == "0",
        "cuda extension with dlink requires ninja to build",
    )
    def test_cuda_dlink_libs(self):
        from torch_test_cpp_extension import cuda_dlink

        a = torch.randn(8, dtype=torch.float, device="cuda")
        b = torch.randn(8, dtype=torch.float, device="cuda")
        ref = a + b
        test = cuda_dlink.add(a, b)
        self.assertEqual(test, ref)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):
    """Pybind tests for ahead-of-time cpp extensions

    These tests verify the types returned from cpp code using custom type
    casters. By exercising pybind, we also verify that the type casters work
    properly.

    For each type caster in `torch/csrc/utils/pybind.h` we create a pybind
    function that takes no arguments and returns the type_caster type. The
    second argument to `PYBIND11_TYPE_CASTER` should be the type we expect to
    receive in Python; in these tests we verify this at run-time.
    """

    @staticmethod
    def expected_return_type(func):
        """
        Our pybind functions have a signature of the form `() -> return_type`.
        """
        # Imports needed for the `eval` below.
        from typing import List, Tuple  # noqa: F401

        return eval(re.search("-> (.*)\n", func.__doc__).group(1))

    def check(self, func):
        val = func()
        expected = self.expected_return_type(func)
        origin = get_origin(expected)
        if origin is list:
            self.check_list(val, expected)
        elif origin is tuple:
            self.check_tuple(val, expected)
        else:
            self.assertIsInstance(val, expected)

    def check_list(self, vals, expected):
        self.assertIsInstance(vals, list)
        list_type = get_args(expected)[0]
        for val in vals:
            self.assertIsInstance(val, list_type)

    def check_tuple(self, vals, expected):
        self.assertIsInstance(vals, tuple)
        tuple_types = get_args(expected)
        if tuple_types[1] is ...:
            tuple_types = repeat(tuple_types[0])
        for val, tuple_type in zip(vals, tuple_types):
            self.assertIsInstance(val, tuple_type)

    def check_union(self, funcs):
        """Special handling for Union type casters.

        A single cpp type can sometimes be cast to different types in Python.
        In these cases we expect to get exactly one function per Python type.
        """
        # Verify that all functions have the same return type.
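        # (For instance, a caster whose declared Python type is a Union such
        # as `Union[int, torch.SymInt]` parses to a single typing.Union here,
        # even though individual calls may return different member types.)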
        union_type = {self.expected_return_type(f) for f in funcs}
        assert len(union_type) == 1
        union_type = union_type.pop()
        self.assertIs(Union, get_origin(union_type))
        # SymInt is inconvenient to test, so don't require it
        expected_types = set(get_args(union_type)) - {torch.SymInt}
        for func in funcs:
            val = func()
            for tp in expected_types:
                if isinstance(val, tp):
                    expected_types.remove(tp)
                    break
            else:
                raise AssertionError(f"{val} is not an instance of {expected_types}")
        self.assertFalse(
            expected_types, f"Missing functions for types {expected_types}"
        )

    def test_pybind_return_types(self):
        functions = [
            cpp_extension.get_complex,
            cpp_extension.get_device,
            cpp_extension.get_generator,
            cpp_extension.get_intarrayref,
            cpp_extension.get_memory_format,
            cpp_extension.get_storage,
            cpp_extension.get_symfloat,
            cpp_extension.get_symintarrayref,
            cpp_extension.get_tensor,
        ]
        union_functions = [
            [cpp_extension.get_symint],
        ]
        for func in functions:
            with self.subTest(msg=f"check {func.__name__}"):
                self.check(func)
        for funcs in union_functions:
            with self.subTest(msg=f"check {[f.__name__ for f in funcs]}"):
                self.check_union(funcs)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestMAIATensor(common.TestCase):
    def test_unregistered(self):
        a = torch.arange(0, 10, device="cpu")
        with self.assertRaisesRegex(RuntimeError, "Could not run"):
            b = torch.arange(0, 10, device="maia")

    @skipIfTorchDynamo("dynamo cannot model maia device")
    def test_zeros(self):
        a = torch.empty(5, 5, device="cpu")
        self.assertEqual(a.device, torch.device("cpu"))

        b = torch.empty(5, 5, device="maia")
        self.assertEqual(b.device, torch.device("maia", 0))
        self.assertEqual(maia_extension.get_test_int(), 0)
        self.assertEqual(torch.get_default_dtype(), b.dtype)

        c = torch.empty((5, 5), dtype=torch.int64, device="maia")
        self.assertEqual(maia_extension.get_test_int(), 0)
        self.assertEqual(torch.int64, c.dtype)

    def test_add(self):
        a = torch.empty(5, 5, device="maia", requires_grad=True)
        self.assertEqual(maia_extension.get_test_int(), 0)

        b = torch.empty(5, 5, device="maia")
        self.assertEqual(maia_extension.get_test_int(), 0)

        c = a + b
        self.assertEqual(maia_extension.get_test_int(), 1)

    def test_conv_backend_override(self):
        # To simplify tests, we use 4d input here to avoid doing view4d (which
        # needs more overrides) in _convolution.
        input = torch.empty(2, 4, 10, 2, device="maia", requires_grad=True)
        weight = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True)
        bias = torch.empty(6, device="maia")

        # Make sure forward is overridden
        out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
        self.assertEqual(maia_extension.get_test_int(), 2)
        self.assertEqual(out.shape[0], input.shape[0])
        self.assertEqual(out.shape[1], weight.shape[0])

        # Make sure backward is overridden
        # Double backward is dispatched to _convolution_double_backward.
        # It is not tested here as it involves more computation/overrides.
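        # The call below runs a single backward through the overridden
        # convolution (the extension's test counter should advance from 2 to
        # 3); create_graph=True only keeps the graph around for a double
        # backward, which is not exercised here.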
        grad = torch.autograd.grad(out, input, out, create_graph=True)
        self.assertEqual(maia_extension.get_test_int(), 3)
        self.assertEqual(grad[0].shape, input.shape)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestRNGExtension(common.TestCase):
    def setUp(self):
        super().setUp()

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_rng(self):
        forty_two = torch.full((10,), 42, dtype=torch.int64)

        t = torch.empty(10, dtype=torch.int64).random_()
        self.assertNotEqual(t, forty_two)

        gen = torch.Generator(device="cpu")
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertNotEqual(t, forty_two)

        self.assertEqual(rng_extension.getInstanceCount(), 0)
        gen = rng_extension.createTestCPUGenerator(42)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        copy = gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy)
        copy2 = rng_extension.identity(copy)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy2)
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(t, forty_two)
        del gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy2
        self.assertEqual(rng_extension.getInstanceCount(), 0)


@torch.testing._internal.common_utils.markDynamoStrictTest
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
class TestTorchLibrary(common.TestCase):
    def test_torch_library(self):
        import torch_test_cpp_extension.torch_library  # noqa: F401

        def f(a: bool, b: bool):
            return torch.ops.torch_library.logical_and(a, b)

        self.assertTrue(f(True, True))
        self.assertFalse(f(True, False))
        self.assertFalse(f(False, True))
        self.assertFalse(f(False, False))
        s = torch.jit.script(f)
        self.assertTrue(s(True, True))
        self.assertFalse(s(True, False))
        self.assertFalse(s(False, True))
        self.assertFalse(s(False, False))
        self.assertIn("torch_library::logical_and", str(s.graph))


if __name__ == "__main__":
    common.run_tests()