"""
Prepare a MobileNetV2 TorchScript model for the PyTorch Android custom
selective build test.

The script traces MobileNetV2, saves the traced module, and dumps the root
operators it uses. A custom build script then reads that list to produce a
tailored build containing only those operators.
"""

import yaml
from torchvision import models

import torch


# --- Model preparation -----------------------------------------------------
# Build MobileNetV2 with the torchvision IMAGENET1K_V1 weights (downloaded on
# first use) and put it in eval mode before tracing.
pretrained = models.MobileNet_V2_Weights.IMAGENET1K_V1
net = models.mobilenet_v2(weights=pretrained)
net.eval()

# --- Tracing -----------------------------------------------------------------
# Tracing records the ops executed for this one random input, shaped
# (1, 3, 224, 224).
dummy_input = torch.rand(1, 3, 224, 224)
# TODO: create script model with `torch.jit.script`
traced = torch.jit.trace(net, dummy_input)

# Persist the traced TorchScript module.
traced.save("MobileNetV2.pt")

# --- Root-op dump --------------------------------------------------------------
# Collect the operator names used by the traced module; the custom build
# consumes this YAML list to decide which ops to keep.
root_ops = torch.jit.export_opnames(traced)

with open("MobileNetV2.yaml", "w") as yaml_file:
    yaml.dump(root_ops, yaml_file)