xref: /aosp_15_r20/external/pytorch/torchgen/api/translate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3from typing import NoReturn, Sequence
4
5from torchgen.api.types import (
6    ArrayRefCType,
7    BaseCType,
8    Binding,
9    boolT,
10    ConstRefCType,
11    deviceT,
12    Expr,
13    intArrayRefT,
14    iOptTensorListRefT,
15    layoutT,
16    ListCType,
17    longT,
18    memoryFormatT,
19    MutRefCType,
20    NamedCType,
21    opmath_t,
22    OptionalCType,
23    optionalIntArrayRefT,
24    optionalScalarRefT,
25    optionalSymIntArrayRefT,
26    optionalTensorRefT,
27    scalar_t,
28    scalarT,
29    scalarTypeT,
30    SpecialArgName,
31    symIntArrayRefT,
32    SymIntT,
33    tensorOptionsT,
34    tensorT,
35    VectorCType,
36)
37
38
39# This file implements a small program synthesis engine that implements
40# conversions between one API to another.
41#
# The key data type in this file is NamedCType, short for Named C++ semantic type.  A NamedCType
# represents a C++ type, plus semantic information about what it represents.
44# For example, consider the argument "bool pin_memory"; its normal C++ type is
45# "bool", but its C++ semantic type also keeps track that this represents a
46# "pin_memory"; you can't just use a random other boolean in a context where you
47# need a "pin_memory"!
48#
49# The translator takes a list of needed NamedCTypes, and then figures out how
50# to construct expressions with these NamedCTypes from the given bindings.  Many
# of these expressions are trivial (I need a Tensor other; there's a Tensor
# other in scope); others are more nontrivial and may require packing/unpacking.
53# Some examples of non-trivial action:
54#
55#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
56#     in the context, instead, "options" is, and you need to extract
57#     it from there.  (Gather)
58#
#   - Need the "options" binding?  Well, maybe "options" isn't available
60#     in the context, and you need to construct it from "dtype", "device",
61#     etc.  (Scatter)
62#
63#   - Need the "memory_format" binding?  Well, actually, it's available
64#     from both "memory_format" and "options", so you had better make sure
65#     they are consistent.  (Join)
66
# Named goals that translate() looks up by name when gathering tensor options
# from either an "options" argument or an "out" tensor.
options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT)))
out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT)))

# Owning (value) container types used as sources/targets of conversions
# to and from the corresponding ArrayRef-style reference types.
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))

# Optional value types; the optional vector wraps the same element type as
# longVec_ctype.
optionalLongVec_ctype = OptionalCType(longVec_ctype)
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
76
77
class UnsatError(RuntimeError):
    """Signals that translate() could not synthesize an expression for a goal."""
81
82# Given a set of in-scope bindings and a set of target bindings, synthesize
83# a list of expressions that uses only the in-scope bindings (bindings) that
84# have all of the types of goals.  You may want to use this function if
85# you're generating code for a function like:
86#
87#   void f({args}) {
88#     g({exprs}); // g is a different API
89#   }
90#
91# and you need to generate "exprs".
92#
93# Typically, a list of Bindings is convenient to get (you usually call something
94# like arguments() to get them); but technically you only need less information:
95# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
96# 'goals', an (ordered) list of NamedCType goals is sufficient.  If you are doing
97# something more complicated, e.g., tracking the set of bindings in a context,
98# you may find using these smaller types more convenient.
def translate(
    bindings: Sequence[Expr | Binding],
    goals: Sequence[NamedCType | Binding],
    *,
    method: bool = False,
    allow_expensive_conversions: bool = False,
) -> list[Expr]:
    """Synthesize one C++ expression per goal using only the in-scope bindings.

    Args:
        bindings: What is in scope.  A Binding is treated as an Expr whose
            expression is the binding's name and whose type is its nctype.
        goals: The (ordered) named types to produce.  A Binding contributes
            its nctype.
        method: When True, "self" is additionally made available (as both a
            mutable and a const reference) via ``const_cast<Tensor&>(*this)``.
        allow_expensive_conversions: When True, also permit conversions from
            reference types to value types (e.g. IntArrayRef -> vector).

    Returns:
        One Expr per goal, in the same order as ``goals``.

    Raises:
        UnsatError: If no rule can produce some goal from the context.
    """
    binding_exprs: list[Expr] = [
        Expr(expr=b.name, type=b.nctype) if isinstance(b, Binding) else b
        for b in bindings
    ]
    goal_ctypes: list[NamedCType] = [
        g.nctype if isinstance(g, Binding) else g for g in goals
    ]

    # Add all the bindings to the context
    ctx: dict[NamedCType, str] = {}
    for b in binding_exprs:
        ctx[b.type] = b.expr

        # While we're at it, do some simple forward inference, looking through
        # constructors.
        #
        # NB: When should you do forward inference versus backward inference?
        # The general idea:
        #
        #   - Backward inference WHEN the goal gets smaller
        #   - Forward inference WHEN the hypothesis gets smaller
        #
        # This helps ensure termination: backward inference starts with a goal
        # and tries to make it simpler and simpler until it's trivial; if the
        # goal can grow in size, we blow up to a really huge goal size.
        # Similarly, with forward inference we take hypotheses and decompose
        # them into simpler hypotheses; if hypotheses could expand in size,
        # we also have potential nontermination.  (In the code below, forward
        # inference is only ever carried out at a single step, but you could
        # imagine repeated application of forward inference being profitable.)
        #
        # A good starting point in the literature for exploring more about proof
        # search are these lecture notes
        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
        #
        # TODO: My kingdom for a pattern matcher
        # https://www.python.org/dev/peps/pep-0634/
        #
        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
        # Fix this by implementing some sort of sharing so that if multiple
        # goals share the same expression, we only compute it once.  This seems
        # to matter in practice as compiler is often unwilling to CSE nontrivial
        # expressions like scalar.to<scalar_t>()
        t = b.type
        # NOTE(review): `t` is the key stored in `ctx` (a NamedCType), and every
        # other rule below inspects `t.type` / `t.name`.  This check is applied
        # to `t` itself, so it looks like it can never hold -- presumably
        # `t.type` was intended.  Left unchanged so that generated code does not
        # change; confirm before enabling.
        if (
            isinstance(t, ConstRefCType)
            and isinstance(t.elem, OptionalCType)
            and isinstance(t.elem.elem, BaseCType)
            and str(t.elem.elem.type) == "at::Tensor"
        ):
            ctx[
                NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"

        # const optional<Tensor>& -> OptionalTensorRef
        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalTensorRefT))
            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"

        # const Scalar& -> opmath_t
        if t.type == ConstRefCType(BaseCType(scalarT)):
            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"

        # const optional<Scalar>& -> OptionalScalarRef
        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
            ctx[
                NamedCType(t.name, BaseCType(optionalScalarRefT))
            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"

        # scalar_t -> opmath_t
        if t.type == BaseCType(scalar_t):
            ctx[
                NamedCType(t.name, BaseCType(opmath_t))
            ] = f"static_cast<opmath_t>({b.expr})"

        # [Note: IOptTensorListRef]
        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
            ctx[
                NamedCType(t.name, BaseCType(iOptTensorListRefT))
            ] = f"at::IOptTensorListRef({b.expr})"

    # Add implicit bindings if the generated code is inside a Tensor method
    if method:
        ctx[
            NamedCType("self", MutRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        ctx[
            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
        ] = "const_cast<Tensor&>(*this)"
        # This is better!  Byte-for-byte compat
        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"

    def unsat(goal: NamedCType) -> NoReturn:
        # Always raises; the message lists every binding currently in ctx.
        ctx_desc = "\n".join(
            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
        )
        raise UnsatError(
            f"""
Failed to synthesize the expression "{goal.cpp_type()} {goal.name}".
When I failed, the following bindings were available in the context:

{ctx_desc}

This probably means there is a missing rule in the rules of torchgen.api.translate.
Check this module for more information.
"""
        )

    # A shitty backtracking search implementation.  It's shitty because it
    # does backtracking via stack (bad idea!) and for the most part tries to
    # avoid backtracking.  In particular, if
    # direct=True, we won't try to do any fancy synthesis, just trivial
    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
    # existing rules in this function simply try to solve immediately,
    # and bail if things don't work out.
    def solve(goal: NamedCType, *, direct: bool) -> str:
        def direct_solve(goal: NamedCType) -> str:
            return solve(goal, direct=True)

        if goal in ctx:
            # Trivial
            return ctx[goal]

        # const & is satisfied with mutable &
        if isinstance(goal.type, ConstRefCType):
            try:
                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
                return solve(
                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
                )
            except UnsatError:
                pass

        # mutable & is satisfied with value
        if isinstance(goal.type, MutRefCType):
            try:
                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
            except UnsatError:
                pass

        # TODO: These are referentially equal, shouldn't have to do this;
        # ensuring we don't use type synonym IntArrayRef in codegen would
        # help
        if goal.type == ArrayRefCType(BaseCType(longT)):
            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)

        # Everything below is "fancy" synthesis, which direct mode forbids.
        if direct:
            unsat(goal)

        # For now, all of these rules are mutually exclusive.
        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
            memory_format = direct_solve(
                NamedCType(
                    SpecialArgName.possibly_redundant_memory_format,
                    OptionalCType(BaseCType(memoryFormatT)),
                )
            )
            # No need to join "memory_format" and "options" if the target API takes "options" directly.
            # Otherwise it will cause the redundant memory_format error.
            if options_ctype in goal_ctypes:
                return memory_format
            try:
                options = direct_solve(options_ctype)
                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
            except UnsatError:
                return memory_format
        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
            dtype = direct_solve(
                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
            )
            pin_memory = direct_solve(
                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
            )
            device = direct_solve(
                NamedCType("device", OptionalCType(BaseCType(deviceT)))
            )
            layout = direct_solve(
                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
            )
            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"

        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
            try:
                options = direct_solve(options_ctype)
                return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.scalar_type()"

        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.layout_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.layout()"

        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.device_opt()"
            except UnsatError:
                out_tensor = direct_solve(out_tensor_ctype)
                return f"{out_tensor}.device()"

        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
            try:
                options = direct_solve(options_ctype)
                return f"{options}.pinned_memory_opt()"
            except UnsatError:
                # If we're calling a factory op from its out= variant,
                # We don't actually care about the value of pin_memory.
                # The lookup is still performed (result unused) so that a
                # missing "out" binding raises UnsatError as before.
                direct_solve(out_tensor_ctype)
                return "::std::nullopt"

        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
        elif goal.type == BaseCType(intArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, longVec_ctype))
            except UnsatError:
                # We can also go SymIntArrayRef -> IntArrayRef
                symIntArrayRef_type = direct_solve(
                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
                )
                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
        elif goal.type == BaseCType(symIntArrayRefT):
            try:
                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
                return f"c10::fromIntArrayRefSlow({r})"
            except UnsatError:
                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
        elif goal.type == BaseCType(SymIntT):
            return direct_solve(NamedCType(goal.name, BaseCType(longT)))
        elif goal.type == OptionalCType(BaseCType(SymIntT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(longT)))
            )
            return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(longT):
            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
            return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
        elif goal.type == OptionalCType(BaseCType(longT)):
            argname = direct_solve(
                NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
            )
            return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
        elif goal.type == BaseCType(optionalIntArrayRefT):
            try:
                return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
            except UnsatError:
                argname = direct_solve(
                    NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
                )
                return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(optionalSymIntArrayRefT):
            # TODO: You might also want to solve this from longSymVec_ctype or
            # an optional version of it
            argname = direct_solve(
                NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
            )
            return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
        elif goal.type == BaseCType(optionalScalarRefT):
            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
        elif goal.type == BaseCType(optionalTensorRefT):
            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))

        # Note [translation from C++ reference to value types]
        # The below cases are all for when we have an argument with a reference type,
        # and a corresponding goal with a value type.
        # These are needed when we populate the inputs to a lambda capture and we need
        # to guarantee the lifetime of each captured argument.
        # We guard it with an explicit kwarg because converting to a value type is expensive
        # (O(n)) to convert from IntArrayRef to vector<int>),
        # so the caller of translate() should be explicit that they need it.
        if allow_expensive_conversions:
            if goal.type == VectorCType(BaseCType(longT)):
                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
                argname = direct_solve(intArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == VectorCType(BaseCType(SymIntT)):
                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
                argname = direct_solve(symIntArrayRef_ctype)
                return f"{argname}.vec()"
            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
                optionalIntArrayRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalIntArrayRefT)
                )
                argname = direct_solve(optionalIntArrayRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
            elif goal.type == OptionalCType(BaseCType(scalarT)):
                optionalScalarRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalScalarRefT)
                )
                argname = direct_solve(optionalScalarRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
            elif goal.type == OptionalCType(BaseCType(tensorT)):
                # This arm previously repeated the optional<Scalar> test above
                # and so could never run; it solves from OptionalTensorRef, so
                # the goal it serves is optional<Tensor>.
                # NOTE(review): the emitted expression was never exercised
                # while the arm was unreachable -- confirm it compiles.
                optionalTensorRef_ctype = NamedCType(
                    goal.name, BaseCType(optionalTensorRefT)
                )
                argname = direct_solve(optionalTensorRef_ctype)
                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
            # Technically, we also need to handle cases of C++ containers holding reference types.
            # But there currently aren't any ops that require lambda capture codegen
            # With arguments like ::std::vector<IntArrayRef>.
            # If that changes, we'll have to add the translation here.

        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
        # We could probably generalize this to non-tensor types too.
        if goal.type == MutRefCType(BaseCType(tensorT)):
            const_ref_tensor_ctype = NamedCType(
                goal.name, ConstRefCType(BaseCType(tensorT))
            )
            argname = direct_solve(const_ref_tensor_ctype)
            return f"const_cast<Tensor&>({argname})"

        unsat(goal)

    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
434