# mypy: allow-untyped-defs
import dataclasses
import inspect
import sys
from typing import Any, Callable, Dict, Iterable, Tuple, Union

import torch
from torch import _C, _utils_internal
from torch._ops import OpOverload


@dataclasses.dataclass
class Kernel:
    """Models a (function, source location) pair.

    Calling a ``Kernel`` forwards all positional and keyword arguments to
    ``func`` and returns its result.
    """

    # The callable that implements this kernel.
    func: Callable
    # A string describing where ``func`` came from (its source location).
    source: str

    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)


class RegistrationHandle:
    """Does something when someone calls ``.destroy()`` on it.

    Args:
        on_destroy: zero-argument callable that :meth:`destroy` runs.
    """

    def __init__(self, on_destroy: Callable):
        self._on_destroy = on_destroy

    def destroy(self) -> None:
        """Run the ``on_destroy`` callback.

        Nothing guards against repeated calls: each call runs the callback
        again.
        """
        self._on_destroy()


def get_source(stacklevel: int) -> str:
    """Get a string that represents the caller.

    Example: "/path/to/foo.py:42"

    Use stacklevel=1 to get the caller's source
    Use stacklevel=2 to get the caller's caller's source
    etc.
    """
    frame = inspect.getframeinfo(sys._getframe(stacklevel))
    return f"{frame.filename}:{frame.lineno}"


def parse_namespace(qualname: str) -> Tuple[str, str]:
    """Split ``"namespace::name"`` into ``(namespace, name)``.

    Raises:
        ValueError: if ``qualname`` does not split into exactly two parts
            on ``"::"``.
    """
    splits = qualname.split("::")
    if len(splits) != 2:
        raise ValueError(
            f"Expected `qualname` to be of the form "
            f'"namespace::name", but got {qualname}. '
            f"The qualname passed to the torch.library APIs must consist "
            f"of a namespace and a name, e.g. aten::sin"
        )
    return splits[0], splits[1]


def lookup_op(qualname: str) -> OpOverload:
    """Resolve a qualified operator name to its ``OpOverload``.

    ``qualname`` is either ``"namespace::name"`` or
    ``"namespace::name.overload"``. When no overload is given, ``"default"``
    is used. The lookup is done with ``getattr`` on ``torch.ops``, so an
    unknown namespace/op/overload surfaces as whatever that attribute access
    raises.
    """
    namespace, name = parse_namespace(qualname)
    if "." in name:
        name, overload = name.split(".")
    else:
        overload = "default"
    ns = getattr(torch.ops, namespace)
    packet = getattr(ns, name)
    return getattr(packet, overload)


def is_builtin(op: OpOverload) -> bool:
    """Return True if ``op`` lives in one of the aten/prim/prims namespaces."""
    assert isinstance(op, OpOverload)
    return op.namespace in {"aten", "prim", "prims"}


def is_functional_schema(schema: Any) -> bool:
    """Check if the schema is functional.

    An operator is functional if:
    - it does not mutate any of its inputs
    - it does not return a view on any of its inputs
    - it has at least one return

    ``schema`` may be a ``torch._C.FunctionSchema``, a schema string, or a
    ``torchgen.model.FunctionSchema``. Strings and torchgen schemas are
    converted to ``torch._C.FunctionSchema`` first, because the check below
    reads that class's attributes (``is_mutable`` as a property,
    ``alias_info`` on returns). Those attributes do not exist with the same
    meaning on torchgen's schema type.
    """

    def is_functional(schema):
        if schema.is_mutable:
            return False
        rets = schema.returns
        if not rets:
            return False
        # A return that aliases an input without writing to it is a view.
        is_non_mutating_view = any(
            r.alias_info is not None and not r.alias_info.is_write for r in rets
        )
        return not is_non_mutating_view

    if isinstance(schema, torch._C.FunctionSchema):
        return is_functional(schema)

    if isinstance(schema, str):
        # torch._C.parse_schema is available in every build (unlike torchgen).
        return is_functional(torch._C.parse_schema(schema))

    # Lazy import because not all PyTorch builds have torchgen
    from torchgen.model import FunctionSchema

    assert isinstance(schema, FunctionSchema)
    # Round-trip through the schema string to get a torch._C.FunctionSchema.
    return is_functional(torch._C.parse_schema(str(schema)))


# should be torch._C.JitType but that annotation is busted
def is_tensorlist_like_type(typ: Any) -> bool:
    """Return True for Tensor[], Tensor?[], Tensor[]? and Tensor?[]? types."""
    return (
        typ == _C.ListType(_C.TensorType.get())
        or typ == _C.ListType(_C.OptionalType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.TensorType.get()))
        or typ == _C.OptionalType(_C.ListType(_C.OptionalType(_C.TensorType.get())))
    )


# should be torch._C.JitType but that annotation is busted
def is_tensor_like_type(typ: Any) -> bool:
    """Return True for the Tensor and Tensor? types."""
    return typ == _C.TensorType.get() or typ == _C.OptionalType(_C.TensorType.get())


def mutates_and_returns_first_arg(op: OpOverload) -> bool:
    """Check if an op is an inplace aten op, i.e. it mutates and returns the first arg.

    TODO: torchgen/model.py's FunctionSchema.parse is the source of truth for this,
    but not all PyTorch builds have torchgen (due to the yaml dependency being weird).
    Figure this out.

    Example: add_(Tensor(a!) x, Tensor y) -> Tensor(a)

    Concretely, this returns True only when all of the following hold:
    - the op is in the aten namespace
    - it has exactly one return, annotated with exactly one alias set
    - its first argument is written to (``!``) and has that same alias set
    - no other argument carries an alias annotation
    """
    if op.namespace != "aten":
        return False
    schema = op._schema
    if not len(schema.returns) == 1:
        return False
    if schema.returns[0].alias_info is None:
        return False
    alias_set = schema.returns[0].alias_info.after_set
    if len(alias_set) != 1:
        return False
    loc = next(iter(alias_set))
    if len(schema.arguments) < 1:
        return False
    first_arg = schema.arguments[0]
    if first_arg.alias_info is None:
        return False
    if not first_arg.alias_info.is_write:
        return False
    alias_set = first_arg.alias_info.after_set
    if len(alias_set) != 1:
        return False
    if loc != next(iter(alias_set)):
        return False
    for arg in schema.arguments[1:]:
        if arg.alias_info is not None:
            return False
    return True


def fill_defaults(schema, args, kwargs):
    """Expand ``(args, kwargs)`` so that every schema argument has a value.

    Keyword-only arguments always land in the returned kwargs dict, taken
    from ``kwargs`` when present and otherwise from the schema's
    ``default_value``. All other arguments land in the returned args tuple,
    taken from ``args`` by position when present and otherwise from
    ``default_value``.

    Returns:
        A ``(new_args, new_kwargs)`` pair.
    """
    new_args = []
    new_kwargs = {}
    for i, info in enumerate(schema.arguments):
        if info.kwarg_only:
            new_kwargs[info.name] = (
                kwargs[info.name] if info.name in kwargs else info.default_value
            )
        elif i < len(args):
            new_args.append(args[i])
        else:
            new_args.append(info.default_value)
    return tuple(new_args), new_kwargs


def zip_schema(
    schema: _C.FunctionSchema, args: Tuple[Any, ...], kwargs: Dict[str, Any]
) -> Iterable[Tuple[_C.Argument, Any]]:
    """zips schema.arguments and (args, kwargs) together.

    Assumes that (args, kwargs) were the inputs to some torch._ops.OpOverload:
    that is, default values may be omitted. Keyword-only arguments are read
    from ``kwargs``. Other arguments are read from ``args`` by position, or
    from ``kwargs`` by name when they were passed as keywords. Arguments that
    were not supplied at all are skipped.
    """
    assert len(schema.arguments) >= len(args) + len(kwargs)
    for i, info in enumerate(schema.arguments):
        if info.kwarg_only:
            if info.name in kwargs:
                yield info, kwargs[info.name]
            continue
        if i < len(args):
            yield info, args[i]
        elif info.name in kwargs:
            # A positional-or-keyword argument that was passed by name.
            yield info, kwargs[info.name]
        # Otherwise: args that are equal to their default values are not
        # populated if they are followed by args that are equal to their
        # defaults. Skip these.


def hop_schema_from_fx_node(node):
    """Generate a schema for the higher-order operator an FX node calls.

    Example inputs come from ``node.args``. Each arg must be a Node or a
    list/tuple of Nodes; its value is read from ``meta["val"]``, or, if that
    is missing, from the ``get_attr`` target on the owning module. The
    example output is ``node``'s own value. These values are passed to
    torchgen's ``FunctionSchemaGen.from_example``.

    Raises:
        RuntimeError: if ``node.target`` is not a HigherOrderOperator or an
            arg has an unsupported type.
    """
    from torchgen.gen_schema_utils import FunctionSchemaGen

    hop = node.target
    if not isinstance(hop, torch._ops.HigherOrderOperator):
        raise RuntimeError("fx_node's target must be a hop.")

    def _collect_example_val(node):
        meta_val = node.meta.get("val", None)
        if meta_val is None:
            # Without a recorded value the node must be a get_attr; read the
            # attribute directly from the owning module.
            assert node.op == "get_attr"
            meta_val = getattr(node.graph.owning_module, node.target)
        return meta_val

    example_inputs = []
    for arg in node.args:
        if isinstance(arg, (torch.fx.Node, torch.fx.node.Node)):
            example_inputs.append(_collect_example_val(arg))
        elif isinstance(
            arg, (torch.fx.immutable_collections.immutable_list, list, tuple)
        ):
            example_inputs.append([_collect_example_val(x) for x in arg])
        else:
            raise RuntimeError(f"Unsupported arg type {type(arg)}")

    # Bound the arguments to make sure number of inputs are correct
    bound_args: inspect.BoundArguments = inspect.signature(hop.__call__).bind(
        *example_inputs
    )

    # We treat example_output as a single value in return. This is to differentiate 1. return a single val
    # vs 2. return a tuple with one element.
    example_output = _collect_example_val(node)
    return FunctionSchemaGen.from_example(
        hop._name, tuple(bound_args.arguments.items()), (list(example_output),)
    )


def can_generate_trivial_fake_impl(op: OpOverload) -> bool:
    """Return True if ``op`` is a non-builtin, mutable op that returns nothing.

    Such an op's fake impl can simply do nothing.
    """
    assert isinstance(op, OpOverload)
    if is_builtin(op):
        # We control the built-ins. These may (in rare cases)
        # do input metadata mutation (which we have banned on custom ops)
        return False
    schema = op._schema
    # It's suspicious if the op is not mutable but returns nothing, so we return False out of an abundance of caution
    if not schema.is_mutable:
        return False
    if len(schema.returns) > 0:
        return False
    # If the op returns nothing, then it has a trivial fake impl.
    return True


def requires_set_python_module() -> bool:
    """If an op was defined in C++ and extended from Python using the
    torch.library APIs, returns if we require that there have been a
    m.set_python_module("mylib.ops") call from C++ that associates
    the C++ op with a python module.

    Reads ``_utils_internal.REQUIRES_SET_PYTHON_MODULE``, defaulting to True
    when that attribute is absent.
    """
    return getattr(_utils_internal, "REQUIRES_SET_PYTHON_MODULE", True)


def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
    """Call ``curr_mode.__torch_dispatch__`` for ``op_overload`` directly.

    The ``types`` argument passed along is the list of types of every tensor,
    found anywhere in ``args`` or in the values of ``kwargs``, that has the
    Python dispatch key.
    """
    assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
    overload_types = []
    # Flatten the kwargs dict itself rather than ``kwargs.values()``: a
    # dict_values view is not a registered pytree container, so it would be
    # treated as one opaque leaf and tensors passed by keyword would be missed.
    args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs))
    for a in args_flattened:
        # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
        # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
        # where in one case we only include tensors with the python key, and in another
        # we include **all** tensors.
        if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
            torch._C.DispatchKey.Python
        ):
            overload_types.append(type(a))
    # TODO: check that I got these args correct (in C++, we pass in "0000"??)

    return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)


def has_kwarg_only_args(schema: _C.FunctionSchema):
    """Return True if any argument in ``schema`` is keyword-only."""
    return any(a.kwarg_only for a in schema.arguments)


def has_kwarg_only_tensors(schema: _C.FunctionSchema):
    """Return True if some keyword-only argument is Tensor-like or TensorList-like."""
    return any(
        a.kwarg_only
        and (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
        for a in schema.arguments
    )


def has_tensor_arg(schema: _C.FunctionSchema) -> bool:
    """
    Given a schema, returns True if the schema has a Tensor arg.
    A Tensor arg is any arg with a type annotation that might involve Tensor.
    """
    return any(
        (is_tensor_like_type(a.type) or is_tensorlist_like_type(a.type))
        for a in schema.arguments
    )


def get_device_arg_index(schema: _C.FunctionSchema) -> Union[int, None]:
    """
    Given a schema, returns the id of the `device: torch.device` argument.
    If it does not exist, returns None.

    Only a non-optional ``Device`` argument named ``device`` matches.
    """
    device_type = _C.DeviceObjType.get()
    for index, arg in enumerate(schema.arguments):
        # Compare JIT types by value, like the other type helpers in this
        # module; identity of the pybind wrapper objects is not guaranteed.
        if arg.type == device_type and arg.name == "device":
            return index
    return None