import io
import sys

import yaml
from android_api_module import AndroidAPIModule
from builtin_ops import TSBuiltinOpsModule, TSCollectionOpsModule
from math_ops import (
    BlasLapackOpsModule,
    ComparisonOpsModule,
    OtherMathOpsModule,
    PointwiseOpsModule,
    ReductionOpsModule,
    SpectralOpsModule,
)
from nn_ops import (
    NNActivationModule,
    NNConvolutionModule,
    NNDistanceModule,
    NNDropoutModule,
    NNLinearModule,
    NNLossFunctionModule,
    NNNormalizationModule,
    NNPaddingModule,
    NNPoolingModule,
    NNRecurrentModule,
    NNShuffleModule,
    NNSparseModule,
    NNTransformerModule,
    NNUtilsModule,
    NNVisionModule,
)
from quantization_ops import FusedQuantModule, GeneralQuantModule, StaticQuantModule
from sampling_ops import SamplingOpsModule
from tensor_ops import (
    TensorCreationOpsModule,
    TensorIndexingOpsModule,
    TensorOpsModule,
    TensorTypingOpsModule,
    TensorViewOpsModule,
)
from torchvision_models import (
    MobileNetV2Module,
    MobileNetV2VulkanModule,
    Resnet18Module,
)

import torch
from torch.jit.mobile import _load_for_lite_interpreter


test_path_ios = "ios/TestApp/models/"
test_path_android = "android/pytorch_android/src/androidTest/assets/"

production_ops_path = "test/mobile/model_test/model_ops.yaml"
coverage_out_path = "test/mobile/model_test/coverage.yaml"
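# model_ops.yaml is expected to hold a "root_operators" mapping from operator
# name to occurrence count (see calcOpsCoverage), e.g. (illustrative values only):
#
#   root_operators:
#     aten::add.Tensor: 123
#     aten::relu: 4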

all_modules = {
    # math ops
    "pointwise_ops": PointwiseOpsModule(),
    "reduction_ops": ReductionOpsModule(),
    "comparison_ops": ComparisonOpsModule(),
    "spectral_ops": SpectralOpsModule(),
    "other_math_ops": OtherMathOpsModule(),
    "blas_lapack_ops": BlasLapackOpsModule(),
    # sampling
    "sampling_ops": SamplingOpsModule(),
    # tensor ops
    "tensor_general_ops": TensorOpsModule(),
    "tensor_creation_ops": TensorCreationOpsModule(),
    "tensor_indexing_ops": TensorIndexingOpsModule(),
    "tensor_typing_ops": TensorTypingOpsModule(),
    "tensor_view_ops": TensorViewOpsModule(),
    # nn ops
    "convolution_ops": NNConvolutionModule(),
    "pooling_ops": NNPoolingModule(),
    "padding_ops": NNPaddingModule(),
    "activation_ops": NNActivationModule(),
    "normalization_ops": NNNormalizationModule(),
    "recurrent_ops": NNRecurrentModule(),
    "transformer_ops": NNTransformerModule(),
    "linear_ops": NNLinearModule(),
    "dropout_ops": NNDropoutModule(),
    "sparse_ops": NNSparseModule(),
    "distance_function_ops": NNDistanceModule(),
    "loss_function_ops": NNLossFunctionModule(),
    "vision_function_ops": NNVisionModule(),
    "shuffle_ops": NNShuffleModule(),
    "nn_utils_ops": NNUtilsModule(),
    # quantization ops
    "general_quant_ops": GeneralQuantModule(),
    # TODO([email protected]): fix and re-enable dynamic_quant_ops
    # "dynamic_quant_ops": DynamicQuantModule(),
    "static_quant_ops": StaticQuantModule(),
    "fused_quant_ops": FusedQuantModule(),
    # TorchScript builtin ops
    "torchscript_builtin_ops": TSBuiltinOpsModule(),
    "torchscript_collection_ops": TSCollectionOpsModule(),
    # vision
    "mobilenet_v2": MobileNetV2Module(),
    "mobilenet_v2_vulkan": MobileNetV2VulkanModule(),
    "resnet18": Resnet18Module(),
    # android api module
    "android_api_module": AndroidAPIModule(),
}

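# Modules listed here are exported with torch.jit.trace (with no example inputs)
# instead of torch.jit.script; see getModuleFromName.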
models_need_trace = [
    "static_quant_ops",
]


def calcOpsCoverage(ops):
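    """Compute coverage of production root operators by the generated test models.

    Reads the production op list from production_ops_path, compares it with the
    ops emitted by the generated models, prints a summary, and writes detailed
    results to coverage_out_path.
    """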
    with open(production_ops_path) as input_yaml_file:
        production_ops_dict = yaml.safe_load(input_yaml_file)

    production_ops = set(production_ops_dict["root_operators"].keys())
    all_generated_ops = set(ops)
    covered_ops = production_ops.intersection(all_generated_ops)
    uncovered_ops = production_ops - covered_ops
    coverage = round(100 * len(covered_ops) / len(production_ops), 2)

    # weighted coverage (takes op occurrences into account)
    total_occurrences = sum(production_ops_dict["root_operators"].values())
    covered_ops_dict = {
        op: production_ops_dict["root_operators"][op] for op in covered_ops
    }
    uncovered_ops_dict = {
        op: production_ops_dict["root_operators"][op] for op in uncovered_ops
    }
    covered_occurrences = sum(covered_ops_dict.values())
    occurrences_coverage = round(100 * covered_occurrences / total_occurrences, 2)

    print(f"\n{len(uncovered_ops)} uncovered ops: {uncovered_ops}\n")
    print(f"Generated {len(all_generated_ops)} ops")
    print(
        f"Covered {len(covered_ops)}/{len(production_ops)} ({coverage}%) production ops"
    )
    print(
        f"Covered {covered_occurrences}/{total_occurrences} ({occurrences_coverage}%) occurrences"
    )
    print(f"pytorch ver {torch.__version__}\n")

    with open(coverage_out_path, "w") as f:
        yaml.safe_dump(
            {
                "_covered_ops": len(covered_ops),
                "_production_ops": len(production_ops),
                "_generated_ops": len(all_generated_ops),
                "_uncovered_ops": len(uncovered_ops),
                "_coverage": round(coverage, 2),
                "uncovered_ops": uncovered_ops_dict,
                "covered_ops": covered_ops_dict,
                "all_generated_ops": sorted(all_generated_ops),
            },
            f,
        )


def getModuleFromName(model_name):
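    """Look up a test module by name and export it for the lite interpreter.

    Returns a (scripted_or_traced_module, exported_op_names) tuple, or
    (None, []) if the name is unknown. The module is also run once via
    runModule as a smoke test.
    """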
    if model_name not in all_modules:
        print("Cannot find test model for " + model_name)
        return None, []

    module = all_modules[model_name]
    if not isinstance(module, torch.nn.Module):
        module = module.getModule()

    has_bundled_inputs = False  # module.find_method("get_all_bundled_inputs")

    if model_name in models_need_trace:
        module = torch.jit.trace(module, [])
    else:
        module = torch.jit.script(module)

    ops = torch.jit.export_opnames(module)
    print(ops)

    # try to run the model
    runModule(module)

    return module, ops


def runModule(module):
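    """Save the module for the lite interpreter, reload it, and run it once.

    Uses the first bundled input when the model provides bundled inputs;
    otherwise the model is called with no arguments.
    """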
    buffer = io.BytesIO(module._save_to_buffer_for_lite_interpreter())
    buffer.seek(0)
    lite_module = _load_for_lite_interpreter(buffer)
    if lite_module.find_method("get_all_bundled_inputs"):
        # run with the first bundled input
        input = lite_module.run_method("get_all_bundled_inputs")[0]
        lite_module.forward(*input)
    else:
        # assuming model has no input
        lite_module()


# Generate all models into the given folder.
# In "on the fly" mode, a "_temp" suffix is added to each model file name.
def generateAllModels(folder, on_the_fly=False):
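    """Export and save every module in all_modules, then report operator
    coverage via calcOpsCoverage."""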
    all_ops = []
    for name in all_modules:
        module, ops = getModuleFromName(name)
        all_ops = all_ops + ops
        path = folder + name + ("_temp.ptl" if on_the_fly else ".ptl")
        module._save_for_lite_interpreter(path)
        print("model saved to " + path)
    calcOpsCoverage(all_ops)


# Generate/update a given checked-in ("storage") model.
def generateModel(name):
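    """Export the named module and save it to both the iOS and Android test
    asset directories."""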
    module, ops = getModuleFromName(name)
    if module is None:
        return
    path_ios = test_path_ios + name + ".ptl"
    path_android = test_path_android + name + ".ptl"
    module._save_for_lite_interpreter(path_ios)
    module._save_for_lite_interpreter(path_android)
    print("model saved to " + path_ios + " and " + path_android)


def main(argv):
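    """Dispatch on a single command-line argument; print usage if it is missing
    or malformed."""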
    if argv is None or len(argv) != 1:
        print(
            """
This script generates models for mobile tests. For each model we have a "storage"
version and an "on-the-fly" version. The "on-the-fly" version is generated during
the test and should not be committed to the repo.
The "storage" version is for backward compatibility tests (a model generated today
should run on the master branch in the next 6 months). We can use this script to
update a model that is no longer supported.
- use 'python gen_test_model.py android-test' to generate on-the-fly models for android
- use 'python gen_test_model.py ios-test' to generate on-the-fly models for ios
- use 'python gen_test_model.py android' to generate checked-in models for android
- use 'python gen_test_model.py ios' to generate checked-in models for ios
- use 'python gen_test_model.py <model_name_no_suffix>' to update the given storage model
"""
        )
        return

    if argv[0] == "android":
        generateAllModels(test_path_android, on_the_fly=False)
    elif argv[0] == "ios":
        generateAllModels(test_path_ios, on_the_fly=False)
    elif argv[0] == "android-test":
        generateAllModels(test_path_android, on_the_fly=True)
    elif argv[0] == "ios-test":
        generateAllModels(test_path_ios, on_the_fly=True)
    else:
        generateModel(argv[0])


if __name__ == "__main__":
    main(sys.argv[1:])