# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch

from torchvision.models import (  # @manual
    resnet18,
    ResNet18_Weights,
    resnet50,
    ResNet50_Weights,
)

from ..model_base import EagerModelBase


# Shape of the single example input shared by both wrappers below: one
# 3-channel 224x224 image. Presumably NCHW (batch, channels, height, width),
# the layout torchvision ResNets consume — TODO confirm with callers.
_EXAMPLE_INPUT_SHAPE = (1, 3, 224, 224)


def _random_image_batch():
    """Return a 1-tuple holding a fresh ``torch.randn`` tensor of ``_EXAMPLE_INPUT_SHAPE``.

    The tensor is drawn from torch's global RNG and is not seeded here, so
    successive calls yield different values.
    """
    return (torch.randn(_EXAMPLE_INPUT_SHAPE),)


class ResNet18Model(EagerModelBase):
    """Eager-mode wrapper around torchvision's ResNet-18.

    The model is built with ``ResNet18_Weights.IMAGENET1K_V1`` pretrained
    weights (torchvision may fetch the checkpoint on first use).
    """

    def __init__(self):
        """Set up nothing: this wrapper keeps no per-instance state.

        NOTE(review): ``EagerModelBase.__init__`` is intentionally not
        invoked; presumably the base does not require it — confirm before
        adding a ``super().__init__()`` call.
        """

    def get_eager_model(self) -> torch.nn.Module:
        """Construct and return ResNet-18 initialized with IMAGENET1K_V1 weights."""
        logging.info("Loading torchvision resnet18 model")
        model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
        logging.info("Loaded torchvision resnet18 model")
        return model

    def get_example_inputs(self):
        """Return a 1-tuple with one random tensor of shape (1, 3, 224, 224)."""
        return _random_image_batch()


class ResNet50Model(EagerModelBase):
    """Eager-mode wrapper around torchvision's ResNet-50.

    The model is built with ``ResNet50_Weights.IMAGENET1K_V1`` pretrained
    weights (torchvision may fetch the checkpoint on first use).
    """

    def __init__(self):
        """Set up nothing: this wrapper keeps no per-instance state.

        NOTE(review): ``EagerModelBase.__init__`` is intentionally not
        invoked; presumably the base does not require it — confirm before
        adding a ``super().__init__()`` call.
        """

    def get_eager_model(self) -> torch.nn.Module:
        """Construct and return ResNet-50 initialized with IMAGENET1K_V1 weights."""
        logging.info("Loading torchvision resnet50 model")
        model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)
        logging.info("Loaded torchvision resnet50 model")
        return model

    def get_example_inputs(self):
        """Return a 1-tuple with one random tensor of shape (1, 3, 224, 224)."""
        return _random_image_batch()