# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
from functools import partial
from itertools import combinations
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
from executorch.exir.dialects.edge.arg.model import BaseArg, BaseKwarg, Return
from executorch.exir.dialects.edge.op.sample_input import SAMPLE_INPUT

from torch._ops import OpOverloadPacket

# Prefix for synthesized names of unnamed return values: "__ret_0", "__ret_1", ...
_RET_NAME_BASE = "__ret_"


def is_tensor_arg(arg: torch._C.Argument) -> bool:
    """Check if a schema argument is a Tensor, Optional[Tensor] or List[Tensor]."""
    arg_type_str: str = str(arg.type)
    return arg_type_str in ["Tensor", "Optional[Tensor]", "List[Tensor]"]


def is_tensor_val(arg: Any) -> bool:
    """Check if a given value is a Tensor-like value.

    A value is Tensor-like if it is a ``torch.Tensor`` or a non-empty list/tuple
    whose elements are all ``torch.Tensor``. The accepted values should be kept
    in sync with ``is_tensor_arg``.
    """
    # NOTE(review): is_tensor_arg also accepts Optional[Tensor], but a None value
    # is not considered Tensor-like here.
    if isinstance(arg, torch.Tensor):
        return True
    if isinstance(arg, (list, tuple)):
        return len(arg) > 0 and all(isinstance(v, torch.Tensor) for v in arg)
    return False


def is_essential_tensor_arg(arg: torch._C.Argument) -> bool:
    """Check if a schema argument is a non-optional tensor (Tensor or List[Tensor])."""
    arg_type_str: str = str(arg.type)
    return arg_type_str in ["Tensor", "List[Tensor]"]


def is_optional_tensor_arg(arg: torch._C.Argument) -> bool:
    """Check if a schema argument is an Optional[Tensor]."""
    arg_type_str: str = str(arg.type)
    return arg_type_str == "Optional[Tensor]"


def get_tensor_variable_names(
    func_schema: torch._C.FunctionSchema,
) -> Tuple[List[str], List[str], List[str]]:
    """Get names of essential tensor variables, optional tensor
    variables and all tensor variables from given function schema.
    The tensor variables here include both input tensors and output tensors.

    Returns:
        A tuple ``(essential, optional, all)``:
        ``essential`` is the Tensor / List[Tensor] argument names followed by
        the tensor return names; ``optional`` is the Optional[Tensor] argument
        names; ``all`` is every tensor argument name followed by the tensor
        return names. An unnamed return is called ``__ret_<i>``, where ``i`` is
        its position among all returns of the schema (the same naming used by
        ``get_names_for_args_with_dtype``).
    """
    essential_tensor_arg_names: List[str] = [
        arg.name for arg in func_schema.arguments if is_essential_tensor_arg(arg)
    ]
    optional_tensor_arg_names: List[str] = [
        arg.name for arg in func_schema.arguments if is_optional_tensor_arg(arg)
    ]
    all_tensor_arg_names: List[str] = [
        arg.name for arg in func_schema.arguments if is_tensor_arg(arg)
    ]

    # The synthesized index counts every return (tensor or not) so the names
    # line up with the positional names from get_names_for_args_with_dtype.
    return_tensor_variable_names: List[str] = [
        ret.name if ret.name else f"{_RET_NAME_BASE}{ret_id}"
        for ret_id, ret in enumerate(func_schema.returns)
        if is_tensor_arg(ret)
    ]

    return (
        essential_tensor_arg_names + return_tensor_variable_names,
        optional_tensor_arg_names,
        all_tensor_arg_names + return_tensor_variable_names,
    )


def get_args_rets(op_name: str) -> List[BaseArg]:
    """Return the sample "args" of ``op_name`` followed by its sample "returns".

    Both lists are read from ``SAMPLE_INPUT[op_name]``; a missing entry is
    treated as empty.
    """
    args_rets: List[BaseArg] = []
    args_rets.extend(SAMPLE_INPUT[op_name].get("args", []))
    args_rets.extend(SAMPLE_INPUT[op_name].get("returns", []))
    return args_rets


def get_names_for_args_with_dtype(
    op_name: str, func_schema: torch._C.FunctionSchema
) -> List[str]:
    """Return the names of the arguments and returns whose dtypes edge dialect tracks.

    The dtype runner reports dtypes for more arguments than edge dialect cares
    about; the names returned here select the ones that matter. The sample
    inputs of ``op_name`` are split into positional args, kwargs and returns:

    * positional sample args are paired with ``func_schema.arguments`` in order,
      contributing the schema name when the sample's type has a dtype;
    * kwargs contribute their ``argname`` when their type has a dtype;
    * returns are paired with ``func_schema.returns`` in order and always
      contribute the schema name, or ``__ret_<i>`` for an unnamed return.
    """
    args_rets: List[BaseArg] = get_args_rets(op_name)
    args, kwargs, rets = [], [], []
    for sample in args_rets:
        if isinstance(sample, Return):
            rets.append(sample)
        elif isinstance(sample, BaseKwarg):
            kwargs.append(sample)
        else:
            args.append(sample)

    names: List[str] = [
        schema.name
        for sample, schema in zip(args, func_schema.arguments)
        if sample.type.has_dtype()
    ]
    names.extend(sample.argname for sample in kwargs if sample.type.has_dtype())
    # zip stops at the shorter sequence: only returns that have a sample are named.
    for ret_id, (_, schema) in enumerate(zip(rets, func_schema.returns)):
        names.append(schema.name if schema.name else f"{_RET_NAME_BASE}{ret_id}")
    return names


def get_torch_op_overload(
    namespace: str, opname: str, overload: Optional[str]
) -> torch._ops.OpOverload:
    """Look up ``torch.ops.<namespace>.<opname>.<overload>``.

    A falsy ``overload`` (``None`` or ``""``) selects the ``default`` overload.
    """
    packet: OpOverloadPacket = getattr(getattr(torch.ops, namespace), opname)
    if overload:
        return getattr(packet, overload)
    return packet.default


def group_by_format(
    all_combinations: Set[Tuple[str]],
) -> List[Tuple[int, Tuple[Tuple[str]]]]:
    """Group combinations that differ from each other at exactly one position.

    For every combination and every position ``index``, the combination together
    with all combinations that differ from it only at ``index`` forms a group
    ``(index, sorted_combinations)``. E.g. ('Float', 'Float', 'Float') and
    ('Float', 'Float', 'Half') are grouped at index 2, whereas
    ('Float', 'Float', 'Float') and ('Float', 'Half', 'Half') are not grouped
    because they differ at two positions. A combination with no such partner at
    any position is emitted alone as ``(-1, (combination,))``. A combination may
    belong to several groups (one per position).

    Returns:
        The distinct groups, sorted so the result is deterministic.
    """
    grouped_combinations: Set[Tuple[int, Tuple[Tuple[str]]]] = set()

    def almost_same_except(
        other: Tuple[str], combination: Tuple[str], index: int
    ) -> bool:
        """Check if ``other`` differs from ``combination`` at ``index`` and only there."""
        for i, (aa, bb) in enumerate(zip(combination, other)):
            if (i == index and aa == bb) or (i != index and aa != bb):
                return False
        return True

    for combination in all_combinations:
        has_same_comb: bool = False
        for index in range(len(combination)):
            # Collect the combinations that only differ from this one at index.
            filtered: Set[Tuple[str]] = {combination}
            combo_filter = partial(
                almost_same_except, combination=combination, index=index
            )
            filtered.update(filter(combo_filter, all_combinations))
            if len(filtered) > 1:
                has_same_comb = True
                grouped_combinations.add((index, tuple(sorted(filtered))))
        if not has_same_comb:
            grouped_combinations.add((-1, (combination,)))
    return sorted(grouped_combinations)


def update_type_alias(type_alias: Dict[Tuple[str], int], new_key: Tuple[str]) -> None:
    """Register ``new_key`` with the next free index if it is not known yet."""
    if new_key not in type_alias:
        type_alias[new_key] = len(type_alias)


def gen_index_pairs_to_types_mapping(
    type_alias: Dict[Tuple[str], int], type_constraint: List[List[int]]
) -> Dict[Tuple[int], List[str]]:
    """Map each pair of positions to the types that occupy both positions with the same alias.

    For every constraint, positions holding the same alias index are paired
    (i < j), and the types of that alias are collected under the pair. For
    example, given type_constraint [[0, 0], [1, 1]] and type_alias
    {('Double',): 0, ('Int',): 1}, the output is {(0, 1): ['Double', 'Int']}.

    Returns:
        A dict from position pairs to the sorted list of collected types.
    """
    # Inverse of type_alias; the first alias wins if an index is ever duplicated.
    alias_by_index: Dict[int, Tuple[str]] = {}
    for alias, alias_index in type_alias.items():
        alias_by_index.setdefault(alias_index, alias)

    reverse: Dict[Tuple[int], Set[str]] = defaultdict(set)
    for constraint in type_constraint:
        # alias index -> positions at which it appears in this constraint
        positions: Dict[int, List[int]] = defaultdict(list)
        for pos, alias_index in enumerate(constraint):
            positions[alias_index].append(pos)
        for alias_index, pos_list in positions.items():
            alias = alias_by_index[alias_index]
            # Only pairs of positions are considered for now.
            for pair in combinations(pos_list, 2):
                reverse[pair].update(alias)
    return {k: sorted(v) for k, v in reverse.items()}


def check_new_alias_fit_constraints(
    type_alias: Dict[Tuple[str], int],
    type_constraint: List[List[int]],
    new_alias: Tuple[str],
) -> bool:
    """Check whether a merged type alias is fully backed by existing constraints.

    The new alias fits only if, for every type in it, there is a constraint that
    uses that type's single-type alias at every position. For example, with
    aliases {('Float',): 0, ('Int',): 1} and type_constraint [[0, 0]], the new
    alias ('Float', 'Int') does not fit, because both [0, 0] and [1, 1] would
    have to be present.

    Returns False when there are no constraints or when a type in ``new_alias``
    has no single-type alias.
    """
    if not type_constraint:
        return False
    length = len(type_constraint[0])
    constraint_set: Set[Tuple[int]] = {
        tuple(constraint) for constraint in type_constraint
    }
    for type_info in new_alias:
        single_index = type_alias.get((type_info,))
        if single_index is None or (single_index,) * length not in constraint_set:
            return False
    return True


def aggregate_if_two_types_being_the_same(
    type_alias: Dict[Tuple[str], int], type_constraint: List[List[int]]
) -> Tuple[List[Tuple[str]], List[Tuple[int]]]:
    """Merge constraints that use one single type at every position.

    For example, [0, 0] and [1, 1] with aliases ('Double',): 0 and ('Int',): 1
    are merged into [2, 2] with the new alias ('Double', 'Int'): 2.

    Candidate type sets come from ``gen_index_pairs_to_types_mapping``. A
    candidate is merged only if ``check_new_alias_fit_constraints`` confirms a
    uniform constraint for each of its types, and only those uniform
    constraints are rewritten. Every other constraint keeps allowing exactly the
    same combinations. Afterwards:

    * merged single-type aliases that are no longer referenced are dropped;
    * aliases are sorted and renumbered from 0;
    * duplicate constraints are removed.

    The inputs are not modified.

    Returns:
        ``(sorted_aliases, sorted_constraints)``; constraint entries index into
        ``sorted_aliases``.
    """
    # Work on copies so the caller's containers are left untouched.
    type_alias = dict(type_alias)
    type_constraint = [list(constraint) for constraint in type_constraint]

    reverse: Dict[Tuple[int], List[str]] = gen_index_pairs_to_types_mapping(
        type_alias, type_constraint
    )

    merged_single_indices: Set[int] = set()
    for types in reverse.values():
        new_alias = tuple(types)
        if new_alias in type_alias or not check_new_alias_fit_constraints(
            type_alias, type_constraint, new_alias
        ):
            continue
        length = len(type_constraint[0])
        single_indices = {type_alias[(type_str,)] for type_str in new_alias}
        uniform_rows = {(index,) * length for index in single_indices}
        new_index = len(type_alias)
        type_alias[new_alias] = new_index
        # Rewrite only rows [t, t, ..., t] for t in new_alias. Substituting the
        # single-type (or subset) aliases anywhere else would change what those
        # rows allow: e.g. [Double, Int, Half] would turn into [new, Int, new],
        # forcing positions 0 and 2 to share a type.
        type_constraint = [
            [new_index] * length if tuple(row) in uniform_rows else row
            for row in type_constraint
        ]
        merged_single_indices.update(single_indices)

    # Drop merged single-type aliases that no constraint references any more.
    referenced = {index for row in type_constraint for index in row}
    type_alias = {
        k: v
        for k, v in type_alias.items()
        if v not in merged_single_indices or v in referenced
    }
    sorted_keys = sorted(type_alias.keys())
    # Renumber aliases so they are contiguous from 0, following the sorted order.
    index_map = {type_alias[key]: i for i, key in enumerate(sorted_keys)}
    # Remap and deduplicate constraints.
    constraint_set: Set[Tuple[int]] = {
        tuple(index_map[i] for i in row) for row in type_constraint
    }
    return list(sorted_keys), sorted(constraint_set)


def aggregate_grouped_type_combinations(
    grouped_combinations: List[Tuple[int, Tuple[Tuple[str]]]],
) -> Tuple[Dict[Tuple[str], int], List[List[int]]]:
    """Turn grouped type combinations into type aliases and constraints.

    For a single-combination group, each type becomes its own single-type alias.
    For a group whose combinations differ only at ``distinct_id``, the set of
    types at that position becomes one multi-type alias. The other positions
    take single-type aliases from the group's first combination.

    Returns:
        ``(type_alias, type_constraint)``: aliases map to indices in insertion
        order, and each constraint lists one alias index per position.
    """
    type_alias: Dict[Tuple[str], int] = {}
    type_constraint: List[List[int]] = []
    for distinct_id, same_format_combinations in grouped_combinations:
        # Combinations of one group agree outside distinct_id, so the first one
        # provides the types for those positions.
        comb: Tuple[str] = next(iter(same_format_combinations))
        distinct_alias: Optional[Tuple[str]] = None
        if len(same_format_combinations) > 1:
            # Gather the differing types into a separate (multi-type) alias.
            distinct_alias = tuple(
                sorted({sf_comb[distinct_id] for sf_comb in same_format_combinations})
            )
            update_type_alias(type_alias, distinct_alias)

        temp_type_constraint: List[int] = []
        for i, type_str in enumerate(comb):
            if distinct_alias is not None and i == distinct_id:
                temp_type_constraint.append(type_alias[distinct_alias])
            else:
                update_type_alias(type_alias, (type_str,))
                temp_type_constraint.append(type_alias[(type_str,)])
        type_constraint.append(temp_type_constraint)
    return type_alias, type_constraint


def type_aggregrate(
    allow_types: Set[Tuple[str]],
) -> Tuple[List[Tuple[str]], List[Tuple[int]]]:
    """
    Aggregate the enumerated combinations of supported types into type-alias form.

    E.g. input: {("Float", "Float", "Float"), ("Half", "Half", "Half"), ("Char", "Char", "Int")}
    output: [("Char",), ("Float", "Half"), ("Int",)], [(0, 0, 2), (1, 1, 1)]

    Semantics of the output: for the i-th tuple in type_constraint and any j in
    [0, len(self.tensor_variable_names)), self.tensor_variable_names[j] can be any
    of the types in type_alias[type_constraint[i][j]]. In addition,
    self.tensor_variable_names[k] and self.tensor_variable_names[l] must have the
    same type if type_constraint[i][k] == type_constraint[i][l].

    NOTE: This is not the optimum way to aggregate types. It generates a correct
    but not necessarily minimal representation.
    TODO(gasoonjia): continue updating the aggregation algorithm.
    """
    # Group combinations with the same format.
    grouped_combinations: List[Tuple[int, Tuple[Tuple[str]]]] = group_by_format(
        allow_types
    )

    type_alias, type_constraint = aggregate_grouped_type_combinations(
        grouped_combinations
    )

    sorted_type_alias, sorted_type_constraint = aggregate_if_two_types_being_the_same(
        type_alias, type_constraint
    )

    return sorted_type_alias, sorted_type_constraint