# xref: /aosp_15_r20/external/pytorch/torch/utils/bundled_inputs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# mypy: allow-untyped-defs
3from typing import Any, TypeVar, Optional, Tuple, List, NamedTuple, Union, Sequence, Dict, Callable
4import textwrap
5import torch
6from torch._C import TupleType, ListType
7from torch.jit._recursive import wrap_cpp_module
8
9
# Type variable used by _inflate_expr to express that the deflated value it
# returns is either the same type as the incoming argument or a Tensor.
T = TypeVar("T")

# Tensors whose typed storage holds at most this many elements are bundled
# raw; larger tensors must be compressible (uniform-valued) or wrapped in an
# InflatableArg, otherwise _inflate_expr raises.
MAX_RAW_TENSOR_SIZE = 16
13
class InflatableArg(NamedTuple):
    """Helper type for bundled inputs.

    'value' is the compressed/deflated input that is stored in the model. Value
    must be of the same type as the argument to the function that it is a deflated
    input for.

    'fmt' is a formatable code string that is executed to inflate the compressed data into
    the appropriate input. It can use 'value' as an input to the format str. It must result
    in a value of the same type as 'value'.

    'fmt_fn' is a formatable function code string that is executed to inflate the compressed
    data into the appropriate input. It must result in a value of the same type as 'value'.
    The function name should be the formatable part of the string.

    Note: Only top level InflatableArgs can be inflated. i.e. you cannot place
    an inflatable arg inside of some other structure. You should instead create
    an inflatable arg such that the fmt code string returns the full structure
    of your input.
    """

    # The deflated value that is stored inside the model.
    value: Any
    # TorchScript expression template; "{}" is replaced with a reference to the
    # stored value at inflation time (see _inflate_expr).
    fmt: str = "{}"
    # TorchScript function-definition template; "{}" is replaced with the
    # generated helper-function name (see _get_inflate_helper_fn_name).
    fmt_fn: str = ""
38
39
def bundle_inputs(
    model: torch.jit.ScriptModule,
    inputs: Union[Optional[Sequence[Tuple[Any, ...]]], Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]]],
    info: Optional[Union[List[str], Dict[Callable, List[str]]]] = None,
    *,
    _receive_inflate_expr: Optional[List[str]] = None,
) -> torch.jit.ScriptModule:
    """Return a copy of ``model`` augmented with bundled sample inputs.

    The original model is never mutated; a clone (minus any pre-existing
    bundled-input helpers) is produced and the helpers are defined on it.

    If ``inputs`` is a sequence, the inputs are bundled for ``forward``.
    If ``inputs`` is a dict mapping functions to input lists, every mapped
    function gets its own bundle. ``info`` must use the matching shape
    (a list for the sequence form, a dict keyed by function otherwise).

    The returned module supports:

        ``get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]``
            Inputs suitable for
            ``for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)``

        ``get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]``
            Maps each function name to metadata:
            'get_inputs_function_name' -> name of the accessor method, and
            'info' -> the user-supplied info strings.

    If ``forward`` has bundled inputs, ``get_all_bundled_inputs()`` and
    ``get_num_bundled_inputs()`` are also defined.

    Inputs may alternatively be produced by a model-defined
    ``_generate_bundled_inputs_for_<function_name>`` method, in which case the
    corresponding entry in ``inputs`` must be None.

    Each input list has type List[Tuple[Any, ...]]: the outer list enumerates
    inputs, each inner tuple holds the args making up one call.

    ``info`` optionally attaches human-readable strings per function, e.g.
    ``info={model.forward: ['man eating icecream', 'an airplane', 'a dog']}``.

    Top-level tensor arguments such as ``torch.zeros(1000)`` are stored
    compactly when possible; tensors nested in lists/tuples are not.

    :param _receive_inflate_expr: debugging back-channel; generated inflation
        expressions are appended to this list when provided.
    :raises Exception: if ``model`` is not a ScriptModule.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")  # noqa: TRY002

    # Strip any bundled-input artifacts from the clone so augmentation below
    # starts from a clean slate.
    skip_methods, skip_attrs = _get_bundled_inputs_attributes_and_methods(model)
    raw_clone = torch._C._hack_do_not_use_clone_module_with_class(  # type: ignore[attr-defined]
        model._c,
        skip_methods,
        skip_attrs,
    )

    # The clone comes back as a torch._C.ScriptModule; wrap_cpp_module turns
    # it into the torch.jit.ScriptModule callers expect.
    augmented = wrap_cpp_module(raw_clone)
    if isinstance(inputs, dict):
        assert info is None or isinstance(info, dict)
        augment_many_model_functions_with_bundled_inputs(augmented, inputs, _receive_inflate_expr, info)
    else:
        assert info is None or isinstance(info, list)
        augment_model_with_bundled_inputs(augmented, inputs, _receive_inflate_expr, info)
    return augmented
127
def augment_model_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Optional[Sequence[Tuple[Any, ...]]] = None,
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
        info: Optional[List[str]] = None,  # Optional argument to provide info about forward or its inputs
        skip_size_check=False,
) -> None:
    """Attach bundled sample inputs for ``forward`` to ``model`` in place.

    Thin convenience wrapper over
    ``augment_many_model_functions_with_bundled_inputs`` that targets only the
    ``forward`` method. The augmented model gains:

        ``get_all_bundled_inputs() -> List[Tuple[Any, ...]]``
            Inputs usable as ``for inp in model.get_all_bundled_inputs(): model(*inp)``

        ``get_num_bundled_inputs() -> int``
            Same as ``len(model.get_all_bundled_inputs())``, easier from C++.

        ``get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str, List[str]]]``
            Per-function metadata: 'get_inputs_function_name' (accessor method
            name) and 'info' (the user-provided strings).

    ``inputs`` is a List[Tuple[Any, ...]]: each tuple holds the args of one
    call. Pass None if the model defines
    ``_generate_bundled_inputs_for_forward`` itself.

    :raises Exception: if ``model`` is not a ScriptModule.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")  # noqa: TRY002

    fwd: Callable = model.forward

    # A ScriptModule's forward may lack __name__; the many-functions helper
    # derives attribute/method names from it, so patch one on if needed.
    if not hasattr(fwd, "__name__"):
        fwd.__name__ = 'forward'

    augment_many_model_functions_with_bundled_inputs(
        model,
        inputs={fwd: inputs},
        _receive_inflate_expr=_receive_inflate_expr,
        info=None if not info else {fwd: info},
        skip_size_check=skip_size_check,
    )
180
181
def augment_many_model_functions_with_bundled_inputs(
        model: torch.jit.ScriptModule,
        inputs: Dict[Callable, Optional[Sequence[Tuple[Any, ...]]]],
        _receive_inflate_expr: Optional[List[str]] = None,  # For debugging.
        info: Optional[Dict[Callable, List[str]]] = None,  # Optional argument to provide info about the function or its inputs
        skip_size_check: bool = False,
) -> None:
    """Add bundled sample inputs to a model for an arbitrary list of public functions.

    Models with bundled inputs can be invoked in a uniform manner by
    benchmarking and code coverage tools.

    Augmented models will support the following methods:

        `get_all_bundled_inputs_for_<function_name>() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs_for_foo(): model.foo(*inp)`

        `get_bundled_inputs_functions_and_info() -> Dict[str, Dict[str: List[str]]]`
            Returns a dictionary mapping function names to a metadata dictionary.
            This nested dictionary maps preset strings like:
                'get_inputs_function_name' -> the name of a function attribute in this model that can be
                    run to get back a list of inputs corresponding to that function.
                'info' -> the user provided extra information about the bundled inputs

    If forward has bundled inputs then these following functions are also defined:

        `get_all_bundled_inputs() -> List[Tuple[Any, ...]]`
            Returns a list of tuples suitable for passing to the model like
            `for inp in model.get_all_bundled_inputs(): model(*inp)`

        `get_num_bundled_inputs() -> int`
            Equivalent to `len(model.get_all_bundled_inputs())`,
            but slightly easier to call from C++.

    Inputs can be specified in one of two ways:

      - The model can define `_generate_bundled_inputs_for_<function_name>`.
        If the user chooses this method inputs[<function>] should map to None

      - The `inputs` argument to this function can be a dictionary mapping functions to a
        list of inputs, of the same form that will be returned by get_all_bundled_inputs_for_<function_name>.
        The type of the inputs is List[Tuple[Any, ...]]. The outer list corresponds with a
        list of inputs, the inner tuple is the list of args that together make up one input.
        For inputs of functions that take one arg, this will be a tuple of length one. The Any, ...
        is the actual data that makes up the args, e.g. a tensor.

    Info is an optional parameter that maps functions to a list of strings providing extra information about that
    function's bundled inputs. This could be descriptions, expected outputs, etc.
        - Ex: info={model.forward : ['man eating icecream', 'an airplane', 'a dog']}

    This function will attempt to optimize arguments so that (e.g.)
    arguments like `torch.zeros(1000)` will be represented compactly.
    Only top-level arguments will be optimized.
    Tensors in lists or tuples will not.
    """
    if not isinstance(model, torch.jit.ScriptModule):
        raise Exception("Only ScriptModule is supported.")  # noqa: TRY002

    if not inputs:
        raise Exception("Please provide inputs for at least 1 function")  # noqa: TRY002

    # Refuse double augmentation: the generated methods/attributes would
    # collide with the ones already defined on the module.
    if hasattr(model, "get_all_bundled_inputs") or hasattr(model, "get_bundled_inputs_functions_and_info"):
        raise Exception(  # noqa: TRY002
            "Models can only be augmented with bundled inputs once. "
            "This Model seems to have already been augmented with "
            "bundled inputs. Please start afresh with one that "
            "doesn't have bundled inputs.",
        )

    # Accumulates the per-function TorchScript statements that are later
    # spliced into get_bundled_inputs_functions_and_info.
    get_bundled_inputs_functions_and_info_template = ""

    for function, input_list in inputs.items():
        # Resolve a name for the function: Python callables carry __name__,
        # ScriptMethods may instead expose .name.
        if hasattr(function, "__name__"):
            function_name = function.__name__
        else:
            if hasattr(function, "name"):
                function_name = function.name  # type: ignore[attr-defined]
            else:
                raise Exception(  # noqa: TRY002
                    'At least one of your functions has no attribute name please ensure all have one. m.foo.name = "foo"')


        if input_list is not None and not isinstance(input_list, Sequence):
            raise TypeError(f"Error inputs for function {function_name} is not a Sequence")

        # Register a typed List[Tuple[...]] attribute to hold the deflated
        # inputs; the element types come from the function schema (skipping
        # the implicit self argument).
        function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore[attr-defined]
        deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
        model._c._register_attribute(f"_bundled_inputs_deflated_{function_name}", deflated_inputs_type, [])

        if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
            if input_list is not None:
                raise Exception(  # noqa: TRY002
                    f"inputs[{function_name}] is not None, but _generate_bundled_inputs_for_{function_name} is already defined"
                )
            # Model author already defined _generate_bundled_inputs_for_<function_name>.
        elif input_list is None or len(input_list) == 0:
            raise Exception(  # noqa: TRY002
                f"inputs for {function_name} must be specified if "
                f"_generate_bundled_inputs_for_{function_name} is not already defined"
            )
        else:
            # Iterate over the inputs and args in each input.
            # Accumulate `deflated_inputs` as (possibly) compressed values
            # and `parts` to be joined into the expression that unpacks them.
            deflated_inputs = []
            parts = []
            for inp_idx, args in enumerate(input_list):
                if not isinstance(args, Tuple) and not isinstance(args, List):  # type: ignore[arg-type]
                    raise TypeError(
                        f"Error bundled input for function {function_name} idx: {inp_idx} is not a Tuple or a List"
                    )
                deflated_args = []
                parts.append("(")
                for arg_idx, arg in enumerate(args):
                    inflate_helper_fn_name = _get_inflate_helper_fn_name(arg_idx, inp_idx, function_name)
                    deflated, inflater, helper_definition = _inflate_expr(
                        arg,
                        f"deflated[{inp_idx}][{arg_idx}]",
                        inflate_helper_fn_name,
                        skip_size_check=skip_size_check,
                    )
                    deflated_args.append(deflated)
                    parts.append(f"    {inflater},")
                    # fmt_fn-based InflatableArgs need their helper compiled
                    # onto the module before the generated code can call it.
                    if helper_definition:
                        model.define(textwrap.dedent(helper_definition))
                deflated_inputs.append(tuple(deflated_args))
                parts.append("),")
            parts.append("")
            # `expr` is the body of a TorchScript list literal that rebuilds
            # the original inputs from the deflated attribute.
            expr = "\n".join(parts)

            # Back-channel return this expr for debugging.
            if _receive_inflate_expr is not None:
                _receive_inflate_expr.append(expr)
            setattr(model, f"_bundled_inputs_deflated_{function_name}", deflated_inputs)
            definition = textwrap.dedent("""
                def _generate_bundled_inputs_for_{name}(self):
                    deflated = self._bundled_inputs_deflated_{name}
                    return [
                {expr}
                    ]
                """).format(expr=expr, name=function_name)
            model.define(definition)

        # Define get_all_bundled_inputs_for_<function_name> that caches the generated inputs.
        model.define(textwrap.dedent("""
            def get_all_bundled_inputs_for_{name}(self):
                all_inputs = self._generate_bundled_inputs_for_{name}()
                assert all_inputs is not None
                return all_inputs
            """).format(name=function_name))

        # Add to the high level helper methods
        inputs_info = repr(info[function]) if info and function in info else '[]'
        get_bundled_inputs_functions_and_info_template += f"""
            temp_dict : Dict[str,List[str]] = {{}}
            info: List[str] = {inputs_info}

            temp_dict['info'] = info
            temp_dict['get_inputs_function_name'] = ['get_all_bundled_inputs_for_{function_name}']
            all_inputs['{function_name}'] = temp_dict
            """

        # To ensure backwards compatibility and a streamlined api for forward these wrappers are provided
        if function_name == 'forward':
            model.define(textwrap.dedent("""
                def get_all_bundled_inputs(self):
                    return self.get_all_bundled_inputs_for_forward()
                """))
            model.define(textwrap.dedent("""
                def get_num_bundled_inputs(self):
                    return len(self.get_all_bundled_inputs_for_forward())
                """))

    # Define some high level helper methods that act on all bundled inputs
    model.define(textwrap.dedent(f"""
        def get_bundled_inputs_functions_and_info(self):
            all_inputs : Dict[str, Dict[str,List[str]]] = {{}}
            {get_bundled_inputs_functions_and_info_template}
            return all_inputs
        """))
363
364def _inflate_expr(
365    arg: T, ref: str, inflate_helper_fn_name: str, skip_size_check: bool = False
366) -> Tuple[Union[T, torch.Tensor], str, Optional[str]]:
367    # Allow custom inflation expressions any object.
368    # For example, calling custom image-decoding ops.
369    # Or just use "{}" as the format string to ignore size limits.
370    if isinstance(arg, InflatableArg):
371        if arg.fmt_fn:
372            if arg.fmt not in ["{}", ""]:
373                raise Exception(  # noqa: TRY002
374                    f"Bundled input argument at position '{ref}' has "
375                    f"both arg.fmt_fn => \n{arg.fmt_fn} "
376                    f"\n and arg.fmt  => {arg.fmt}. "
377                    "Please choose `arg.fmt` if the deflater is straightforward or "
378                    "`arg.fmt_fn` if you need a function."
379                )
380
381            helper_definition = arg.fmt_fn.format(inflate_helper_fn_name)
382            expr = f"self.{inflate_helper_fn_name}({ref})"
383
384            return arg.value, expr, helper_definition
385        else:
386            return arg.value, arg.fmt.format(ref), None
387
388    if isinstance(arg, torch.Tensor):
389        # Small-storage tensors can just be saved directly.
390        if arg._typed_storage().size() <= MAX_RAW_TENSOR_SIZE or skip_size_check:
391            return arg, ref, None
392        # Small contiguous tensors can be cloned to have small storage.
393        # TODO: Should we do this even for non-contiguous tensors?
394        if arg.is_contiguous() and arg.numel() <= MAX_RAW_TENSOR_SIZE:
395            return arg.clone(), ref, None
396        # Example inputs commonly come from torch.zeros, torch.ones, or torch.full.
397        # These can be represented compactly.
398        for fmt in [torch.contiguous_format, torch.channels_last]:
399            if arg.is_contiguous(memory_format=fmt) and (arg == arg.flatten()[0]).all().item():
400                return (arg.flatten()[0].clone().expand(*arg.size()),
401                        f"{ref}.contiguous(memory_format={fmt})", None)
402        # Prevent big tensors from being bundled by default.
403        # TODO: Provide more useful diagnostics.
404        raise Exception(  # noqa: TRY002
405            f"Bundled input argument at position '{ref}' is "
406            f"a tensor with storage size {arg._typed_storage().size()}. "
407            f"You probably don't want to bundle this as an input. "
408        )
409    else:
410        return arg, ref, None
411
412def _get_bundled_inputs_attributes_and_methods(script_module: torch.jit.ScriptModule) -> Tuple[List[str], List[str]]:
413    methods: List[str] = []
414    attributes: List[str] = []
415
416    # Has bundled inputs for forward
417    if hasattr(script_module, 'get_all_bundled_inputs'):
418        methods.append('get_all_bundled_inputs')
419        methods.append('get_num_bundled_inputs')
420        methods.append('run_on_bundled_input')
421
422    if hasattr(script_module, 'get_bundled_inputs_functions_and_info'):
423        methods.append('get_bundled_inputs_functions_and_info')
424        all_info = script_module.get_bundled_inputs_functions_and_info()
425        for function_name in all_info:
426            methods.append("get_all_bundled_inputs_for_" + function_name)
427            methods.append("_generate_bundled_inputs_for_" + function_name)
428            attributes.append("_bundled_inputs_deflated_" + function_name)
429
430            bundled_inputs_fn = getattr(
431                script_module,
432                f"get_all_bundled_inputs_for_{function_name}"
433            )
434            num_bundled_inputs: int = len(bundled_inputs_fn())
435
436            # Check inflate helper functions for each function, argument and bundled input
437            func = getattr(script_module, function_name)
438            for arg_idx in range(len(func.schema.arguments) - 1):
439                for input_idx in range(num_bundled_inputs):
440                    helper_fn_name = _get_inflate_helper_fn_name(
441                        arg_idx=arg_idx,
442                        input_idx=input_idx,
443                        function_name=function_name
444                    )
445                    # if the arg has an InflatableArg with fmt_fn, add the helper function name
446                    if hasattr(script_module, helper_fn_name):
447                        methods.append(helper_fn_name)
448
449    return (methods, attributes)
450
451
452def _get_inflate_helper_fn_name(
453    arg_idx: int,
454    input_idx: int,
455    function_name: str,
456) -> str:
457    return f"_inflate_helper_for_{function_name}_input_{input_idx}_arg_{arg_idx}"
458
459
460
def bundle_randn(*size, dtype=None):
    """Bundle a placeholder that inflates to ``torch.randn`` of the given size.

    Only a single-element stub (expanded to ``size``) is stored in the model,
    so arbitrarily large random inputs stay compact.
    """
    stub = torch.zeros(1, dtype=dtype)
    return InflatableArg(value=stub.expand(*size), fmt="torch.randn_like({})")
465
466
def bundle_large_tensor(t):
    """Wrap tensor ``t`` so it is bundled verbatim, bypassing the size limit.

    The identity format string "{}" tells _inflate_expr to store the tensor
    as-is regardless of MAX_RAW_TENSOR_SIZE.
    """
    return InflatableArg(t, "{}")
470