from torchvision import models

import torch
from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
from torch.utils.mobile_optimizer import optimize_for_mobile


def _build_mobile_module(model, backend=None):
    """Trace ``model``, optimize it for mobile, and bundle an example input.

    Shared implementation of the ``getModule`` methods below, which otherwise
    repeat this exact sequence and differ only in the model and backend.

    Args:
        model: The torch module to export. It is switched to eval mode here.
        backend: Optional value forwarded as ``optimize_for_mobile``'s
            ``backend`` argument (e.g. ``"vulkan"``). When ``None`` the
            argument is omitted entirely, so ``optimize_for_mobile`` uses its
            own default.

    Returns:
        The optimized TorchScript module, with the example input bundled.
    """
    model.eval()
    # All-zeros dummy input of shape (1, 3, 224, 224). It is used both to
    # trace the model and as the bundled sample input.
    example = torch.zeros(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    if backend is None:
        optimized_module = optimize_for_mobile(traced_script_module)
    else:
        optimized_module = optimize_for_mobile(traced_script_module, backend=backend)
    # The return value is intentionally ignored: the bundled inputs are
    # attached to optimized_module itself, which is what gets returned.
    augment_model_with_bundled_inputs(
        optimized_module,
        [
            (example,),
        ],
    )
    # Run once on the example and discard the output. This is presumably a
    # smoke test that the optimized module executes before it is handed out.
    optimized_module(example)
    return optimized_module


class MobileNetV2Module:
    """MobileNetV2 with IMAGENET1K_V1 weights, optimized for the default mobile backend."""

    def getModule(self):
        """Return the traced, mobile-optimized MobileNetV2 with a bundled input.

        NOTE: torchvision may download the pretrained weights on first use.
        """
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        return _build_mobile_module(model)


class MobileNetV2VulkanModule:
    """MobileNetV2 with IMAGENET1K_V1 weights, optimized for the Vulkan backend."""

    def getModule(self):
        """Return the traced MobileNetV2 optimized with ``backend="vulkan"``.

        The example input is bundled into the returned module. NOTE:
        torchvision may download the pretrained weights on first use.
        """
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        return _build_mobile_module(model, backend="vulkan")


class Resnet18Module:
    """ResNet-18 with IMAGENET1K_V1 weights, optimized for the default mobile backend."""

    def getModule(self):
        """Return the traced, mobile-optimized ResNet-18 with a bundled input.

        NOTE: torchvision may download the pretrained weights on first use.
        """
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        return _build_mobile_module(model)