# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa
from torch.export import Dim
from torchtune.models.clip.inference._transform import _CLIPImageTransform

from ...model_base import EagerModelBase


@dataclass
class PreprocessConfig:
    image_mean: Optional[List[float]] = None
    image_std: Optional[List[float]] = None
    resample: str = "bilinear"
    max_num_tiles: int = 4
    tile_size: int = 224
    antialias: bool = False
    # Used only by the reference eager model from torchtune; not passed to
    # the exportable _CLIPImageTransform below.
    resize_to_max_canvas: bool = False
    possible_resolutions: Optional[List[Tuple[int, int]]] = None

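# A hypothetical configuration (illustration only, not part of the original
# file): the mean/std below are the standard CLIP normalization constants;
# the remaining fields keep their defaults. `example_config` is a made-up
# name.
#
# example_config = PreprocessConfig(
#     image_mean=[0.48145466, 0.4578275, 0.40821073],
#     image_std=[0.26862954, 0.26130258, 0.27577711],
# )
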

class CLIPImageTransformModel(EagerModelBase):
    def __init__(
        self,
        config: PreprocessConfig,
    ):
        super().__init__()

        # Eager model.
        self.model = _CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            max_num_tiles=config.max_num_tiles,
            tile_size=config.tile_size,
            antialias=config.antialias,
        )

        # Replace non-exportable ops with custom ops. The tile_crop custom op
        # is registered by the op_tile_crop_aot import at the top of this file.
        self.model.tile_crop = torch.ops.preprocess.tile_crop.default

    def get_eager_model(self) -> torch.nn.Module:
        return self.model

    def get_example_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # A CHW image, the aspect-preserving resize target (800x600 scaled to
        # fit a 448x448 canvas gives 448x336), and the padded canvas size
        # (448 = 2 tiles of 224 per side).
        image = torch.ones(3, 800, 600)
        target_size = torch.tensor([448, 336])
        canvas_size = torch.tensor([448, 448])
        return (image, target_size, canvas_size)

    def get_dynamic_shapes(self) -> Dict[str, Optional[Dict[int, Dim]]]:
        img_h = Dim("img_h", min=1, max=4000)
        img_w = Dim("img_w", min=1, max=4000)

        # Only the image's height (dim 1) and width (dim 2) are dynamic; the
        # channel dim and the two size tensors are static.
        dynamic_shapes = {
            "image": {1: img_h, 2: img_w},
            "target_size": None,
            "canvas_size": None,
        }
        return dynamic_shapes
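

# A minimal export sketch, not part of the original file: one way this wrapper
# might be driven end to end, assuming a standard torch.export setup. The
# tile_crop custom op must already be registered, which the op_tile_crop_aot
# import above takes care of.
if __name__ == "__main__":
    model = CLIPImageTransformModel(PreprocessConfig())
    # Export with dynamic height/width so one program serves many image sizes.
    exported = torch.export.export(
        model.get_eager_model(),
        model.get_example_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
    )
    print(exported)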