# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import get_args, List, Union

import torch
from executorch.devtools.bundled_program.config import DataContainer

from executorch.devtools.bundled_program.util.test_util import (
    get_random_test_suites,
    get_random_test_suites_with_eager_model,
    SampleModel,
)
from executorch.extension.pytree import tree_flatten


class TestConfig(unittest.TestCase):
    """Check that method test suites built by the test utilities carry the
    method names, inputs and expected outputs they were generated from."""

    def assertTensorEqual(self, t1: torch.Tensor, t2: torch.Tensor) -> None:
        """Assert that two tensors have the same shape and equal elements.

        The shape check comes first because ``(t1 == t2).all()`` on its own
        broadcasts: tensors of shape ``[2, 2]`` and ``[2]`` (or ``[1]``) would
        otherwise be reported as equal.
        """
        self.assertEqual(t1.shape, t2.shape)
        self.assertTrue((t1 == t2).all())

    def assertIOListEqual(
        self,
        tl1: List[Union[bool, float, int, torch.Tensor]],
        tl2: List[Union[bool, float, int, torch.Tensor]],
    ) -> None:
        """Assert two lists of method inputs/outputs match element by element.

        Tensor elements are compared with ``assertTensorEqual``; any other
        element (bool/float/int) is compared with ``assertEqual``, which
        reports both values on failure.
        """
        self.assertEqual(len(tl1), len(tl2))
        for t1, t2 in zip(tl1, tl2):
            if isinstance(t1, torch.Tensor):
                # A plain assert (rather than assertIsInstance) lets the type
                # checker narrow t2 to Tensor for the call below.
                assert isinstance(t2, torch.Tensor)
                self.assertTensorEqual(t1, t2)
            else:
                self.assertEqual(t1, t2)

    def test_create_test_suites(self) -> None:
        """Suites from ``get_random_test_suites`` must echo the generated
        method names, inputs and expected outputs, in order."""
        n_sets_per_plan_test = 10
        n_method_test_suites = 5

        (
            rand_method_names,
            rand_inputs,
            rand_expected_outputs,
            method_test_suites,
        ) = get_random_test_suites(
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            n_model_outputs=1,
            model_output_sizes=[[2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
            n_method_test_suites=n_method_test_suites,
        )

        self.assertEqual(len(method_test_suites), n_method_test_suites)

        # Compare to see if bundled execution plan test match expectations.
        for suite_idx, method_test_suite in enumerate(method_test_suites):
            with self.subTest(method_test_suite=suite_idx):
                self.assertEqual(
                    method_test_suite.method_name,
                    rand_method_names[suite_idx],
                )
                for testset_idx in range(n_sets_per_plan_test):
                    test_case = method_test_suite.test_cases[testset_idx]
                    self.assertIOListEqual(
                        # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                        rand_inputs[suite_idx][testset_idx],
                        test_case.inputs,
                    )
                    self.assertIOListEqual(
                        # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                        rand_expected_outputs[suite_idx][testset_idx],
                        test_case.expected_outputs,
                    )

    def test_create_test_suites_from_eager_model(self) -> None:
        """Suites built from an eager model must store the generated inputs
        and, as expected outputs, what the eager method returns for them."""
        n_sets_per_plan_test = 10
        eager_model = SampleModel()
        method_names: List[str] = eager_model.method_names

        rand_inputs, method_test_suites = get_random_test_suites_with_eager_model(
            eager_model=eager_model,
            method_names=method_names,
            n_model_inputs=2,
            model_input_sizes=[[2, 2], [2, 2]],
            dtype=torch.int32,
            n_sets_per_plan_test=n_sets_per_plan_test,
        )

        self.assertEqual(len(method_test_suites), len(method_names))

        # Compare to see if bundled testcases match expectations.
        for suite_idx, method_name in enumerate(method_names):
            method_test_suite = method_test_suites[suite_idx]
            eager_method = getattr(eager_model, method_name)
            with self.subTest(method_name=method_name):
                self.assertEqual(method_test_suite.method_name, method_name)
                for testset_idx in range(n_sets_per_plan_test):
                    test_case = method_test_suite.test_cases[testset_idx]
                    ri = rand_inputs[suite_idx][testset_idx]
                    self.assertIOListEqual(
                        # pyre-ignore [6]: expected `List[Union[bool, float, int, Tensor]]` but got `Sequence[Union[bool, float, int, Tensor]]
                        ri,
                        test_case.inputs,
                    )

                    # Re-run the eager method on the same inputs; the stored
                    # expected outputs must equal its (flattened) result.
                    model_outputs = eager_method(*ri)
                    if isinstance(model_outputs, get_args(DataContainer)):
                        # NOTE(review): if this tree_flatten returns a
                        # (leaves, spec) pair like torch's pytree does, this
                        # should be `tree_flatten(...)[0]` — confirm before
                        # exercising this branch with container outputs.
                        # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
                        flatten_eager_model_outputs = tree_flatten(model_outputs)
                    else:
                        flatten_eager_model_outputs = [model_outputs]

                    self.assertIOListEqual(
                        flatten_eager_model_outputs,
                        test_case.expected_outputs,
                    )