xref: /aosp_15_r20/external/pytorch/test/mobile/model_test/torchvision_models.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from torchvision import models
2
3import torch
4from torch.utils.bundled_inputs import augment_model_with_bundled_inputs
5from torch.utils.mobile_optimizer import optimize_for_mobile
6
7
def _trace_optimize_and_bundle(
    model: torch.nn.Module, backend: str = "CPU"
) -> torch.jit.ScriptModule:
    """Turn *model* into a mobile-optimized TorchScript module with a bundled input.

    The model is switched to eval mode and traced on a single all-zeros tensor
    of shape ``(1, 3, 224, 224)``. The traced module is passed through
    ``optimize_for_mobile`` for the requested backend, and that same zeros
    tensor is attached as the module's bundled input.

    Args:
        model: The eager-mode model to convert.
        backend: Backend forwarded to ``optimize_for_mobile``. The default
            ``"CPU"`` matches ``optimize_for_mobile``'s own default.

    Returns:
        The optimized TorchScript module, augmented with one bundled input.
    """
    model.eval()
    example = torch.zeros(1, 3, 224, 224)
    traced_script_module = torch.jit.trace(model, example)
    optimized_module = optimize_for_mobile(traced_script_module, backend=backend)
    augment_model_with_bundled_inputs(
        optimized_module,
        [
            (example,),
        ],
    )
    # Run the optimized module once on the example input. The output is
    # discarded; this only checks that the module executes, and any error
    # propagates to the caller.
    optimized_module(example)
    return optimized_module


class MobileNetV2Module:
    """MobileNetV2 (IMAGENET1K_V1 weights) optimized for the default CPU backend."""

    def getModule(self):
        """Return the traced, mobile-optimized MobileNetV2 with a bundled input."""
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        return _trace_optimize_and_bundle(model)


class MobileNetV2VulkanModule:
    """MobileNetV2 (IMAGENET1K_V1 weights) optimized for the Vulkan backend."""

    def getModule(self):
        """Return the traced MobileNetV2 optimized with ``backend="vulkan"``."""
        model = models.mobilenet_v2(weights=models.MobileNet_V2_Weights.IMAGENET1K_V1)
        return _trace_optimize_and_bundle(model, backend="vulkan")


class Resnet18Module:
    """ResNet-18 (IMAGENET1K_V1 weights) optimized for the default CPU backend."""

    def getModule(self):
        """Return the traced, mobile-optimized ResNet-18 with a bundled input."""
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        return _trace_optimize_and_bundle(model)
57