# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch

from executorch.extension.llm.custom_ops import op_tile_crop_aot  # noqa
from torch.export import Dim
from torchtune.models.clip.inference._transform import _CLIPImageTransform

from ...model_base import EagerModelBase


@dataclass
class PreprocessConfig:
    """Settings for building the CLIP image preprocessing model.

    The first six fields (``image_mean`` through ``antialias``) are forwarded
    unchanged to ``_CLIPImageTransform`` by ``CLIPImageTransformModel``. The
    last two fields are not read by ``CLIPImageTransformModel``.
    """

    # Normalization statistics; None is passed through to the transform as-is.
    image_mean: Optional[List[float]] = None
    image_std: Optional[List[float]] = None
    # Resampling mode name handed to the transform.
    resample: str = "bilinear"
    max_num_tiles: int = 4
    tile_size: int = 224
    antialias: bool = False
    # Used for reference eager model from torchtune.
    resize_to_max_canvas: bool = False
    possible_resolutions: Optional[List[Tuple[int, int]]] = None


class CLIPImageTransformModel(EagerModelBase):
    """Export wrapper around torchtune's ``_CLIPImageTransform``.

    Builds the eager transform from a ``PreprocessConfig``. It then replaces
    the transform's ``tile_crop`` attribute with the
    ``torch.ops.preprocess.tile_crop`` custom op so the model can be exported.
    """

    def __init__(
        self,
        config: PreprocessConfig,
    ):
        super().__init__()

        # Eager model.
        self.model = _CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            max_num_tiles=config.max_num_tiles,
            tile_size=config.tile_size,
            antialias=config.antialias,
        )

        # Replace non-exportable ops with custom ops. The op lives under
        # torch.ops.preprocess; presumably it is registered as a side effect
        # of importing op_tile_crop_aot above — confirm.
        self.model.tile_crop = torch.ops.preprocess.tile_crop.default

    def get_eager_model(self) -> torch.nn.Module:
        """Return the wrapped eager transform, with ``tile_crop`` patched."""
        return self.model

    def get_example_inputs(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return sample ``(image, target_size, canvas_size)`` inputs for export.

        ``image`` is an all-ones tensor of shape (3, 800, 600).
        ``target_size`` is ``[448, 336]`` and ``canvas_size`` is ``[448, 448]``.
        The two size tensors are presumably (height, width); confirm this
        against ``_CLIPImageTransform.forward``.
        """
        image = torch.ones(3, 800, 600)
        target_size = torch.tensor([448, 336])
        canvas_size = torch.tensor([448, 448])
        return (image, target_size, canvas_size)

    def get_dynamic_shapes(self) -> Dict[str, Optional[Dict[int, Dim]]]:
        """Return the ``dynamic_shapes`` spec used when exporting the model.

        Dimensions 1 and 2 of ``image`` are dynamic, each bounded to
        [1, 4000]. No dimension of ``target_size`` or ``canvas_size`` is
        marked dynamic, so their entries are ``None``. That is why the value
        type is ``Optional``.

        The keys presumably have to match the parameter names of the eager
        model's forward; confirm this against ``_CLIPImageTransform``.
        """
        img_h = Dim("img_h", min=1, max=4000)
        img_w = Dim("img_w", min=1, max=4000)

        dynamic_shapes: Dict[str, Optional[Dict[int, Dim]]] = {
            "image": {1: img_h, 2: img_w},
            "target_size": None,
            "canvas_size": None,
        }
        return dynamic_shapes