# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

r"""
Handle the following op conversions:
- convert a functional op to an out variant op
- convert an out variant op to a scratch op.

We assume a functionalization pass has already run that removes aliases and inplace variants.

For the to_out_variant conversion, the functional variant is represented
as the qualified op name plus the overload name. The returned out variant contains
the following information:
- the OpOverload for the out variant
- the list of keyword argument names that are out variables. There should be
  at least one out variable. Some ops may also have multiple out variables,
  e.g. aten::topk.values returns both values and indices for the topk elements.

"""

import dataclasses
import logging
from typing import Dict, Optional, Tuple

import torch
from torch._ops import OpOverload
from torchgen.model import FunctionSchema, SchemaKind

# Cache the native FunctionSchema so we don't need to parse it every time.
# Use OpOverload as the hash key. We cannot use torch._C.FunctionSchema as the key
# since it's not hashable.
_op_overload_to_schema_cache: Dict[OpOverload, FunctionSchema] = {}

# The value type is Optional so we can cache None if an op does not have an
# out variant/scratch op. This way, the case where the op does not exist is not
# confused with a cache miss.
_func_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}
_out_variant_to_scratch_map: Dict[OpOverload, Optional[OpOverload]] = {}
_mutable_to_out_variant_map: Dict[OpOverload, Optional[OpOverload]] = {}

# We've found a functional and an out variant with the same name, but their
# schemas mismatch.
# This map collects all of these cases and provides a proper
# error message to the user. The key is an `OpOverload` of a functional variant.
_schema_mismatch_map: Dict[OpOverload, Optional[FunctionSchema]] = {}


def _pybind_schema_to_native_schema(
    pybind_schema: torch._C.FunctionSchema,
) -> Optional[FunctionSchema]:
    """
    We have two FunctionSchema definitions in Python.
    One is defined in torchgen (call it the native FunctionSchema); the other is a
    pybind of c10::FunctionSchema (call it the pybind FunctionSchema).
    Because we want to leverage torchgen to handle out variants, we
    convert any pybind FunctionSchema to a native FunctionSchema.
    """
    native_schema = None
    try:
        native_schema = FunctionSchema.parse(str(pybind_schema))
    except (RuntimeError, AssertionError, ValueError):
        # We need to catch AssertionError since parsing prim ops like:
        #   aten::to.prim_other(Tensor(a) self, bool non_blocking=False, bool copy=False) -> Tensor(a|b)
        # causes an assertion error in torchgen when parsing the annotation 'a|b'.
        # We should ignore it.
        # Hopefully one day the C++ FunctionSchema parsing will be
        # 100% consistent with the Python FunctionSchema parsing; then we won't
        # need to catch these exceptions any more.

        # We also need to catch ValueError for schemas like:
        #   aten::copy.Dict_str(Dict(str, t)(a) self) -> Dict(str, t)
        # torchgen throws ValueError since it does not expect the type string
        # to contain commas. Ignore those schemas for now.
        logging.debug(f"Fail to parse function schema: {str(pybind_schema)}")
        # Ignore the failure and return None. There are some schemas defined as
        # prim ops that cannot be parsed by torchgen. E.g.:
        # https://www.fburl.com/code/1vvzhssa
        # It should be safe to ignore them since PyE does not use these ops.
    return native_schema


def _get_overload_schema(op_overload: OpOverload) -> Optional[FunctionSchema]:
    native_schema = _op_overload_to_schema_cache.get(op_overload)
    if not native_schema:
        native_schema = _pybind_schema_to_native_schema(op_overload._schema)
        _op_overload_to_schema_cache[op_overload] = native_schema  # pyre-ignore
    return native_schema


def get_out_args_from_opoverload(op_overload: OpOverload) -> Tuple[str]:
    return get_out_args_from_schema(_get_overload_schema(op_overload))  # pyre-ignore


def get_out_args_from_schema(out_var_schema: FunctionSchema) -> Tuple[str]:
    """
    Assume the input is the schema for an out variant.
    Return the names of its out arguments.
    """
    assert (
        out_var_schema.is_out_fn()
    ), f"Expect an out variant, but get: {out_var_schema}"
    return tuple(arg.name for arg in out_var_schema.arguments.out)


def parse_qualified_opname(qualified_opname: str) -> Tuple[str, str]:
    """
    Given a qualified opname like aten::add, return a tuple of the namespace
    (aten here) and the op name (add here).
    """
    ns_and_opname = qualified_opname.split("::")
    if len(ns_and_opname) != 2:
        raise RuntimeError(f"Invalid qualified_opname {qualified_opname}")
    return tuple(ns_and_opname)


def get_op_overload(qualified_opname: str, overload: str) -> OpOverload:
    """
    Arguments:
        qualified_opname: string like {namespace}::{op name}
        overload: the overload string of the op
    """
    ns, opname = parse_qualified_opname(qualified_opname)
    if not overload:
        overload = "default"
    return getattr(getattr(getattr(torch.ops, ns), opname), overload)


def schema_to_opoverload(schema: FunctionSchema) -> OpOverload:
    qualified_name = str(schema.name.name)
    overload = schema.name.overload_name
    return get_op_overload(qualified_name, overload)


def set_mapping_for_op(op: OpOverload) -> None:
    """
    op can be a functional op, a mutable op, or an out variant op.
    This method is only called if:
    1. op is a functional op and it's missing from the _func_to_out_variant_map cache, or
    2. op is a mutable op and it's missing from the _mutable_to_out_variant_map cache, or
    3. op is an out variant op and it's missing from the _out_variant_to_scratch_map cache.

    Set up entries in _func_to_out_variant_map, _mutable_to_out_variant_map and
    _out_variant_to_scratch_map for all ops sharing the same op name as the
    passed-in OpOverload.
    """
    native_schema = _pybind_schema_to_native_schema(op._schema)
    # pyre-fixme[16]: `Optional` has no attribute `kind`.
    assert native_schema.kind() in (
        SchemaKind.functional,
        SchemaKind.out,
        SchemaKind.mutable,
    )
    assert not (
        native_schema.kind() == SchemaKind.functional and op in _func_to_out_variant_map
    )
    assert not (
        native_schema.kind() == SchemaKind.out and op in _out_variant_to_scratch_map
    )
    assert not (
        native_schema.kind() == SchemaKind.mutable and op in _mutable_to_out_variant_map
    )
    qualified_opname = str(op._schema.name)

    all_schemas = [
        _pybind_schema_to_native_schema(pybind_schema)
        for pybind_schema in torch._C._jit_get_schemas_for_operator(qualified_opname)
    ]

    # skip the schemas that cannot be parsed by torchgen
    all_schemas = [schema for schema in all_schemas if schema is not None]

    group_by_signature: Dict[str, Dict[SchemaKind, FunctionSchema]] = {}

    for schema in all_schemas:
        signature = schema.signature()
        # Override the return type with an empty tuple. Otherwise, for ops like
        # aten.split.Tensor_out that return a Tensor list, the signature of the
        # schema does not match the one for the functional op aten.split.Tensor
        # because of the different return type.
        # Schema for aten.split.Tensor_out:
        #   split.Tensor_out(Tensor(a -> *) self, int split_size, int dim=0, *, Tensor(a!)[] out) -> ()
        # Schema for aten.split.Tensor:
        #   split.Tensor(Tensor(a -> *) self, int split_size, int dim=0) -> Tensor(a)[]
        # The reason for the above inconsistency is explained in: https://github.com/pytorch/pytorch/pull/76049
        signature = dataclasses.replace(signature, returns=())

        kind = schema.kind()
        # pyre-fixme[6]: For 1st argument expected `str` but got `FunctionSchema`.
        group_by_kind = group_by_signature.setdefault(signature, {})
        assert (
            kind not in group_by_kind
        ), f"Schema of kind {kind} already exist for {schema}"
        group_by_kind[kind] = schema

    # Add the functional -> out variant, out variant -> scratch and
    # mutable -> out variant pairs to the caches.
    for group_by_kind in group_by_signature.values():
        func_op_schema = group_by_kind.get(SchemaKind.functional)
        out_var_schema = group_by_kind.get(SchemaKind.out)
        mutable_op_schema = group_by_kind.get(SchemaKind.mutable)
        scratch_schema = group_by_kind.get(SchemaKind.scratch)

        # update the map even if out_var_schema is None to cache the negative
        # case
        if func_op_schema:
            _func_to_out_variant_map[schema_to_opoverload(func_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )
            # the out variant schema is missing from group_by_kind
            if out_var_schema is None:
                # find an out variant whose schema differs from the functional variant's
                mismatched_out_schema: Optional[FunctionSchema] = next(
                    (s for s in all_schemas if s.kind() == SchemaKind.out), None
                )
                _schema_mismatch_map[schema_to_opoverload(func_op_schema)] = (
                    mismatched_out_schema
                )

        # update the map even if scratch_schema is None to cache the negative
        # case
        if out_var_schema:
            _out_variant_to_scratch_map[schema_to_opoverload(out_var_schema)] = (
                schema_to_opoverload(scratch_schema) if scratch_schema else None
            )
        if mutable_op_schema:
            _mutable_to_out_variant_map[schema_to_opoverload(mutable_op_schema)] = (
                schema_to_opoverload(out_var_schema) if out_var_schema else None
            )


def to_out_variant(op_overload: OpOverload) -> Tuple[OpOverload, Tuple[str]]:
    r"""
    Convert the passed-in OpOverload to its out variant. If the op is already an
    out variant, it is returned unchanged; otherwise raise an exception if no
    matching out variant can be found.

    If a conversion is found, return the out variant OpOverload along with the
    names of its out arguments.
    """
    schema = _get_overload_schema(op_overload)
    if schema.is_out_fn():  # pyre-ignore
        return op_overload, get_out_args_from_schema(schema)  # pyre-ignore[6]

    # at this point the op should be functional or mutable
    assert (
        schema.kind() == SchemaKind.functional  # pyre-ignore[16]
        or schema.kind() == SchemaKind.mutable
    ), f"Expect a functionalish op, but get {schema.kind()} {schema}"

    if (
        op_overload not in _func_to_out_variant_map
        and op_overload not in _mutable_to_out_variant_map
    ):
        # set up the out variant mapping for this op
        set_mapping_for_op(op_overload)

    if op_overload in _mutable_to_out_variant_map:
        out_var = _mutable_to_out_variant_map[op_overload]
    else:
        out_var = _func_to_out_variant_map.get(op_overload)

    if not out_var:
        msg = f"Missing out variant for functional op: {schema} . Make sure you have loaded your custom operator library for compiler. E.g., custom_ops_generated_lib"
        if op_overload in _schema_mismatch_map:
            if _schema_mismatch_map[op_overload]:
                msg += (
                    f"\nFound an out variant for operator name {op_overload.name()} but its schema mismatched with functional op."
                    f"\nfunctional op schema:\t{schema}"
                    f"\nout variant op schema:\t{_schema_mismatch_map[op_overload]}"
                )
        raise RuntimeError(msg)

    return out_var, get_out_args_from_opoverload(out_var)


def to_scratch_op(op_overload: OpOverload) -> Optional[OpOverload]:
    schema = _get_overload_schema(op_overload)

    # If the op is not an out variant, then we must have ignored some failure in
    # the to_out_var pass. Return immediately rather than throwing an exception,
    # since the user must have ignored errors for some reason (e.g. to design
    # some special unit tests, or to unblock new use cases).
    if schema.kind() != SchemaKind.out:  # pyre-ignore
        logging.debug(f"Expect an out variant op as input, got: {schema.kind()}")
        return None

    if op_overload not in _out_variant_to_scratch_map:
        set_mapping_for_op(op_overload)
    scratch_op = _out_variant_to_scratch_map.get(op_overload)

    # scratch_op can be None
    return scratch_op


def is_out_variant(qualified_opname: str, overload: str) -> bool:
    op_overload = get_op_overload(qualified_opname, overload)
    schema = _get_overload_schema(op_overload)
    if schema is None:
        return False
    return schema.is_out_fn()


def is_inplace_variant(qualified_opname: str, overload: str) -> bool:
    op_overload = get_op_overload(qualified_opname, overload)
    schema = _get_overload_schema(op_overload)
    if schema is None:
        return False
    return schema.kind() == SchemaKind.inplace
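

# ---------------------------------------------------------------------------
# Illustrative sketch only. This code is not used by the conversion functions
# above. It models, without torch or torchgen, how set_mapping_for_op pairs a
# functional overload with its out variant. The approach is to drop the out
# arguments and the return type, then group the remaining signatures. When no
# out variant shares a functional op's signature, the sketch records None, so
# "no out variant" is not confused with a cache miss.
# Assumption: _SketchSchema, _sketch_signature and _sketch_func_to_out_map are
# hypothetical names. The real code relies on torchgen's
# FunctionSchema.signature() and SchemaKind instead.
# ---------------------------------------------------------------------------
import dataclasses
from typing import Dict, List, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class _SketchSchema:
    name: str  # qualified op name, e.g. "demo::scale" (hypothetical op)
    overload: str  # overload name, e.g. "Tensor" or "out"
    kind: str  # "functional", "out" or "mutable" (stands in for SchemaKind)
    arg_types: Tuple[str, ...]  # types of the non-out arguments
    out_args: Tuple[str, ...] = ()  # names of the out arguments, if any
    returns: Tuple[str, ...] = ()  # return types; ignored when grouping


def _sketch_signature(schema: _SketchSchema) -> Tuple[str, Tuple[str, ...]]:
    # Keep only the op name and the non-out argument types. The overload name,
    # out arguments and return types are dropped.
    return (schema.name, schema.arg_types)


def _sketch_func_to_out_map(
    schemas: List[_SketchSchema],
) -> Dict[str, Optional[str]]:
    # Group schemas that share a signature, keyed by their kind.
    groups: Dict[Tuple[str, Tuple[str, ...]], Dict[str, _SketchSchema]] = {}
    for schema in schemas:
        by_kind = groups.setdefault(_sketch_signature(schema), {})
        assert schema.kind not in by_kind, f"duplicate {schema.kind} for {schema}"
        by_kind[schema.kind] = schema

    # Map "name.overload" of each functional op to its out variant. If the
    # group has no out variant, map it to None (a cached negative result).
    result: Dict[str, Optional[str]] = {}
    for by_kind in groups.values():
        func = by_kind.get("functional")
        out = by_kind.get("out")
        if func is not None:
            result[f"{func.name}.{func.overload}"] = (
                f"{out.name}.{out.overload}" if out is not None else None
            )
    return result


# Example with hypothetical schemas. "demo::scale.Tensor" pairs with
# "demo::scale.out" even though their return types differ, which mirrors the
# split.Tensor_out note in set_mapping_for_op. "demo::scale.Scalar" has no
# out variant with the same signature, so it maps to None.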