xref: /aosp_15_r20/external/executorch/devtools/bundled_program/test/test_config.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9import unittest
10from typing import get_args, List, Union
11
12import torch
13from executorch.devtools.bundled_program.config import DataContainer
14
15from executorch.devtools.bundled_program.util.test_util import (
16    get_random_test_suites,
17    get_random_test_suites_with_eager_model,
18    SampleModel,
19)
20from executorch.extension.pytree import tree_flatten
21
22
class TestConfig(unittest.TestCase):
    """Tests for creating bundled-program method test suites.

    Checks that the suites returned by the random test-suite helpers carry
    the method names, inputs and expected outputs those helpers report.
    """

    def assertTensorEqual(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
        """Assert that two tensors have the same shape and equal elements.

        The shape is checked first because elementwise ``==`` broadcasts:
        without it, e.g. a ``[2, 2]`` tensor would compare "equal" to a
        ``[2]`` tensor holding matching values.
        """
        self.assertEqual(t1.shape, t2.shape)
        self.assertTrue(
            bool((t1 == t2).all()),
            msg=f"tensors differ:\n{t1}\n{t2}",
        )

    def assertIOListEqual(
        self,
        tl1: List[Union[bool, float, int, torch.Tensor]],
        tl2: List[Union[bool, float, int, torch.Tensor]],
    ) -> None:
        """Assert that two flat IO lists match element by element.

        Both lists must have the same length. When an element of ``tl1`` is a
        tensor, the matching element of ``tl2`` must also be a tensor and is
        compared with ``assertTensorEqual``; any other value is compared with
        ``assertEqual``.
        """
        self.assertEqual(len(tl1), len(tl2))
        for t1, t2 in zip(tl1, tl2):
            if isinstance(t1, torch.Tensor):
                # A unittest assertion instead of a bare `assert`, which is
                # stripped when Python runs with -O.
                self.assertIsInstance(t2, torch.Tensor)
                self.assertTensorEqual(t1, t2)
            else:
                self.assertEqual(t1, t2)

    def test_create_test_suites(self) -> None:
        """Suites built from random data echo the data they were built from."""
        n_sets_per_plan_test = 10
        n_method_test_suites = 5

        (
            rand_method_names,
            rand_inputs,
            rand_expected_outputs,
            method_test_suites,
        ) = get_random_test_suites(
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            n_model_outputs=1,
            model_output_sizes=[[2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
            n_method_test_suites=n_method_test_suites,
        )

        self.assertEqual(len(method_test_suites), n_method_test_suites)

        # Compare to see if bundled execution plan tests match expectations.
        for suite_idx, method_test_suite in enumerate(method_test_suites):
            self.assertEqual(
                method_test_suite.method_name,
                rand_method_names[suite_idx],
            )
            test_cases = method_test_suite.test_cases
            self.assertEqual(len(test_cases), n_sets_per_plan_test)
            for case_idx, test_case in enumerate(test_cases):
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    rand_inputs[suite_idx][case_idx],
                    test_case.inputs,
                )
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    rand_expected_outputs[suite_idx][case_idx],
                    test_case.expected_outputs,
                )

    def test_create_test_suites_from_eager_model(self) -> None:
        """Expected outputs in eager-model suites equal re-running the model."""
        n_sets_per_plan_test = 10
        eager_model = SampleModel()
        method_names: List[str] = eager_model.method_names

        rand_inputs, method_test_suites = get_random_test_suites_with_eager_model(
            eager_model=eager_model,
            method_names=method_names,
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
        )

        self.assertEqual(len(method_test_suites), len(method_names))

        # Compare to see if bundled testcases match expectations.
        for suite_idx, method_test_suite in enumerate(method_test_suites):
            method_name = method_names[suite_idx]
            self.assertEqual(method_test_suite.method_name, method_name)
            # Look the eager method up once per suite, not once per case.
            eager_method = getattr(eager_model, method_name)

            test_cases = method_test_suite.test_cases
            self.assertEqual(len(test_cases), n_sets_per_plan_test)
            for case_idx, test_case in enumerate(test_cases):
                ri = rand_inputs[suite_idx][case_idx]
                self.assertIOListEqual(
                    # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                    ri,
                    test_case.inputs,
                )

                model_outputs = eager_method(*ri)
                if isinstance(model_outputs, get_args(DataContainer)):
                    # tree_flatten returns a (leaves, spec) pair; only the
                    # flat leaves are comparable with expected_outputs.
                    # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                    flatten_eager_model_outputs, _ = tree_flatten(model_outputs)
                else:
                    flatten_eager_model_outputs = [model_outputs]

                self.assertIOListEqual(
                    flatten_eager_model_outputs,
                    test_case.expected_outputs,
                )
132