#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
# mypy: allow-untyped-defs

import io

import cv2

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase


torch.ops.load_library("//caffe2/torch/fb/operators:decode_bundled_image")


def model_size(sm):
    """Return the number of bytes ``torch.jit.save`` writes for ``sm``.

    The module is serialized into an in-memory buffer; nothing touches disk.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())


def save_and_load(sm):
    """Round-trip ``sm`` through ``torch.jit.save`` / ``torch.jit.load``.

    Serialization goes through an in-memory buffer; the freshly loaded
    module is returned.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    # Rewind so that torch.jit.load reads from the start of the buffer.
    buffer.seek(0)
    return torch.jit.load(buffer)


def bundle_jpeg_image(img_tensor, quality):
    """Return an InflatableArg that contains a tensor of the compressed image
    and the way to decode it.

    Keyword arguments:
    img_tensor -- the raw image tensor in HWC or NCHW with pixel value of
                  type unsigned int; if in NCHW format, N should be 1
    quality -- the quality needed to compress the image

    Raises:
    AssertionError -- if a 4-D tensor is passed whose first dimension is not 1
    RuntimeError -- if OpenCV reports that JPEG encoding failed
    """
    # turn NCHW to HWC
    if img_tensor.dim() == 4:
        assert img_tensor.size(0) == 1, "NCHW input must have N == 1"
        img_tensor = img_tensor[0].permute(1, 2, 0)
    pixels = img_tensor.numpy()
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, enc_img = cv2.imencode(".JPEG", pixels, encode_param)
    if not ok:
        # Surface the failure here instead of bundling an invalid buffer.
        raise RuntimeError("cv2.imencode failed to JPEG-encode the image")
    # Flatten the encoded buffer into a 1-D byte tensor for bundling.
    enc_img_tensor = torch.flatten(torch.from_numpy(enc_img)).byte()
    # The format string is the expression that inflates the bundled bytes
    # back into an image tensor when the bundled inputs are retrieved.
    return torch.utils.bundled_inputs.InflatableArg(
        enc_img_tensor, "torch.ops.fb.decode_bundled_image({})"
    )


def get_tensor_from_raw_BGR(im) -> torch.Tensor:
    """Convert a BGR image array into a float tensor with a batch dimension.

    The image is converted from BGR to RGB, cast to float, its last axis is
    moved to the front (``permute(2, 0, 1)``), values are divided by 255,
    and a leading dimension of size 1 is added.
    """
    raw_data = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    raw_data = torch.from_numpy(raw_data).float()
    raw_data = raw_data.permute(2, 0, 1)
    raw_data = torch.div(raw_data, 255).unsqueeze(0)
    return raw_data


class TestBundledImages(TestCase):
    def test_single_tensors(self):
        """Bundle a JPEG-compressed image as a model input and check that the
        inflated tensor, and the output of ``fb::image_decode_to_NCHW``,
        both approximately match the directly loaded image.
        """

        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        image_path = "caffe2/test/test_img/p1.jpg"
        im = cv2.imread(image_path)
        # cv2.imread signals failure by returning None rather than raising.
        self.assertIsNotNone(im, f"could not read test image {image_path}")
        tensor = torch.from_numpy(im)
        inflatable_arg = bundle_jpeg_image(tensor, 90)
        bundled_inputs = [(inflatable_arg,)]
        sm = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(
            sm, bundled_inputs
        )
        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()
        decoded_data = inflated[0][0]

        # raw image
        raw_data = get_tensor_from_raw_BGR(im)

        self.assertEqual(len(inflated), 1)
        self.assertEqual(len(inflated[0]), 1)
        self.assertEqual(raw_data.shape, decoded_data.shape)
        # JPEG compression is lossy, so compare with loose tolerances.
        self.assertEqual(raw_data, decoded_data, atol=0.1, rtol=1e-01)

        # Check if fb::image_decode_to_NCHW works as expected
        with open(image_path, "rb") as fp:
            byte_tensor = torch.tensor(list(fp.read())).byte()
        # Diagonal weight of 1/255 with zero bias, matching the scaling
        # applied in get_tensor_from_raw_BGR.
        weight = torch.full((3,), 1.0 / 255.0).diag()
        bias = torch.zeros(3)
        im2_tensor = torch.ops.fb.image_decode_to_NCHW(byte_tensor, weight, bias)
        self.assertEqual(raw_data.shape, im2_tensor.shape)
        self.assertEqual(raw_data, im2_tensor, atol=0.1, rtol=1e-01)