xref: /aosp_15_r20/external/pytorch/test/inductor/test_extension_backend.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: inductor"]
2import os
3import shutil
4import sys
5import unittest
6
7import torch
8import torch._dynamo
9import torch.utils.cpp_extension
10from torch._C import FileCheck
11
12try:
13    from extension_backends.cpp.extension_codegen_backend import (
14        ExtensionCppWrapperCodegen,
15        ExtensionScheduling,
16        ExtensionWrapperCodegen,
17    )
18except ImportError:
19    from .extension_backends.cpp.extension_codegen_backend import (
20        ExtensionCppWrapperCodegen,
21        ExtensionScheduling,
22        ExtensionWrapperCodegen,
23    )
24
25import torch._inductor.config as config
26from torch._inductor import codecache, metrics
27from torch._inductor.codegen import cpp_utils
28from torch._inductor.codegen.common import (
29    get_scheduling_for_device,
30    get_wrapper_codegen_for_device,
31    register_backend_for_device,
32)
33from torch.testing._internal.common_utils import IS_FBCODE, IS_MACOS
34
35try:
36    try:
37        from . import test_torchinductor
38    except ImportError:
39        import test_torchinductor
40except unittest.SkipTest:
41    if __name__ == "__main__":
42        sys.exit(0)
43    raise
44
45
# Shared pieces from the main Inductor test module: ``run_and_get_cpp_code``
# (used below to obtain the generated code of a compiled call) and the
# ``TestCase`` base class that the tests in this file derive from.
run_and_get_cpp_code, TestCase = (
    test_torchinductor.run_and_get_cpp_code,
    test_torchinductor.TestCase,
)
48
49
def remove_build_path():
    """Delete torch's default cpp_extension build root if it exists.

    This is a no-op on Windows, where the extensions build folder is
    deliberately not wiped. Errors raised while deleting are ignored.
    """
    if sys.platform != "win32":
        build_root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(build_root):
            shutil.rmtree(build_root, ignore_errors=True)
57
58
@unittest.skipIf(IS_FBCODE, "cpp_extension doesn't work in fbcode right now")
class ExtensionBackendTests(TestCase):
    """Inductor tests for an out-of-tree device backend built as a C++ extension.

    ``setUpClass`` compiles ``extension_backends/cpp/extension_device.cpp``
    (located next to this file) with ``torch.utils.cpp_extension.load``. The
    test then exposes that module as the ``extension_device`` device and checks
    that Inductor generates and runs code for it.
    """

    # Compiled C++ extension module; assigned once by setUpClass.
    module = None

    @classmethod
    def setUpClass(cls):
        """Build the extension-device C++ module shared by every test."""
        super().setUpClass()

        # Wipe the cpp_extension build root first (presumably so that a stale
        # build from an earlier run is not reused).
        remove_build_path()
        this_dir = os.path.dirname(os.path.abspath(__file__))
        source_file = os.path.join(
            this_dir, "extension_backends/cpp/extension_device.cpp"
        )
        cls.module = torch.utils.cpp_extension.load(
            name="extension_device",
            sources=[
                str(source_file),
            ],
            extra_cflags=["-g"],
            verbose=True,
        )

    @classmethod
    def tearDownClass(cls):
        """Tear down class-level state and wipe the extension build root.

        ``remove_build_path`` runs in a ``finally`` so that the compiled
        artifacts are removed even when the base-class teardown raises.
        """
        try:
            # NOTE(review): ``_stack`` is never assigned in this class; it is
            # presumably an ExitStack created by the base TestCase's
            # setUpClass — confirm.
            cls._stack.close()
            super().tearDownClass()
        finally:
            remove_build_path()

    def setUp(self):
        torch._dynamo.reset()
        super().setUp()
        # Validate before touching the working directory. A failing setUp
        # skips tearDown, so nothing may be left to undo at this point.
        self.assertIsNotNone(self.module)

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we change the working directory temporarily. The
        # restore is registered with addCleanup. unittest runs cleanups after
        # tearDown, and also when setUp or tearDown raises.
        self.old_working_dir = os.getcwd()
        self.addCleanup(os.chdir, self.old_working_dir)
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        torch._dynamo.reset()
        # The original working directory is restored by the cleanup that
        # setUp registered; unittest runs it after this method.

    def test_open_device_registration(self):
        """Register the extension device with Inductor and compile a kernel.

        Checks the codegen registry lookups and that the extension's custom-op
        flag flips once a tensor is created on the device. It also checks that
        ``a * b + c``, compiled with and without ``cpp_wrapper``, produces the
        expected generated code and numerical result.
        """
        torch.utils.rename_privateuse1_backend("extension_device")
        torch._register_device_module("extension_device", self.module)

        register_backend_for_device(
            "extension_device",
            ExtensionScheduling,
            ExtensionWrapperCodegen,
            ExtensionCppWrapperCodegen,
        )
        # assertIs / assertEqual report both operands on failure, unlike
        # assertTrue(a == b).
        self.assertIs(
            get_scheduling_for_device("extension_device"), ExtensionScheduling
        )
        self.assertIs(
            get_wrapper_codegen_for_device("extension_device"),
            ExtensionWrapperCodegen,
        )
        # Passing True as the second argument selects the cpp-wrapper codegen.
        self.assertIs(
            get_wrapper_codegen_for_device("extension_device", True),
            ExtensionCppWrapperCodegen,
        )

        self.assertFalse(self.module.custom_op_called())
        device = self.module.custom_device()
        x = torch.empty(2, 16).to(device=device).fill_(1)
        # Creating/filling x on the device is expected to flip the extension's
        # custom-op flag.
        self.assertTrue(self.module.custom_op_called())
        y = torch.empty(2, 16).to(device=device).fill_(2)
        z = torch.empty(2, 16).to(device=device).fill_(3)
        # Expected value of fn(x, y, z): 1 * 2 + 3.
        ref = torch.empty(2, 16).fill_(5)

        self.assertEqual(x.device, device)
        self.assertEqual(y.device, device)
        self.assertEqual(z.device, device)

        def fn(a, b, c):
            return a * b + c

        # NOTE: mutates a module-level mapping; it is not undone after the test.
        cpp_utils.DEVICE_TO_ATEN["extension_device"] = "at::kPrivateUse1"
        for cpp_wrapper_flag in [True, False]:
            with config.patch({"cpp_wrapper": cpp_wrapper_flag}):
                metrics.reset()
                opt_fn = torch.compile()(fn)
                _, code = run_and_get_cpp_code(opt_fn, x, y, z)
                # Vectorized kernels use `loadu`; otherwise expect a scalar load.
                if codecache.valid_vec_isa_list():
                    load_expr = "loadu"
                else:
                    load_expr = " = in_ptr0[static_cast<long>(i0)];"
                FileCheck().check("void").check(load_expr).check(
                    "extension_device"
                ).run(code)
                # The first call's result is discarded; only the second call's
                # output is compared against the CPU reference.
                opt_fn(x, y, z)
                res = opt_fn(x, y, z)
                self.assertEqual(ref, res.to(device="cpu"))
159
160
161if __name__ == "__main__":
162    from torch._inductor.test_case import run_tests
163    from torch.testing._internal.inductor_utils import HAS_CPU
164
165    # cpp_extension doesn't work in fbcode right now
166    if HAS_CPU and not IS_MACOS and not IS_FBCODE:
167        run_tests(needs="filelock")
168