xref: /aosp_15_r20/external/pytorch/torchgen/dest/ufunc.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom dataclasses import dataclass
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence, TYPE_CHECKING
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torchgen.api.ufunc as ufunc
7*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.translate import translate
8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
9*da0073e9SAndroid Build Coastguard Worker    BaseCType,
10*da0073e9SAndroid Build Coastguard Worker    Binding,
11*da0073e9SAndroid Build Coastguard Worker    CType,
12*da0073e9SAndroid Build Coastguard Worker    Expr,
13*da0073e9SAndroid Build Coastguard Worker    NamedCType,
14*da0073e9SAndroid Build Coastguard Worker    opmath_t,
15*da0073e9SAndroid Build Coastguard Worker    scalar_t,
16*da0073e9SAndroid Build Coastguard Worker    StructuredImplSignature,
17*da0073e9SAndroid Build Coastguard Worker    VectorizedCType,
18*da0073e9SAndroid Build Coastguard Worker)
19*da0073e9SAndroid Build Coastguard Workerfrom torchgen.context import with_native_function
20*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
21*da0073e9SAndroid Build Coastguard Worker    Argument,
22*da0073e9SAndroid Build Coastguard Worker    BaseTy,
23*da0073e9SAndroid Build Coastguard Worker    BaseType,
24*da0073e9SAndroid Build Coastguard Worker    DispatchKey,
25*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsGroup,
26*da0073e9SAndroid Build Coastguard Worker    ScalarType,
27*da0073e9SAndroid Build Coastguard Worker    UfuncKey,
28*da0073e9SAndroid Build Coastguard Worker)
29*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import OrderedSet
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Workerif TYPE_CHECKING:
33*da0073e9SAndroid Build Coastguard Worker    from torchgen.api.ufunc import UfunctorBindings
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
37*da0073e9SAndroid Build Coastguard Worker#
38*da0073e9SAndroid Build Coastguard Worker#                                  CUDA STUFF
39*da0073e9SAndroid Build Coastguard Worker#
40*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker# TODO: use BackendIndex
46*da0073e9SAndroid Build Coastguard Worker# dispatch_key: DispatchKey  # only CPU/CUDA right now
47*da0073e9SAndroid Build Coastguard Worker
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker# Represents functors for implementing CUDA ufuncs.
50*da0073e9SAndroid Build Coastguard Worker# Functors are templated by scalar_t because when USERS instantiate functors
51*da0073e9SAndroid Build Coastguard Worker# they are templated.  A functor looks something like this:
52*da0073e9SAndroid Build Coastguard Worker#
53*da0073e9SAndroid Build Coastguard Worker#   template <typename scalar_t>
54*da0073e9SAndroid Build Coastguard Worker#   struct CUDAFunctorOnSelf_add {
55*da0073e9SAndroid Build Coastguard Worker#     using opmath_t = at::opmath_type<scalar_t>;
56*da0073e9SAndroid Build Coastguard Worker#     opmath_t other_;
57*da0073e9SAndroid Build Coastguard Worker#     opmath_t alpha_;
58*da0073e9SAndroid Build Coastguard Worker#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
59*da0073e9SAndroid Build Coastguard Worker#         : other_(other), alpha_(alpha) {}
60*da0073e9SAndroid Build Coastguard Worker#     __device__ scalar_t operator()(scalar_t self) {
61*da0073e9SAndroid Build Coastguard Worker#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
62*da0073e9SAndroid Build Coastguard Worker#     }
63*da0073e9SAndroid Build Coastguard Worker#   };
64*da0073e9SAndroid Build Coastguard Worker#
@dataclass(frozen=True)
class UfunctorSignature:
    """Signature of a CUDA ufunctor struct (generated or user-supplied).

    The struct's constructor parameters become its data members, and its
    ``operator()`` takes the remaining ("apply") arguments.
    """

    g: NativeFunctionsGroup
    # Forwarded to ufunc.ufunctor_arguments; presumably the position of the
    # tensor input that is captured as a scalar by the constructor (None when
    # every tensor is an operator() argument) -- confirm in api/ufunc.py.
    scalar_tensor_idx: int | None
    # C++ name of the functor struct.
    name: str

    def arguments(self) -> UfunctorBindings:
        """Return the constructor (``.ctor``) and ``operator()`` (``.apply``) bindings."""
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> list[Binding]:
        """Return the struct's data members: the ctor bindings renamed with a trailing ``_``."""
        ctor_bindings = self.arguments().ctor
        return [binding.rename(binding.name + "_") for binding in ctor_bindings]

    def returns_type(self) -> CType:
        """Return the C++ type produced by ``operator()``."""
        # TODO: don't hardcode; the return type should eventually be inferred
        # from tags on the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        """Render one ``<type> <name>;`` member declaration per line."""
        declarations = [f"{member.type} {member.name};" for member in self.fields()]
        return "\n".join(declarations)

    def inline_defn_ctor(self) -> str:
        """Render the inline constructor that copies each parameter into its field."""
        ctor_bindings = self.arguments().ctor
        params = ", ".join(param.decl() for param in ctor_bindings)
        # NB: translate could derive the initializers, but the mapping
        # ``x`` -> ``x_`` is regular enough to spell out directly.
        initializers = ", ".join(
            f"{param.name}_({param.name})" for param in ctor_bindings
        )
        return f"{self.name}({params}) : {initializers} {{}}"

    def decl_apply(self) -> str:
        """Render the ``operator()`` declarator (without body or qualifiers like __device__)."""
        params = ", ".join(param.decl() for param in self.arguments().apply)
        result_type = self.returns_type().cpp_type()
        return f"{result_type} operator()({params}) const"
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class UfuncSignature:
    """Signature of the templated compute function a functor body calls."""

    g: NativeFunctionsGroup
    # C++ callee name; compute_ufunc_cuda_functors passes "ufunc::<name>".
    name: str
    # C++ type the arguments are expressed in (opmath_t for CUDA functors).
    compute_t: CType

    def arguments(self) -> list[Binding]:
        """Return the bindings the compute function expects."""
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Binding | Expr]) -> str:
        """Render a call expression, translating ``ctx`` into this function's arguments."""
        translated = translate(ctx, self.arguments())
        rendered_args = ", ".join(item.expr for item in translated)
        return f"{self.name}({rendered_args})"
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker
113*da0073e9SAndroid Build Coastguard Worker# steps:
114*da0073e9SAndroid Build Coastguard Worker#   1. take the functional signature
115*da0073e9SAndroid Build Coastguard Worker#   2. use api.ufunc to convert it to template signature.  this establishes
116*da0073e9SAndroid Build Coastguard Worker#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
119*da0073e9SAndroid Build Coastguard Worker#
120*da0073e9SAndroid Build Coastguard Worker# StructuredImplSignature context
121*da0073e9SAndroid Build Coastguard Worker#   ~> functor constructor sig
122*da0073e9SAndroid Build Coastguard Worker#
123*da0073e9SAndroid Build Coastguard Worker# Functor constructor context
124*da0073e9SAndroid Build Coastguard Worker#   ~> functor fields sig
125*da0073e9SAndroid Build Coastguard Worker#
126*da0073e9SAndroid Build Coastguard Worker# Functor apply context (functor fields + functor apply sig)
127*da0073e9SAndroid Build Coastguard Worker#   ~> template sig
128*da0073e9SAndroid Build Coastguard Worker#
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker
def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    """Return True iff the functional overload has exactly two tensor-like inputs.

    Only such functions may use the CUDAFunctorOnSelf / CUDAFunctorOnOther
    (one-input-is-a-scalar) functors.
    """
    tensor_inputs = [
        arg
        for arg in g.functional.func.arguments.flat_non_out
        if arg.type.is_tensor_like()
    ]
    return len(tensor_inputs) == 2
136*da0073e9SAndroid Build Coastguard Worker
137*da0073e9SAndroid Build Coastguard Worker
def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
    """Collect the CUDA functor signatures for ``g`` and generate missing functors.

    Returns:
        ``(sigs_by_dtype, source)``. ``sigs_by_dtype`` maps each supported
        dtype to the functor signature to use for each CUDA functor key.
        ``source`` holds the C++ for every functor that is not listed directly
        in ``g.out.ufunc_inner_loop``, joined by newlines.

    Raises:
        AssertionError: if OnSelf/OnOther functors are listed for a function
            not eligible for binary-scalar specialization, if no ScalarOnly or
            Generic loop exists to generate a functor from, or if ScalarOnly
            and Generic name different ufuncs.
    """
    sigs_by_dtype: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
    functor_sources: list[str] = []
    loops = g.out.ufunc_inner_loop
    # Value forwarded as UfunctorSignature.scalar_tensor_idx for each key.
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }

    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        for scalar_key in (UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther):
            assert (
                scalar_key not in loops
            ), f"cannot use {scalar_key} on non-binary function"
        keys = [UfuncKey.CUDAFunctor]

    def record(key, functor_sig, dtypes):
        # Register functor_sig as the implementation of `key` for every dtype.
        for dtype in dtypes:
            sigs_by_dtype.setdefault(dtype, {})[key] = functor_sig

    for key in keys:
        scalar_idx = scalar_tensor_idx_lookup[key]

        # A functor listed directly under this key is assumed to be written by
        # the user already, so only its signature is recorded; no codegen.
        if key in loops:
            user_loop = loops[key]
            record(
                key,
                UfunctorSignature(g, scalar_tensor_idx=scalar_idx, name=user_loop.name),
                user_loop.supported_dtypes,
            )
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, derive the functor from ANY of the generic entries.  For
        # simplicity of codegen, if both ScalarOnly and Generic are defined,
        # their ufunc names must match (if they didn't, we'd have to generate
        # distinct functors per dtype, which is awful, so we're not going to
        # do it unless someone really forces us to).
        ufunc_name: str | None = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for loop_key in (UfuncKey.ScalarOnly, UfuncKey.Generic):
            if loop_key not in loops:
                continue
            generic_loop = loops[loop_key]
            if ufunc_name is None:
                ufunc_name = generic_loop.name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == generic_loop.name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= generic_loop.supported_dtypes
        assert ufunc_name is not None

        functor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_idx, name=f"{key}_{ufunc_name}"
        )
        record(key, functor_sig, supported_dtypes)

        compute_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        # Inside operator() both the struct members and the apply parameters
        # are in scope.
        apply_ctx = functor_sig.fields() + functor_sig.arguments().apply
        functor_sources.append(
            f"""
template <typename scalar_t>
struct {functor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {functor_sig.decl_fields()}
  {functor_sig.inline_defn_ctor()}
  __device__ {functor_sig.decl_apply()} {{
    return {compute_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return sigs_by_dtype, "\n".join(functor_sources)
219*da0073e9SAndroid Build Coastguard Worker
220*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    """One CPU-scalar specialization of a binary CUDA ufunc kernel."""

    # Position of the input that may be a CPU scalar (0 for self, 1 for
    # other); the generated kernel checks TensorIterator operand index
    # scalar_idx + 1.
    scalar_idx: int
    # Name the scalar value is bound to when constructing the functor.
    ctor_tensor: str
    # Key of the functor used when that input is a CPU scalar.
    ufunc_key: UfuncKey
226*da0073e9SAndroid Build Coastguard Worker
227*da0073e9SAndroid Build Coastguard Worker
# The two CPU-scalar specializations a binary ufunc can have, in the order
# compute_ufunc_cuda_dtype_body tests them.
BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=scalar_idx, ctor_tensor=ctor_tensor, ufunc_key=ufunc_key
    )
    for scalar_idx, ctor_tensor, ufunc_key in (
        # self is the CPU scalar
        (0, "self", UfuncKey.CUDAFunctorOnOther),
        # other is the CPU scalar
        (1, "other", UfuncKey.CUDAFunctorOnSelf),
    )
]
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker
def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Generate the C++ body of one AT_DISPATCH_CASE for a CUDA ufunc.

    For every binary-scalar config whose functor is in ``inner_loops``, emit a
    branch that, when the corresponding input is a CPU scalar, builds that
    functor from the scalar value, drops the operand from ``iter`` and launches
    ``gpu_kernel``.  The final ``else`` launches the general CUDAFunctor.

    Args:
        g: currently unused here.
        dtype: currently unused here.
        inner_loops: functor signature per UfuncKey for this dtype; must
            contain UfuncKey.CUDAFunctor.
        parent_ctx: bindings in scope in the enclosing kernel.
    """
    # NOTE(review): there is no newline between the using-declaration and
    # `if (false) {}`; harmless to the C++ compiler, kept for output stability.
    pieces: list[str] = [
        "using opmath_t = at::opmath_type<scalar_t>;",
        "if (false) {}\n",  # lets every real branch start with `else if`
    ]

    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        functor = inner_loops[config.ufunc_key]
        # Operand index inside the TensorIterator; presumably offset by one
        # because the output comes first -- confirm against TensorIterator.
        operand = config.scalar_idx + 1
        # Copy the parent context (never mutate the caller's sequence) and
        # expose the scalar, widened to opmath_t, under the ctor tensor's name.
        ctor_ctx: list[Expr | Binding] = [
            *parent_ctx,
            Expr(
                expr=f"iter.scalar_value<opmath_t>({operand})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            ),
        ]
        ctor_args = ", ".join(
            item.expr for item in translate(ctor_ctx, functor.arguments().ctor)
        )
        # NB: the ufunctor must be constructed before iter.remove_operand is
        # called, since its constructor reads from iter.
        pieces.append(
            f"else if (iter.is_cpu_scalar({operand})) {{\n"
            f"  {functor.name}<scalar_t> ufunctor({ctor_args});\n"
            f"  iter.remove_operand({operand});\n"
            "  gpu_kernel(iter, ufunctor);\n"
            "}"
        )

    fallback = inner_loops[UfuncKey.CUDAFunctor]
    fallback_args = ", ".join(
        item.expr for item in translate(parent_ctx, fallback.arguments().ctor)
    )
    pieces.append(
        "\nelse {\n"
        f"  gpu_kernel(iter, {fallback.name}<scalar_t>({fallback_args}));\n"
        "}\n"
        "    "
    )
    return "".join(pieces)
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker
@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    """Generate the full CUDA implementation of a ufunc-based structured kernel.

    The output contains, in order:
      * the generated functor structs,
      * the stub's function-pointer typedef and DECLARE_DISPATCH,
      * a kernel that switches on ``iter.common_dtype()`` with one
        AT_DISPATCH_CASE per supported dtype,
      * REGISTER_DISPATCH binding the stub to that kernel, and
      * the structured impl, which calls the kernel directly (skipping the
        dynamic stub dispatch).
    """
    # Functor sources, plus the functor signatures to use per dtype.
    functor_sigs_by_dtype, functor_source = compute_ufunc_cuda_functors(g)

    impl_sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))

    def render_case(dtype, functor_sigs):
        # One AT_DISPATCH_CASE whose lambda holds the per-dtype dispatch body.
        case_body = compute_ufunc_cuda_dtype_body(
            g, dtype, functor_sigs, impl_sig.arguments()
        )
        return (
            "\n"
            f"AT_DISPATCH_CASE(at::ScalarType::{dtype},\n"
            "  [&]() {\n"
            f"    {case_body}\n"
            "  }\n"
            ")\n"
        )

    cases = "\n".join(
        render_case(dtype, functor_sigs)
        for dtype, functor_sigs in functor_sigs_by_dtype.items()
    )

    stub = StubSignature(g)

    return f"""
{functor_source}

{stub.type_defn()};
{stub.dispatch_decl()};

{stub.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{impl_sig.name}",
    {cases}
  );
}}
REGISTER_DISPATCH({stub.name}, &{stub.kernel_name});

{impl_sig.defn()} {{
  {stub.direct_call(impl_sig.arguments())};
}}
"""
329*da0073e9SAndroid Build Coastguard Worker
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
332*da0073e9SAndroid Build Coastguard Worker#
333*da0073e9SAndroid Build Coastguard Worker#                                   CPU STUFF
334*da0073e9SAndroid Build Coastguard Worker#
335*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
336*da0073e9SAndroid Build Coastguard Worker
337*da0073e9SAndroid Build Coastguard Worker
@dataclass(frozen=True)
class StubSignature:
    """Names and C++ signatures for the DispatchStub behind a ufunc kernel.

    Every name is derived from the functional overload's operator name. For
    ``add``, the names are ``add_stub``, ``add_kernel`` and ``add_fn``.

    Parameter lists are built by joining a fixed leading part (the
    TensorIteratorBase, and for ``call`` the device type) with the stub
    arguments. An empty ``arguments()`` therefore yields valid C++ with no
    dangling ``", "``.
    """

    g: NativeFunctionsGroup

    @property
    def _base_name(self) -> str:
        # Operator name taken from the functional overload's schema.
        return str(self.g.functional.func.name.name)

    @property
    def name(self) -> str:
        return f"{self._base_name}_stub"

    @property
    def kernel_name(self) -> str:
        return f"{self._base_name}_kernel"

    @property
    def type_name(self) -> str:
        return f"{self._base_name}_fn"

    def arguments(self) -> list[Binding]:
        """Return the non-iterator arguments the stub/kernel take."""
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        """Render the C++ function-pointer type of the stub."""
        params = ["TensorIteratorBase&", *(a.type for a in self.arguments())]
        return f"void(*)({', '.join(params)})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        """Render the kernel's definition header (without body)."""
        params = ["TensorIteratorBase& iter", *(a.defn() for a in self.arguments())]
        return f"void {self.kernel_name}({', '.join(params)})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from context where this is TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        """Render a dynamic-dispatch call through the stub."""
        args = [
            "device_type()",
            "*this",
            *(e.expr for e in translate(ctx, self.arguments())),
        ]
        return f"{self.name}({', '.join(args)})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        """Render a direct call to the kernel, bypassing the stub."""
        args = ["*this", *(e.expr for e in translate(ctx, self.arguments()))]
        return f"{self.kernel_name}({', '.join(args)})"
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker
@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    """Generate the CPU-side glue for a ufunc-based structured kernel.

    Emits the stub's function-pointer typedef, DECLARE_DISPATCH and
    DEFINE_DISPATCH, followed by the structured impl, which forwards to the
    stub (dynamic dispatch on ``device_type()``).
    """
    stub = StubSignature(g)
    impl_sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    out_lines = [
        "",
        f"{stub.type_defn()};",
        f"{stub.dispatch_decl()};",
        f"{stub.dispatch_defn()};",
        "",
        f"{impl_sig.defn()} {{",
        f"  {stub.call(impl_sig.arguments())};",
        "}",
        "",
    ]
    return "\n".join(out_lines)
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker
def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    """Emit the C++ statements for a single dtype case of a CPU ufunc kernel.

    The generated code first declares a ``scalar_t`` local ``_s_<name>`` for
    every binding in ``parent_ctx`` that is a Scalar argument (or whose
    ``argument`` is not an ``Argument``). When a ``CPUVector`` inner loop is
    present, it also declares a broadcast ``at::vec::Vectorized<scalar_t>``
    local ``_v_<name>`` for each of them. It then calls ``cpu_kernel``
    (scalar loop only) or ``cpu_kernel_vec`` (scalar and vector loops) on
    ``iter``, passing by-value (``[=]``) lambdas that forward to the inner
    ufunc loops.

    Args:
        g: Native function group. The tensor arguments of its functional
            variant become the lambda parameters.
        dtype: The dtype this case is generated for. Only used in the
            assertion message.
        inner_loops: Inner-loop signatures keyed by ``UfuncKey``. Must contain
            ``CPUScalar`` and may additionally contain ``CPUVector``.
        parent_ctx: Bindings available in the enclosing stub kernel.

    Returns:
        The generated C++ source text.
    """
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = inner_loops.get(UfuncKey.CPUVector)

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Bindings that need explicit unpacking. Every non-Scalar Argument is
    # skipped; the lambdas capture those implicitly.
    scalar_ty = BaseType(BaseTy.Scalar)
    unpacked = [
        b
        for b in parent_ctx
        if not isinstance(b.argument, Argument) or b.argument.type == scalar_ty
    ]

    scalar_ctype = BaseCType(scalar_t)
    body = [f"auto _s_{b.name} = {b.name}.to<scalar_t>();" for b in unpacked]
    ctx: list[Expr] = [
        Expr(f"_s_{b.name}", NamedCType(b.nctype.name, scalar_ctype))
        for b in unpacked
    ]
    if vec_loop is not None:
        # Each vector local is built from its scalar counterpart, so all of
        # these are emitted after every scalar local.
        vec_ctype = VectorizedCType(scalar_ctype)
        body += [
            f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            for b in unpacked
        ]
        ctx += [
            Expr(f"_v_{b.name}", NamedCType(b.nctype.name, vec_ctype))
            for b in unpacked
        ]

    # Lambda parameters: one per tensor argument of the functional variant.
    # NB: simplified version of ufunctor_arguments
    tensor_args = [
        a for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    ]
    for a in tensor_args:
        assert a.type == BaseType(BaseTy.Tensor)

    def make_bindings(compute_t: CType) -> list[Binding]:
        # One lambda parameter per tensor argument, typed as ``compute_t``.
        return [
            Binding(name=a.name, nctype=NamedCType(a.name, compute_t), argument=a)
            for a in tensor_args
        ]

    scalar_bindings = make_bindings(scalar_ctype)
    vec_bindings = (
        make_bindings(VectorizedCType(scalar_ctype)) if vec_loop is not None else []
    )

    def with_ctx(bs: Sequence[Binding]) -> list[Expr | Binding]:
        # Unpacked locals first, then the lambda's own parameters.
        return [*ctx, *bs]

    def lambda_src(bs: list[Binding], loop: UfuncSignature) -> str:
        # One indented C++ lambda that forwards its parameters to ``loop``.
        params = ", ".join(b.decl() for b in bs)
        return f"  [=]({params}) {{ return {loop.call(with_ctx(bs))}; }}"

    lambdas = [lambda_src(scalar_bindings, scalar_loop)]
    kernel_fn = "cpu_kernel"
    if vec_loop is not None:
        lambdas.append(lambda_src(vec_bindings, vec_loop))
        kernel_fn = "cpu_kernel_vec"

    body_str = "\n".join(body)
    lambdas_str = ",\n".join(lambdas)
    return f"""
{body_str}
{kernel_fn}(iter,
{lambdas_str}
);
"""
489*da0073e9SAndroid Build Coastguard Worker
490*da0073e9SAndroid Build Coastguard Worker
@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    """Generate the CPU kernel, stub type and dispatch registration for a ufunc.

    The inner loops of ``g.out`` are re-indexed by dtype. For each CPU key
    (``CPUScalar``, then ``CPUVector``) the candidate loops, in order of
    precedence, are:

    - an explicit loop for that key;
    - ``ScalarOnly`` (for ``CPUScalar`` only);
    - ``Generic``.

    For each dtype, the first candidate that lists it in ``supported_dtypes``
    wins. Each dtype then becomes an ``AT_DISPATCH_CASE`` inside an
    ``AT_DISPATCH_SWITCH`` over ``iter.common_dtype()``.

    Args:
        g: Native function group. Inner loops are read from its out variant.

    Returns:
        C++ source containing:

        - the kernel definition inside an anonymous namespace;
        - the stub type definition;
        - the dispatch declaration;
        - the ``REGISTER_DISPATCH`` line.
    """
    stub_sig = StubSignature(g)

    # Reindex the ufunc by dtypes; processing generic/scalaronly as well
    loops = g.out.ufunc_inner_loop

    # C++ compute type of the generated inner-loop signature, per CPU key.
    compute_types: dict[UfuncKey, CType] = {
        UfuncKey.CPUScalar: BaseCType(scalar_t),
        UfuncKey.CPUVector: VectorizedCType(BaseCType(scalar_t)),
    }

    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
    for cpu_key, compute_t in compute_types.items():
        # ORDER MATTERS: this specifies overriding precedence
        # (an explicit loop for cpu_key itself should happen rarely)
        precedence = [cpu_key]
        if cpu_key is UfuncKey.CPUScalar:
            precedence.append(UfuncKey.ScalarOnly)
        precedence.append(UfuncKey.Generic)

        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for loop_key in precedence:
            if loop_key not in loops:
                continue
            loop = loops[loop_key]
            for dtype in loop.supported_dtypes:
                per_key = ufunc_sigs.setdefault(dtype, {})
                # A higher-precedence loop may already have claimed this dtype.
                if cpu_key not in per_key:
                    per_key[cpu_key] = UfuncSignature(
                        g, name=f"ufunc::{loop.name}", compute_t=compute_t
                    )

    # Build the conditionals
    dtype_cases = [
        f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, sigs, stub_sig.arguments())}
  }}
)
"""
        for dtype, sigs in ufunc_sigs.items()
    ]
    dtype_cases_str = "\n".join(dtype_cases)

    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""
552