1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass 4*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence, TYPE_CHECKING 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.ufunc as ufunc 7*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.translate import translate 8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import ( 9*da0073e9SAndroid Build Coastguard Worker BaseCType, 10*da0073e9SAndroid Build Coastguard Worker Binding, 11*da0073e9SAndroid Build Coastguard Worker CType, 12*da0073e9SAndroid Build Coastguard Worker Expr, 13*da0073e9SAndroid Build Coastguard Worker NamedCType, 14*da0073e9SAndroid Build Coastguard Worker opmath_t, 15*da0073e9SAndroid Build Coastguard Worker scalar_t, 16*da0073e9SAndroid Build Coastguard Worker StructuredImplSignature, 17*da0073e9SAndroid Build Coastguard Worker VectorizedCType, 18*da0073e9SAndroid Build Coastguard Worker) 19*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import with_native_function 20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 21*da0073e9SAndroid Build Coastguard Worker Argument, 22*da0073e9SAndroid Build Coastguard Worker BaseTy, 23*da0073e9SAndroid Build Coastguard Worker BaseType, 24*da0073e9SAndroid Build Coastguard Worker DispatchKey, 25*da0073e9SAndroid Build Coastguard Worker NativeFunctionsGroup, 26*da0073e9SAndroid Build Coastguard Worker ScalarType, 27*da0073e9SAndroid Build Coastguard Worker UfuncKey, 28*da0073e9SAndroid Build Coastguard Worker) 29*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import OrderedSet 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING: 33*da0073e9SAndroid Build Coastguard Worker from 
torchgen.api.ufunc import UfunctorBindings 34*da0073e9SAndroid Build Coastguard Worker 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 37*da0073e9SAndroid Build Coastguard Worker# 38*da0073e9SAndroid Build Coastguard Worker# CUDA STUFF 39*da0073e9SAndroid Build Coastguard Worker# 40*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker# NB: not bothering to generate dispatch stub forward declaration in header, 43*da0073e9SAndroid Build Coastguard Worker# we can just paste it wherever necessary 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker# TODO: use BackendIndex 46*da0073e9SAndroid Build Coastguard Worker# dispatch_key: DispatchKey # only CPU/CUDA right now 47*da0073e9SAndroid Build Coastguard Worker 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker# Represents functors for implementing CUDA ufuncs. 50*da0073e9SAndroid Build Coastguard Worker# Functors are templated by scalar_t because when USERS instantiate functors 51*da0073e9SAndroid Build Coastguard Worker# they are templated.
# A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) const {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
@dataclass(frozen=True)
class UfunctorSignature:
    """Describes one C++ functor struct emitted for a CUDA ufunc.

    The constructor bindings become the struct's fields (renamed with a
    trailing underscore) and the apply bindings become the parameters of
    ``operator()``.
    """

    # Structured function group the functor is generated for.
    g: NativeFunctionsGroup
    # Forwarded verbatim to ufunc.ufunctor_arguments.
    # NOTE(review): presumably the index of the tensor operand that is passed
    # as a scalar through the constructor (None = no scalar specialization);
    # confirm against torchgen.api.ufunc.
    scalar_tensor_idx: int | None
    # C++ name of the functor struct.
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the constructor/apply bindings computed by ``api.ufunc``."""
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> list[Binding]:
        """Return the struct fields: one per constructor argument."""
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        """Return the C++ return type of ``operator()``."""
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Return the C++ field declarations, one per line."""
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        """Return the inline C++ constructor definition.

        When the functor takes no constructor arguments the member
        initializer list is omitted entirely: ``Name() : {}`` is not valid
        C++, whereas ``Name() {}`` is.
        """
        ctor_args = self.arguments().ctor
        if not ctor_args:
            return f"{self.name}() {{}}"
        args_str = ", ".join(a.decl() for a in ctor_args)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in ctor_args)
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        """Return the declaration of the (const) ``operator()``."""
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"


@dataclass(frozen=True)
class UfuncSignature:
    """Describes a call to the C++ ufunc template function for a group."""

    # Structured function group whose ufunc is being called.
    g: NativeFunctionsGroup
    # C++ name of the function to call, used verbatim in the call expression.
    name: str
    # C++ type the ufunc arguments are computed in.
    compute_t: CType

    def arguments(self) -> list[Binding]:
        """Return the ufunc's argument bindings for ``compute_t``."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Binding | Expr]) -> str:
        """Return a C++ call expression with arguments translated from ``ctx``."""
        exprs = translate(ctx, self.arguments())
        return f"{self.name}({', '.join(a.expr for a in exprs)})"


# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to template signature.  this establishes
#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#


def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True iff the functional variant has exactly two tensor-like inputs."""
    tensor_like_args = [
        a for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    ]
    return len(tensor_like_args) == 2


def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
    """Build the CUDA functor signatures and C++ functor definitions for ``g``.

    Returns ``(signatures, source)``.  ``signatures`` maps every supported
    dtype to the functor signature to use for each ``UfuncKey``.  ``source``
    is the C++ text of the functors generated here; functors whose key is
    listed directly in ``g.out.ufunc_inner_loop`` are recorded in
    ``signatures`` but no code is emitted for them.
    """
    loops = g.out.ufunc_inner_loop
    sigs_by_dtype: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
    functor_defns: list[str] = []

    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
    scalar_specialized_keys = [
        UfuncKey.CUDAFunctorOnSelf,
        UfuncKey.CUDAFunctorOnOther,
    ]
    if eligible_for_binary_scalar_specialization(g):
        keys = [*scalar_specialized_keys, UfuncKey.CUDAFunctor]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in scalar_specialized_keys:
            assert k not in loops, f"cannot use {k} on non-binary function"

    for k in keys:
        scalar_tensor_idx = scalar_tensor_idx_lookup[k]

        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            user_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx, name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                sigs_by_dtype.setdefault(dtype, {})[k] = user_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        generic_loops = [
            loops[lk] for lk in (UfuncKey.ScalarOnly, UfuncKey.Generic) if lk in loops
        ]
        for loop in generic_loops:
            if ufunc_name is None:
                ufunc_name = loop.name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loop.name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loop.supported_dtypes
        assert ufunc_name is not None

        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx, name=f"{k}_{ufunc_name}"
        )
        for dtype in supported_dtypes:
            sigs_by_dtype.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        # operator() sees both the stored fields and its own parameters
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        functor_defns.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return sigs_by_dtype, "\n".join(functor_defns)


@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """One way of specializing a binary ufunc on a CPU-scalar operand."""

    # Zero-based index of the scalar operand among the inputs.
    scalar_idx: int
    # Name of the functor constructor argument that receives the scalar value.
    ctor_tensor: str
    # Functor variant to use when this specialization applies.
    ufunc_key: UfuncKey


BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]


def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Return the C++ body of one dtype case of the CUDA kernel.

    The body is an ``if``/``else if`` chain: one branch per available
    scalar specialization in ``inner_loops``, followed by the generic
    ``UfuncKey.CUDAFunctor`` fallback (which must be present).
    """
    pieces = [
        "using opmath_t = at::opmath_type<scalar_t>;",
        "if (false) {}\n",  # for ease of codegen
    ]
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        specialized_sig = inner_loops[config.ufunc_key]
        # the generated code addresses the operand at config.scalar_idx + 1
        scalar_idx = config.scalar_idx + 1
        scalar_expr = Expr(
            expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
            type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
        )
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: list[Expr | Binding] = [*parent_ctx, scalar_expr]
        ctor_exprs = translate(ctx, specialized_sig.arguments().ctor)
        ufunctor_ctor_exprs_str = ", ".join(a.expr for a in ctor_exprs)

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        pieces.append(
            f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {specialized_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""
        )

    generic_sig = inner_loops[UfuncKey.CUDAFunctor]
    generic_ctor_exprs = translate(parent_ctx, generic_sig.arguments().ctor)
    ufunctor_ctor_exprs_str = ", ".join(a.expr for a in generic_ctor_exprs)
    pieces.append(
        f"""
else {{
  gpu_kernel(iter, {generic_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
    """
    )
    return "".join(pieces)


@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Return the complete generated C++ for the CUDA kernel of ``g``.

    The output contains the functor definitions, the stub type and dispatch
    declaration, the dtype-switching kernel with its registration, and the
    structured impl function, which calls the kernel directly.
    """
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = [
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        for dtype, inner_ufunc_sigs in ufunctor_sigs.items()
    ]
    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                           CPU STUFF
#
335*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # 336*da0073e9SAndroid Build Coastguard Worker 337*da0073e9SAndroid Build Coastguard Worker 338*da0073e9SAndroid Build Coastguard Worker@dataclass(frozen=True) 339*da0073e9SAndroid Build Coastguard Workerclass StubSignature: 340*da0073e9SAndroid Build Coastguard Worker g: NativeFunctionsGroup 341*da0073e9SAndroid Build Coastguard Worker 342*da0073e9SAndroid Build Coastguard Worker @property 343*da0073e9SAndroid Build Coastguard Worker def name(self) -> str: 344*da0073e9SAndroid Build Coastguard Worker return f"{str(self.g.functional.func.name.name)}_stub" 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker @property 347*da0073e9SAndroid Build Coastguard Worker def kernel_name(self) -> str: 348*da0073e9SAndroid Build Coastguard Worker return f"{str(self.g.functional.func.name.name)}_kernel" 349*da0073e9SAndroid Build Coastguard Worker 350*da0073e9SAndroid Build Coastguard Worker @property 351*da0073e9SAndroid Build Coastguard Worker def type_name(self) -> str: 352*da0073e9SAndroid Build Coastguard Worker return f"{str(self.g.functional.func.name.name)}_fn" 353*da0073e9SAndroid Build Coastguard Worker 354*da0073e9SAndroid Build Coastguard Worker def arguments(self) -> list[Binding]: 355*da0073e9SAndroid Build Coastguard Worker return ufunc.stub_arguments(self.g) 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker def type(self) -> str: 358*da0073e9SAndroid Build Coastguard Worker cpp_args = self.arguments() 359*da0073e9SAndroid Build Coastguard Worker return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})" 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker def dispatch_decl(self) -> str: 362*da0073e9SAndroid Build Coastguard Worker return f"DECLARE_DISPATCH({self.type_name}, {self.name})" 363*da0073e9SAndroid 
Build Coastguard Worker 364*da0073e9SAndroid Build Coastguard Worker def dispatch_defn(self) -> str: 365*da0073e9SAndroid Build Coastguard Worker return f"DEFINE_DISPATCH({self.name})" 366*da0073e9SAndroid Build Coastguard Worker 367*da0073e9SAndroid Build Coastguard Worker def kernel_defn(self) -> str: 368*da0073e9SAndroid Build Coastguard Worker return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})" 369*da0073e9SAndroid Build Coastguard Worker 370*da0073e9SAndroid Build Coastguard Worker def type_defn(self) -> str: 371*da0073e9SAndroid Build Coastguard Worker return f"using {self.type_name} = {self.type()}" 372*da0073e9SAndroid Build Coastguard Worker 373*da0073e9SAndroid Build Coastguard Worker # must be called from context where this is TensorIteratorBase* 374*da0073e9SAndroid Build Coastguard Worker def call(self, ctx: Sequence[Binding]) -> str: 375*da0073e9SAndroid Build Coastguard Worker return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" 376*da0073e9SAndroid Build Coastguard Worker 377*da0073e9SAndroid Build Coastguard Worker # used in CUDA to skip the unnecessary dynamic dispatch 378*da0073e9SAndroid Build Coastguard Worker def direct_call(self, ctx: Sequence[Binding]) -> str: 379*da0073e9SAndroid Build Coastguard Worker return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})" 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker 382*da0073e9SAndroid Build Coastguard Worker@with_native_function 383*da0073e9SAndroid Build Coastguard Workerdef compute_ufunc_cpu(g: NativeFunctionsGroup) -> str: 384*da0073e9SAndroid Build Coastguard Worker stub_sig = StubSignature(g) 385*da0073e9SAndroid Build Coastguard Worker sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU)) 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker return 
f""" 388*da0073e9SAndroid Build Coastguard Worker{stub_sig.type_defn()}; 389*da0073e9SAndroid Build Coastguard Worker{stub_sig.dispatch_decl()}; 390*da0073e9SAndroid Build Coastguard Worker{stub_sig.dispatch_defn()}; 391*da0073e9SAndroid Build Coastguard Worker 392*da0073e9SAndroid Build Coastguard Worker{sig.defn()} {{ 393*da0073e9SAndroid Build Coastguard Worker {stub_sig.call(sig.arguments())}; 394*da0073e9SAndroid Build Coastguard Worker}} 395*da0073e9SAndroid Build Coastguard Worker""" 396*da0073e9SAndroid Build Coastguard Worker 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Workerdef compute_ufunc_cpu_dtype_body( 399*da0073e9SAndroid Build Coastguard Worker g: NativeFunctionsGroup, 400*da0073e9SAndroid Build Coastguard Worker dtype: ScalarType, 401*da0073e9SAndroid Build Coastguard Worker inner_loops: dict[UfuncKey, UfuncSignature], 402*da0073e9SAndroid Build Coastguard Worker parent_ctx: Sequence[Binding], 403*da0073e9SAndroid Build Coastguard Worker) -> str: 404*da0073e9SAndroid Build Coastguard Worker assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}" 405*da0073e9SAndroid Build Coastguard Worker assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector} 406*da0073e9SAndroid Build Coastguard Worker scalar_loop = inner_loops[UfuncKey.CPUScalar] 407*da0073e9SAndroid Build Coastguard Worker vec_loop = None 408*da0073e9SAndroid Build Coastguard Worker if UfuncKey.CPUVector in inner_loops: 409*da0073e9SAndroid Build Coastguard Worker vec_loop = inner_loops[UfuncKey.CPUVector] 410*da0073e9SAndroid Build Coastguard Worker 411*da0073e9SAndroid Build Coastguard Worker # NB: We DON'T use translate here, because translate is 412*da0073e9SAndroid Build Coastguard Worker # incapable of CSE'ing the scalar accesses in case it is also 413*da0073e9SAndroid Build Coastguard Worker # used by Vectorized; also, the unpacking here is very simple 414*da0073e9SAndroid Build Coastguard Worker # and only affects 
Scalar; everything else is implicitly captured 415*da0073e9SAndroid Build Coastguard Worker # by the lambda 416*da0073e9SAndroid Build Coastguard Worker 417*da0073e9SAndroid Build Coastguard Worker # Setup scalar in scope 418*da0073e9SAndroid Build Coastguard Worker body = [] 419*da0073e9SAndroid Build Coastguard Worker ctx = [] 420*da0073e9SAndroid Build Coastguard Worker for b in parent_ctx: 421*da0073e9SAndroid Build Coastguard Worker if isinstance(b.argument, Argument) and b.argument.type != BaseType( 422*da0073e9SAndroid Build Coastguard Worker BaseTy.Scalar 423*da0073e9SAndroid Build Coastguard Worker ): 424*da0073e9SAndroid Build Coastguard Worker continue 425*da0073e9SAndroid Build Coastguard Worker body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();") 426*da0073e9SAndroid Build Coastguard Worker ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t)))) 427*da0073e9SAndroid Build Coastguard Worker if vec_loop is not None: 428*da0073e9SAndroid Build Coastguard Worker for b in parent_ctx: 429*da0073e9SAndroid Build Coastguard Worker if isinstance(b.argument, Argument) and b.argument.type != BaseType( 430*da0073e9SAndroid Build Coastguard Worker BaseTy.Scalar 431*da0073e9SAndroid Build Coastguard Worker ): 432*da0073e9SAndroid Build Coastguard Worker continue 433*da0073e9SAndroid Build Coastguard Worker body.append( 434*da0073e9SAndroid Build Coastguard Worker f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});" 435*da0073e9SAndroid Build Coastguard Worker ) 436*da0073e9SAndroid Build Coastguard Worker ctx.append( 437*da0073e9SAndroid Build Coastguard Worker Expr( 438*da0073e9SAndroid Build Coastguard Worker f"_v_{b.name}", 439*da0073e9SAndroid Build Coastguard Worker NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))), 440*da0073e9SAndroid Build Coastguard Worker ) 441*da0073e9SAndroid Build Coastguard Worker ) 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker # 
Setup lambda signature 444*da0073e9SAndroid Build Coastguard Worker # NB: simplified version of ufunctor_arguments 445*da0073e9SAndroid Build Coastguard Worker scalar_bindings = [] 446*da0073e9SAndroid Build Coastguard Worker vec_bindings = [] 447*da0073e9SAndroid Build Coastguard Worker for a in g.functional.func.arguments.flat_non_out: 448*da0073e9SAndroid Build Coastguard Worker if not a.type.is_tensor_like(): 449*da0073e9SAndroid Build Coastguard Worker continue 450*da0073e9SAndroid Build Coastguard Worker assert a.type == BaseType(BaseTy.Tensor) 451*da0073e9SAndroid Build Coastguard Worker scalar_bindings.append( 452*da0073e9SAndroid Build Coastguard Worker Binding( 453*da0073e9SAndroid Build Coastguard Worker name=a.name, 454*da0073e9SAndroid Build Coastguard Worker nctype=NamedCType(a.name, BaseCType(scalar_t)), 455*da0073e9SAndroid Build Coastguard Worker argument=a, 456*da0073e9SAndroid Build Coastguard Worker ) 457*da0073e9SAndroid Build Coastguard Worker ) 458*da0073e9SAndroid Build Coastguard Worker if vec_loop is not None: 459*da0073e9SAndroid Build Coastguard Worker vec_bindings.append( 460*da0073e9SAndroid Build Coastguard Worker Binding( 461*da0073e9SAndroid Build Coastguard Worker name=a.name, 462*da0073e9SAndroid Build Coastguard Worker nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))), 463*da0073e9SAndroid Build Coastguard Worker argument=a, 464*da0073e9SAndroid Build Coastguard Worker ) 465*da0073e9SAndroid Build Coastguard Worker ) 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]: 468*da0073e9SAndroid Build Coastguard Worker r: list[Expr | Binding] = [] 469*da0073e9SAndroid Build Coastguard Worker r.extend(ctx) 470*da0073e9SAndroid Build Coastguard Worker r.extend(b) 471*da0073e9SAndroid Build Coastguard Worker return r 472*da0073e9SAndroid Build Coastguard Worker 473*da0073e9SAndroid Build Coastguard Worker body_str = 
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Render the C++ text of the CPU kernel for the ufunc group ``g``.

    The inner loops declared on ``g.out.ufunc_inner_loop`` are regrouped by
    dtype, one ``AT_DISPATCH_CASE`` is rendered per dtype, and the cases are
    wrapped in an ``AT_DISPATCH_SWITCH`` on ``iter.common_dtype()`` inside an
    anonymous namespace.  The stub's type definition, dispatch declaration
    and a ``REGISTER_DISPATCH`` line are appended after the namespace.

    Args:
        g: the native functions group whose ``out`` variant carries the
            ``ufunc_inner_loop`` table.

    Returns:
        The generated C++ source as a single string.
    """
    stub_sig = StubSignature(g)
    loops = g.out.ufunc_inner_loop

    def sources_for(target: UfuncKey) -> list[UfuncKey]:
        # ORDER MATTERS: earlier entries override later ones.  The exact key
        # comes first (rare), ScalarOnly may only back the scalar loop, and
        # Generic is the final fallback.
        precedence = [target]
        if target is UfuncKey.CPUScalar:
            precedence.append(UfuncKey.ScalarOnly)
        precedence.append(UfuncKey.Generic)
        return [key for key in precedence if key in loops]

    # Each CPU loop flavour, paired with the C++ type it computes in.
    targets: list[tuple[UfuncKey, CType]] = [
        (UfuncKey.CPUScalar, BaseCType(scalar_t)),
        (UfuncKey.CPUVector, VectorizedCType(BaseCType(scalar_t))),
    ]

    # Reindex the ufunc loops by dtype, folding in Generic/ScalarOnly.
    sigs_by_dtype: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
    for target, compute_t in targets:
        for source in sources_for(target):
            loop = loops[source]
            for dtype in loop.supported_dtypes:
                per_dtype = sigs_by_dtype.setdefault(dtype, {})
                if target in per_dtype:
                    # A higher-precedence loop already claimed this slot.
                    continue
                # TODO: don't hardcode ufunc:: namespace here, should be
                # centralized
                per_dtype[target] = UfuncSignature(
                    g, name=f"ufunc::{loop.name}", compute_t=compute_t
                )

    # Render one dispatch case per dtype, in first-seen order.
    dtype_cases_str = "\n".join(
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, per_dtype, stub_sig.arguments())}
  }}
)
"""
        for dtype, per_dtype in sigs_by_dtype.items()
    )

    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""