# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Callable, Tuple

import torch

from torchvision.models import (  # @manual
    resnet18,
    ResNet18_Weights,
    resnet50,
    ResNet50_Weights,
)

from ..model_base import EagerModelBase

# Shape of the random example input shared by every model in this file.
_EXAMPLE_INPUT_SHAPE = (1, 3, 224, 224)


def _load_torchvision_model(
    name: str,
    builder: Callable[..., torch.nn.Module],
    weights,
) -> torch.nn.Module:
    """Build a torchvision model with the given pretrained weights.

    Emits an INFO log line (via the root ``logging`` module) before and after
    construction.

    Args:
        name: Model name used only in the log messages, e.g. ``"resnet18"``.
        builder: torchvision constructor, e.g. ``resnet18``; called as
            ``builder(weights=weights)``.
        weights: Weights enum member forwarded to ``builder``.

    Returns:
        The module returned by ``builder``.
    """
    logging.info("Loading torchvision %s model", name)
    model = builder(weights=weights)
    logging.info("Loaded torchvision %s model", name)
    return model


def _random_example_inputs() -> Tuple[torch.Tensor]:
    """Return a 1-tuple holding a ``torch.randn`` tensor of ``_EXAMPLE_INPUT_SHAPE``.

    A new tensor is drawn on every call; no seed is set here.
    """
    return (torch.randn(_EXAMPLE_INPUT_SHAPE),)


class ResNet18Model(EagerModelBase):
    """torchvision ResNet-18 loaded with ``ResNet18_Weights.IMAGENET1K_V1``."""

    def __init__(self):
        # Deliberately a no-op: there is no per-instance state, and (as before)
        # EagerModelBase.__init__ is not invoked.
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Return a freshly constructed pretrained ResNet-18 module."""
        return _load_torchvision_model(
            "resnet18", resnet18, ResNet18_Weights.IMAGENET1K_V1
        )

    def get_example_inputs(self) -> Tuple[torch.Tensor]:
        """Return a 1-tuple with a random ``(1, 3, 224, 224)`` tensor."""
        return _random_example_inputs()


class ResNet50Model(EagerModelBase):
    """torchvision ResNet-50 loaded with ``ResNet50_Weights.IMAGENET1K_V1``."""

    def __init__(self):
        # Deliberately a no-op: there is no per-instance state, and (as before)
        # EagerModelBase.__init__ is not invoked.
        pass

    def get_eager_model(self) -> torch.nn.Module:
        """Return a freshly constructed pretrained ResNet-50 module."""
        return _load_torchvision_model(
            "resnet50", resnet50, ResNet50_Weights.IMAGENET1K_V1
        )

    def get_example_inputs(self) -> Tuple[torch.Tensor]:
        """Return a 1-tuple with a random ``(1, 3, 224, 224)`` tensor."""
        return _random_example_inputs()