#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
# mypy: allow-untyped-defs

import io

import cv2

import torch
import torch.utils.bundled_inputs
from torch.testing._internal.common_utils import TestCase


torch.ops.load_library("//caffe2/torch/fb/operators:decode_bundled_image")

# Sample JPEG used by the tests below (relative path, resolved against the
# test's working directory).
_TEST_IMAGE_PATH = "caffe2/test/test_img/p1.jpg"


def model_size(sm) -> int:
    """Return the number of bytes ``sm`` occupies when serialized with ``torch.jit.save``."""
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    return len(buffer.getvalue())


def save_and_load(sm):
    """Round-trip ``sm`` through an in-memory ``torch.jit.save`` / ``torch.jit.load``.

    Returns the freshly loaded module, so callers can check that state attached
    to ``sm`` (e.g. bundled inputs) survives serialization.
    """
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    buffer.seek(0)  # rewind so load() reads from the start of what was written
    return torch.jit.load(buffer)


def bundle_jpeg_image(img_tensor, quality):
    """Return an InflatableArg that contains a tensor of the compressed image and the way to decode it.

    Keyword arguments:
    img_tensor -- the raw image tensor in HWC or NCHW layout with unsigned-integer
                  pixel values; if in NCHW format, N must be 1
    quality -- JPEG quality passed to OpenCV as ``cv2.IMWRITE_JPEG_QUALITY``

    The encoded bytes are stored as a flat uint8 tensor, and the InflatableArg
    uses the format string ``torch.ops.fb.decode_bundled_image({})`` to inflate it.

    Raises:
    ValueError -- if a 4-D input has a batch dimension other than 1
    RuntimeError -- if OpenCV fails to encode the image
    """
    if img_tensor.dim() == 4:
        if img_tensor.size(0) != 1:
            raise ValueError(
                f"NCHW input must have N == 1, got N == {img_tensor.size(0)}"
            )
        # NCHW -> HWC. permute() returns a strided view; make it dense so the
        # numpy array handed to OpenCV is a plain contiguous buffer.
        img_tensor = img_tensor[0].permute(1, 2, 0).contiguous()
    pixels = img_tensor.numpy()
    encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
    ok, enc_img = cv2.imencode(".JPEG", pixels, encode_param)
    if not ok:
        raise RuntimeError("cv2.imencode failed to encode the image as JPEG")
    enc_img_tensor = torch.flatten(torch.from_numpy(enc_img)).byte()
    return torch.utils.bundled_inputs.InflatableArg(
        enc_img_tensor, "torch.ops.fb.decode_bundled_image({})"
    )


def get_tensor_from_raw_BGR(im) -> torch.Tensor:
    """Convert an HWC BGR image array (as from ``cv2.imread``) to a float RGB tensor.

    Channels are reordered BGR -> RGB, the layout is permuted HWC -> CHW, values
    are divided by 255 (mapping 8-bit pixels into [0, 1]) and a leading batch
    dimension is added, giving shape (1, C, H, W).
    """
    rgb = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    raw_data = torch.from_numpy(rgb).float()
    raw_data = raw_data.permute(2, 0, 1)  # HWC -> CHW
    return torch.div(raw_data, 255).unsqueeze(0)


class TestBundledImages(TestCase):
    def test_single_tensors(self):
        class SingleTensorModel(torch.nn.Module):
            def forward(self, arg):
                return arg

        im = cv2.imread(_TEST_IMAGE_PATH)
        # cv2.imread signals a missing/unreadable file by returning None.
        self.assertIsNotNone(im, f"could not read test image {_TEST_IMAGE_PATH}")
        tensor = torch.from_numpy(im)
        inflatable_arg = bundle_jpeg_image(tensor, 90)
        inputs = [(inflatable_arg,)]
        sm = torch.jit.script(SingleTensorModel())
        torch.utils.bundled_inputs.augment_model_with_bundled_inputs(sm, inputs)
        loaded = save_and_load(sm)
        inflated = loaded.get_all_bundled_inputs()

        # Check the structure before indexing so a mismatch fails with an
        # assertion message rather than an IndexError.
        self.assertEqual(len(inflated), 1)
        self.assertEqual(len(inflated[0]), 1)
        decoded_data = inflated[0][0]

        # Reference: the original (uncompressed) image in the same layout.
        raw_data = get_tensor_from_raw_BGR(im)
        self.assertEqual(raw_data.shape, decoded_data.shape)
        # Loose tolerances because JPEG at quality 90 is lossy.
        self.assertEqual(raw_data, decoded_data, atol=0.1, rtol=1e-01)

        # Check if fb::image_decode_to_NCHW works as expected.
        with open(_TEST_IMAGE_PATH, "rb") as fp:
            jpeg_bytes = fp.read()
        # NOTE(review): weight/bias look like a per-channel affine applied after
        # decoding; diag(1/255) with zero bias should mirror the [0, 1] scaling
        # in get_tensor_from_raw_BGR -- confirm against the op's definition.
        weight = torch.full((3,), 1.0 / 255.0).diag()
        bias = torch.zeros(3)
        byte_tensor = torch.tensor(list(jpeg_bytes), dtype=torch.uint8)
        im2_tensor = torch.ops.fb.image_decode_to_NCHW(byte_tensor, weight, bias)
        self.assertEqual(raw_data.shape, im2_tensor.shape)
        self.assertEqual(raw_data, im2_tensor, atol=0.1, rtol=1e-01)