# mypy: allow-untyped-defs
import torch
import inspect
import numbers
import types
import typing
import enum
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket, OpOverload

if TYPE_CHECKING:
    from .node import Argument

__all__ = ["ArgsKwargsPair", "check_for_mutable_operation", "get_signature_for_torch_op", "create_type_hint",
           "type_matches", "normalize_function", "normalize_module"]


@compatibility(is_backward_compatible=False)
class ArgsKwargsPair(NamedTuple):
    """
    Simple named tuple for wrapping args/kwargs pairs.
    """
    args: Tuple[Any, ...]
    kwargs: Dict[str, Any]


# Hand-written signature lists for callables that are resolved here instead of
# through the TorchScript builtin registry (consulted in `get_signature_for_torch_op`).
_manual_overrides : Dict[Callable, List[inspect.Signature]] = {}


def _nonzero_schemas():
    """
    Build the two overload signatures of ``torch.nonzero``: one taking only
    ``self``, and one that additionally requires the keyword-only ``as_tuple``.
    """
    signatures = []

    def nonzero(self):
        pass
    signatures.append(inspect.signature(nonzero))

    # Intentional redefinition: only the signature of each stub is captured.
    def nonzero(self, *, as_tuple : bool):  # type: ignore[no-redef]
        pass
    signatures.append(inspect.signature(nonzero))

    return signatures


_manual_overrides[torch.nonzero] = _nonzero_schemas()


class _FakeGlobalNamespace:
    """
    Stand-in for the ``__torch__`` global while eval'ing annotation strings.
    Only the ``torch`` attribute resolves; any other attribute lookup raises.
    """
    def __getattr__(self, name):
        if name == 'torch':
            return torch
        raise RuntimeError('Expected a torch namespace lookup')


# Globals used when eval'ing TorchScript annotation strings into Python types.
# Schema-specific names are mapped explicitly; every name in `dir(typing)` is
# added below so strings such as "List[int]" or "Optional[Tensor]" resolve.
_type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout,
                      'number' : numbers.Number, 'Future' : torch.jit.Future,
                      'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme,
                      '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None),
                      'Storage': torch.UntypedStorage,
                      't': typing.TypeVar('t')}
for k in dir(typing):
    _type_eval_globals[k] = getattr(typing, k)


def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any:
    """
    Convert a TorchScript type to a Python type (including subtypes) via
    eval'ing the annotation_str. _type_eval_globals sets up expressions
    like "List" and "Future" to map to actual types (typing.List and jit.Future)
    """
    # NOTE: this uses `eval`. Within this module it is only fed annotation
    # strings of operator schemas (see the callers below). Never pass a type
    # derived from untrusted input.
    return eval(ts_type.annotation_str, _type_eval_globals)


def _torchscript_schema_to_signature_impl(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    """
    Build an ``inspect.Signature`` equivalent to a TorchScript ``FunctionSchema``.

    Arguments map to parameters (``kwarg_only`` -> KEYWORD_ONLY, otherwise
    POSITIONAL_OR_KEYWORD). A first argument named ``self`` is renamed to ``input``.
    A ``from`` argument forces itself and every earlier parameter to POSITIONAL_ONLY.
    The return annotation is ``None`` for no returns, the single type for one
    return, or a tuple of types for several.
    """
    from inspect import Parameter
    parameters : List[Parameter] = []
    for arg in ts_schema.arguments:
        arg_type = _torchscript_type_to_python_type(arg.type)
        default = arg.default_value if arg.has_default_value() else Parameter.empty
        # TODO: Figure out if this is safe. It seems like when generating the type signatures for
        # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor
        # argument name. Downstream, if someone converts that positional argument to a keyword
        # argument, the name mismatch will break things, so here we're going to normalize the
        # name to "input"
        name = arg.name if arg.name != 'self' else 'input'
        kind = Parameter.KEYWORD_ONLY if arg.kwarg_only else Parameter.POSITIONAL_OR_KEYWORD
        # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument
        if name == "from":
            assert kind == Parameter.POSITIONAL_OR_KEYWORD
            # ParameterKind type is internal implementation detail to inspect package
            # which makes it hard to do type annotation
            kind = Parameter.POSITIONAL_ONLY  # type: ignore[assignment]
            # This renders all previous arguments to positional only
            for idx, p in enumerate(parameters):
                assert p.kind == Parameter.POSITIONAL_OR_KEYWORD
                parameters[idx] = Parameter(name=p.name, kind=Parameter.POSITIONAL_ONLY,
                                            default=p.default, annotation=p.annotation)
        parameters.append(Parameter(name=name, kind=kind, default=default, annotation=arg_type))
    return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns]
    if len(return_types) == 0:
        return_type = None
    elif len(return_types) == 1:
        return_type = return_types[0]
    else:
        return_type = tuple(return_types)

    return inspect.Signature(parameters, return_annotation=return_type)


# Unbounded memo of converted schemas, keyed by (schema name, overload name).
_SCHEMA_TO_SIGNATURE_CACHE : Dict[Tuple[str, str], inspect.Signature] = {}


def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature:
    """
    Memoized wrapper around `_torchscript_schema_to_signature_impl`.
    """
    # Cached as it's called in the hot path of FakeTensor dispatch
    cache_key = ts_schema.name, ts_schema.overload_name
    cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key)
    if cache_val is not None:
        return cache_val

    res = _torchscript_schema_to_signature_impl(ts_schema)
    _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res
    return res


@compatibility(is_backward_compatible=False)
def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']):
    """
    Best-effort check that a call to ``target`` is not a mutating (in-place / ``out=``) operation.

    Args:
        target (Callable): The operator being called.
        args (Tuple[Argument, ...]): Positional arguments at the call site.
        kwargs (Dict[str, Argument]): Keyword arguments at the call site.

    Raises:
        RuntimeError: If exactly one schema of ``target`` binds ``args``/``kwargs``
            and that schema is mutable. If no schema binds, or more than one does,
            no check is performed.
    """
    signatures, schemas = get_signature_for_torch_op(target, return_schemas=True)

    if signatures and schemas:
        matched_schemas = []

        # Collect every (signature, schema) pair whose signature accepts the
        # call-site args/kwargs. `Signature.bind` raises TypeError on mismatch.
        for candidate_signature, schema in zip(signatures, schemas):
            try:
                candidate_signature.bind(*args, **kwargs)
                matched_schemas.append((candidate_signature, schema))
            except TypeError:
                continue

        def throw_if_mutable(schema):
            if schema.is_mutable:
                raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional '
                                   f'code, so operations that mutate operands in-place (e.g. via `out` arguments) '
                                   f'are not supported')

        if len(matched_schemas) == 0:
            # Did not match any schema. Cannot check for mutation
            pass
        elif len(matched_schemas) == 1:
            # Matched exactly one schema, unambiguous
            _, schema_to_check = matched_schemas[0]
            throw_if_mutable(schema_to_check)
        else:
            # Ambiguous schema match. Since mutability checking is best effort,
            # do nothing.
            pass


@compatibility(is_backward_compatible=False)
def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
    """
    Given an operator on the `torch` namespace, return a list of `inspect.Signature`
    objects corresponding to the overloads of that op. May return `None` if a signature
    could not be retrieved.

    Args:
        op (Callable): An operator on the `torch` namespace to look up a signature for
        return_schemas (bool): If True, also return the TorchScript schemas.

    Returns:
        Optional[List[inspect.Signature]]: A list of signatures for the overloads of this
            operator, or None if the operator signatures could not be retrieved. If
            return_schemas=True, returns a tuple containing the optional Python signatures
            and the optional TorchScript Function schemas. Ops found in the manual
            override table yield ``(override_signatures, None)`` when
            return_schemas=True and ``None`` otherwise. Ops with no builtin schema
            yield ``(None, None)`` when return_schemas=True.
    """
    if isinstance(op, OpOverload):
        schemas = [op._schema]
    elif isinstance(op, OpOverloadPacket):
        schemas = [getattr(op, overload)._schema for overload in op.overloads()]
    else:
        override = _manual_overrides.get(op)
        if override:
            # NOTE(review): with return_schemas=False the override signatures are
            # *not* returned (None is returned instead), so normalize_function does
            # not normalize these ops. The override stubs name their first parameter
            # `self` rather than `input` (schema conversion renames it). Returning them
            # here would therefore change normalized kwargs. Confirm intent before changing.
            return (override, None) if return_schemas else None

        aten_fn = torch.jit._builtins._find_builtin(op)

        if aten_fn is None:
            return (None, None) if return_schemas else None
        schemas = torch._C._jit_get_schemas_for_operator(aten_fn)

    signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
    return (signatures, schemas) if return_schemas else signatures


@compatibility(is_backward_compatible=False)
def create_type_hint(x):
    """
    Produces a type hint for the given argument.

    The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`.

    If `x` is a `list` or `tuple` of types, it looks for an element that is a superclass
    of all the others, and uses that as `base_type` for the `List[base_type]` (for a list)
    or `Tuple[base_type, ...]` (for a tuple) to be returned. If `x` is empty, or no such
    element is found, it returns `List[Any]` / `Tuple[Any, ...]` respectively.

    If `x` is neither a `list` nor a `tuple`, it returns `x`. If building the hint raises
    (e.g. an element is not a class), a warning is emitted and `x` is returned unchanged.
    """
    try:
        if isinstance(x, (list, tuple)):
            # todo(chilli): Figure out the right way for mypy to handle this
            if isinstance(x, list):
                def ret_type(x):
                    return List[x]  # type: ignore[valid-type]
            else:
                def ret_type(x):
                    return Tuple[x, ...]
            if len(x) == 0:
                return ret_type(Any)
            base_type = x[0]
            for t in x:
                if issubclass(t, base_type):
                    continue
                elif issubclass(base_type, t):
                    base_type = t
                else:
                    return ret_type(Any)
            return ret_type(base_type)
    except Exception:
        # We tried to create a type hint for list but failed.
        warnings.warn(f"We were not able to successfully create type hint from the type {x}")
    return x


@compatibility(is_backward_compatible=False)
def type_matches(signature_type : Any, argument_type : Any):
    """
    Return True if a value of type ``argument_type`` is acceptable for a
    parameter annotated ``signature_type``.

    Rules handled: identical types; ``Union`` (any member matches); ``int`` for
    ``List[int]``; ``List[T]`` / homogeneous ``Tuple[T, ...]`` for ``List[S]`` when
    ``T`` subclasses ``S``; ``torch.dtype`` for ``int``; ``int``/``float`` for
    ``numbers.Number``; and plain subclassing between classes.
    """
    sig_origin_type = getattr(signature_type, '__origin__', signature_type)

    if signature_type is argument_type:
        return True

    # Union types in signature. Given type needs to match one of the
    # contained types in the Union
    if sig_origin_type is typing.Union and signature_type != argument_type:
        sig_contained = signature_type.__args__
        return any(type_matches(c, argument_type) for c in sig_contained)

    if signature_type is List[int] and argument_type is int:
        # int can be promoted to List[int]
        return True

    if getattr(signature_type, '__origin__', None) in {list, List}:
        sig_el_type = signature_type.__args__[0]
        if not inspect.isclass(sig_el_type):
            warnings.warn(
                f"Does not support nested parametric types, got {signature_type}. Please file a bug.")
            return False
        if getattr(argument_type, '__origin__', None) in {list, List}:
            return issubclass(argument_type.__args__[0], sig_el_type)

        def is_homogeneous_tuple(t):
            if getattr(t, "__origin__", None) not in {tuple, Tuple}:
                return False
            contained = t.__args__
            if t.__args__ == ((),):  # Tuple[()].__args__ == ((),) for some reason
                return True
            return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained)

        # Tuple[T] is accepted for List[T] parameters
        return is_homogeneous_tuple(argument_type)

    # Dtype is an int in schemas
    if signature_type is int and argument_type is torch.dtype:
        return True

    if signature_type is numbers.Number and argument_type in {int, float}:
        return True
    if inspect.isclass(argument_type) and inspect.isclass(signature_type):
        return issubclass(argument_type, signature_type)

    return False


@compatibility(is_backward_compatible=False)
def normalize_function(
        target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None,
        kwarg_types : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch functions. This means that
    `args/kwargs` will be matched up to the functional's
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs). Does not support modules.

    May require `arg_types` and `kwarg_types` in order to disambiguate overloads.

    Args:
        target (Callable): Function that we are normalizing
        args (Tuple[Any]): Tuple of args to the function
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function
        arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args
        kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: If several overload schemas accept `args/kwargs` and neither
            `arg_types` nor `kwarg_types` was given to disambiguate them.
    """
    if kwargs is None:
        kwargs = {}
    new_args_and_kwargs = None
    if not isinstance(target, types.BuiltinFunctionType) and not (
        isinstance(target, (OpOverloadPacket, OpOverload))
    ):
        # Plain Python callable: normalize against its own `inspect.signature`.
        target_for_analysis = target
        if target in boolean_dispatched:
            # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have
            # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false`
            # branches of the dispatch have exactly the same signature. If they do, use the `true`
            # branch signature for analysis. Otherwise, leave this un-normalized
            assert not isinstance(target, str)
            dispatched = boolean_dispatched[target]
            if_true, if_false = dispatched['if_true'], dispatched['if_false']
            if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters:
                return None
            target_for_analysis = if_true

        assert callable(target_for_analysis)
        sig = inspect.signature(inspect.unwrap(target_for_analysis))
        new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs)
    else:
        # Builtin / torch op: normalize against its TorchScript overload schemas.
        assert callable(target)
        torch_op_schemas = get_signature_for_torch_op(target)
        matched_schemas = []
        if torch_op_schemas:
            # Collect every overload signature that accepts the call-site
            # args/kwargs (`Signature.bind` raises TypeError on mismatch).
            for candidate_signature in torch_op_schemas:
                try:
                    candidate_signature.bind(*args, **kwargs)
                    matched_schemas.append(candidate_signature)
                except TypeError:
                    continue

            if len(matched_schemas) == 0:
                # Did not match any schema. Cannot normalize
                pass
            elif len(matched_schemas) == 1:
                # Matched exactly one schema, unambiguous
                new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs,
                                                                             normalize_to_only_use_kwargs)
            else:
                if arg_types is not None or kwarg_types is not None:
                    arg_types = arg_types if arg_types else cast(Tuple[Any], ())
                    kwarg_types = kwarg_types if kwarg_types else {}
                    # Disambiguate by argument types. Only schemas that already bound
                    # the real args/kwargs are considered: a schema that cannot bind
                    # them would make the normalization step below fail (TypeError
                    # from `sig.bind`) even if the types happen to match.
                    for candidate_signature in matched_schemas:
                        sig_matches = True
                        try:
                            bound_types = candidate_signature.bind(*arg_types, **kwarg_types)
                            for arg_name, arg_type in bound_types.arguments.items():
                                param = candidate_signature.parameters[arg_name]
                                sig_matches = sig_matches and type_matches(param.annotation, arg_type)
                        except TypeError:
                            sig_matches = False
                        if sig_matches:
                            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs,
                                                                                         normalize_to_only_use_kwargs)
                            break
                else:
                    # Matched more than one schema. In this situation, the caller must provide the types of
                    # the arguments of the overload they expect.
                    schema_printouts = '\n'.join(str(schema) for schema in matched_schemas)
                    raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but '
                                       f'the schema match was ambiguous! Please provide argument types to '
                                       f'the normalize_arguments() call. Available schemas:\n{schema_printouts}')

    return new_args_and_kwargs


@compatibility(is_backward_compatible=False)
def normalize_module(
        root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None,
        normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]:
    """
    Returns normalized arguments to PyTorch modules. This means that
    `args/kwargs` will be matched up to the submodule's `forward`
    signature and return exclusively kwargs in positional order if
    `normalize_to_only_use_kwargs` is True.
    Also populates default values. Does not support positional-only
    parameters or varargs parameters (*args, **kwargs).

    Only submodules whose class is the same object as the `torch.nn` attribute
    of the same name are normalized; for any other module `None` is returned.

    Args:
        root (nn.Module): root module upon which we query modules
        target (str): Qualified name of the submodule of `root` being called
        args (Tuple[Any]): Tuple of args to the module's forward
        kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the module's forward
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Returns normalized_args_and_kwargs, or `None` if not successful.

    Raises:
        RuntimeError: If `root` has no submodule named `target`.
    """
    try:
        submod = root.get_submodule(target)
    except AttributeError as e:
        raise RuntimeError(f"Tried to normalize node with target {target} but root did not "
                           f"have that target!") from e
    if hasattr(submod.__class__, '__name__'):
        classname = submod.__class__.__name__
        if getattr(torch.nn, classname, None) == submod.__class__:
            sig = inspect.signature(inspect.unwrap(submod.forward))
            if kwargs is None:
                kwargs = {}
            new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs,
                                                                         normalize_to_only_use_kwargs)
            return new_args_and_kwargs
    return None


def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...],
                                           kwargs : Dict[str, Any],
                                           normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]:
    """
    Given a signature, args, and kwargs, return the arguments normalized into
    an ArgsKwargsPair, or None if the type signature is not supported by
    this normalization.

    Args:

        sig (inspect.Signature): Signature object for the target
        args (Tuple): Arguments that appear at the callsite for `target`
        kwargs (Dict): Keyword arguments that appear at the callsite for `target`
        normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs.

    Returns:

        Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if
            this target is not supported. Defaults are filled in. The first
            `len(args)` parameters stay positional unless `normalize_to_only_use_kwargs`
            is set; all remaining parameters are emitted as kwargs in declaration order.

    Raises:
        TypeError: Propagated from `inspect.Signature.bind` if `args/kwargs`
            do not fit `sig`.
    """

    # Don't currently support positional-only
    # or varargs (*args, **kwargs) signatures
    supported_parameter_types = {
        inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY}
    if any(p.kind not in supported_parameter_types for p in sig.parameters.values()):
        # Add an exception for one signature, which is common for random/uniform, i.e.:
        # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None
        # `from` is Python keyword and as such functions with that signature should have
        # positional-only args, but at the same time they could be dispatched as kwargs
        if list(sig.parameters.keys()) != ['input', 'from', 'to', 'generator']:
            return None

    bound_args = sig.bind(*args, **kwargs)
    bound_args.apply_defaults()

    new_kwargs : Dict[str, Any] = {}
    new_args : List[Any] = []
    for i, param in enumerate(sig.parameters):
        if not normalize_to_only_use_kwargs and i < len(args):
            new_args.append(bound_args.arguments[param])
        else:
            new_kwargs[param] = bound_args.arguments[param]

    return ArgsKwargsPair(tuple(new_args), new_kwargs)