# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Tuple

import torch
from torchvision.models.segmentation import (
    deeplabv3,
    deeplabv3_resnet101,
    deeplabv3_resnet50,
)

from ..model_base import EagerModelBase


class _DeepLabV3ModelBase(EagerModelBase):
    """Shared implementation for the torchvision DeepLabV3 example models.

    Subclasses set ``_MODEL_NAME`` (used in the log messages) and implement
    ``_build_model`` to construct their concrete torchvision network. This
    class wraps that construction with "loading"/"loaded" log lines and
    supplies the random example inputs for the model.
    """

    # Name reported in the "loading ... model" / "loaded ... model" log lines.
    _MODEL_NAME: str = ""
    # Shape of the random example input; presumably (N, C, H, W), i.e. one
    # 3-channel 224x224 image -- confirm against the export pipeline.
    _INPUT_SHAPE: Tuple[int, ...] = (1, 3, 224, 224)

    def __init__(self) -> None:
        # Nothing to configure here: the network and its weights are only
        # constructed when get_eager_model() is called.
        pass

    def _build_model(self) -> torch.nn.Module:
        """Construct the concrete torchvision model; implemented by subclasses."""
        raise NotImplementedError("_build_model")

    def get_eager_model(self) -> torch.nn.Module:
        """Build and return the eager DeepLabV3 model with pretrained weights.

        Returns:
            The ``torch.nn.Module`` produced by the subclass's torchvision
            builder, called with the DEFAULT pretrained weights (torchvision
            may download these on first use).

        NOTE(review): the module is returned exactly as the builder produced
        it; ``.eval()`` is not called here -- confirm callers switch it to
        inference mode if they need that.
        """
        logging.info("loading %s model", self._MODEL_NAME)
        model = self._build_model()
        logging.info("loaded %s model", self._MODEL_NAME)
        return model

    def get_example_inputs(self) -> Tuple[torch.Tensor]:
        """Return a 1-tuple holding a ``torch.randn`` tensor of ``_INPUT_SHAPE``."""
        return (torch.randn(self._INPUT_SHAPE),)


class DeepLabV3ResNet50Model(_DeepLabV3ModelBase):
    """DeepLabV3 with a ResNet-50 backbone and torchvision's DEFAULT weights."""

    _MODEL_NAME = "deeplabv3_resnet50"

    def _build_model(self) -> torch.nn.Module:
        """Build torchvision's ``deeplabv3_resnet50`` with DEFAULT weights."""
        return deeplabv3_resnet50(
            weights=deeplabv3.DeepLabV3_ResNet50_Weights.DEFAULT
        )


class DeepLabV3ResNet101Model(_DeepLabV3ModelBase):
    """DeepLabV3 with a ResNet-101 backbone and torchvision's DEFAULT weights."""

    _MODEL_NAME = "deeplabv3_resnet101"

    def _build_model(self) -> torch.nn.Module:
        """Build torchvision's ``deeplabv3_resnet101`` with DEFAULT weights."""
        return deeplabv3_resnet101(
            weights=deeplabv3.DeepLabV3_ResNet101_Weights.DEFAULT
        )