from typing import Any, Optional, Tuple, Union

from torchgen.model import (
    Annotation,
    Argument,
    Arguments,
    BaseOperatorName,
    BaseTy,
    BaseType,
    CustomClassType,
    FunctionSchema,
    ListType,
    OperatorName,
    Return,
)


# Note: These utilities aren't actually used by torchgen itself; they generate a
# schema from real (example) arguments. For example, they are used to generate
# HigherOrderOperators' schemas, since those schemas can vary between different
# instances of the same HOP.


class TypeGen:
    """Infer a torchgen schema type from an example Python value."""

    # Looked up with the exact ``type(obj)`` (not ``isinstance``), so ``True``
    # maps to ``bool`` rather than ``int`` even though ``bool`` subclasses ``int``.
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        """Return the schema type that describes ``obj``.

        Supported values are ``torch.fx.GraphModule``, ``torch.Tensor``,
        ``torch.SymInt``, ``torch.SymBool``, ``torch.ScriptObject`` (mapped to a
        ``CustomClassType`` named after the script class), exact
        ``int``/``float``/``str``/``bool`` instances, and non-empty lists/tuples
        whose elements all map to the same type. A sequence is mapped to a
        ``ListType`` whose size is the sequence's length.

        Raises:
            RuntimeError: if ``obj`` is an empty sequence, a sequence with
                elements of different types, or a value of an unsupported type.
        """
        # Imported lazily; presumably so this module can be imported without torch.
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        elif isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        elif isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        elif isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        elif isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        elif isinstance(obj, (list, tuple)):
            # Explicit raise instead of ``assert``: the check must survive
            # ``python -O``, otherwise ``all_base_tys[0]`` fails with IndexError.
            if not obj:
                raise RuntimeError(
                    "Cannot generate schema for an empty sequence: "
                    "its element type cannot be inferred."
                )
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])


class ReturnGen:
    """Build a schema ``Return`` from an example output value."""

    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        """Return a ``Return`` named ``name`` whose type is inferred from ``obj``.

        Raises:
            RuntimeError: propagated from ``TypeGen.from_example`` when the type
                of ``obj`` cannot be inferred.
        """
        return Return(name, TypeGen.from_example(obj), annotation)


class ArgumentGen:
    """Build a schema ``Argument`` from an example input value."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        """Return an ``Argument`` named ``name`` whose type is inferred from ``obj``.

        Raises:
            RuntimeError: propagated from ``TypeGen.from_example`` when the type
                of ``obj`` cannot be inferred.
        """
        return Argument(
            name, TypeGen.from_example(obj), default=default, annotation=annotation
        )


class FunctionSchemaGen:
    """Build a whole ``FunctionSchema`` from example inputs and outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: Tuple[Tuple[str, Any], ...],
        example_outputs: Tuple[Any, ...],
    ) -> FunctionSchema:
        """Return a schema for ``op_name`` inferred from example values.

        Args:
            op_name: base name of the operator; no overload name is attached.
            example_inputs: ``(name, value)`` pairs, one per argument, in order.
            example_outputs: one example value per (unnamed) return, in order.

        Raises:
            RuntimeError: if any input or output type cannot be inferred.
        """
        args = tuple(
            ArgumentGen.from_example(name, inp, None, None)
            for name, inp in example_inputs
        )
        # Annotations, defaults and the other argument groups are ignored for
        # now; more can be added when needed. All example inputs go into the
        # third ``Arguments`` field and every other field is left empty.
        arguments = Arguments((), None, args, (), None, (), ())
        returns = tuple(
            ReturnGen.from_example(None, out, None) for out in example_outputs
        )
        # Use a new local rather than rebinding the ``str`` parameter ``op_name``
        # to a value of a different type.
        operator_name = OperatorName(
            BaseOperatorName(op_name, False, False, False), ""
        )
        return FunctionSchema(operator_name, arguments, returns)