# Owner(s): ["oncall: pt2"]
import functools
import itertools
import os
import sys
import textwrap
import unittest

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
from torch._inductor import config
from torch._inductor.codecache import HalideCodeCache
from torch._inductor.runtime.hints import HalideInputSpec, HalideMeta
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import parallel_num_threads
from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton


if IS_WINDOWS and IS_CI:
    sys.stderr.write(
        "Windows CI does not have necessary dependencies for test_torchinductor_dynamic_shapes yet\n"
    )
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("requires sympy/functorch/filelock")

try:
    import halide

    HAS_HALIDE = halide is not None
except ImportError:
    HAS_HALIDE = False


try:
    from . import test_torchinductor
except ImportError:
    import test_torchinductor


# Config patch that switches both Inductor backends to Halide codegen.
make_halide = config.patch(
    {
        "halide.scan_kernels": True,
        "cpu_backend": "halide",
        "cuda_backend": "halide",
    }
)


@unittest.skipUnless(HAS_HALIDE, "requires halide")
class HalideTests(TestCase):
    def test_codecache(self):
        # Compile a hand-written elementwise-add generator through the code
        # cache with the Mullapudi2016 autoscheduler, then verify the result.
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler="Mullapudi2016",
                scheduler_flags={
                    "parallelism": parallel_num_threads(),
                },
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert g.using_autoscheduler()
                        in_ptr0.set_estimates([hl.Range(1024, 1024)])
                        in_ptr1.set_estimates([hl.Range(1024, 1024)])
                        out_ptr0.set_estimates([hl.Range(1024, 1024)])

                __name__ == '__main__' and hl.main()
                """
            ),
        )
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)

    def test_manual_schedule(self):
        # Same elementwise-add kernel, but with an explicit schedule
        # (split/parallel/vectorize) instead of an autoscheduler.
        fn = HalideCodeCache.generate_halide(
            HalideMeta(
                argtypes=[
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="in_ptr1",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                    HalideInputSpec(
                        ctype="float*",
                        name="out_ptr0",
                        shape=["1024L"],
                        stride=["1L"],
                        offset="0",
                    ),
                ],
                target="host-no_runtime",
                scheduler=None,
            ),
            textwrap.dedent(
                """
                import halide as hl

                @hl.generator(name="kernel")
                class Kernel:
                    in_ptr0 = hl.InputBuffer(hl.Float(32), 1)
                    in_ptr1 = hl.InputBuffer(hl.Float(32), 1)
                    out_ptr0 = hl.OutputBuffer(hl.Float(32), 1)

                    def generate(g):
                        in_ptr0 = g.in_ptr0
                        in_ptr1 = g.in_ptr1
                        out_ptr0 = g.out_ptr0
                        xindex = hl.Var('xindex')
                        x0 = xindex
                        tmp0 = hl.Func()
                        tmp0[xindex] = in_ptr0[x0]
                        tmp1 = hl.Func()
                        tmp1[xindex] = in_ptr1[x0]
                        tmp2 = hl.Func()
                        tmp2[xindex] = tmp0[xindex] + tmp1[xindex]
                        out_ptr0[x0] = tmp2[xindex]

                        assert not g.using_autoscheduler()
                        i = hl.Var()
                        j = hl.Var()
                        out_ptr0.compute_root()
                        out_ptr0.split(xindex, i, j, 32)
                        out_ptr0.parallel(i)
                        out_ptr0.vectorize(j)
                        tmp2.compute_at(out_ptr0, i)
                        tmp2.store_at(out_ptr0, i)
                        tmp1.compute_inline()

                __name__ == '__main__' and hl.main()
                """
            ),
        )
        a = torch.randn(1024)
        b = torch.randn(1024)
        c = torch.randn(1024)
        fn(a, b, c)
        self.assertEqual(c, a + b)

    @unittest.skipUnless(has_triton(), "requires triton")
    def test_random_consistency(self):
        # The Halide and Triton CUDA backends should produce identical random
        # tensors when seeded identically.
        seed = 1234
        shape = (3, 3)
        dtype = torch.float32

        for (rand_fn,) in itertools.product(
            (
                functools.partial(torch.rand, shape, dtype=dtype, device="cuda"),
                functools.partial(torch.randn, shape, dtype=dtype, device="cuda"),
                functools.partial(
                    torch.randint,
                    -1000,
                    1000,
                    size=shape,
                    dtype=torch.int64,
                    device="cuda",
                ),
            )
        ):

            @torch.compile(backend="inductor", options={"cuda_backend": "halide"})
            def get_rand_halide():
                return rand_fn()

            @torch.compile(backend="inductor", options={"cuda_backend": "triton"})
            def get_rand_triton():
                return rand_fn()

            torch.manual_seed(seed)
            halide_output = get_rand_halide()
            torch.manual_seed(seed)
            triton_output = get_rand_triton()

            self.assertEqual(halide_output, triton_output)


if test_torchinductor.HAS_CPU and HAS_HALIDE:
    SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
    CpuHalideTests = make_halide(test_torchinductor.CpuTests)

if (
    test_torchinductor.HAS_GPU
    and HAS_HALIDE
    and os.environ.get("TEST_HALIDE_GPU") == "1"
):
    SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
    GPUHalideTests = make_halide(test_torchinductor.GPUTests)

if __name__ == "__main__":
    if HAS_CPU and not IS_MACOS and HAS_HALIDE:
        run_tests(needs="filelock")