import io
import sys

import yaml
from android_api_module import AndroidAPIModule
from builtin_ops import TSBuiltinOpsModule, TSCollectionOpsModule
from math_ops import (
    BlasLapackOpsModule,
    ComparisonOpsModule,
    OtherMathOpsModule,
    PointwiseOpsModule,
    ReductionOpsModule,
    SpectralOpsModule,
)
from nn_ops import (
    NNActivationModule,
    NNConvolutionModule,
    NNDistanceModule,
    NNDropoutModule,
    NNLinearModule,
    NNLossFunctionModule,
    NNNormalizationModule,
    NNPaddingModule,
    NNPoolingModule,
    NNRecurrentModule,
    NNShuffleModule,
    NNSparseModule,
    NNTransformerModule,
    NNUtilsModule,
    NNVisionModule,
)
from quantization_ops import FusedQuantModule, GeneralQuantModule, StaticQuantModule
from sampling_ops import SamplingOpsModule
from tensor_ops import (
    TensorCreationOpsModule,
    TensorIndexingOpsModule,
    TensorOpsModule,
    TensorTypingOpsModule,
    TensorViewOpsModule,
)
from torchvision_models import (
    MobileNetV2Module,
    MobileNetV2VulkanModule,
    Resnet18Module,
)

import torch
from torch.jit.mobile import _load_for_lite_interpreter


# Output folders for the generated ".ptl" models (note the trailing slash:
# file names are appended directly to these strings).
test_path_ios = "ios/TestApp/models/"
test_path_android = "android/pytorch_android/src/androidTest/assets/"

# Input YAML with the production operator list and output YAML for the
# coverage report written by calcOpsCoverage().
production_ops_path = "test/mobile/model_test/model_ops.yaml"
coverage_out_path = "test/mobile/model_test/coverage.yaml"

# Model name -> module instance. The name doubles as the generated file stem.
all_modules = {
    # math ops
    "pointwise_ops": PointwiseOpsModule(),
    "reduction_ops": ReductionOpsModule(),
    "comparison_ops": ComparisonOpsModule(),
    "spectral_ops": SpectralOpsModule(),
    "other_math_ops": OtherMathOpsModule(),
    "blas_lapack_ops": BlasLapackOpsModule(),
    # sampling
    "sampling_ops": SamplingOpsModule(),
    # tensor ops
    "tensor_general_ops": TensorOpsModule(),
    "tensor_creation_ops": TensorCreationOpsModule(),
    "tensor_indexing_ops": TensorIndexingOpsModule(),
    "tensor_typing_ops": TensorTypingOpsModule(),
    "tensor_view_ops": TensorViewOpsModule(),
    # nn ops
    "convolution_ops": NNConvolutionModule(),
    "pooling_ops": NNPoolingModule(),
    "padding_ops": NNPaddingModule(),
    "activation_ops": NNActivationModule(),
    "normalization_ops": NNNormalizationModule(),
    "recurrent_ops": NNRecurrentModule(),
    "transformer_ops": NNTransformerModule(),
    "linear_ops": NNLinearModule(),
    "dropout_ops": NNDropoutModule(),
    "sparse_ops": NNSparseModule(),
    "distance_function_ops": NNDistanceModule(),
    "loss_function_ops": NNLossFunctionModule(),
    "vision_function_ops": NNVisionModule(),
    "shuffle_ops": NNShuffleModule(),
    "nn_utils_ops": NNUtilsModule(),
    # quantization ops
    "general_quant_ops": GeneralQuantModule(),
    # TODO([email protected]): fix and re-enable dynamic_quant_ops
    # "dynamic_quant_ops": DynamicQuantModule(),
    "static_quant_ops": StaticQuantModule(),
    "fused_quant_ops": FusedQuantModule(),
    # TorchScript builtin ops
    "torchscript_builtin_ops": TSBuiltinOpsModule(),
    "torchscript_collection_ops": TSCollectionOpsModule(),
    # vision
    "mobilenet_v2": MobileNetV2Module(),
    "mobilenet_v2_vulkan": MobileNetV2VulkanModule(),
    "resnet18": Resnet18Module(),
    # android api module
    "android_api_module": AndroidAPIModule(),
}

# Models listed here are converted with torch.jit.trace instead of
# torch.jit.script (see getModuleFromName).
models_need_trace = [
    "static_quant_ops",
]


def _percentage(part, whole):
    """Return ``part / whole`` as a percentage rounded to two decimals.

    Returns 0.0 when ``whole`` is zero so that an empty production op list
    (or one whose occurrence counts sum to zero) does not raise
    ZeroDivisionError.
    """
    if not whole:
        return 0.0
    return round(100 * part / whole, 2)


def calcOpsCoverage(ops):
    """Compare generated ops against the production op list and report coverage.

    Reads the "root_operators" mapping (op name -> occurrence count) from
    ``production_ops_path``, prints a summary to stdout, and writes the
    detailed report to ``coverage_out_path``.

    Args:
        ops: iterable of operator names exported by the generated models.
            Duplicates are allowed; they are collapsed into a set.
    """
    with open(production_ops_path) as input_yaml_file:
        production_ops_dict = yaml.safe_load(input_yaml_file)

    root_operators = production_ops_dict["root_operators"]
    production_ops = set(root_operators)
    all_generated_ops = set(ops)
    covered_ops = production_ops & all_generated_ops
    uncovered_ops = production_ops - covered_ops
    coverage = _percentage(len(covered_ops), len(production_ops))

    # weighted coverage (take op occurrences into account)
    total_occurrences = sum(root_operators.values())
    covered_ops_dict = {op: root_operators[op] for op in covered_ops}
    uncovered_ops_dict = {op: root_operators[op] for op in uncovered_ops}
    covered_occurrences = sum(covered_ops_dict.values())
    occurrences_coverage = _percentage(covered_occurrences, total_occurrences)

    print(f"\n{len(uncovered_ops)} uncovered ops: {uncovered_ops}\n")
    print(f"Generated {len(all_generated_ops)} ops")
    print(
        f"Covered {len(covered_ops)}/{len(production_ops)} ({coverage}%) production ops"
    )
    print(
        f"Covered {covered_occurrences}/{total_occurrences} ({occurrences_coverage}%) occurances"
    )
    print(f"pytorch ver {torch.__version__}\n")

    with open(coverage_out_path, "w") as f:
        yaml.safe_dump(
            {
                "_covered_ops": len(covered_ops),
                "_production_ops": len(production_ops),
                "_generated_ops": len(all_generated_ops),
                "_uncovered_ops": len(uncovered_ops),
                # already rounded to two decimals by _percentage()
                "_coverage": coverage,
                "uncovered_ops": uncovered_ops_dict,
                "covered_ops": covered_ops_dict,
                "all_generated_ops": sorted(all_generated_ops),
            },
            f,
        )


def getModuleFromName(model_name):
    """Build the TorchScript module registered under ``model_name``.

    The module is traced when its name is in ``models_need_trace`` and
    scripted otherwise. Its exported op names are printed, and the module is
    run once through the lite interpreter via runModule() as a smoke test.

    Args:
        model_name: key into ``all_modules``.

    Returns:
        A ``(module, ops)`` tuple, or ``(None, [])`` when ``model_name`` is
        not a known model.
    """
    if model_name not in all_modules:
        print("Cannot find test model for " + model_name)
        return None, []

    module = all_modules[model_name]
    if not isinstance(module, torch.nn.Module):
        # Entries that are not nn.Modules are wrappers exposing getModule().
        module = module.getModule()

    if model_name in models_need_trace:
        module = torch.jit.trace(module, [])
    else:
        module = torch.jit.script(module)

    ops = torch.jit.export_opnames(module)
    print(ops)

    # try to run the model
    runModule(module)

    return module, ops


def runModule(module):
    """Round-trip ``module`` through the lite interpreter and execute it once.

    If the loaded module has a "get_all_bundled_inputs" method, forward() is
    called with the first bundled input; otherwise the module is called with
    no arguments.
    """
    buffer = io.BytesIO(module._save_to_buffer_for_lite_interpreter())
    buffer.seek(0)
    lite_module = _load_for_lite_interpreter(buffer)
    if lite_module.find_method("get_all_bundled_inputs"):
        # run with the first bundled input
        first_input = lite_module.run_method("get_all_bundled_inputs")[0]
        lite_module.forward(*first_input)
    else:
        # assuming model has no input
        lite_module()


def generateAllModels(folder, on_the_fly=False):
    """Generate every model in ``all_modules`` into ``folder``.

    Also writes the op coverage report for the combined set of exported ops.

    Args:
        folder: output directory prefix; the file name is appended directly,
            so it must end with a path separator.
        on_the_fly: when True, files get the "_temp.ptl" suffix instead of
            ".ptl".
    """
    suffix = "_temp.ptl" if on_the_fly else ".ptl"
    all_ops = []
    for name in all_modules:
        module, ops = getModuleFromName(name)
        all_ops.extend(ops)
        path = folder + name + suffix
        module._save_for_lite_interpreter(path)
        print("model saved to " + path)
    calcOpsCoverage(all_ops)


def generateModel(name):
    """Generate/update the storage version of one model for both iOS and Android.

    Does nothing (beyond the message printed by getModuleFromName) when
    ``name`` is not a known model.
    """
    module, ops = getModuleFromName(name)
    if module is None:
        return
    path_ios = test_path_ios + name + ".ptl"
    path_android = test_path_android + name + ".ptl"
    module._save_for_lite_interpreter(path_ios)
    module._save_for_lite_interpreter(path_android)
    print("model saved to " + path_ios + " and " + path_android)


def main(argv):
    """Command-line entry point.

    Args:
        argv: argument list without the program name. Exactly one argument is
            expected: "android", "ios", "android-test", "ios-test", or a
            model name. Anything else prints the usage text.
    """
    if argv is None or len(argv) != 1:
        print(
            """
This script generate models for mobile test. For each model we have a "storage" version
and an "on-the-fly" version. The "on-the-fly" version will be generated during test,and
should not be committed to the repo.
The "storage" version is for back compatibility # test (a model generated today should
run on master branch in the next 6 months). We can use this script to update a model that
is no longer supported.
- use 'python gen_test_model.py android-test' to generate on-the-fly models for android
- use 'python gen_test_model.py ios-test' to generate on-the-fly models for ios
- use 'python gen_test_model.py android' to generate checked-in models for android
- use 'python gen_test_model.py ios' to generate checked-in models for ios
- use 'python gen_test_model.py <model_name_no_suffix>' to update the given storage model
"""
        )
        return

    # target name -> (output folder, on_the_fly)
    bulk_targets = {
        "android": (test_path_android, False),
        "ios": (test_path_ios, False),
        "android-test": (test_path_android, True),
        "ios-test": (test_path_ios, True),
    }

    target = argv[0]
    if target in bulk_targets:
        folder, on_the_fly = bulk_targets[target]
        generateAllModels(folder, on_the_fly=on_the_fly)
    else:
        # Any other argument is treated as a single model name.
        generateModel(target)


if __name__ == "__main__":
    main(sys.argv[1:])