xref: /aosp_15_r20/external/pytorch/torchgen/api/translate.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import NoReturn, Sequence
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
6*da0073e9SAndroid Build Coastguard Worker    ArrayRefCType,
7*da0073e9SAndroid Build Coastguard Worker    BaseCType,
8*da0073e9SAndroid Build Coastguard Worker    Binding,
9*da0073e9SAndroid Build Coastguard Worker    boolT,
10*da0073e9SAndroid Build Coastguard Worker    ConstRefCType,
11*da0073e9SAndroid Build Coastguard Worker    deviceT,
12*da0073e9SAndroid Build Coastguard Worker    Expr,
13*da0073e9SAndroid Build Coastguard Worker    intArrayRefT,
14*da0073e9SAndroid Build Coastguard Worker    iOptTensorListRefT,
15*da0073e9SAndroid Build Coastguard Worker    layoutT,
16*da0073e9SAndroid Build Coastguard Worker    ListCType,
17*da0073e9SAndroid Build Coastguard Worker    longT,
18*da0073e9SAndroid Build Coastguard Worker    memoryFormatT,
19*da0073e9SAndroid Build Coastguard Worker    MutRefCType,
20*da0073e9SAndroid Build Coastguard Worker    NamedCType,
21*da0073e9SAndroid Build Coastguard Worker    opmath_t,
22*da0073e9SAndroid Build Coastguard Worker    OptionalCType,
23*da0073e9SAndroid Build Coastguard Worker    optionalIntArrayRefT,
24*da0073e9SAndroid Build Coastguard Worker    optionalScalarRefT,
25*da0073e9SAndroid Build Coastguard Worker    optionalSymIntArrayRefT,
26*da0073e9SAndroid Build Coastguard Worker    optionalTensorRefT,
27*da0073e9SAndroid Build Coastguard Worker    scalar_t,
28*da0073e9SAndroid Build Coastguard Worker    scalarT,
29*da0073e9SAndroid Build Coastguard Worker    scalarTypeT,
30*da0073e9SAndroid Build Coastguard Worker    SpecialArgName,
31*da0073e9SAndroid Build Coastguard Worker    symIntArrayRefT,
32*da0073e9SAndroid Build Coastguard Worker    SymIntT,
33*da0073e9SAndroid Build Coastguard Worker    tensorOptionsT,
34*da0073e9SAndroid Build Coastguard Worker    tensorT,
35*da0073e9SAndroid Build Coastguard Worker    VectorCType,
36*da0073e9SAndroid Build Coastguard Worker)
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
# This file implements a small program synthesis engine that implements
# conversions from one API to another.
#
# The key data type in this file is NamedCType, short for Named C++ semantic type.  A NamedCType
# represents a C++ type, plus semantic information about what it represents.
# For example, consider the argument "bool pin_memory"; its normal C++ type is
# "bool", but its C++ semantic type also keeps track that this represents a
# "pin_memory"; you can't just use a random other boolean in a context where you
# need a "pin_memory"!
#
# The translator takes a list of needed NamedCTypes, and then figures out how
# to construct expressions with these NamedCTypes from the given bindings.  Many
# of these expressions are trivial (I need a Tensor other; there's a Tensor
# other in scope); others are more nontrivial and may require packing/unpacking.
# Some examples of non-trivial actions:
#
#   - Need the "dtype" binding?  Well, maybe "dtype" isn't available
#     in the context; instead, "options" is, and you need to extract
#     it from there.  (Gather)
#
#   - Need the "options" binding?  Well, maybe "options" isn't available
#     in the context, and you need to construct it from "dtype", "device",
#     etc.  (Scatter)
#
#   - Need the "memory_format" binding?  Well, actually, it's available
#     from both "memory_format" and "options", so you had better make sure
#     they are consistent.  (Join)
66*da0073e9SAndroid Build Coastguard Worker
# Pre-built semantic types that the rules inside translate() match goals
# against.  They are plain values: translate() only compares them by
# equality (membership in the goal list, lookup in the binding context).

# `const TensorOptions& options`: the bundled argument from which the
# dtype / layout / device / pin_memory goals can be extracted.
options_ctype = NamedCType(
    "options",
    ConstRefCType(BaseCType(tensorOptionsT)),
)

# `const Tensor& out`: the output tensor of an out= variant, used as a
# fallback source for dtype / layout / device when "options" is absent.
out_tensor_ctype = NamedCType(
    "out",
    ConstRefCType(BaseCType(tensorT)),
)

# Value (owning) C++ types.  Reference-typed goals such as IntArrayRef,
# SymIntArrayRef, OptionalIntArrayRef, OptionalScalarRef and
# OptionalTensorRef can be satisfied directly from bindings of these types.
longVec_ctype = VectorCType(BaseCType(longT))
longSymVec_ctype = VectorCType(BaseCType(SymIntT))
optionalLongVec_ctype = OptionalCType(longVec_ctype)
optionalScalar_ctype = OptionalCType(BaseCType(scalarT))
optionalTensor_ctype = OptionalCType(BaseCType(tensorT))
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker
class UnsatError(RuntimeError):
    """Signal that translate() could not synthesize an expression for a goal.

    It is raised with a description of the failed goal and of every binding
    available at the time.  The solver also catches it internally to fall
    back to alternative derivation rules before giving up.
    """
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
# Given a set of in-scope bindings and a set of target bindings (goals),
# synthesize a list of expressions that use only the in-scope bindings and
# that have the types of the goals.  You may want to use this function if
# you're generating code for a function like:
#
#   void f({args}) {
#     g({exprs}); // g is a different API
#   }
#
# and you need to generate "exprs".
#
# Typically, a list of Bindings is convenient to get (you usually call something
# like arguments() to get them); but technically you need less information:
# for 'bindings' an (un-ordered) list of Exprs is sufficient; similarly, for
# 'goals', an (ordered) list of NamedCType goals is sufficient.  If you are doing
# something more complicated, e.g., tracking the set of bindings in a context,
# you may find using these smaller types more convenient.
99*da0073e9SAndroid Build Coastguard Workerdef translate(
100*da0073e9SAndroid Build Coastguard Worker    bindings: Sequence[Expr | Binding],
101*da0073e9SAndroid Build Coastguard Worker    goals: Sequence[NamedCType | Binding],
102*da0073e9SAndroid Build Coastguard Worker    *,
103*da0073e9SAndroid Build Coastguard Worker    method: bool = False,
104*da0073e9SAndroid Build Coastguard Worker    allow_expensive_conversions: bool = False,
105*da0073e9SAndroid Build Coastguard Worker) -> list[Expr]:
106*da0073e9SAndroid Build Coastguard Worker    binding_exprs: list[Expr] = []
107*da0073e9SAndroid Build Coastguard Worker    for b in bindings:
108*da0073e9SAndroid Build Coastguard Worker        if isinstance(b, Binding):
109*da0073e9SAndroid Build Coastguard Worker            binding_exprs.append(
110*da0073e9SAndroid Build Coastguard Worker                Expr(
111*da0073e9SAndroid Build Coastguard Worker                    expr=b.name,
112*da0073e9SAndroid Build Coastguard Worker                    type=b.nctype,
113*da0073e9SAndroid Build Coastguard Worker                )
114*da0073e9SAndroid Build Coastguard Worker            )
115*da0073e9SAndroid Build Coastguard Worker        else:
116*da0073e9SAndroid Build Coastguard Worker            binding_exprs.append(b)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker    goal_ctypes: list[NamedCType] = []
119*da0073e9SAndroid Build Coastguard Worker    for g in goals:
120*da0073e9SAndroid Build Coastguard Worker        if isinstance(g, Binding):
121*da0073e9SAndroid Build Coastguard Worker            goal_ctypes.append(g.nctype)
122*da0073e9SAndroid Build Coastguard Worker        else:
123*da0073e9SAndroid Build Coastguard Worker            goal_ctypes.append(g)
124*da0073e9SAndroid Build Coastguard Worker
125*da0073e9SAndroid Build Coastguard Worker    # Add all the bindings to the context
126*da0073e9SAndroid Build Coastguard Worker    ctx: dict[NamedCType, str] = {}
127*da0073e9SAndroid Build Coastguard Worker    for b in binding_exprs:
128*da0073e9SAndroid Build Coastguard Worker        ctx[b.type] = b.expr
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker        # While we're at it, do some simple forward inference, looking through
131*da0073e9SAndroid Build Coastguard Worker        # constructors.
132*da0073e9SAndroid Build Coastguard Worker        #
133*da0073e9SAndroid Build Coastguard Worker        # NB: When should you do forward inference versus backward inference?
134*da0073e9SAndroid Build Coastguard Worker        # The general idea:
135*da0073e9SAndroid Build Coastguard Worker        #
136*da0073e9SAndroid Build Coastguard Worker        #   - Backward inference WHEN the goal gets smaller
137*da0073e9SAndroid Build Coastguard Worker        #   - Forward inference WHEN the hypothesis gets smaller
138*da0073e9SAndroid Build Coastguard Worker        #
139*da0073e9SAndroid Build Coastguard Worker        # This helps ensure termination: backward inference starts with a goal
140*da0073e9SAndroid Build Coastguard Worker        # and tries to make it simpler and simpler until it's trivial; if the
141*da0073e9SAndroid Build Coastguard Worker        # goal can grow in size, we blow up to a really huge goal size.
142*da0073e9SAndroid Build Coastguard Worker        # Similarly, with forward inference we take hypotheses and decompose
143*da0073e9SAndroid Build Coastguard Worker        # them into simpler hypotheses; if hypotheses could expand in size,
144*da0073e9SAndroid Build Coastguard Worker        # we also have potential nontermination.  (In the code below, forward
145*da0073e9SAndroid Build Coastguard Worker        # inference is only ever carried out at a single step, but you could
146*da0073e9SAndroid Build Coastguard Worker        # imagine repeated application of forward inference being profitable.)
147*da0073e9SAndroid Build Coastguard Worker        #
148*da0073e9SAndroid Build Coastguard Worker        # A good starting point in the literature for exploring more about proof
149*da0073e9SAndroid Build Coastguard Worker        # search are these lecture notes
150*da0073e9SAndroid Build Coastguard Worker        # https://www.cs.cmu.edu/~fp/courses/oregon-m10/04-focusing.pdf
151*da0073e9SAndroid Build Coastguard Worker        #
152*da0073e9SAndroid Build Coastguard Worker        # TODO: My kingdom for a pattern matcher
153*da0073e9SAndroid Build Coastguard Worker        # https://www.python.org/dev/peps/pep-0634/
154*da0073e9SAndroid Build Coastguard Worker        #
155*da0073e9SAndroid Build Coastguard Worker        # TODO: This could get us in recomputation trouble if b.expr is nontrivial.
156*da0073e9SAndroid Build Coastguard Worker        # Fix this by implementing some sort of sharing so that if multiple
157*da0073e9SAndroid Build Coastguard Worker        # goals share the same expression, we only compute it once.  This seems
158*da0073e9SAndroid Build Coastguard Worker        # to matter in practice as compiler is often unwilling to CSE nontrivial
159*da0073e9SAndroid Build Coastguard Worker        # expressions like scalar.to<scalar_t>()
160*da0073e9SAndroid Build Coastguard Worker        t = b.type
161*da0073e9SAndroid Build Coastguard Worker        if (
162*da0073e9SAndroid Build Coastguard Worker            isinstance(t, ConstRefCType)
163*da0073e9SAndroid Build Coastguard Worker            and isinstance(t.elem, OptionalCType)
164*da0073e9SAndroid Build Coastguard Worker            and isinstance(t.elem.elem, BaseCType)
165*da0073e9SAndroid Build Coastguard Worker            and str(t.elem.elem.type) == "at::Tensor"
166*da0073e9SAndroid Build Coastguard Worker        ):
167*da0073e9SAndroid Build Coastguard Worker            ctx[
168*da0073e9SAndroid Build Coastguard Worker                NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT)))
169*da0073e9SAndroid Build Coastguard Worker            ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())"
170*da0073e9SAndroid Build Coastguard Worker
171*da0073e9SAndroid Build Coastguard Worker        if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))):
172*da0073e9SAndroid Build Coastguard Worker            ctx[
173*da0073e9SAndroid Build Coastguard Worker                NamedCType(t.name, BaseCType(optionalTensorRefT))
174*da0073e9SAndroid Build Coastguard Worker            ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())"
175*da0073e9SAndroid Build Coastguard Worker
176*da0073e9SAndroid Build Coastguard Worker        if t.type == ConstRefCType(BaseCType(scalarT)):
177*da0073e9SAndroid Build Coastguard Worker            ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()"
178*da0073e9SAndroid Build Coastguard Worker
179*da0073e9SAndroid Build Coastguard Worker        if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))):
180*da0073e9SAndroid Build Coastguard Worker            ctx[
181*da0073e9SAndroid Build Coastguard Worker                NamedCType(t.name, BaseCType(optionalScalarRefT))
182*da0073e9SAndroid Build Coastguard Worker            ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())"
183*da0073e9SAndroid Build Coastguard Worker
184*da0073e9SAndroid Build Coastguard Worker        if t.type == BaseCType(scalar_t):
185*da0073e9SAndroid Build Coastguard Worker            ctx[
186*da0073e9SAndroid Build Coastguard Worker                NamedCType(t.name, BaseCType(opmath_t))
187*da0073e9SAndroid Build Coastguard Worker            ] = f"static_cast<opmath_t>({b.expr})"
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker        # [Note: IOptTensorListRef]
190*da0073e9SAndroid Build Coastguard Worker        if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))):
191*da0073e9SAndroid Build Coastguard Worker            ctx[
192*da0073e9SAndroid Build Coastguard Worker                NamedCType(t.name, BaseCType(iOptTensorListRefT))
193*da0073e9SAndroid Build Coastguard Worker            ] = f"at::IOptTensorListRef({b.expr})"
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker    # Add implicit bindings if the generated code is inside a Tensor method
196*da0073e9SAndroid Build Coastguard Worker    if method:
197*da0073e9SAndroid Build Coastguard Worker        ctx[
198*da0073e9SAndroid Build Coastguard Worker            NamedCType("self", MutRefCType(BaseCType(tensorT)))
199*da0073e9SAndroid Build Coastguard Worker        ] = "const_cast<Tensor&>(*this)"
200*da0073e9SAndroid Build Coastguard Worker        ctx[
201*da0073e9SAndroid Build Coastguard Worker            NamedCType("self", ConstRefCType(BaseCType(tensorT)))
202*da0073e9SAndroid Build Coastguard Worker        ] = "const_cast<Tensor&>(*this)"
203*da0073e9SAndroid Build Coastguard Worker        # This is better!  Byte-for-byte compat
204*da0073e9SAndroid Build Coastguard Worker        # ctx[NamedCType("self", ConstRefCType(BaseCType(tensorT)))] = "*this"
205*da0073e9SAndroid Build Coastguard Worker
206*da0073e9SAndroid Build Coastguard Worker    def unsat(goal: NamedCType) -> NoReturn:
207*da0073e9SAndroid Build Coastguard Worker        ctx_desc = "\n".join(
208*da0073e9SAndroid Build Coastguard Worker            f"  {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items()
209*da0073e9SAndroid Build Coastguard Worker        )
210*da0073e9SAndroid Build Coastguard Worker        raise UnsatError(
211*da0073e9SAndroid Build Coastguard Worker            f"""
212*da0073e9SAndroid Build Coastguard WorkerFailed to synthesize the expression "{goal.cpp_type()} {goal.name}".
213*da0073e9SAndroid Build Coastguard WorkerWhen I failed, the following bindings were available in the context:
214*da0073e9SAndroid Build Coastguard Worker
215*da0073e9SAndroid Build Coastguard Worker{ctx_desc}
216*da0073e9SAndroid Build Coastguard Worker
217*da0073e9SAndroid Build Coastguard WorkerThis probably means there is a missing rule in the rules of torchgen.api.translate.
218*da0073e9SAndroid Build Coastguard WorkerCheck this module for more information.
219*da0073e9SAndroid Build Coastguard Worker"""
220*da0073e9SAndroid Build Coastguard Worker        )
221*da0073e9SAndroid Build Coastguard Worker
222*da0073e9SAndroid Build Coastguard Worker    # A shitty backtracking search implementation.  It's shitty because it
223*da0073e9SAndroid Build Coastguard Worker    # does backtracking via stack (bad idea!) and for the most part tries to
224*da0073e9SAndroid Build Coastguard Worker    # avoid backtracking.  In particular, if
225*da0073e9SAndroid Build Coastguard Worker    # direct=True, we won't try to do any fancy synthesis, just trivial
226*da0073e9SAndroid Build Coastguard Worker    # conversions (e.g., "T a" is OK for "const T& a").  So all of the
227*da0073e9SAndroid Build Coastguard Worker    # existing rules in this function simply try to solve immediately,
228*da0073e9SAndroid Build Coastguard Worker    # and bail if things don't work out.
229*da0073e9SAndroid Build Coastguard Worker    def solve(goal: NamedCType, *, direct: bool) -> str:
230*da0073e9SAndroid Build Coastguard Worker        def direct_solve(goal: NamedCType) -> str:
231*da0073e9SAndroid Build Coastguard Worker            return solve(goal, direct=True)
232*da0073e9SAndroid Build Coastguard Worker
233*da0073e9SAndroid Build Coastguard Worker        if goal in ctx:
234*da0073e9SAndroid Build Coastguard Worker            # Trivial
235*da0073e9SAndroid Build Coastguard Worker            return ctx[goal]
236*da0073e9SAndroid Build Coastguard Worker
237*da0073e9SAndroid Build Coastguard Worker        # const & is satisfied with mutable &
238*da0073e9SAndroid Build Coastguard Worker        if isinstance(goal.type, ConstRefCType):
239*da0073e9SAndroid Build Coastguard Worker            try:
240*da0073e9SAndroid Build Coastguard Worker                # WARNING: not strictly decreasing; be careful not
                # to add a direct conversion that satisfies
                # mutable& with const&
243*da0073e9SAndroid Build Coastguard Worker                return solve(
244*da0073e9SAndroid Build Coastguard Worker                    NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct
245*da0073e9SAndroid Build Coastguard Worker                )
246*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
247*da0073e9SAndroid Build Coastguard Worker                pass
248*da0073e9SAndroid Build Coastguard Worker
249*da0073e9SAndroid Build Coastguard Worker        # mutable & is satisfied with value
250*da0073e9SAndroid Build Coastguard Worker        if isinstance(goal.type, MutRefCType):
251*da0073e9SAndroid Build Coastguard Worker            try:
252*da0073e9SAndroid Build Coastguard Worker                return solve(NamedCType(goal.name, goal.type.elem), direct=direct)
253*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
254*da0073e9SAndroid Build Coastguard Worker                pass
255*da0073e9SAndroid Build Coastguard Worker
256*da0073e9SAndroid Build Coastguard Worker        # TODO: These are referentially equal, shouldn't have to do this;
257*da0073e9SAndroid Build Coastguard Worker        # ensuring we don't use type synonym IntArrayRef in codegen would
258*da0073e9SAndroid Build Coastguard Worker        # help
259*da0073e9SAndroid Build Coastguard Worker        if goal.type == ArrayRefCType(BaseCType(longT)):
260*da0073e9SAndroid Build Coastguard Worker            return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct)
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker        if direct:
263*da0073e9SAndroid Build Coastguard Worker            unsat(goal)
264*da0073e9SAndroid Build Coastguard Worker
265*da0073e9SAndroid Build Coastguard Worker        # For now, all of these rules are mutually exclusive.
266*da0073e9SAndroid Build Coastguard Worker        if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))):
267*da0073e9SAndroid Build Coastguard Worker            memory_format = direct_solve(
268*da0073e9SAndroid Build Coastguard Worker                NamedCType(
269*da0073e9SAndroid Build Coastguard Worker                    SpecialArgName.possibly_redundant_memory_format,
270*da0073e9SAndroid Build Coastguard Worker                    OptionalCType(BaseCType(memoryFormatT)),
271*da0073e9SAndroid Build Coastguard Worker                )
272*da0073e9SAndroid Build Coastguard Worker            )
273*da0073e9SAndroid Build Coastguard Worker            # No need to join "memory_format" and "options" if the target API takes "options" directly.
274*da0073e9SAndroid Build Coastguard Worker            # Otherwise it will cause the redundant memory_format error.
275*da0073e9SAndroid Build Coastguard Worker            if options_ctype in goal_ctypes:
276*da0073e9SAndroid Build Coastguard Worker                return memory_format
277*da0073e9SAndroid Build Coastguard Worker            try:
278*da0073e9SAndroid Build Coastguard Worker                options = direct_solve(options_ctype)
279*da0073e9SAndroid Build Coastguard Worker                return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})"
280*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
281*da0073e9SAndroid Build Coastguard Worker                return memory_format
282*da0073e9SAndroid Build Coastguard Worker        elif goal == NamedCType("options", BaseCType(tensorOptionsT)):
283*da0073e9SAndroid Build Coastguard Worker            dtype = direct_solve(
284*da0073e9SAndroid Build Coastguard Worker                NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT)))
285*da0073e9SAndroid Build Coastguard Worker            )
286*da0073e9SAndroid Build Coastguard Worker            pin_memory = direct_solve(
287*da0073e9SAndroid Build Coastguard Worker                NamedCType("pin_memory", OptionalCType(BaseCType(boolT)))
288*da0073e9SAndroid Build Coastguard Worker            )
289*da0073e9SAndroid Build Coastguard Worker            device = direct_solve(
290*da0073e9SAndroid Build Coastguard Worker                NamedCType("device", OptionalCType(BaseCType(deviceT)))
291*da0073e9SAndroid Build Coastguard Worker            )
292*da0073e9SAndroid Build Coastguard Worker            layout = direct_solve(
293*da0073e9SAndroid Build Coastguard Worker                NamedCType("layout", OptionalCType(BaseCType(layoutT)))
294*da0073e9SAndroid Build Coastguard Worker            )
295*da0073e9SAndroid Build Coastguard Worker            return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})"
296*da0073e9SAndroid Build Coastguard Worker
297*da0073e9SAndroid Build Coastguard Worker        elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))):
298*da0073e9SAndroid Build Coastguard Worker            try:
299*da0073e9SAndroid Build Coastguard Worker                options = direct_solve(options_ctype)
300*da0073e9SAndroid Build Coastguard Worker                return f"c10::optTypeMetaToScalarType({options}.dtype_opt())"
301*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
302*da0073e9SAndroid Build Coastguard Worker                out_tensor = direct_solve(out_tensor_ctype)
303*da0073e9SAndroid Build Coastguard Worker                return f"{out_tensor}.scalar_type()"
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker        elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))):
306*da0073e9SAndroid Build Coastguard Worker            try:
307*da0073e9SAndroid Build Coastguard Worker                options = direct_solve(options_ctype)
308*da0073e9SAndroid Build Coastguard Worker                return f"{options}.layout_opt()"
309*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
310*da0073e9SAndroid Build Coastguard Worker                out_tensor = direct_solve(out_tensor_ctype)
311*da0073e9SAndroid Build Coastguard Worker                return f"{out_tensor}.layout()"
312*da0073e9SAndroid Build Coastguard Worker
313*da0073e9SAndroid Build Coastguard Worker        elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))):
314*da0073e9SAndroid Build Coastguard Worker            try:
315*da0073e9SAndroid Build Coastguard Worker                options = direct_solve(options_ctype)
316*da0073e9SAndroid Build Coastguard Worker                return f"{options}.device_opt()"
317*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
318*da0073e9SAndroid Build Coastguard Worker                out_tensor = direct_solve(out_tensor_ctype)
319*da0073e9SAndroid Build Coastguard Worker                return f"{out_tensor}.device()"
320*da0073e9SAndroid Build Coastguard Worker
321*da0073e9SAndroid Build Coastguard Worker        elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))):
322*da0073e9SAndroid Build Coastguard Worker            try:
323*da0073e9SAndroid Build Coastguard Worker                options = direct_solve(options_ctype)
324*da0073e9SAndroid Build Coastguard Worker                return f"{options}.pinned_memory_opt()"
325*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
326*da0073e9SAndroid Build Coastguard Worker                # If we're calling a factory op from its out= variant,
327*da0073e9SAndroid Build Coastguard Worker                # We don't actually care about the value of pin_memory.
328*da0073e9SAndroid Build Coastguard Worker                out_tensor = direct_solve(out_tensor_ctype)
329*da0073e9SAndroid Build Coastguard Worker                return "::std::nullopt"
330*da0073e9SAndroid Build Coastguard Worker
331*da0073e9SAndroid Build Coastguard Worker        # We can always do translations from value types to reference types, like vector<int> -> IntArrayRef
332*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(intArrayRefT):
333*da0073e9SAndroid Build Coastguard Worker            try:
334*da0073e9SAndroid Build Coastguard Worker                return direct_solve(NamedCType(goal.name, longVec_ctype))
335*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
336*da0073e9SAndroid Build Coastguard Worker                # We can also go SymIntArrayRef -> IntArrayRef
337*da0073e9SAndroid Build Coastguard Worker                symIntArrayRef_type = direct_solve(
338*da0073e9SAndroid Build Coastguard Worker                    NamedCType(goal.name, BaseCType(symIntArrayRefT))
339*da0073e9SAndroid Build Coastguard Worker                )
340*da0073e9SAndroid Build Coastguard Worker                return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})"
341*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(symIntArrayRefT):
342*da0073e9SAndroid Build Coastguard Worker            try:
343*da0073e9SAndroid Build Coastguard Worker                r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT)))
344*da0073e9SAndroid Build Coastguard Worker                return f"c10::fromIntArrayRefSlow({r})"
345*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
346*da0073e9SAndroid Build Coastguard Worker                return direct_solve(NamedCType(goal.name, longSymVec_ctype))
347*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(SymIntT):
348*da0073e9SAndroid Build Coastguard Worker            return direct_solve(NamedCType(goal.name, BaseCType(longT)))
349*da0073e9SAndroid Build Coastguard Worker        elif goal.type == OptionalCType(BaseCType(SymIntT)):
350*da0073e9SAndroid Build Coastguard Worker            argname = direct_solve(
351*da0073e9SAndroid Build Coastguard Worker                NamedCType(goal.name, OptionalCType(BaseCType(longT)))
352*da0073e9SAndroid Build Coastguard Worker            )
353*da0073e9SAndroid Build Coastguard Worker            return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt"
354*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(longT):
355*da0073e9SAndroid Build Coastguard Worker            symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT)))
356*da0073e9SAndroid Build Coastguard Worker            return f"{symInt_type}.guard_int(__FILE__, __LINE__)"
357*da0073e9SAndroid Build Coastguard Worker        elif goal.type == OptionalCType(BaseCType(longT)):
358*da0073e9SAndroid Build Coastguard Worker            argname = direct_solve(
359*da0073e9SAndroid Build Coastguard Worker                NamedCType(goal.name, OptionalCType(BaseCType(SymIntT)))
360*da0073e9SAndroid Build Coastguard Worker            )
361*da0073e9SAndroid Build Coastguard Worker            return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt"
362*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(optionalIntArrayRefT):
363*da0073e9SAndroid Build Coastguard Worker            try:
364*da0073e9SAndroid Build Coastguard Worker                return direct_solve(NamedCType(goal.name, optionalLongVec_ctype))
365*da0073e9SAndroid Build Coastguard Worker            except UnsatError:
366*da0073e9SAndroid Build Coastguard Worker                argname = direct_solve(
367*da0073e9SAndroid Build Coastguard Worker                    NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT))
368*da0073e9SAndroid Build Coastguard Worker                )
369*da0073e9SAndroid Build Coastguard Worker                return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt"
370*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(optionalSymIntArrayRefT):
371*da0073e9SAndroid Build Coastguard Worker            # TODO: You might also want to solve this from longSymVec_ctype or
372*da0073e9SAndroid Build Coastguard Worker            # an optional version of it
373*da0073e9SAndroid Build Coastguard Worker            argname = direct_solve(
374*da0073e9SAndroid Build Coastguard Worker                NamedCType(goal.name, BaseCType(optionalIntArrayRefT))
375*da0073e9SAndroid Build Coastguard Worker            )
376*da0073e9SAndroid Build Coastguard Worker            return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt"
377*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(optionalScalarRefT):
378*da0073e9SAndroid Build Coastguard Worker            return direct_solve(NamedCType(goal.name, optionalScalar_ctype))
379*da0073e9SAndroid Build Coastguard Worker        elif goal.type == BaseCType(optionalTensorRefT):
380*da0073e9SAndroid Build Coastguard Worker            return direct_solve(NamedCType(goal.name, optionalTensor_ctype))
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Worker        # Note [translation from C++ reference to value types]
383*da0073e9SAndroid Build Coastguard Worker        # The below cases are all for when we have an argument with a reference type,
384*da0073e9SAndroid Build Coastguard Worker        # and a corresponding goal with a value type.
385*da0073e9SAndroid Build Coastguard Worker        # These are needed when we populate the inputs to a lambda capture and we need
386*da0073e9SAndroid Build Coastguard Worker        # to guarantee the lifetime of each captured argument.
387*da0073e9SAndroid Build Coastguard Worker        # We guard it with an explicit kwarg because converting to a value type is expensive
388*da0073e9SAndroid Build Coastguard Worker        # (O(n)) to convert from IntArrayRef to vector<int>),
389*da0073e9SAndroid Build Coastguard Worker        # so the caller of translate() should be explicit that they need it.
390*da0073e9SAndroid Build Coastguard Worker        if allow_expensive_conversions:
391*da0073e9SAndroid Build Coastguard Worker            if goal.type == VectorCType(BaseCType(longT)):
392*da0073e9SAndroid Build Coastguard Worker                intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT))
393*da0073e9SAndroid Build Coastguard Worker                argname = direct_solve(intArrayRef_ctype)
394*da0073e9SAndroid Build Coastguard Worker                return f"{argname}.vec()"
395*da0073e9SAndroid Build Coastguard Worker            if goal.type == VectorCType(BaseCType(SymIntT)):
396*da0073e9SAndroid Build Coastguard Worker                symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT))
397*da0073e9SAndroid Build Coastguard Worker                argname = direct_solve(symIntArrayRef_ctype)
398*da0073e9SAndroid Build Coastguard Worker                return f"{argname}.vec()"
399*da0073e9SAndroid Build Coastguard Worker            elif goal.type == OptionalCType(VectorCType(BaseCType(longT))):
400*da0073e9SAndroid Build Coastguard Worker                optionalIntArrayRef_ctype = NamedCType(
401*da0073e9SAndroid Build Coastguard Worker                    goal.name, BaseCType(optionalIntArrayRefT)
402*da0073e9SAndroid Build Coastguard Worker                )
403*da0073e9SAndroid Build Coastguard Worker                argname = direct_solve(optionalIntArrayRef_ctype)
404*da0073e9SAndroid Build Coastguard Worker                return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt"
405*da0073e9SAndroid Build Coastguard Worker            elif goal.type == OptionalCType(BaseCType(scalarT)):
406*da0073e9SAndroid Build Coastguard Worker                optionalScalarRef_ctype = NamedCType(
407*da0073e9SAndroid Build Coastguard Worker                    goal.name, BaseCType(optionalScalarRefT)
408*da0073e9SAndroid Build Coastguard Worker                )
409*da0073e9SAndroid Build Coastguard Worker                argname = direct_solve(optionalScalarRef_ctype)
410*da0073e9SAndroid Build Coastguard Worker                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
411*da0073e9SAndroid Build Coastguard Worker            elif goal.type == OptionalCType(BaseCType(scalarT)):
412*da0073e9SAndroid Build Coastguard Worker                optionalTensorRef_ctype = NamedCType(
413*da0073e9SAndroid Build Coastguard Worker                    goal.name, BaseCType(optionalTensorRefT)
414*da0073e9SAndroid Build Coastguard Worker                )
415*da0073e9SAndroid Build Coastguard Worker                argname = direct_solve(optionalTensorRef_ctype)
416*da0073e9SAndroid Build Coastguard Worker                return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt"
417*da0073e9SAndroid Build Coastguard Worker            # Technically, we also need to handle cases of C++ containers holding reference types.
418*da0073e9SAndroid Build Coastguard Worker            # But there currently aren't any ops that require lambda capture codegen
419*da0073e9SAndroid Build Coastguard Worker            # With arguments like ::std::vector<IntArrayRef>.
420*da0073e9SAndroid Build Coastguard Worker            # If that changes, we'll have to add the translation here.
421*da0073e9SAndroid Build Coastguard Worker
422*da0073e9SAndroid Build Coastguard Worker        # We allow const casting on tensors, since const-correctness is a bit broken for at::Tensor.
423*da0073e9SAndroid Build Coastguard Worker        # We could probably generalize this to non-tensor types too.
424*da0073e9SAndroid Build Coastguard Worker        if goal.type == MutRefCType(BaseCType(tensorT)):
425*da0073e9SAndroid Build Coastguard Worker            const_ref_tensor_ctype = NamedCType(
426*da0073e9SAndroid Build Coastguard Worker                goal.name, ConstRefCType(BaseCType(tensorT))
427*da0073e9SAndroid Build Coastguard Worker            )
428*da0073e9SAndroid Build Coastguard Worker            argname = direct_solve(const_ref_tensor_ctype)
429*da0073e9SAndroid Build Coastguard Worker            return f"const_cast<Tensor&>({argname})"
430*da0073e9SAndroid Build Coastguard Worker
431*da0073e9SAndroid Build Coastguard Worker        unsat(goal)
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker    return [Expr(solve(g, direct=False), g) for g in goal_ctypes]
434