# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch
from timm.models import inception_v4

from ..model_base import EagerModelBase


class InceptionV4Model(EagerModelBase):
    """Eager-mode wrapper around timm's ``inception_v4`` model.

    Provides the model instance and a matching sample input tuple.
    """

    def __init__(self):
        """Create the wrapper; no state is stored on the instance."""

    def get_eager_model(self) -> torch.nn.Module:
        """Build and return ``inception_v4`` with ``pretrained=True``.

        An info-level log message is emitted before and after the model
        is constructed.
        """
        logging.info("Loading inception_v4 model")
        model = inception_v4(pretrained=True)
        logging.info("Loaded inception_v4 model")
        return model

    def get_example_inputs(self):
        """Return a one-element tuple holding a random input tensor.

        The tensor is drawn with shape (3, 299, 299) and then given a
        leading dimension of size 1, yielding shape (1, 3, 299, 299).
        """
        sample = torch.randn(3, 299, 299)
        batched = torch.unsqueeze(sample, 0)
        return (batched,)