# Owner(s): ["oncall: pt2"]
import functools
import itertools
import os
import sys
import textwrap
import unittest

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._inductor import config
from torch._inductor.codecache import HalideCodeCache
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import parallel_num_threads
from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_halide yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

try:
    import halide

    HAS_HALIDE = halide is not None
except ImportError:
    HAS_HALIDE = False


try:
    from . import test_torchinductor
except ImportError:
    import test_torchinductor


# Applying make_halide to an Inductor TestCase class re-runs its tests with
# the Halide backend for both CPU and CUDA codegen (see the bottom of this file).
make_halide = config.patch(
    {
        "halide.scan_kernels": True,
        "cpu_backend": "halide",
        "cuda_backend": "halide",
    }
)


@unittest.skipUnless(HAS_HALIDE, "requires halide")
class HalideTests(TestCase):
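    # Build and run a small "a + b" kernel directly through HalideCodeCache,
    # scheduled by the Mullapudi2016 autoscheduler.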
    def test_codecache(self):
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler="Mullapudi2016",
                scheduler_flags={
                    "parallelism": parallel_num_threads(),
                },
            ),
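            # The generator source below is ordinary Halide Python: a pointwise
            # tmp2 = in_ptr0 + in_ptr1 written to out_ptr0 over a 1024-element axis.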
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert g.using_autoscheduler()
                        in_ptr0.set_estimates([hl.Range(1024, 1024)])
                        in_ptr1.set_estimates([hl.Range(1024, 1024)])
                        out_ptr0.set_estimates([hl.Range(1024, 1024)])

                __name__ == '__main__' and hl.main()
                """
            ),
        )
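        # The compiled kernel writes into its last argument in place: c <- a + b.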
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)

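    # Same kernel as above, but with scheduler=None: the generator supplies an
    # explicit split/parallel/vectorize schedule instead of an autoscheduler.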
    def test_manual_schedule(self):
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler=None,
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert not g.using_autoscheduler()
                        # Manual schedule: split xindex into an outer var i and an
                        # inner var j of extent 32, run the outer loop in parallel,
                        # and vectorize the inner loop.
                        i = hl.Var()
                        j = hl.Var()
                        out_ptr0.compute_root()
                        out_ptr0.split(xindex, i, j, 32)
                        out_ptr0.parallel(i)
                        out_ptr0.vectorize(j)
                        # Compute and store tmp2 per outer-loop iteration; fold tmp1
                        # into its consumer.
                        tmp2.compute_at(out_ptr0, i)
                        tmp2.store_at(out_ptr0, i)
                        tmp1.compute_inline()

                __name__ == '__main__' and hl.main()
                """
            ),
        )
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)

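    # With the same seed, the Halide and Triton backends should produce
    # identical values from the compiled rand/randn/randint ops.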
    @unittest.skipUnless(has_triton(), "requires triton")
    def test_random_consistency(self):
        seed = 1234
        shape = (3, 3)
        dtype = torch.float32

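        # Sweep several random-op factories; compile each with both backends.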
        for (rand_fn,) in itertools.product(
            (
                functools.partial(torch.rand, shape, dtype=dtype, device="cuda"),
                functools.partial(torch.randn, shape, dtype=dtype, device="cuda"),
                functools.partial(
                    torch.randint,
                    -1000,
                    1000,
                    size=shape,
                    dtype=torch.int64,
                    device="cuda",
                ),
            )
        ):

            @torch.compile(backend="inductor", options={"cuda_backend": "halide"})
            def get_rand_halide():
                return rand_fn()

            @torch.compile(backend="inductor", options={"cuda_backend": "triton"})
            def get_rand_triton():
                return rand_fn()

            torch.manual_seed(seed)
            halide_output = get_rand_halide()
            torch.manual_seed(seed)
            triton_output = get_rand_triton()

            # Compare inside the loop so every rand_fn is checked, not just the last.
            self.assertEqual(halide_output, triton_output)

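# Re-run the shared Inductor test suites under the Halide backend; the GPU
# variants are opt-in via TEST_HALIDE_GPU=1.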
if test_torchinductor.HAS_CPU and HAS_HALIDE:
    SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
    CpuHalideTests = make_halide(test_torchinductor.CpuTests)

if (
    test_torchinductor.HAS_GPU
    and HAS_HALIDE
    and os.environ.get("TEST_HALIDE_GPU") == "1"
):
    SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
    GPUHalideTests = make_halide(test_torchinductor.GPUTests)

if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS and HAS_HALIDE:
        run_tests(needs="filelock")