xref: /aosp_15_r20/external/executorch/extension/llm/custom_ops/test_preprocess_custom_ops.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import unittest
10from typing import Tuple
11
12import torch
13
14from .preprocess_custom_ops import preprocess_op_lib  # noqa
15
16
17class PreprocessTest(unittest.TestCase):
18
19    def setUp(self):
20        # tile_crop
21        self.tile_size = 224
22
23    def _test_tile_crop(self, image: torch.Tensor, expected_shape: Tuple[int]) -> None:
24        output = torch.ops.preprocess.tile_crop.default(image, self.tile_size)
25        self.assertTrue(output.shape == expected_shape)
26
27    def test_op_tile_crop_2x2(self):
28        self._test_tile_crop(torch.ones(3, 448, 448), (4, 3, 224, 224))
29
30    def test_op_tile_crop_1x3(self):
31        self._test_tile_crop(torch.ones(3, 224, 672), (3, 3, 224, 224))
32
33    def test_op_tile_crop_4x2(self):
34        self._test_tile_crop(torch.ones(3, 896, 448), (8, 3, 224, 224))
35