xref: /aosp_15_r20/external/executorch/devtools/bundled_program/config.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Worker# pyre-strict
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerfrom dataclasses import dataclass
10*523fa7a6SAndroid Build Coastguard Workerfrom typing import get_args, List, Optional, Sequence, Union
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport torch
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_flatten
15*523fa7a6SAndroid Build Coastguard Worker
16*523fa7a6SAndroid Build Coastguard Workerfrom typing_extensions import TypeAlias
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Worker"""
19*523fa7a6SAndroid Build Coastguard WorkerThe data types currently supported for element to be bundled. It should be
20*523fa7a6SAndroid Build Coastguard Workerconsistent with the types in bundled_program.schema.Value.
21*523fa7a6SAndroid Build Coastguard Worker"""
22*523fa7a6SAndroid Build Coastguard WorkerConfigValue: TypeAlias = Union[
23*523fa7a6SAndroid Build Coastguard Worker    torch.Tensor,
24*523fa7a6SAndroid Build Coastguard Worker    int,
25*523fa7a6SAndroid Build Coastguard Worker    bool,
26*523fa7a6SAndroid Build Coastguard Worker    float,
27*523fa7a6SAndroid Build Coastguard Worker]
28*523fa7a6SAndroid Build Coastguard Worker
29*523fa7a6SAndroid Build Coastguard Worker"""
30*523fa7a6SAndroid Build Coastguard WorkerThe data type of the input for method single execution.
31*523fa7a6SAndroid Build Coastguard Worker"""
32*523fa7a6SAndroid Build Coastguard WorkerMethodInputType: TypeAlias = Sequence[ConfigValue]
33*523fa7a6SAndroid Build Coastguard Worker
34*523fa7a6SAndroid Build Coastguard Worker"""
35*523fa7a6SAndroid Build Coastguard WorkerThe data type of the output for method single execution.
36*523fa7a6SAndroid Build Coastguard Worker"""
37*523fa7a6SAndroid Build Coastguard WorkerMethodOutputType: TypeAlias = Sequence[torch.Tensor]
38*523fa7a6SAndroid Build Coastguard Worker
39*523fa7a6SAndroid Build Coastguard Worker"""
40*523fa7a6SAndroid Build Coastguard WorkerAll supported types for input/expected output of MethodTestCase.
41*523fa7a6SAndroid Build Coastguard Worker
42*523fa7a6SAndroid Build Coastguard WorkerNamedtuple is also supported and listed implicitly since it is a subclass of tuple.
43*523fa7a6SAndroid Build Coastguard Worker"""
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Worker# pyre-ignore
46*523fa7a6SAndroid Build Coastguard WorkerDataContainer: TypeAlias = Union[list, tuple, dict]
47*523fa7a6SAndroid Build Coastguard Worker
48*523fa7a6SAndroid Build Coastguard Worker
class MethodTestCase:
    """Test case with inputs and expected outputs.

    The expected_outputs are optional and only required if the user wants to
    verify model outputs after execution.

    Attributes:
        inputs: The given inputs, flattened with ``tree_flatten`` into a list
            of leaves that passed the type sanity check.
        expected_outputs: The given expected outputs, flattened and checked the
            same way; an empty list when no expected outputs were provided.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with specific inference
                method for one-time execution.

                It is worth mentioning that, although both bundled program and
                ET runtime apis support setting input other than
                `torch.tensor` type, only the input in `torch.tensor` type will
                be actually updated in the method, and the rest of the inputs
                will just do a sanity check if they match the default value in
                method.

            expected_outputs: Expected output of given input for verification.
                It can be None if user only wants to use the test case for
                profiling.

        Raises:
            AssertionError: If any flattened leaf of ``inputs`` or
                ``expected_outputs`` is ``None`` or is not one of the types
                listed in ``ConfigValue``.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every leaf has a legal type.

        Args:
            unflatten_data: Data to be flattened.

        Returns:
            flatten_data: The flattened leaves, all of a type in ConfigValue.

        Raises:
            AssertionError: If a leaf is ``None`` or its type is not in
                ConfigValue.
        """
        flatten_data, _ = tree_flatten(unflatten_data)

        # Hoisted: the set of legal leaf types does not change per element.
        supported_types = get_args(ConfigValue)

        for data in flatten_data:
            # Explicit raises instead of `assert` statements so the check still
            # runs under `python -O`; AssertionError is kept so existing callers
            # that catch it keep working.
            # The None check must come first: None is not an instance of any
            # supported type, so checking types first would always report the
            # generic message and make this dedicated one unreachable.
            if data is None:
                raise AssertionError(
                    "The input {} should not be in null type.\n".format(data)
                )
            if not isinstance(data, supported_types):
                raise AssertionError(
                    "The type of input {} with type {} is not supported.\n".format(
                        data, type(data)
                    )
                )

        return flatten_data
108*523fa7a6SAndroid Build Coastguard Worker
109*523fa7a6SAndroid Build Coastguard Worker
@dataclass()
class MethodTestSuite:
    """Everything needed to verify a single method.

    Attributes:
        method_name: Name of the method under verification.
        test_cases: The test cases to run against that method.
    """

    method_name: str
    test_cases: Sequence[MethodTestCase]
121