# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Optional, Set, Union

import pkg_resources

import torch

from executorch.exir.dialects.edge.dtype.supported import regular_tensor_str_to_dtypes
from executorch.exir.dialects.edge.op.api import to_variant
from executorch.exir.dialects.edge.spec.utils import get_tensor_variable_names

# pyre-ignore
from ruamel.yaml import YAML
from torchgen.model import SchemaKind


class AllowedDtypeSet:
    """All legal dtypes for current type alias.

    This class is a wrapper of Set[torch.dtype]. Normally it is the set of all legal
    types listed in the edge/edge.yaml file for a type alias. Once an argument bound
    to the type alias receives its actual dtype, the set can be reduced (see
    ``reduce_to``) to contain only that dtype, until ``clear`` is called.

    TODO(gasoonjia): Prevent users from misusing.

    Public Attributes:
        types: a set of all allowed dtypes listed in edge/edge.yaml.

    Private Attributes:
        _reduced_type: the actual type this type alias currently represents. 0 means
            unrestricted, i.e. each type in self.types is legal.
    """

    def __init__(self, types: Set[torch.dtype]):
        self.types: Set[torch.dtype] = types
        self._reduced_type: Union[torch.dtype, int] = 0

    def reduce_to(self, t: torch.dtype) -> bool:
        """Reduce the legal dtypes of this alias to ``t`` alone.

        ``t`` must currently be legal for this alias (see ``__contains__``); in
        particular, an already-reduced alias only accepts its reduced dtype.

        Returns:
            True if the reduction succeeded; otherwise False, with the set unchanged.
        """
        if t not in self:
            return False
        self._reduced_type = t
        return True

    def clear(self) -> None:
        """Derestrict AllowedDtypeSet to all allowed dtypes in yaml."""
        self._reduced_type = 0

    def __contains__(self, key: torch.dtype) -> bool:
        """Check if key is a legal type of this type alias.

        While reduced, only the reduced dtype is legal; otherwise any dtype in
        ``types`` is.
        """
        if self._reduced_type:
            return key == self._reduced_type
        return key in self.types


class FunctionDtypeConstraint:
    """Dtype constraint for each EdgeDialect op.

    Arguments:
        essential_tensor_io_names: All names of essential tensor inputs and outputs.
        optional_tensor_io_names: All names of optional tensor inputs.
        type_alias: Dict of type alias name to corresponding list of dtypes.
        type_constraint: List of dicts, each mapping an arg name to a type alias
            name. Every dict is one alternative legal dtype combination; ``validate``
            accepts an input if it satisfies at least one of them.

    Raises:
        RuntimeError: if an entry of ``type_constraint`` does not name every
            essential and optional tensor argument.
    """

    def __init__(
        self,
        essential_tensor_io_names: List[str],
        optional_tensor_io_names: List[str],
        type_alias: Dict[str, List[torch.dtype]],
        type_constraint: List[Dict[str, str]],
    ):
        self.essential_tensor_io_names: List[str] = essential_tensor_io_names
        self.optional_tensor_io_names: List[str] = optional_tensor_io_names
        self.type_alias: Dict[str, AllowedDtypeSet] = {
            alias: AllowedDtypeSet(set(types)) for alias, types in type_alias.items()
        }
        self.type_constraint: List[Dict[str, str]] = type_constraint
        # type_constraint's non return entries should include all tensor-like
        # arguments. The required names are the same for every entry, so build the
        # set once rather than once per entry.
        all_tensor_arg_names = set(
            self.essential_tensor_io_names + self.optional_tensor_io_names
        )
        for t_constraint in self.type_constraint:
            type_constraint_names = set(t_constraint)
            if not all_tensor_arg_names.issubset(type_constraint_names):
                raise RuntimeError(
                    "Input entries of type_constraint must contain all tensor-like arguments, "
                    + f"but get {type_constraint_names} and {all_tensor_arg_names}"
                )

    def validate(self, types: Dict[str, Optional[torch.dtype]]) -> bool:
        """Check if the given input type combination is a legal one for this function.

        Args:
            types: A dict of arg name to its current dtype. A ``None`` dtype means
                the dtype was not set (e.g. an empty tensor list) and is not checked.

        Returns:
            True iff a. types are legal for current operator b. all arg names can be
            found in current operator and c. input contains all essential tensor
            inputs; False otherwise.

        The essential tensor inputs here mean non-optional inputs in tensor and
        tensor list.
        """
        # Every arg name in `types` should be one of the tensor ios in current function.
        if any(arg_name not in self for arg_name in types):
            return False

        # Any essential tensor input should exist in current `types` input.
        if any(io_name not in types for io_name in self.essential_tensor_io_names):
            return False

        # Legal iff at least one alternative constraint entry accepts every dtype.
        return any(
            self._satisfies_constraint(constraint, types)
            for constraint in self.type_constraint
        )

    def _satisfies_constraint(
        self, constraint: Dict[str, str], types: Dict[str, Optional[torch.dtype]]
    ) -> bool:
        """Check ``types`` against a single constraint entry.

        Alias reductions made while checking are always undone before returning
        (also when a lookup raises), so neither a rejected constraint nor an
        interrupted call leaves stale restrictions on ``self.type_alias``.
        """
        try:
            for arg_name, arg_type in types.items():
                if arg_type is None:
                    # None means the user didn't set dtype for this argument
                    # (i.e. empty tensorlist), skipping the validation.
                    continue
                # Narrow the alias down to the actual dtype, so later args bound to
                # the same alias must have exactly this dtype.
                if not self.type_alias[constraint[arg_name]].reduce_to(arg_type):
                    return False
            return True
        finally:
            for alias in self.type_alias.values():
                alias.clear()

    def __contains__(self, key: str) -> bool:
        """Check whether ``key`` is an arg name covered by these dtype constraints."""
        return key in self.type_constraint[0]

    def __getitem__(self, arg_name: str) -> Set[torch.dtype]:
        """Return all legal types for given arg name.

        The result is the union, over every constraint entry, of the dtypes allowed
        by the alias bound to ``arg_name``; an empty set if ``arg_name`` cannot be
        found in current function.
        """
        if arg_name not in self:
            return set()

        valid_dtype: Set[torch.dtype] = set()
        for constraint in self.type_constraint:
            valid_dtype |= self.type_alias[constraint[arg_name]].types
        return valid_dtype


def _load_edge_dialect_info() -> Dict[str, Dict[str, Any]]:
    """Load the bundled edge.yaml, indexing its entries by their ``inherits`` name.

    Returns an empty dict when the yaml document is empty.
    """
    # NOTE(review): pkg_resources is deprecated by setuptools; consider migrating
    # to importlib.resources.
    # pyre-ignore
    yaml = YAML(typ="safe")
    edge_dialect_yaml_info = yaml.load(
        pkg_resources.resource_string(__name__, "edge.yaml").decode("utf8")
    )
    if not edge_dialect_yaml_info:
        return {}
    return {
        edge_op_yaml_info["inherits"]: edge_op_yaml_info
        for edge_op_yaml_info in edge_dialect_yaml_info
    }


_edge_dialect_info: Dict[str, Dict[str, Any]] = _load_edge_dialect_info()


def _get_wrapped(obj: object, attr: str, name: str) -> Any:
    """Return ``obj.__dict__[attr]``, raising ``AttributeError(name)`` if it is unset.

    Helper for delegating ``__getattr__`` implementations. ``__getattr__`` only runs
    when normal lookup fails, so reading the wrapped object via ``self.<attr>`` on
    an instance created without ``__init__`` (copy/pickle build instances with
    ``__new__``) would re-enter ``__getattr__`` and recurse until RecursionError.
    """
    try:
        return vars(obj)[attr]
    except KeyError:
        raise AttributeError(name) from None


class EdgeDialectArgument:
    """Argument class for EdgeDialect ops.

    Wraps around torch._C.Argument with dtype constraints.
    Redirects all other attribute lookups to the wrapped torch._C.Argument.
    """

    def __init__(self, argument: torch._C.Argument, allowed_types: Set[torch.dtype]):
        self.argument = argument
        self.allowed_types = allowed_types

    def __getattr__(self, name):
        if name == "allowed_types":
            # Set by __init__; reaching here means __init__ has not run.
            raise AttributeError(name)
        return getattr(_get_wrapped(self, "argument", name), name)


class EdgeDialectFunctionSchema:
    """FunctionSchema class for EdgeDialect ops.

    Wraps around torch._C.FunctionSchema with Tensor dtype constraints.
    In constructor, walk through all Tensor arguments and returns in the original
    schema for ATen operator, replace the argument with EdgeDialectArgument.
    """

    def __init__(
        self,
        schema: torch._C.FunctionSchema,
    ):
        self.schema = schema
        edge_op_full_name = (
            f"{schema.name}.{schema.overload_name}"
            if schema.overload_name
            else schema.name
        )

        (
            essential_tensor_io_names,
            optional_tensor_io_names,
            all_tensor_io_names,
        ) = get_tensor_variable_names(self.schema)

        if edge_op_full_name in _edge_dialect_info:
            # Directly use the information from edge.yaml if available.
            _edge_op_info = _edge_dialect_info[edge_op_full_name]
            type_alias = {
                alias: [regular_tensor_str_to_dtypes[t] for t in types]
                for alias, types in _edge_op_info["type_alias"].items()
            }
            type_constraint = _edge_op_info["type_constraint"]
        else:
            # Not in edge.yaml: create a dtype constraint for this operator that
            # allows any dtype combination, as long as each dtype is legal in
            # ExecuTorch (one independent alias per tensor io).
            type_alias = {
                f"T{idx}": list(regular_tensor_str_to_dtypes.values())
                for idx in range(len(all_tensor_io_names))
            }
            type_constraint = [
                {io_name: f"T{idx}" for idx, io_name in enumerate(all_tensor_io_names)}
            ]

        self.dtype_constraint = FunctionDtypeConstraint(
            essential_tensor_io_names=essential_tensor_io_names,
            optional_tensor_io_names=optional_tensor_io_names,
            type_alias=type_alias,
            type_constraint=type_constraint,
        )

        arg_list: List[Union[torch._C.Argument, EdgeDialectArgument]] = []
        for argument in self.schema.arguments:
            if argument.name in self.dtype_constraint:
                arg_list.append(
                    EdgeDialectArgument(
                        argument,
                        self.dtype_constraint[argument.name],
                    )
                )
            else:
                arg_list.append(argument)
        self.arguments = arg_list

        # Tensor returns are paired, in order, with the "__ret*" constraint names.
        # NOTE(review): this is a lexicographic sort; if return names carry numeric
        # suffixes it mis-orders them past 10 returns ("__ret_10" < "__ret_2") —
        # confirm the naming scheme of get_tensor_variable_names.
        return_names = sorted(
            n
            for n in self.dtype_constraint.type_constraint[0].keys()
            if n.startswith("__ret")
        )
        ret_list: List[Union[torch._C.Argument, EdgeDialectArgument]] = []
        ret_iter = iter(return_names)
        for ret in self.schema.returns:
            if isinstance(ret.type, torch.TensorType):
                name = next(ret_iter, None)
                if name is not None:
                    ret_list.append(
                        EdgeDialectArgument(ret, self.dtype_constraint[name])
                    )
                    continue
            ret_list.append(ret)
        self.returns = ret_list

    def __getattr__(self, name):
        if name in ("arguments", "returns", "dtype_constraint"):
            # Set by __init__ and deliberately shadowing the wrapped schema's
            # attributes; reaching here means __init__ has not run, so do not fall
            # back to the unconstrained torch._C.FunctionSchema values.
            raise AttributeError(name)
        return getattr(_get_wrapped(self, "schema", name), name)

    def __str__(self):
        return str(self.schema)


class EdgeOpOverload:
    """OpOverload for edge ops.

    Contains API to find the out variant of this operator overload.
    """

    def __init__(
        self,
        op: torch._ops.OpOverload,
        schema: EdgeDialectFunctionSchema,
    ):
        self._schema = schema
        self._op = op
        # `namespace` is forwarded to the wrapped op through __getattr__.
        self.__name__ = f"{self.namespace}.{self._op.__name__}"

    def to_out_variant(self) -> torch._ops.OpOverload:
        """Find out the out-variant of this operator and return it.

        TODO (larryliu): Implement execution dialect class and let this function
        return that. This implementation assumes the out variant is available in
        torch.ops.*.

        Raises:
            RuntimeError: if we couldn't find the out variant, raise an exception.
                TODO (larryliu): Catch this in BackendDialect and generate an
                operator definition for missing out variant.
        Returns:
            torch._ops.OpOverload: The out-variant operator of self.
        """
        # Return the cached result if already found. Read __dict__ directly so an
        # unset cache is not forwarded to the wrapped op by __getattr__.
        cached = self.__dict__.get("_out_variant")
        if cached:
            return cached
        out_variant = to_variant(self._op, SchemaKind.out)
        self._out_variant = out_variant
        return out_variant

    def __getattr__(self, name):
        if name == "_schema":
            # Set by __init__; the wrapped OpOverload also has a `_schema`, so do not
            # silently return the ATen schema when this instance lacks its own.
            raise AttributeError(name)
        return getattr(_get_wrapped(self, "_op", name), name)

    def __call__(self, *args, **kwargs):
        return self._op(*args, **kwargs)

    def __repr__(self):
        return "<EdgeOpOverload: {}>: schema = {}".format(
            self.__name__, self._schema.schema
        )

    __str__ = __repr__


class EdgeOpOverloadPacket:
    """OpOverloadPacket for edge ops.

    Wraps torch._ops.OpOverloadPacket and overrides __getattr__ to return OpOverload
    for Edge ops. The main difference between an Edge op and its corresponding ATen
    op is that Edge op contains a different schema (see EdgeDialectFunctionSchema).
    """

    def __init__(
        self,
        qualified_op_name: str,  # e.g., edge::aten::add
        op_name: str,
        parent_overload_packet: torch._ops.OpOverloadPacket,
    ):
        self._parent_overload_packet = parent_overload_packet
        self._parent_qualified_op_name = parent_overload_packet._qualified_op_name
        self._qualified_op_name = qualified_op_name
        self.__name__ = self._qualified_op_name.replace("::", ".")
        self._op = parent_overload_packet._op
        self._overload_names = parent_overload_packet._overload_names
        # Overload names resolved (and cached as attributes) so far.
        self._dir: List[str] = []

    def __repr__(self):
        return "<EdgeOpOverloadPacket(op='{}', parent_op='{}')>".format(
            self._qualified_op_name.replace("::", "."),
            self._parent_qualified_op_name.replace("::", "."),
        )

    def __hash__(self):
        return hash(self._op)

    def __str__(self):
        return "{}".format(self._qualified_op_name.replace("::", "."))

    @property
    def op(self):
        return self._op

    def __getattr__(self, key):
        # It is not a valid op_name when __file__ is passed in
        if key == "__file__":
            return "exir.ops.edge"
        parent_packet = _get_wrapped(self, "_parent_overload_packet", key)
        try:
            parent_overload = getattr(parent_packet, key)
        except AttributeError:
            raise AttributeError(
                "The underlying op of '{}' has no overload name '{}'".format(
                    str(self), key
                )
            ) from None

        # Create a new schema based on the parent op schema.
        edge_schema = EdgeDialectFunctionSchema(
            parent_overload._schema,
        )
        overload = EdgeOpOverload(
            parent_overload,
            edge_schema,
        )
        # Cache the overload object so later lookups bypass __getattr__.
        setattr(self, key, overload)
        self._dir.append(key)
        return overload

    def __call__(self, *args, **kwargs):
        return self._parent_overload_packet(*args, **kwargs)