xref: /aosp_15_r20/external/executorch/examples/models/llava/image_util.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# Utility functions for image processing. Run it with your image:
8
9# python image_util.py --image-path <path_to_image>
10
11import logging
12from argparse import ArgumentParser
13
14import torch
15import torchvision
16from PIL import Image
17from torch import nn
18
19
# Layout for every log record emitted by this script: level, timestamp,
# source file and line number, followed by the message itself.
FORMAT = (
    "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] "
    "%(message)s"
)
# Configure the root logger at import time (INFO and above, using FORMAT).
logging.basicConfig(format=FORMAT, level=logging.INFO)
22
23
def _fit_size(height: int, width: int, target_h: int, target_w: int) -> tuple:
    """Return ``(out_h, out_w)``: ``height x width`` scaled by one common factor
    so that it fits a ``target_h x target_w`` canvas with one side touching.

    Integer arithmetic is used so the constraining side lands exactly on its
    target. The float form ``h / (h / t)`` can round to just below ``t``, and
    ``int()`` would then truncate it to ``t - 1``. The other side is floored,
    matching ``int()`` truncation of the exact ratio, and clamped to >= 1 so
    very thin images never collapse to a zero-sized output.

    Raises:
        ValueError: If any of the four sizes is not positive.
    """
    if min(height, width, target_h, target_w) <= 0:
        raise ValueError(
            f"all sizes must be positive, got image {height}x{width} "
            f"and target {target_h}x{target_w}"
        )
    # height / target_h >= width / target_w, compared without division.
    if height * target_w >= width * target_h:
        # Height is the binding dimension (ties resolve here too, as max() did).
        return target_h, max(1, width * target_h // height)
    return max(1, height * target_w // width), target_w


def prepare_image(image: "Image.Image", target_h: int, target_w: int) -> torch.Tensor:
    """Read image into a tensor and resize it so that it fits in a
    target_h x target_w canvas.

    The image is scaled up or down, preserving its aspect ratio, so that it
    fits inside the canvas with at least one side exactly equal to the
    corresponding target dimension. No padding is added.

    Args:
        image (Image.Image): A PIL image. ``pil_to_tensor`` keeps the image's
            own channels, so convert it to RGB first if the consumer expects
            three channels.
        target_h (int): Target canvas height in pixels; must be positive.
        target_w (int): Target canvas width in pixels; must be positive.

    Returns:
        torch.Tensor: Resized image tensor laid out as (C, H, W).

    Raises:
        ValueError: If the image or the target canvas has a non-positive size.
    """
    img = torchvision.transforms.functional.pil_to_tensor(image)
    # img is (C, H, W): index 1 is height, index 2 is width.
    output_size = _fit_size(img.shape[1], img.shape[2], target_h, target_w)
    return torchvision.transforms.Resize(size=output_size)(img)
46
47
def serialize_image(image: torch.Tensor, path: str) -> None:
    """Save ``image`` as parameter "0" of a scripted module written to ``path``.

    A detached copy of the tensor is registered as a non-trainable parameter
    named "0" on an otherwise empty ``nn.Module``. The module is compiled with
    ``torch.jit.script`` and saved, so the file can be read back with
    ``torch.jit.load``.

    Args:
        image (torch.Tensor): Image tensor to store (e.g. prepare_image output).
        path (str): Destination file path for the saved TorchScript module.
    """
    # detach().clone() is the recommended way to copy an existing tensor;
    # torch.tensor(tensor) does the same copy but emits a UserWarning.
    copy = image.detach().clone()
    module = nn.Module()
    # The tensor is stored data, not a trainable weight; requires_grad=False
    # also allows integer dtypes such as uint8 image tensors.
    module.register_parameter("0", nn.Parameter(copy, requires_grad=False))
    scripted = torch.jit.script(module)
    scripted.save(path)

    logging.info("Saved image tensor to %s", path)
57
58
def main():
    """Command-line entry point: resize an image and save it as a tensor file.

    Reads the image at --image-path, converts it to RGB, resizes it to fit a
    336x336 canvas with prepare_image, and writes the result to --output-path
    with serialize_image.
    """
    parser = ArgumentParser(
        description="Resize an image and save it as a TorchScript tensor file."
    )
    parser.add_argument(
        "--image-path",
        required=True,
        help="Path to the image.",
    )
    parser.add_argument(
        "--output-path",
        default="image.pt",
        help="Where to write the serialized image tensor (default: image.pt).",
    )
    args = parser.parse_args()

    # The context manager guarantees the file opened by Image.open is closed.
    with Image.open(args.image_path) as image:
        # Force 3-channel RGB: grayscale, palette or RGBA images would
        # otherwise produce 1- or 4-channel tensors.
        image_tensor = prepare_image(
            image.convert("RGB"), target_h=336, target_w=336
        )
    serialize_image(image_tensor, args.output_path)
75
76
# Run the CLI only when this file is executed as a script, so that importing
# the module (e.g. for prepare_image / serialize_image) has no CLI side effects.
if __name__ == "__main__":
    main()
79