xref: /aosp_15_r20/external/executorch/exir/operator/convert.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerr"""
Handle the following op conversions:
- convert a functional op to an out variant op
- convert an out variant op to a scratch op.

We assume a functionalization pass has already been run that removes aliases and inplace variants.

For the to_out_variant conversion, the functional variant is represented
as the qualified op name plus the overload name. The returned out variant contains
the following information:
- the OpOverload for the out variant
- the list of keyword argument names that are out variables. There should be
  at least one out variable. Some ops may also have multiple out variables,
  e.g. aten::topk.values returns both values and indices for the topk elements.

24*523fa7a6SAndroid Build Coastguard Worker"""
25*523fa7a6SAndroid Build Coastguard Worker
26*523fa7a6SAndroid Build Coastguard Workerimport dataclasses
27*523fa7a6SAndroid Build Coastguard Workerimport logging
28*523fa7a6SAndroid Build Coastguard Workerfrom typing import Dict, Optional, Tuple
29*523fa7a6SAndroid Build Coastguard Worker
30*523fa7a6SAndroid Build Coastguard Workerimport torch
31*523fa7a6SAndroid Build Coastguard Workerfrom torch._ops import OpOverload
32*523fa7a6SAndroid Build Coastguard Workerfrom torchgen.model import FunctionSchema, SchemaKind
33*523fa7a6SAndroid Build Coastguard Worker
# Cache of parsed (torchgen) FunctionSchemas so we don't need to parse every time.
# Use OpOverload as the hash key. We cannot use torch._C.FunctionSchema as the key
# since it's not hashable.
# Values may be None: this cache stores the result of
# _pybind_schema_to_native_schema, which returns None for schemas that torchgen
# cannot parse.
_op_overload_to_schema_cache: Dict[OpOverload, Optional[FunctionSchema]] = {}

# Value type is Optional so we can cache None if an op does not have an
# out variant/scratch op. This way, we don't confuse the "op does not exist"
# case with a cache miss.
#   _func_to_out_variant_map:    functional op  -> out variant op (or None)
#   _out_variant_to_scratch_map: out variant op -> scratch op     (or None)
#   _mutable_to_out_variant_map: mutable op     -> out variant op (or None)
# All three are populated by set_mapping_for_op.
_func_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
_out_variant_to_scratch_map: Dict[OpOverload, Optional[OpOverload]] = {}
_mutable_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}

# Records functional ops for which no out variant with a matching signature was
# found, so a proper error message can be given to the user. The key is the
# `OpOverload` of a functional variant; the value is an out-variant schema
# registered under the same op name whose signature did not match, or None if
# the op name has no out variant at all.
_schema_mismatch_map: Dict[OpOverload, Optional[FunctionSchema]] = {}
50*523fa7a6SAndroid Build Coastguard Worker
51*523fa7a6SAndroid Build Coastguard Worker
52*523fa7a6SAndroid Build Coastguard Workerdef _pybind_schema_to_native_schema(
53*523fa7a6SAndroid Build Coastguard Worker    pybind_schema: torch._C.FunctionSchema,
54*523fa7a6SAndroid Build Coastguard Worker) -> Optional[FunctionSchema]:
55*523fa7a6SAndroid Build Coastguard Worker    """
56*523fa7a6SAndroid Build Coastguard Worker    We have 2 FunctionSchema definitions in python.
57*523fa7a6SAndroid Build Coastguard Worker    One is defined in torchgen (call it native FunctionSchema), another is a
58*523fa7a6SAndroid Build Coastguard Worker    pybind of c10::FunctionSchema (call it pybind FunctionSchema).
59*523fa7a6SAndroid Build Coastguard Worker    Because we want to leverage torchgen to handle out variant, we will
60*523fa7a6SAndroid Build Coastguard Worker    convert any pybind FunctionSchema to native FunctionSchema.
61*523fa7a6SAndroid Build Coastguard Worker    """
62*523fa7a6SAndroid Build Coastguard Worker    native_schema = None
63*523fa7a6SAndroid Build Coastguard Worker    try:
64*523fa7a6SAndroid Build Coastguard Worker        native_schema = FunctionSchema.parse(str(pybind_schema))
65*523fa7a6SAndroid Build Coastguard Worker    except (RuntimeError, AssertionError, ValueError):
66*523fa7a6SAndroid Build Coastguard Worker        # Need catch AssertionError since parsing prim ops like:
67*523fa7a6SAndroid Build Coastguard Worker        #   aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)
68*523fa7a6SAndroid Build Coastguard Worker        # cause an asertion error in torchgen when parsiong annotation 'a|b'.
69*523fa7a6SAndroid Build Coastguard Worker        # We should ignore it. Hopefully one day the C++ FunctionSchema parsing
70*523fa7a6SAndroid Build Coastguard Worker        # is 100% consistent with Python FunctionSchema parsing, then we don't need
71*523fa7a6SAndroid Build Coastguard Worker        # catch these exceptions any more.
72*523fa7a6SAndroid Build Coastguard Worker
73*523fa7a6SAndroid Build Coastguard Worker        # We also need catch ValueError for schema like:
74*523fa7a6SAndroid Build Coastguard Worker        #   aten::copy.Dict_str(Dict(str, t)(a) self) -> Dict(str, t)
75*523fa7a6SAndroid Build Coastguard Worker        # torchgen throws ValueError since it does not expect the type string
76*523fa7a6SAndroid Build Coastguard Worker        # containing commas. Ignore those schemas for now.
77*523fa7a6SAndroid Build Coastguard Worker        logging.debug(f"Fail to parse function schema: {str(pybind_schema)}")
78*523fa7a6SAndroid Build Coastguard Worker        # ignore failure and return None. There are some schemas defined as
79*523fa7a6SAndroid Build Coastguard Worker        # prim ops that can not be parsed by torchgen. E.g.:
80*523fa7a6SAndroid Build Coastguard Worker        #   https://www.fburl.com/code/1vvzhssa
81*523fa7a6SAndroid Build Coastguard Worker        # We should be safe to ignore them since PyE are not using these ops.
82*523fa7a6SAndroid Build Coastguard Worker    return native_schema
83*523fa7a6SAndroid Build Coastguard Worker
def _get_overload_schema(op_overload: OpOverload) -> Optional[FunctionSchema]:
    """
    Return the torchgen FunctionSchema of ``op_overload``, memoized in
    ``_op_overload_to_schema_cache``.

    Returns None if torchgen cannot parse the op's schema. That negative result
    is cached as well, so a failing schema is only parsed (and logged) once.
    """
    # Use a membership test rather than a truthiness test on the cached value:
    # the cache legitimately holds None for unparsable schemas, and checking
    # `if not cached` would treat those entries as misses and re-parse them on
    # every call.
    if op_overload in _op_overload_to_schema_cache:
        return _op_overload_to_schema_cache[op_overload]
    native_schema = _pybind_schema_to_native_schema(op_overload._schema)
    _op_overload_to_schema_cache[op_overload] = native_schema  # pyre-ignore
    return native_schema
90*523fa7a6SAndroid Build Coastguard Worker    return native_schema
91*523fa7a6SAndroid Build Coastguard Worker
def get_out_args_from_opoverload(op_overload: OpOverload) -> Tuple[str, ...]:
    """
    Return the names of the out arguments of the out-variant ``op_overload``.

    Raises:
        AssertionError: if torchgen cannot parse the op's schema, or the op is
            not an out variant.
    """
    schema = _get_overload_schema(op_overload)
    # Fail with a descriptive message instead of an AttributeError on None.
    assert schema is not None, f"Fail to parse function schema for {op_overload}"
    return get_out_args_from_schema(schema)
94*523fa7a6SAndroid Build Coastguard Worker    return get_out_args_from_schema(_get_overload_schema(op_overload))  # pyre-ignore
95*523fa7a6SAndroid Build Coastguard Worker
def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str, ...]:
    """
    Return the names of the out arguments of an out-variant schema, in
    declaration order. An op may have several, e.g. ``("values", "indices")``
    for aten::topk.values.

    Raises:
        AssertionError: if ``out_var_schema`` is not an out variant.
    """
    assert (
        out_var_schema.is_out_fn()
    ), f"Expect an out variant, but get: {out_var_schema}"
    return tuple(arg.name for arg in out_var_schema.arguments.out)
105*523fa7a6SAndroid Build Coastguard Worker    return tuple(arg.name for arg in out_var_schema.arguments.out)
106*523fa7a6SAndroid Build Coastguard Worker
def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
    """
    Split a qualified op name such as ``aten::add`` into its namespace and op
    name, e.g. ``("aten", "add")``.

    Raises:
        RuntimeError: if the string does not contain exactly one ``::``.
    """
    pieces = qualified_opname.split("::")
    if len(pieces) == 2:
        namespace, op_name = pieces
        return namespace, op_name
    raise RuntimeError(f"Invalid qualified_opname {qualified_opname}")
116*523fa7a6SAndroid Build Coastguard Worker    return tuple(ns_and_opname)
117*523fa7a6SAndroid Build Coastguard Worker
def get_op_overload(qualified_opname: str, overload: str) -> OpOverload:
    """
    Look up an OpOverload in ``torch.ops``.

    Arguments:
        qualified_opname: string like {namespace}::{op name}
        overload: the overload string of the op; an empty (falsy) value
            selects the "default" overload.
    """
    namespace, op_name = parse_qualified_opname(qualified_opname)
    overload_name = overload if overload else "default"
    op_packet = getattr(getattr(torch.ops, namespace), op_name)
    return getattr(op_packet, overload_name)
128*523fa7a6SAndroid Build Coastguard Worker    return getattr(getattr(getattr(torch.ops, ns), opname), overload)
129*523fa7a6SAndroid Build Coastguard Worker
def schema_to_opoverload(schema: FunctionSchema) -> OpOverload:
    """Resolve a torchgen FunctionSchema to its OpOverload in ``torch.ops``."""
    operator_name = schema.name
    return get_op_overload(str(operator_name.name), operator_name.overload_name)
134*523fa7a6SAndroid Build Coastguard Worker    return get_op_overload(qualified_name, overload)
135*523fa7a6SAndroid Build Coastguard Worker
def set_mapping_for_op(op: OpOverload) -> None:
    """
    Populate the conversion caches for every overload sharing ``op``'s name.

    op can be a functional op, a mutable op, or an out variant op.
    This method is only expected to be called when
    1. op is a functional op and it's missing in the _func_to_out_variant_map cache,
    2. or op is a mutable op and it's missing in the _mutable_to_out_variant_map cache,
    3. or op is an out variant op and it's missing in the _out_variant_to_scratch_map cache.

    All schemas registered under op's name are grouped by signature (ignoring
    return types). Within each group the functional, out, mutable and scratch
    variants are paired up. Entries are written to _func_to_out_variant_map,
    _out_variant_to_scratch_map, _mutable_to_out_variant_map and
    _schema_mismatch_map. None values are stored so negative results are cached
    too.
    """
    native_schema = _pybind_schema_to_native_schema(op._schema)
    # NOTE(review): native_schema is None when torchgen cannot parse op's
    # schema; the assert below then fails with an AttributeError instead of a
    # descriptive error.
    # pyre-fixme[16]: `Optional` has no attribute `kind`.
    assert native_schema.kind() in (
        SchemaKind.functional,
        SchemaKind.out,
        SchemaKind.mutable,
    )
    # Guard against being called on a cache hit (see docstring).
    assert not (
        native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map
    )
    assert not (
        native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map
    )
    assert not (
        native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map
    )
    # Presumably the op name without the overload (e.g. "aten::add"), so that
    # the schemas of all overloads are fetched below.
    qualified_opname = str(op._schema.name)

    all_schemas = [
        _pybind_schema_to_native_schema(pybind_schema)
        for pybind_schema in torch._C._jit_get_schemas_for_operator(qualified_opname)
    ]

    # skip schemas that torchgen cannot parse
    all_schemas = [schema for schema in all_schemas if schema is not None]

    # signature (with returns stripped) -> {schema kind -> schema}
    group_by_signature: Dict[FunctionSchema, Dict[SchemaKind, FunctionSchema]] = {}

    for schema in all_schemas:
        signature = schema.signature()
        # Override the return type with an empty tuple. Otherwise, for ops like
        # aten.split.Tensor_out that return a Tensor list, the signature of the
        # out schema does not match the one for the functional op
        # aten.split.Tensor because of the different return type.
        # Schema for aten.split.Tensor_out:
        #   split.Tensor_out(Tensor(a -> *) self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()
        # Schema for aten.split.Tensor:
        #   split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
        # The reason for the above inconsistency is explained in: https://github.com/pytorch/pytorch/pull/76049
        signature = dataclasses.replace(signature, returns=())

        kind = schema.kind()
        group_by_kind = group_by_signature.setdefault(signature, {})
        # At most one schema of each kind may share a signature.
        assert (
            kind not in group_by_kind
        ), f"Schema of kind {kind} already exist for {schema}"
        group_by_kind[kind] = schema

    # add all the functional op -> out variant op pairs to the cache
    for group_by_kind in group_by_signature.values():
        func_op_schema = group_by_kind.get(SchemaKind.functional)
        out_var_schema = group_by_kind.get(SchemaKind.out)
        mutable_op_schema = group_by_kind.get(SchemaKind.mutable)
        scratch_schema = group_by_kind.get(SchemaKind.scratch)

        # update the map even if out_var_schema is None to cache the negative
        # case
        if func_op_schema:
            _func_to_out_variant_map[schema_to_opoverload(func_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )
            # out variant schema missing from group_by_kind
            if out_var_schema is None:
                # find an out variant whose schema differs from the functional variant
                # NOTE(review): this takes the first out-kind schema among ALL
                # overloads of this op name, which is not necessarily the
                # intended counterpart of func_op_schema when several
                # overloads exist.
                mismatched_out_schema: Optional[FunctionSchema] = next(
                    (s for s in all_schemas if s.kind() == SchemaKind.out), None
                )
                _schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
                    mismatched_out_schema
                )

        # update the map even if scratch_schema is None to cache the negative
        # case
        if out_var_schema:
            _out_variant_to_scratch_map[schema_to_opoverload(out_var_schema)] = (
                schema_to_opoverload(scratch_schema) if scratch_schema else None
            )
        # mutable op -> out variant, also caching the negative case
        if mutable_op_schema:
            _mutable_to_out_variant_map[schema_to_opoverload(mutable_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )
228*523fa7a6SAndroid Build Coastguard Worker            )
229*523fa7a6SAndroid Build Coastguard Worker
def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str, ...]]:
    r"""
    Convert the passed in OpOverload to its out variant.

    If op_overload is already an out variant it is returned unchanged.
    Otherwise it must be a functional or mutable op, and its out variant is
    looked up. On a cache miss, the caches are populated via set_mapping_for_op.

    Returns:
        A tuple of the out variant OpOverload and the names of its out
        arguments.

    Raises:
        RuntimeError: if torchgen cannot parse the op's schema, or no out
            variant can be found for it.
        AssertionError: if the op is neither an out variant nor a functional or
            mutable op.
    """
    schema = _get_overload_schema(op_overload)
    if schema is None:
        # Unparsable schema: report it the same way as a missing out variant
        # instead of crashing with an AttributeError on None.
        raise RuntimeError(
            f"Fail to parse function schema for {op_overload}, can not find its out variant"
        )
    if schema.is_out_fn():
        return op_overload, get_out_args_from_schema(schema)

    # should be a functionalish op here
    assert (
        schema.kind() == SchemaKind.functional
        or schema.kind() == SchemaKind.mutable
    ), f"Expect a functionalish op, but get {schema.kind()} {schema}"

    if (
        op_overload not in _func_to_out_variant_map
        and op_overload not in _mutable_to_out_variant_map
    ):
        # Cache miss: build the mappings for every overload of this op name.
        set_mapping_for_op(op_overload)

    if op_overload in _mutable_to_out_variant_map:
        out_var = _mutable_to_out_variant_map[op_overload]
    else:
        out_var = _func_to_out_variant_map.get(op_overload)

    if not out_var:
        msg = f"Missing out variant for functional op: {schema} . Make sure you have loaded your custom operator library for compiler. E.g., custom_ops_generated_lib"
        # Point out an out variant that exists under the same name but whose
        # signature did not match, if one was recorded.
        mismatched_out_schema = _schema_mismatch_map.get(op_overload)
        if mismatched_out_schema:
            msg += (
                f"\nFound an out variant for operator name {op_overload.name()} but its schema mismatched with functional op."
                f"\nfunctional op schema:\t{schema}"
                f"\nout variant op schema:\t{mismatched_out_schema}"
            )
        raise RuntimeError(msg)

    return out_var, get_out_args_from_opoverload(out_var)
272*523fa7a6SAndroid Build Coastguard Worker    return out_var, get_out_args_from_opoverload(out_var)
273*523fa7a6SAndroid Build Coastguard Worker
def to_scratch_op(op_overload: OpOverload) -> Optional[OpOverload]:
    """
    Return the scratch op for the out-variant ``op_overload``.

    Returns None when the op has no scratch op, is not an out variant, or has a
    schema torchgen cannot parse. On a cache miss, the caches are populated via
    set_mapping_for_op.
    """
    schema = _get_overload_schema(op_overload)

    # An unparsable schema cannot be mapped to a scratch op. Follow the same
    # "return rather than throw" policy as the non-out case below instead of
    # crashing with an AttributeError on None.
    if schema is None:
        logging.debug(f"Fail to parse function schema for {op_overload}")
        return None

    # If the op is not an out variant, then some failure in the to_out_var pass
    # must have been ignored. Return immediately rather than throwing an
    # exception, since the user must have ignored errors for some reason (e.g.
    # designing special unit tests, or unblocking new use cases).
    if schema.kind() != SchemaKind.out:
        logging.debug(f"Expect an out variant op as input, got: {schema.kind()}")
        return None

    if op_overload not in _out_variant_to_scratch_map:
        set_mapping_for_op(op_overload)

    # scratch_op can be None (cached negative result)
    return _out_variant_to_scratch_map.get(op_overload)
291*523fa7a6SAndroid Build Coastguard Worker    return scratch_op
292*523fa7a6SAndroid Build Coastguard Worker
def is_out_variant(qualified_opname: str, overload: str) -> bool:
    """
    Return whether ``qualified_opname`` / ``overload`` names an out variant op.

    Ops whose schema torchgen cannot parse are reported as not being out
    variants.
    """
    schema = _get_overload_schema(get_op_overload(qualified_opname, overload))
    return schema is not None and schema.is_out_fn()
299*523fa7a6SAndroid Build Coastguard Worker    return schema.is_out_fn()
300*523fa7a6SAndroid Build Coastguard Worker
def is_inplace_variant(qualified_opname: str, overload: str) -> bool:
    """
    Return whether ``qualified_opname`` / ``overload`` names an inplace variant op.

    Ops whose schema torchgen cannot parse are reported as not being inplace
    variants.
    """
    schema = _get_overload_schema(get_op_overload(qualified_opname, overload))
    return schema is not None and schema.kind() == SchemaKind.inplace
307*523fa7a6SAndroid Build Coastguard Worker    return schema.kind() == SchemaKind.inplace
308