xref: /aosp_15_r20/external/executorch/devtools/bundled_program/config.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from dataclasses import dataclass
10from typing import get_args, List, Optional, Sequence, Union
11
12import torch
13
14from torch.utils._pytree import tree_flatten
15
16from typing_extensions import TypeAlias
17
18"""
19The data types currently supported for element to be bundled. It should be
20consistent with the types in bundled_program.schema.Value.
21"""
22ConfigValue: TypeAlias = Union[
23    torch.Tensor,
24    int,
25    bool,
26    float,
27]
28
29"""
30The data type of the input for method single execution.
31"""
32MethodInputType: TypeAlias = Sequence[ConfigValue]
33
34"""
35The data type of the output for method single execution.
36"""
37MethodOutputType: TypeAlias = Sequence[torch.Tensor]
38
39"""
40All supported types for input/expected output of MethodTestCase.
41
42Namedtuple is also supported and listed implicitly since it is a subclass of tuple.
43"""
44
45# pyre-ignore
46DataContainer: TypeAlias = Union[list, tuple, dict]
47
48
class MethodTestCase:
    """Test case with inputs and expected outputs for a single method execution.

    The expected_outputs are optional and only required if the user wants to
    verify model outputs after execution.

    Attributes:
        inputs: Flattened list of the given inputs (pytree leaves).
        expected_outputs: Flattened list of the given expected outputs, or an
            empty list when no expected outputs were provided.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with specific inference method for one-time execution.

                    It is worth mentioning that, although both bundled program and ET runtime apis support setting input
                    other than `torch.tensor` type, only the input in `torch.tensor` type will be actually updated in
                    the method, and the rest of the inputs will just do a sanity check if they match the default value in method.

            expected_outputs: Expected output of given input for verification. It can be None if user only wants to use the test case for profiling.

        Raises:
            AssertionError: If any flattened element of `inputs` or
                `expected_outputs` is not an instance of one of the ConfigValue types.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Misalign data type for between MethodTestCase attribute and sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every element has a supported type.

        Args:
            unflatten_data: Possibly nested container to be flattened with
                torch's pytree utilities.

        Returns:
            flatten_data: The flattened leaves, each an instance of a ConfigValue type.

        Raises:
            AssertionError: If a flattened leaf is not an instance of one of the
                ConfigValue types.
        """
        flatten_data, _ = tree_flatten(unflatten_data)

        # The accepted leaf types do not change per element; compute them once.
        supported_types = get_args(ConfigValue)
        for data in flatten_data:
            # Raise explicitly rather than using `assert`, so the check still runs
            # under `python -O`. AssertionError is kept for backward compatibility
            # with callers that expect it. NoneType is not among `supported_types`,
            # so a None leaf (if pytree yields one) is rejected here as well; a
            # separate None check would be unreachable.
            if not isinstance(data, supported_types):
                raise AssertionError(
                    "The type of input {} with type {} is not supported.\n".format(
                        data, type(data)
                    )
                )

        return flatten_data
107        return flatten_data
108
109
110@dataclass
111class MethodTestSuite:
112    """All test info related to verify method
113
114    Attributes:
115        method_name: Name of the method to be verified.
116        test_cases: All test cases for verifying the method.
117    """
118
119    method_name: str
120    test_cases: Sequence[MethodTestCase]
121