# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from dataclasses import dataclass
from typing import get_args, List, Optional, Sequence, Union

import torch

from torch.utils._pytree import tree_flatten

from typing_extensions import TypeAlias

# The data types currently supported for an element to be bundled. This must
# stay consistent with the types in bundled_program.schema.Value.
ConfigValue: TypeAlias = Union[
    torch.Tensor,
    int,
    bool,
    float,
]

# The data type of the input for a single execution of a method.
MethodInputType: TypeAlias = Sequence[ConfigValue]

# The data type of the output for a single execution of a method.
MethodOutputType: TypeAlias = Sequence[torch.Tensor]

# All container types supported for the inputs / expected outputs of a
# MethodTestCase. NamedTuple is also supported implicitly, since it is a
# subclass of tuple.
# pyre-ignore
DataContainer: TypeAlias = Union[list, tuple, dict]


class MethodTestCase:
    """Test case with inputs and, optionally, expected outputs for one method.

    The expected_outputs are optional and only required if the user wants to
    verify model outputs after execution.

    Attributes:
        inputs: Flattened list of the test-case inputs.
        expected_outputs: Flattened list of the expected outputs; empty when
            no expected outputs were given.
    """

    def __init__(
        self,
        inputs: MethodInputType,
        expected_outputs: Optional[MethodOutputType] = None,
    ) -> None:
        """Single test case for verifying a specific method.

        Args:
            inputs: All inputs required by eager_model with specific inference
                method for one-time execution. Nested lists, tuples (including
                namedtuples) and dicts are flattened into a single list.

                It is worth mentioning that, although both bundled program and
                ET runtime APIs support setting inputs other than
                `torch.Tensor`, only the inputs of type `torch.Tensor` will
                actually be updated in the method; the rest of the inputs only
                get a sanity check that they match the default value in the
                method.

            expected_outputs: Expected output for the given inputs, used for
                verification. It can be None if the user only wants to use
                the test case for profiling.

        Raises:
            AssertionError: If any flattened element of `inputs` or
                `expected_outputs` is not a torch.Tensor, int, bool or float.
        """
        # TODO(gasoonjia): Update type check logic.
        # pyre-ignore [6]: Data type mismatch between MethodTestCase attribute and sanity check.
        self.inputs: List[ConfigValue] = self._flatten_and_sanity_check(inputs)
        self.expected_outputs: List[ConfigValue] = []
        if expected_outputs is not None:
            # pyre-ignore [6]: Data type mismatch between MethodTestCase attribute and sanity check.
            self.expected_outputs = self._flatten_and_sanity_check(expected_outputs)

    def _flatten_and_sanity_check(
        self, unflatten_data: DataContainer
    ) -> List[ConfigValue]:
        """Flatten the given data and check that every leaf has a legal type.

        Args:
            unflatten_data: Possibly nested data to be flattened.

        Returns:
            flatten_data: The flattened leaves, each an instance of one of the
                types listed in ConfigValue.

        Raises:
            AssertionError: If any leaf is not an instance of a type listed in
                ConfigValue (this includes None leaves).
        """
        flatten_data, _ = tree_flatten(unflatten_data)

        # Computed once instead of on every iteration.
        supported_types = get_args(ConfigValue)
        for data in flatten_data:
            # Raise explicitly rather than using `assert`, so the check still
            # runs under `python -O`. AssertionError is kept for backward
            # compatibility with callers that catch it. A None leaf is also
            # rejected here, because NoneType is not among the supported types.
            if not isinstance(data, supported_types):
                raise AssertionError(
                    "The type of input {} with type {} is not supported.\n".format(
                        data, type(data)
                    )
                )

        return flatten_data


@dataclass
class MethodTestSuite:
    """All test info related to verifying a method.

    Attributes:
        method_name: Name of the method to be verified.
        test_cases: All test cases for verifying the method.
    """

    method_name: str
    test_cases: Sequence[MethodTestCase]