xref: /aosp_15_r20/external/pytorch/test/test_cpp_extensions_mtia_backend.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: mtia"]
2
3import os
4import shutil
5import sys
6import tempfile
7import unittest
8
9import torch
10import torch.testing._internal.common_utils as common
11import torch.utils.cpp_extension
12from torch.testing._internal.common_utils import (
13    IS_ARM64,
14    IS_LINUX,
15    skipIfTorchDynamo,
16    TEST_CUDA,
17    TEST_PRIVATEUSE1,
18    TEST_XPU,
19)
20from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
21
22
# define TEST_ROCM before changing TEST_CUDA
# TEST_ROCM: True only for a HIP build of torch (torch.version.hip set) with a
# discoverable ROCm install; evaluated first because the next line narrows TEST_CUDA.
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
# Narrow TEST_CUDA to environments where the CUDA toolkit itself is installed.
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
26
27
def remove_build_path():
    """Wipe torch's default cpp-extension build root.

    Skipped on Windows, where the build folder cannot be reliably deleted
    (files may still be locked by the loader).
    """
    if sys.platform != "win32":
        build_root = torch.utils.cpp_extension.get_default_build_root()
        if os.path.exists(build_root):
            shutil.rmtree(build_root, ignore_errors=True)
35
36
@unittest.skipIf(
    IS_ARM64 or not IS_LINUX or TEST_CUDA or TEST_PRIVATEUSE1 or TEST_ROCM or TEST_XPU,
    "Only on linux platform and mutual exclusive to other backends",
)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionMTIABackend(common.TestCase):
    """Tests MTIA backend with C++ extensions.

    ``setUpClass`` JIT-compiles ``cpp_extensions/mtia_extension.cpp`` (a fake
    MTIA device guard implementation) into a private temp build directory and
    loads it, so the ``torch.mtia`` APIs below are backed by that extension.
    """

    # Handle to the loaded mtia_extension module; populated in setUpClass.
    module = None
    # Private temp build directory created by setUpClass; removed in tearDownClass.
    build_dir = None

    def setUp(self):
        super().setUp()
        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily.
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()
        # Return to the working directory saved in setUp.
        os.chdir(self.old_working_dir)

    @classmethod
    def tearDownClass(cls):
        remove_build_path()
        # remove_build_path only wipes the *default* build root; the private
        # mkdtemp build dir from setUpClass must be cleaned up explicitly or
        # it leaks on every run.
        if cls.build_dir is not None:
            shutil.rmtree(cls.build_dir, ignore_errors=True)
            cls.build_dir = None

    @classmethod
    def setUpClass(cls):
        remove_build_path()
        # Keep the build dir on the class so tearDownClass can delete it.
        cls.build_dir = tempfile.mkdtemp()
        # Load the fake device guard impl. The odd include paths (spaces,
        # quote) deliberately stress the build-command quoting.
        cls.module = torch.utils.cpp_extension.load(
            name="mtia_extension",
            sources=["cpp_extensions/mtia_extension.cpp"],
            build_directory=cls.build_dir,
            extra_include_paths=[
                "cpp_extensions",
                "path / with spaces in it",
                "path with quote'",
            ],
            is_python_module=False,
            verbose=True,
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_get_device_module(self):
        """torch.get_device_module should resolve to the MTIA backend."""
        device = torch.device("mtia:0")
        default_stream = torch.get_device_module(device).current_stream()
        self.assertEqual(
            default_stream.device_type, int(torch._C._autograd.DeviceType.MTIA)
        )

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_basic(self):
        """Default vs. user stream identity, query() and synchronize()."""
        default_stream = torch.mtia.current_stream()
        user_stream = torch.mtia.Stream()
        self.assertEqual(torch.mtia.current_stream(), default_stream)
        self.assertNotEqual(default_stream, user_stream)
        # Check mtia_extension.cpp, default stream id starts from 0.
        self.assertEqual(default_stream.stream_id, 0)
        self.assertNotEqual(user_stream.stream_id, 0)
        with torch.mtia.stream(user_stream):
            self.assertEqual(torch.mtia.current_stream(), user_stream)
        self.assertTrue(user_stream.query())
        default_stream.synchronize()
        self.assertTrue(default_stream.query())

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context(self):
        """torch.mtia.stream(...) makes the given stream current inside the block."""
        mtia_stream_0 = torch.mtia.Stream(device="mtia:0")
        mtia_stream_1 = torch.mtia.Stream(device="mtia:0")
        with torch.mtia.stream(mtia_stream_0):
            self.assertEqual(torch.mtia.current_stream(), mtia_stream_0)

        with torch.mtia.stream(mtia_stream_1):
            self.assertEqual(torch.mtia.current_stream(), mtia_stream_1)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_stream_context_different_device(self):
        """Entering a stream on another device also switches the current device,
        and both are restored on exit."""
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        mtia_stream_0 = torch.mtia.Stream(device=device_0)
        mtia_stream_1 = torch.mtia.Stream(device=device_1)
        orig_current_device = torch.mtia.current_device()
        with torch.mtia.stream(mtia_stream_0):
            self.assertEqual(torch.mtia.current_device(), device_0.index)
            self.assertEqual(torch.mtia.current_stream(), mtia_stream_0)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)
        with torch.mtia.stream(mtia_stream_1):
            self.assertEqual(torch.mtia.current_device(), device_1.index)
            self.assertEqual(torch.mtia.current_stream(), mtia_stream_1)
        self.assertEqual(torch.mtia.current_device(), orig_current_device)

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    def test_device_context(self):
        """torch.mtia.device(...) sets the current device inside the block."""
        device_0 = torch.device("mtia:0")
        device_1 = torch.device("mtia:1")
        with torch.mtia.device(device_0):
            self.assertEqual(torch.mtia.current_device(), device_0.index)

        with torch.mtia.device(device_1):
            self.assertEqual(torch.mtia.current_device(), device_1.index)
153
154
if __name__ == "__main__":
    # Run this file through torch's shared test harness.
    common.run_tests()
157