# Owner(s): ["module: cpp-extensions"]

import os
import re
import unittest
from itertools import repeat
from typing import get_args, get_origin, Union

import torch
import torch.backends.cudnn
import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfTorchDynamo


try:
    import pytest

    HAS_PYTEST = True
except ImportError as e:
    HAS_PYTEST = False

# TODO: Rewrite these tests so that they can be collected via pytest without
# using run_test.py
try:
    if HAS_PYTEST:
        cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
        maia_extension = pytest.importorskip("torch_test_cpp_extension.maia")
        rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
    else:
        import torch_test_cpp_extension.cpp as cpp_extension
        import torch_test_cpp_extension.maia as maia_extension
        import torch_test_cpp_extension.rng as rng_extension
except ImportError as e:
    raise RuntimeError(
        "test_cpp_extensions_aot.py cannot be invoked directly. Run "
        "`python run_test.py -i test_cpp_extensions_aot_ninja` instead."
    ) from e

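# For context: the extension modules imported above are built ahead of time by
# run_test.py before this file runs. A minimal sketch of what such an AOT build
# usually looks like (the module name and source file below are hypothetical,
# not the actual setup.py used by run_test.py):
#
#     from setuptools import setup
#     from torch.utils.cpp_extension import BuildExtension, CppExtension
#
#     setup(
#         name="torch_test_cpp_extension",
#         ext_modules=[
#             CppExtension("torch_test_cpp_extension.cpp", ["extension.cpp"]),
#         ],
#         cmdclass={"build_ext": BuildExtension},
#     )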

@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionAOT(common.TestCase):
    """Tests ahead-of-time cpp extensions

    NOTE: run_test.py's test_cpp_extensions_aot_ninja target
    also runs this test case, but with ninja enabled. If you are debugging
    a test failure here from the CI, check the logs for which target
    (test_cpp_extensions_aot_no_ninja vs test_cpp_extensions_aot_ninja)
    failed.
    """

    def test_extension_function(self):
        x = torch.randn(4, 4)
        y = torch.randn(4, 4)
        z = cpp_extension.sigmoid_add(x, y)
        self.assertEqual(z, x.sigmoid() + y.sigmoid())
        # Test that pybind supports casting torch.dtype.
        self.assertEqual(
            str(torch.float32), str(cpp_extension.get_math_type(torch.half))
        )

    def test_extension_module(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double)
        expected = mm.get().mm(weights)
        result = mm.forward(weights)
        self.assertEqual(expected, result)

    def test_backward(self):
        mm = cpp_extension.MatrixMultiplier(4, 8)
        weights = torch.rand(8, 4, dtype=torch.double, requires_grad=True)
        result = mm.forward(weights)
        result.sum().backward()
        tensor = mm.get()

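        # forward computes mm.get().mm(weights) (see test_extension_module
        # above): a 4x8 tensor times the 8x4 weights. For loss = result.sum(),
        # the expected gradients are
        #   d(loss)/d(weights) = tensor^T @ ones(4, 4)   (shape 8x4)
        #   d(loss)/d(tensor)  = ones(4, 4) @ weights^T  (shape 4x8)
        # which is exactly what the assertions below check.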
        expected_weights_grad = tensor.t().mm(torch.ones([4, 4], dtype=torch.double))
        self.assertEqual(weights.grad, expected_weights_grad)

        expected_tensor_grad = torch.ones([4, 4], dtype=torch.double).mm(weights.t())
        self.assertEqual(tensor.grad, expected_tensor_grad)

    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cuda_extension(self):
        import torch_test_cpp_extension.cuda as cuda_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        y = torch.zeros(100, device="cuda", dtype=torch.float32)

        z = cuda_extension.sigmoid_add(x, y).cpu()

        # 2 * sigmoid(0) = 2 * 0.5 = 1
        self.assertEqual(z, torch.ones_like(z))

    @unittest.skipIf(not torch.backends.mps.is_available(), "MPS not found")
    def test_mps_extension(self):
        import torch_test_cpp_extension.mps as mps_extension

        tensor_length = 100000
        x = torch.randn(tensor_length, device="cpu", dtype=torch.float32)
        y = torch.randn(tensor_length, device="cpu", dtype=torch.float32)

        cpu_output = mps_extension.get_cpu_add_output(x, y)
        mps_output = mps_extension.get_mps_add_output(x.to("mps"), y.to("mps"))

        self.assertEqual(cpu_output, mps_output.to("cpu"))

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cublas_extension(self):
        from torch_test_cpp_extension import cublas_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        z = cublas_extension.noop_cublas_function(x)
        self.assertEqual(z, x)

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    def test_cusolver_extension(self):
        from torch_test_cpp_extension import cusolver_extension

        x = torch.zeros(100, device="cuda", dtype=torch.float32)
        z = cusolver_extension.noop_cusolver_function(x)
        self.assertEqual(z, x)

    @unittest.skipIf(IS_WINDOWS, "Not available on Windows")
    def test_no_python_abi_suffix_sets_the_correct_library_name(self):
        # For this test, run_test.py will call `python setup.py install` in the
        # cpp_extensions/no_python_abi_suffix_test folder, where the
        # `BuildExtension` class has a `no_python_abi_suffix` option set to
        # `True`. This *should* mean that on Python 3, the produced shared
        # library does not have an ABI suffix like
        # "cpython-37m-x86_64-linux-gnu" before the library suffix, e.g. "so".
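        # A minimal sketch of how such a build could be declared (not the
        # actual setup.py used by run_test.py; the source file name is
        # hypothetical):
        #
        #     from setuptools import setup
        #     from torch.utils.cpp_extension import BuildExtension, CppExtension
        #
        #     setup(
        #         name="no_python_abi_suffix_test",
        #         ext_modules=[
        #             CppExtension("no_python_abi_suffix_test", ["main.cpp"]),
        #         ],
        #         cmdclass={
        #             "build_ext": BuildExtension.with_options(no_python_abi_suffix=True)
        #         },
        #     )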
        root = os.path.join("cpp_extensions", "no_python_abi_suffix_test", "build")
        matches = [f for _, _, fs in os.walk(root) for f in fs if f.endswith("so")]
        self.assertEqual(len(matches), 1, msg=str(matches))
        self.assertEqual(matches[0], "no_python_abi_suffix_test.so", msg=str(matches))

    def test_optional(self):
        has_value = cpp_extension.function_taking_optional(torch.ones(5))
        self.assertTrue(has_value)
        has_value = cpp_extension.function_taking_optional(None)
        self.assertFalse(has_value)

    @common.skipIfRocm
    @unittest.skipIf(common.IS_WINDOWS, "Windows not supported")
    @unittest.skipIf(not TEST_CUDA, "CUDA not found")
    @unittest.skipIf(
        os.getenv("USE_NINJA", "0") == "0",
        "cuda extension with dlink requires ninja to build",
    )
    def test_cuda_dlink_libs(self):
        from torch_test_cpp_extension import cuda_dlink

        a = torch.randn(8, dtype=torch.float, device="cuda")
        b = torch.randn(8, dtype=torch.float, device="cuda")
        ref = a + b
        test = cuda_dlink.add(a, b)
        self.assertEqual(test, ref)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):
    """Pybind tests for ahead-of-time cpp extensions

    These tests verify the types returned from cpp code using custom type
    casters. By exercising pybind, we also verify that the type casters work
    properly.

    For each type caster in `torch/csrc/utils/pybind.h` we create a pybind
    function that takes no arguments and returns the type_caster type. The
    second argument to `PYBIND11_TYPE_CASTER` should be the type we expect to
    receive in Python; in these tests we verify this at run time.
    """

    @staticmethod
    def expected_return_type(func):
        """
        Our Pybind functions have a signature of the form `() -> return_type`.
        """
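        # pybind11 embeds the signature in the generated docstring, so for a
        # binding exposed as, say, `get_tensor() -> torch.Tensor`, the regex
        # below captures the text after "-> " and `eval` turns it into the
        # corresponding Python type object. (The exact annotation text depends
        # on each caster's PYBIND11_TYPE_CASTER declaration; the example above
        # is illustrative.)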
        # Imports needed for the `eval` below.
        from typing import List, Tuple  # noqa: F401

        return eval(re.search("-> (.*)\n", func.__doc__).group(1))

    def check(self, func):
        val = func()
        expected = self.expected_return_type(func)
        origin = get_origin(expected)
        if origin is list:
            self.check_list(val, expected)
        elif origin is tuple:
            self.check_tuple(val, expected)
        else:
            self.assertIsInstance(val, expected)

    def check_list(self, vals, expected):
        self.assertIsInstance(vals, list)
        list_type = get_args(expected)[0]
        for val in vals:
            self.assertIsInstance(val, list_type)

    def check_tuple(self, vals, expected):
        self.assertIsInstance(vals, tuple)
        tuple_types = get_args(expected)
        if tuple_types[1] is ...:
            tuple_types = repeat(tuple_types[0])
        for val, tuple_type in zip(vals, tuple_types):
            self.assertIsInstance(val, tuple_type)

    def check_union(self, funcs):
        """Special handling for Union type casters.

        A single C++ type can sometimes be cast to different types in Python.
        In these cases we expect to get exactly one function per Python type.
        """
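        # For example, the SymInt caster can surface as more than one Python
        # type, so its annotated return type is a Union (e.g. something like
        # `Union[int, torch.SymInt]`; the exact members depend on the caster),
        # and we expect one get_* function per member of that Union.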
        # Verify that all functions have the same return type.
        union_type = {self.expected_return_type(f) for f in funcs}
        assert len(union_type) == 1
        union_type = union_type.pop()
        self.assertIs(Union, get_origin(union_type))
        # SymInt is inconvenient to test, so don't require it
        expected_types = set(get_args(union_type)) - {torch.SymInt}
        for func in funcs:
            val = func()
            for tp in expected_types:
                if isinstance(val, tp):
                    expected_types.remove(tp)
                    break
            else:
                raise AssertionError(f"{val} is not an instance of {expected_types}")
        self.assertFalse(
            expected_types, f"Missing functions for types {expected_types}"
        )

    def test_pybind_return_types(self):
        functions = [
            cpp_extension.get_complex,
            cpp_extension.get_device,
            cpp_extension.get_generator,
            cpp_extension.get_intarrayref,
            cpp_extension.get_memory_format,
            cpp_extension.get_storage,
            cpp_extension.get_symfloat,
            cpp_extension.get_symintarrayref,
            cpp_extension.get_tensor,
        ]
        union_functions = [
            [cpp_extension.get_symint],
        ]
        for func in functions:
            with self.subTest(msg=f"check {func.__name__}"):
                self.check(func)
        for funcs in union_functions:
            with self.subTest(msg=f"check {[f.__name__ for f in funcs]}"):
                self.check_union(funcs)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestMAIATensor(common.TestCase):
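    # The maia extension registers an out-of-tree "maia" device and overrides
    # a handful of kernels. From the assertions below, get_test_int() appears
    # to report which overridden kernel ran last (e.g. 1 after add, 2 after
    # conv forward, 3 after conv backward); this description is inferred from
    # the tests rather than from the extension's source.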
    def test_unregistered(self):
        a = torch.arange(0, 10, device="cpu")
        with self.assertRaisesRegex(RuntimeError, "Could not run"):
            b = torch.arange(0, 10, device="maia")

    @skipIfTorchDynamo("dynamo cannot model maia device")
    def test_zeros(self):
        a = torch.empty(5, 5, device="cpu")
        self.assertEqual(a.device, torch.device("cpu"))

        b = torch.empty(5, 5, device="maia")
        self.assertEqual(b.device, torch.device("maia", 0))
        self.assertEqual(maia_extension.get_test_int(), 0)
        self.assertEqual(torch.get_default_dtype(), b.dtype)

        c = torch.empty((5, 5), dtype=torch.int64, device="maia")
        self.assertEqual(maia_extension.get_test_int(), 0)
        self.assertEqual(torch.int64, c.dtype)

    def test_add(self):
        a = torch.empty(5, 5, device="maia", requires_grad=True)
        self.assertEqual(maia_extension.get_test_int(), 0)

        b = torch.empty(5, 5, device="maia")
        self.assertEqual(maia_extension.get_test_int(), 0)

        c = a + b
        self.assertEqual(maia_extension.get_test_int(), 1)

    def test_conv_backend_override(self):
        # To simplify tests, we use a 4d input here to avoid doing view4d
        # (which needs more overrides) in _convolution.
        input = torch.empty(2, 4, 10, 2, device="maia", requires_grad=True)
        weight = torch.empty(6, 4, 2, 2, device="maia", requires_grad=True)
        bias = torch.empty(6, device="maia")

        # Make sure forward is overridden
        out = torch.nn.functional.conv2d(input, weight, bias, 2, 0, 1, 1)
        self.assertEqual(maia_extension.get_test_int(), 2)
        self.assertEqual(out.shape[0], input.shape[0])
        self.assertEqual(out.shape[1], weight.shape[0])

        # Make sure backward is overridden
        # Double backward is dispatched to _convolution_double_backward.
        # It is not tested here as it involves more computation/overrides.
        grad = torch.autograd.grad(out, input, out, create_graph=True)
        self.assertEqual(maia_extension.get_test_int(), 3)
        self.assertEqual(grad[0].shape, input.shape)


@torch.testing._internal.common_utils.markDynamoStrictTest
class TestRNGExtension(common.TestCase):
    def setUp(self):
        super().setUp()

    @skipIfTorchDynamo("https://github.com/pytorch/torchdynamo/issues/1991")
    def test_rng(self):
        forty_two = torch.full((10,), 42, dtype=torch.int64)

        t = torch.empty(10, dtype=torch.int64).random_()
        self.assertNotEqual(t, forty_two)

        gen = torch.Generator(device="cpu")
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertNotEqual(t, forty_two)

        self.assertEqual(rng_extension.getInstanceCount(), 0)
        gen = rng_extension.createTestCPUGenerator(42)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
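        # `copy = gen` below only creates another Python reference to the same
        # underlying generator, and `rng_extension.identity` passes it through,
        # so the extension's instance count stays at 1 until every reference
        # (gen, copy, copy2) has been deleted.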
        copy = gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy)
        copy2 = rng_extension.identity(copy)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(gen, copy2)
        t = torch.empty(10, dtype=torch.int64).random_(generator=gen)
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        self.assertEqual(t, forty_two)
        del gen
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy
        self.assertEqual(rng_extension.getInstanceCount(), 1)
        del copy2
        self.assertEqual(rng_extension.getInstanceCount(), 0)


@torch.testing._internal.common_utils.markDynamoStrictTest
@unittest.skipIf(not TEST_CUDA, "CUDA not found")
class TestTorchLibrary(common.TestCase):
    def test_torch_library(self):
        import torch_test_cpp_extension.torch_library  # noqa: F401
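        # Importing the extension is what registers the custom
        # `torch_library::logical_and` op with the dispatcher; it is then
        # reachable as torch.ops.torch_library.logical_and and should survive
        # TorchScript compilation, which the graph check at the end verifies.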

        def f(a: bool, b: bool):
            return torch.ops.torch_library.logical_and(a, b)

        self.assertTrue(f(True, True))
        self.assertFalse(f(True, False))
        self.assertFalse(f(False, True))
        self.assertFalse(f(False, False))
        s = torch.jit.script(f)
        self.assertTrue(s(True, True))
        self.assertFalse(s(True, False))
        self.assertFalse(s(False, True))
        self.assertFalse(s(False, False))
        self.assertIn("torch_library::logical_and", str(s.graph))


if __name__ == "__main__":
    common.run_tests()