1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7# Utility functions for image processing. Run it with your image: 8 9# python image_util.py --image-path <path_to_image> 10 11import logging 12from argparse import ArgumentParser 13 14import torch 15import torchvision 16from PIL import Image 17from torch import nn 18 19 20FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s" 21logging.basicConfig(level=logging.INFO, format=FORMAT) 22 23 24def prepare_image(image: Image, target_h: int, target_w: int) -> torch.Tensor: 25 """Read image into a tensor and resize the image so that it fits in 26 a target_h x target_w canvas. 27 28 Args: 29 image (Image): An Image object. 30 target_h (int): Target height. 31 target_w (int): Target width. 32 33 Returns: 34 torch.Tensor: resized image tensor. 35 """ 36 img = torchvision.transforms.functional.pil_to_tensor(image) 37 # height ratio 38 ratio_h = img.shape[1] / target_h 39 # width ratio 40 ratio_w = img.shape[2] / target_w 41 # resize the image so that it fits in a target_h x target_w canvas 42 ratio = max(ratio_h, ratio_w) 43 output_size = (int(img.shape[1] / ratio), int(img.shape[2] / ratio)) 44 img = torchvision.transforms.Resize(size=output_size)(img) 45 return img 46 47 48def serialize_image(image: torch.Tensor, path: str) -> None: 49 copy = torch.tensor(image) 50 m = nn.Module() 51 par = nn.Parameter(copy, requires_grad=False) 52 m.register_parameter("0", par) 53 tensors = torch.jit.script(m) 54 tensors.save(path) 55 56 logging.info(f"Saved image tensor to {path}") 57 58 59def main(): 60 parser = ArgumentParser() 61 parser.add_argument( 62 "--image-path", 63 required=True, 64 help="Path to the image.", 65 ) 66 parser.add_argument( 67 "--output-path", 68 default="image.pt", 69 ) 70 args = parser.parse_args() 71 72 image = Image.open(args.image_path) 73 image_tensor = prepare_image(image, target_h=336, target_w=336) 74 serialize_image(image_tensor, args.output_path) 75 76 77if __name__ == "__main__": 78 main() 79