xref: /aosp_15_r20/external/executorch/examples/models/resnet/model.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8
9import torch
10
11from torchvision.models import (  # @manual
12    resnet18,
13    ResNet18_Weights,
14    resnet50,
15    ResNet50_Weights,
16)
17
18from ..model_base import EagerModelBase
19
20
class ResNet18Model(EagerModelBase):
    """Eager-model wrapper around torchvision's ResNet-18.

    Supplies the ``IMAGENET1K_V1`` pretrained ResNet-18 module together with a
    single random input of shape ``(1, 3, 224, 224)`` for tracing/export.
    """

    def __init__(self):
        """Take no configuration; the network is only built in get_eager_model()."""
        # NOTE(review): EagerModelBase.__init__ is not invoked here — confirm the
        # base class requires no initialization of its own.

    def get_eager_model(self) -> torch.nn.Module:
        """Construct ResNet-18 loaded with ImageNet-1K V1 pretrained weights.

        torchvision may download the checkpoint the first time this runs. The
        module is returned exactly as constructed; no ``.eval()`` call is made
        here (NOTE(review): confirm callers switch to eval mode before export).
        """
        logging.info("Loading torchvision resnet18 model")
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        logging.info("Loaded torchvision resnet18 model")
        return model

    def get_example_inputs(self):
        """Return a 1-tuple containing one random image batch of shape (1, 3, 224, 224)."""
        batch, channels, height, width = 1, 3, 224, 224
        sample = torch.randn(batch, channels, height, width)
        return (sample,)
34
35
class ResNet50Model(EagerModelBase):
    """Eager-model wrapper around torchvision's ResNet-50.

    Supplies the ``IMAGENET1K_V1`` pretrained ResNet-50 module together with a
    single random input of shape ``(1, 3, 224, 224)`` for tracing/export.
    """

    def __init__(self):
        """Take no configuration; the network is only built in get_eager_model()."""
        # NOTE(review): EagerModelBase.__init__ is not invoked here — confirm the
        # base class requires no initialization of its own.

    def get_eager_model(self) -> torch.nn.Module:
        """Construct ResNet-50 loaded with ImageNet-1K V1 pretrained weights.

        torchvision may download the checkpoint the first time this runs. The
        module is returned exactly as constructed; no ``.eval()`` call is made
        here (NOTE(review): confirm callers switch to eval mode before export).
        """
        logging.info("Loading torchvision resnet50 model")
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        logging.info("Loaded torchvision resnet50 model")
        return model

    def get_example_inputs(self):
        """Return a 1-tuple containing one random image batch of shape (1, 3, 224, 224)."""
        batch, channels, height, width = 1, 3, 224, 224
        sample = torch.randn(batch, channels, height, width)
        return (sample,)
49