# Extracted from /aosp_15_r20/external/executorch/examples/models/model_base.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC, abstractmethod

import torch


class EagerModelBase(ABC):
    """
    Abstract base class for eager mode models.

    This abstract class defines the interface that eager mode model classes should adhere to.
    Eager mode models inherit from this class to ensure consistent behavior and structure.

    Concrete subclasses must implement ``__init__``, ``get_eager_model`` and
    ``get_example_inputs``. All three are declared with ``@abstractmethod``, so a
    subclass that omits any of them cannot be instantiated (``TypeError`` is raised
    at construction time).
    """

    @abstractmethod
    def __init__(self):
        """
        Constructor for EagerModelBase.

        This initializer is declared abstract, so every concrete subclass must
        define its own ``__init__`` (even a trivial one) to perform any
        model-specific setup. The base implementation does nothing and may be
        invoked from a subclass via ``super().__init__()``.
        """
        pass

    @abstractmethod
    def get_eager_model(self) -> torch.nn.Module:
        """
        Abstract method to return an eager PyTorch model instance.

        Returns:
            torch.nn.Module: An instance of a PyTorch model, suitable for eager execution.

        Raises:
            NotImplementedError: If this base implementation is reached, e.g. via
                ``super().get_eager_model()`` from a subclass.
        """
        raise NotImplementedError("get_eager_model")

    @abstractmethod
    def get_example_inputs(self):
        """
        Abstract method to provide example inputs for the model.

        Returns:
            Any: Example inputs that can be used for testing and tracing.

        Raises:
            NotImplementedError: If this base implementation is reached, e.g. via
                ``super().get_example_inputs()`` from a subclass.
        """
        raise NotImplementedError("get_example_inputs")
