xref: /aosp_15_r20/external/executorch/examples/models/resnet/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport logging
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Workerimport torch
10*523fa7a6SAndroid Build Coastguard Worker
11*523fa7a6SAndroid Build Coastguard Workerfrom torchvision.models import (  # @manual
12*523fa7a6SAndroid Build Coastguard Worker    resnet18,
13*523fa7a6SAndroid Build Coastguard Worker    ResNet18_Weights,
14*523fa7a6SAndroid Build Coastguard Worker    resnet50,
15*523fa7a6SAndroid Build Coastguard Worker    ResNet50_Weights,
16*523fa7a6SAndroid Build Coastguard Worker)
17*523fa7a6SAndroid Build Coastguard Worker
18*523fa7a6SAndroid Build Coastguard Workerfrom ..model_base import EagerModelBase
19*523fa7a6SAndroid Build Coastguard Worker
20*523fa7a6SAndroid Build Coastguard Worker
class ResNet18Model(EagerModelBase):
    """Example-model wrapper around torchvision's ResNet-18.

    Provides the eager ``torch.nn.Module`` and a matching example input.
    """

    def __init__(self, weights=ResNet18_Weights.IMAGENET1K_V1):
        """Record which weights ``get_eager_model`` should load.

        Args:
            weights: Forwarded unchanged as the ``weights`` argument of
                ``torchvision.models.resnet18``. Defaults to
                ``ResNet18_Weights.IMAGENET1K_V1``, the value this class always
                used before the parameter existed. Per the torchvision docs,
                ``None`` builds an untrained network.
        """
        # NOTE(review): super().__init__() is intentionally not called, matching
        # the previous implementation; confirm EagerModelBase.__init__ before
        # adding such a call.
        self.weights = weights

    def get_eager_model(self) -> torch.nn.Module:
        """Build and return torchvision's ResNet-18 with the configured weights."""
        logging.info("Loading torchvision resnet18 model")
        resnet18_model = resnet18(weights=self.weights)
        logging.info("Loaded torchvision resnet18 model")
        return resnet18_model

    def get_example_inputs(self):
        """Return a 1-tuple holding a random tensor of shape ``(1, 3, 224, 224)``.

        Values come from ``torch.randn``, so they differ on every call.
        """
        input_shape = (1, 3, 224, 224)
        return (torch.randn(input_shape),)
34*523fa7a6SAndroid Build Coastguard Worker
35*523fa7a6SAndroid Build Coastguard Worker
class ResNet50Model(EagerModelBase):
    """Example-model wrapper around torchvision's ResNet-50.

    Provides the eager ``torch.nn.Module`` and a matching example input.
    """

    def __init__(self, weights=ResNet50_Weights.IMAGENET1K_V1):
        """Record which weights ``get_eager_model`` should load.

        Args:
            weights: Forwarded unchanged as the ``weights`` argument of
                ``torchvision.models.resnet50``. Defaults to
                ``ResNet50_Weights.IMAGENET1K_V1``, the value this class always
                used before the parameter existed. Per the torchvision docs,
                ``None`` builds an untrained network.
        """
        # NOTE(review): super().__init__() is intentionally not called, matching
        # the previous implementation; confirm EagerModelBase.__init__ before
        # adding such a call.
        self.weights = weights

    def get_eager_model(self) -> torch.nn.Module:
        """Build and return torchvision's ResNet-50 with the configured weights."""
        logging.info("Loading torchvision resnet50 model")
        resnet50_model = resnet50(weights=self.weights)
        logging.info("Loaded torchvision resnet50 model")
        return resnet50_model

    def get_example_inputs(self):
        """Return a 1-tuple holding a random tensor of shape ``(1, 3, 224, 224)``.

        Values come from ``torch.randn``, so they differ on every call.
        """
        input_shape = (1, 3, 224, 224)
        return (torch.randn(input_shape),)
49