# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import get_args, List, Optional, Sequence, Union

import torch

from torch.utils._pytree import tree_flatten

from typing_extensions import TypeAlias

"""
The data types currently supported for elements to be bundled. They should be
consistent with the types in bundled_program.schema.Value.
"""
ConfigValue: TypeAlias = Union[
    torch.Tensor,
    int,
    bool,
    float,
]

"""
The data type of the inputs for a single execution of a method.
"""
MethodInputType: TypeAlias = Sequence[ConfigValue]

"""
The data type of the outputs for a single execution of a method.
"""
MethodOutputType: TypeAlias = Sequence[torch.Tensor]

"""
All supported container types for the inputs/expected outputs of MethodTestCase.

NamedTuple is also supported implicitly, since it is a subclass of tuple.
"""

# pyre-ignore
DataContainer: TypeAlias = Union[list, tuple, dict]


class MethodTestCase:
    """Test case with inputs and expected outputs.

    The expected_outputs are optional and only required if the user wants to verify model outputs after execution.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with a specific inference method for a single execution.

                Note that although both the bundled program and the ET runtime APIs support setting inputs
                of types other than `torch.Tensor`, only `torch.Tensor` inputs are actually updated in the
                method; the remaining inputs are only sanity-checked against the method's default values.

            expected_outputs: Expected outputs for the given inputs, used for verification. Can be None if
                the user only wants to use the test case for profiling.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Data type mismatch between the MethodTestCase attribute and the sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Data type mismatch between the MethodTestCase attribute and the sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every element has a supported type.

        Args:
            unflatten_data: Data to be flattened.

        Returns:
            flatten_data: Flattened data whose elements all have supported types.
        """

        flatten_data, _ = tree_flatten(unflatten_data)

        for data in flatten_data:
            assert isinstance(
                data,
                get_args(ConfigValue),
            ), "The type of input {} with type {} is not supported.\n".format(
                data, type(data)
            )
            assert not isinstance(
                data,
                type(None),
            ), "The input {} should not be in null type.\n".format(data)

        return flatten_data


@dataclass
class MethodTestSuite:
    """All test information related to verifying a method.

    Attributes:
        method_name: Name of the method to be verified.
        test_cases: All test cases for verifying the method.
    """

    method_name: str
    test_cases: Sequence[MethodTestCase]
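

# Usage sketch (illustrative only; not part of the original module). It is a
# minimal example, assuming a hypothetical method named "forward" and arbitrary
# tensor values chosen for the demo. It shows that nested lists, tuples, and
# dicts passed as inputs are flattened by tree_flatten into a flat list of
# ConfigValue leaves, and how a MethodTestSuite groups test cases.
if __name__ == "__main__":
    # The inputs nest a tuple and a dict; flattening yields [tensor, 3, 0.5].
    case = MethodTestCase(
        inputs=[torch.ones(2, 2), (3, {"scale": 0.5})],
        expected_outputs=[torch.full((2, 2), 1.5)],
    )
    assert len(case.inputs) == 3
    assert isinstance(case.inputs[0], torch.Tensor)
    assert case.inputs[1:] == [3, 0.5]
    assert len(case.expected_outputs) == 1

    # Without expected outputs (e.g. a profiling-only case), expected_outputs is [].
    profiling_case = MethodTestCase(inputs=[torch.zeros(4)])
    assert profiling_case.expected_outputs == []

    # "forward" is a hypothetical method name used only for this sketch.
    suite = MethodTestSuite(method_name="forward", test_cases=[case, profiling_case])
    print(suite.method_name, len(suite.test_cases))  # prints: forward 2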