#!/usr/bin/env python3
# mypy: allow-untyped-defs
from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union, Sequence, Dict, Callable
import textwrap
import torch
from torch._C import TupleType, ListType
from torch.jit._recursive import wrap_cpp_module


T = TypeVar("T")

# Tensors whose storage holds at most this many elements are bundled verbatim.
MAX_RAW_TENSOR_SIZE = 16


class InflatableArg(NamedTuple):
    """Helper type for bundled inputs.

    'value' is the compressed/deflated input that is stored in the model. Value
    must be of the same type as the argument to the function that it is a deflated
    input for.

    'fmt' is a formattable code string that is executed to inflate the compressed data into
    the appropriate input. It can use 'value' as an input to the format str. It must result
    in a value of the same type as 'value'.

    'fmt_fn' is a formattable function code string that is executed to inflate the compressed
    data into the appropriate input. It must result in a value of the same type as 'value'.
    The function name should be the formattable part of the string.

    Note: Only top level InflatableArgs can be inflated. i.e. you cannot place
    an inflatable arg inside of some other structure. You should instead create
    an inflatable arg such that the fmt code string returns the full structure
    of your input.
    """

    value: Any
    fmt: str = "{}"
    fmt_fn: str = ""


def bundle_inputs(
        model: torch.jit.ScriptModule,
        inputs: Union[Optional[Sequence[Tuple[Any, ...]]], Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]]],
        info: Optional[Union[List[str], Dict[Callable, List[str]]]] = None,
        *,
        _receive_inflate_expr: Optional[List[str]] = None,
) -> torch.jit.ScriptModule:
    """Create and return a copy of the specified model with inputs attached.

    The original model is not mutated or changed in any way.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    If inputs is passed in as a list then the inputs will be bundled for 'forward'.
    If inputs is instead passed in as a map then all the methods specified in the map
    will have their corresponding inputs bundled. Info should match whichever type is
    chosen for the inputs.

    The returned model will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions will also be defined on the returned module:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method inputs[<function>] should map to None

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        Alternatively if only bundling inputs for forward the map can be omitted and a singular list of inputs
        can be provided instead.

        The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
        list of inputs, the inner tuple is the list of args that together make up one input.
        For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. Alternatively if only bundling inputs for forward the map can be omitted and
    a singular list of information can be provided instead. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")  # noqa: TRY002

    # Strip any previously bundled methods/attributes while cloning so that the
    # copy can be augmented from a clean slate.
    ignored_methods, ignored_attrs = _get_bundled_inputs_attributes_and_methods(model)
    clone = torch._C._hack_do_not_use_clone_module_with_class(  # type: ignore[attr-defined]
        model._c,
        ignored_methods,
        ignored_attrs,
    )

    # The above cloning function returns a torch._C.scriptmodule and we need a torch.jit.scriptmodule.
    # Fortunately there's a function in _recursive that does exactly that conversion.
    cloned_module = wrap_cpp_module(clone)
    if isinstance(inputs, dict):
        assert isinstance(info, dict) or info is None
        augment_many_model_functions_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
    else:
        assert isinstance(info, list) or info is None
        augment_model_with_bundled_inputs(cloned_module, inputs, _receive_inflate_expr, info)
    return cloned_module


def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
        info: Optional[List[str]] = None,  # Optional argument to provide info about forward or its inputs
        skip_size_check: bool = False,
) -> None:
    """Add bundled sample inputs to a model for the forward function.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_forward`.
        If the user chooses this method inputs should be None

      - `inputs` is a list of inputs of form List[Tuple[Any, ...]]. A list of tuples where the elements
        of each tuple are the args that make up one input.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")  # noqa: TRY002

    forward: Callable = model.forward

    # Sometimes forward won't have a name attached so just in case
    if not hasattr(forward, "__name__"):
        forward.__name__ = 'forward'
    augment_many_model_functions_with_bundled_inputs(
        model,
        inputs={forward : inputs},
        _receive_inflate_expr=_receive_inflate_expr,
        info={forward : info} if info else None,
        skip_size_check=skip_size_check,
    )


def augment_many_model_functions_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
        info: Optional[Dict[Callable, List[str]]] = None,  # Optional argument to provide info about the function or its inputs
        skip_size_check: bool = False,
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method inputs[<function>] should map to None

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
        list of inputs, the inner tuple is the list of args that together make up one input.
        For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")  # noqa: TRY002

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")  # noqa: TRY002

    if hasattr(model, "get_all_bundled_inputs") or hasattr(model, "get_bundled_inputs_functions_and_info"):
        raise Exception(  # noqa: TRY002
            "Models can only be augmented with bundled inputs once. "
            "This Model seems to have already been augmented with "
            "bundled inputs.  Please start afresh with one that "
            "doesn't have bundled inputs.",
        )

    # TorchScript statements (unindented) that populate `all_inputs` inside the
    # generated get_bundled_inputs_functions_and_info method. They are indented
    # explicitly when the method is assembled, so the generated source does not
    # depend on how this Python file happens to be indented.
    info_body_lines: List[str] = []

    for function, input_list in inputs.items():
        if hasattr(function, "__name__"):
            function_name = function.__name__
        elif hasattr(function, "name"):
            function_name = function.name  # type: ignore[attr-defined]
        else:
            raise Exception(  # noqa: TRY002
                'At least one of your functions has no attribute name please ensure all have one. m.foo.name = "foo"')

        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(f"Error inputs for function {function_name} is not a Sequence")

        # The first schema argument is `self`; the rest are the real inputs.
        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore[attr-defined]
        deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
        model._c._register_attribute(f"_bundled_inputs_deflated_{function_name}", deflated_inputs_type, [])

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(  # noqa: TRY002
                    f"inputs[{function_name}] is not None, but _generate_bundled_inputs_for_{function_name} is already defined"
                )
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(  # noqa: TRY002
                f"inputs for {function_name} must be specified if "
                f"_generate_bundled_inputs_for_{function_name} is not already defined"
            )
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, (tuple, list)):
                    raise TypeError(
                        f"Error bundled input for function {function_name} idx: {inp_idx} is not a Tuple or a List"
                    )
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name)
                    deflated, inflater, helper_definition = _inflate_expr(
                        arg,
                        f"deflated[{inp_idx}][{arg_idx}]",
                        inflate_helper_fn_name,
                        skip_size_check=skip_size_check,
                    )
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                    if helper_definition:
                        model.define(textwrap.dedent(helper_definition))
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(model, f"_bundled_inputs_deflated_{function_name}", deflated_inputs)
            # `{expr}` sits at the `def` column on purpose: dedent runs before
            # format, and the expression lives inside a bracketed list where
            # TorchScript does not care about indentation.
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name>, which regenerates
        # (inflates) the inputs on every call.
        model.define(textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                all_inputs = self._generate_bundled_inputs_for_{name}()
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(info[function]) if info and function in info else '[]'
        info_body_lines.extend([
            "temp_dict : Dict[str,List[str]] = {}",
            f"info: List[str] = {inputs_info}",
            "temp_dict['info'] = info",
            f"temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{function_name}']",
            f"all_inputs['{function_name}'] = temp_dict",
        ])

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))

    # Define some high level helper methods that act on all bundled inputs
    helper_body = "\n".join([
        "all_inputs : Dict[str, Dict[str,List[str]]] = {}",
        *info_body_lines,
        "return all_inputs",
    ])
    model.define(
        "def get_bundled_inputs_functions_and_info(self):\n"
        + textwrap.indent(helper_body, "    ")
        + "\n"
    )


def _inflate_expr(
    arg: T, ref: str, inflate_helper_fn_name: str, skip_size_check: bool = False
) -> Tuple[Union[T, torch.Tensor], str, Optional[str]]:
    """Work out how to store one argument and how to re-create it at runtime.

    Args:
        arg: the user supplied argument value (possibly an InflatableArg).
        ref: TorchScript expression that refers to the stored (deflated) value.
        inflate_helper_fn_name: method name to use if the argument supplies
            its own inflation function via ``fmt_fn``.
        skip_size_check: when True, tensors are stored as-is whatever their size.

    Returns:
        A ``(deflated_value, inflate_expression, helper_definition)`` triple.
        ``helper_definition`` is TorchScript source for a helper method, or
        None when no helper is required.

    Raises:
        Exception: if an InflatableArg sets both ``fmt`` and ``fmt_fn``, or if a
            tensor is too large to bundle and cannot be represented compactly.
    """
    # Allow custom inflation expressions any object.
    # For example, calling custom image-decoding ops.
    # Or just use "{}" as the format string to ignore size limits.
    if isinstance(arg, InflatableArg):
        if arg.fmt_fn:
            if arg.fmt not in ["{}", ""]:
                raise Exception(  # noqa: TRY002
                    f"Bundled input argument at position '{ref}' has "
                    f"both arg.fmt_fn => \n{arg.fmt_fn} "
                    f"\n and arg.fmt  => {arg.fmt}. "
                    "Please choose `arg.fmt` if the deflater is straightforward or "
                    "`arg.fmt_fn` if you need a function."
                )

            helper_definition = arg.fmt_fn.format(inflate_helper_fn_name)
            expr = f"self.{inflate_helper_fn_name}({ref})"

            return arg.value, expr, helper_definition
        else:
            return arg.value, arg.fmt.format(ref), None

    if isinstance(arg, torch.Tensor):
        # Small-storage tensors can just be saved directly.
        if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
            return arg, ref, None
        # Small contiguous tensors can be cloned to have small storage.
        # TODO: Should we do this even for non-contiguous tensors?
        if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE:
            return arg.clone(), ref, None
        # Example inputs commonly come from torch.zeros, torch.ones, or torch.full.
        # These can be represented compactly.
        for fmt in [torch.contiguous_format, torch.channels_last]:
            if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item():
                return (arg.flatten()[0].clone().expand(*arg.size()),
                        f"{ref}.contiguous(memory_format={fmt})", None)
        # Prevent big tensors from being bundled by default.
        # TODO: Provide more useful diagnostics.
        raise Exception(  # noqa: TRY002
            f"Bundled input argument at position '{ref}' is "
            f"a tensor with storage size {arg._typed_storage().size()}. "
            f"You probably don't want to bundle this as an input. "
        )
    else:
        return arg, ref, None


def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
    """Collect the names of bundled-input methods and attributes on a module.

    Returns:
        ``(methods, attributes)``: names that were added by a previous
        bundling pass, so that they can be ignored when cloning the module.
    """
    methods: List[str] = []
    attributes: List[str] = []

    # Has bundled inputs for forward
    if hasattr(script_module, 'get_all_bundled_inputs'):
        methods.append('get_all_bundled_inputs')
        methods.append('get_num_bundled_inputs')
        methods.append('run_on_bundled_input')

    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
        methods.append('get_bundled_inputs_functions_and_info')
        all_info = script_module.get_bundled_inputs_functions_and_info()
        for function_name in all_info:
            methods.append("get_all_bundled_inputs_for_" + function_name)
            methods.append("_generate_bundled_inputs_for_" + function_name)
            attributes.append("_bundled_inputs_deflated_" + function_name)

            bundled_inputs_fn = getattr(
                script_module,
                f"get_all_bundled_inputs_for_{function_name}"
            )
            num_bundled_inputs: int = len(bundled_inputs_fn())

            # Check inflate helper functions for each function, argument and bundled input
            func = getattr(script_module, function_name)
            # The schema includes `self`, hence the `- 1`.
            for arg_idx in range(len(func.schema.arguments) - 1):
                for input_idx in range(num_bundled_inputs):
                    helper_fn_name = _get_inflate_helper_fn_name(
                        arg_idx=arg_idx,
                        input_idx=input_idx,
                        function_name=function_name
                    )
                    # if the arg has an InflatableArg with fmt_fn, add the helper function name
                    if hasattr(script_module, helper_fn_name):
                        methods.append(helper_fn_name)

    return (methods, attributes)


def _get_inflate_helper_fn_name(
    arg_idx: int,
    input_idx: int,
    function_name: str,
) -> str:
    """Return the method name used for one argument's custom inflate helper."""
    return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}"


def bundle_randn(*size, dtype=None):
    """Generate a tensor that will be inflated with torch.randn."""
    # A single-element tensor expanded to `size` keeps the stored data tiny.
    stub = torch.zeros(1, dtype=dtype).expand(*size)
    return InflatableArg(value=stub, fmt="torch.randn_like({})")


def bundle_large_tensor(t):
    """Wrap a tensor to allow bundling regardless of size."""
    return InflatableArg(value=t, fmt="{}")