#
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

class Bitstream:
    """State shared by the bitstream reader and writer.

    The buffer is traversed from both ends: plain bits use a backward
    cursor that starts at the least significant bit of the last byte,
    while arithmetic-coder bytes use a forward cursor starting at index 0.
    """

    def __init__(self, data):
        self.bytes = data

        # Backward bit cursor: current byte index and single-bit mask.
        self.mask_bw = 1
        self.bp_bw = len(data) - 1

        # Forward byte cursor plus arithmetic-coder interval state.
        self.range = 0xffffff
        self.low = 0
        self.bp = 0

    def dump(self):
        """Print the buffer as hex, at most 20 bytes per output line."""
        data = self.bytes
        step = 20
        for start in range(0, len(data), step):
            row = data[start:start + step]
            print(''.join('{:02x} '.format(byte) for byte in row))


def _bw_bits_used(stream):
    """Return how many bits the backward cursor of *stream* has moved past.

    The backward cursor starts at bit 0 of the last byte and walks toward
    the front of the buffer, so this is 8 bits per fully passed byte plus
    the bit position inside the current byte.
    """
    full_bytes = len(stream.bytes) - 1 - stream.bp_bw
    # mask_bw is a single set bit, so bit_length() - 1 is its exact position.
    return 8 * full_bytes + stream.mask_bw.bit_length() - 1


class BitstreamReader(Bitstream):
    """Read side of the dual-ended bitstream.

    Plain bits are read backward from the end of the buffer, while the
    arithmetic decoder consumes bytes forward from the start.
    """

    def __init__(self, data):
        """Wrap *data* and prime the arithmetic decoder.

        The first three bytes are loaded into ``low``; raises IndexError
        if *data* holds fewer than three bytes.
        """
        super().__init__(data)

        # The decoder keeps a 24-bit window of the forward byte stream.
        self.low = (self.bytes[0] << 16) | (self.bytes[1] << 8) | self.bytes[2]
        self.bp = 3

    def read_bit(self):
        """Read the next backward bit and return it as a bool."""
        bit = bool(self.bytes[self.bp_bw] & self.mask_bw)

        # Step to the next more-significant bit; after bit 7, continue at
        # bit 0 of the preceding byte.
        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

        return bit

    def read_uint(self, nbits):
        """Read an *nbits*-wide unsigned integer, least significant bit first."""
        val = 0
        for k in range(nbits):
            val |= self.read_bit() << k

        return val

    def ac_decode(self, cum_freqs, sym_freqs):
        """Decode one arithmetic-coded symbol and return its index.

        *cum_freqs[i]* and *sym_freqs[i]* are the cumulative and individual
        frequencies of symbol *i*; both are scaled by ``range >> 10``.

        Raises ValueError when ``low`` falls outside the scaled interval,
        and IndexError if renormalization reads past the end of the buffer.
        """
        r = self.range >> 10
        if self.low >= r << 10:
            raise ValueError('Invalid ac bitstream')

        # Scan down from the last symbol to the first whose scaled
        # cumulative start does not exceed low.
        val = len(cum_freqs) - 1
        while self.low < r * cum_freqs[val]:
            val -= 1

        self.low -= r * cum_freqs[val]
        self.range = r * sym_freqs[val]

        # Renormalize: keep range >= 2^16 by shifting in whole bytes,
        # with low kept to 24 bits.
        while self.range < 0x10000:
            self.range <<= 8
            self.low = ((self.low << 8) & 0xffffff) + self.bytes[self.bp]
            self.bp += 1

        return val

    def get_bits_left(self):
        """Return the number of buffer bits consumed by neither reader side.

        The arithmetic decoder is charged for the bytes read after the
        initial three, plus ``25 - floor(log2(range))`` bits.
        """
        # 8 * (bp - 3) + (25 - floor(log2(range))), computed in exact
        # integer arithmetic: floor(log2(x)) == x.bit_length() - 1 for x > 0.
        nbits_ac = 8 * (self.bp - 3) + 26 - self.range.bit_length()

        return 8 * len(self.bytes) - (_bw_bits_used(self) + nbits_ac)


class BitstreamWriter(Bitstream):
    """Write side of the dual-ended bitstream.

    Plain bits are written backward from the end of a zero-filled buffer,
    while the arithmetic encoder emits bytes forward from the start.
    """

    def __init__(self, nbytes):
        """Allocate a zero-filled output buffer of *nbytes* bytes."""
        super().__init__(bytearray(nbytes))

        # Arithmetic-encoder output state:
        #   cache       -- byte held back until its carry is known (-1: none)
        #   carry       -- pending carry into the cached byte
        #   carry_count -- deferred bytes, written as 0xff (no carry) or
        #                  0x00 (carry) once the carry is resolved
        self.cache = -1
        self.carry = 0
        self.carry_count = 0

    def write_bit(self, bit):
        """Write one bit backward; a *bit* equal to 0 clears, anything else sets."""
        mask = self.mask_bw
        bp = self.bp_bw

        if bit == 0:
            self.bytes[bp] &= ~mask
        else:
            self.bytes[bp] |= mask

        # Step to the next more-significant bit; after bit 7, continue at
        # bit 0 of the preceding byte.
        self.mask_bw <<= 1
        if self.mask_bw == 0x100:
            self.mask_bw = 1
            self.bp_bw -= 1

    def write_uint(self, val, nbits):
        """Write the low *nbits* bits of *val*, least significant bit first."""
        for _ in range(nbits):
            self.write_bit(val & 1)
            val >>= 1

    def ac_shift(self):
        """Shift the top byte of the 24-bit ``low`` out of the encoder.

        A top byte of 0xff with no pending carry could still be changed by
        a later carry, so it is only counted in ``carry_count``. Otherwise
        the cached byte (plus carry) and every deferred byte are written
        forward, and the top byte of ``low`` becomes the new cache.
        """
        if self.low < 0xff0000 or self.carry == 1:

            if self.cache >= 0:
                self.bytes[self.bp] = self.cache + self.carry
                self.bp += 1

            # Deferred 0xff bytes roll over to 0x00 when the carry is set.
            while self.carry_count > 0:
                self.bytes[self.bp] = (self.carry + 0xff) & 0xff
                self.bp += 1
                self.carry_count -= 1

            self.cache = self.low >> 16
            self.carry = 0

        else:
            self.carry_count += 1

        self.low = (self.low << 8) & 0xffffff

    def ac_encode(self, cum_freq, sym_freq):
        """Encode one symbol given its cumulative and individual frequency.

        Both frequencies are scaled by ``range >> 10``.
        """
        r = self.range >> 10
        self.low += r * cum_freq
        if (self.low >> 24) != 0:
            # Overflow beyond 24 bits is propagated as a carry into the
            # cached byte by ac_shift().
            self.carry = 1

        self.low &= 0xffffff
        self.range = r * sym_freq

        # Renormalize: keep range >= 2^16, emitting one byte per shift.
        while self.range < 0x10000:
            self.range <<= 8
            self.ac_shift()

    def get_bits_left(self):
        """Return the number of buffer bits still free for either writer side.

        The arithmetic encoder is charged for the bytes already written, the
        held-back cache byte, the deferred carry bytes, and a reserve of
        ``25 - floor(log2(range))`` bits. That reserve is the largest
        ``bits`` value that terminate() can reach.
        """
        # 8 * bp + (25 - floor(log2(range))), computed in exact integer
        # arithmetic: floor(log2(x)) == x.bit_length() - 1 for x > 0.
        nbits_ac = 8 * self.bp + 26 - self.range.bit_length()
        if self.cache >= 0:
            nbits_ac += 8
        # carry_count only counts up from 0 and back down to 0, never below.
        nbits_ac += 8 * self.carry_count

        return 8 * len(self.bytes) - (_bw_bits_used(self) + nbits_ac)

    def terminate(self):
        """Flush the arithmetic encoder and return the output buffer.

        Rounds ``low`` up to a ``bits``-bit boundary, taking one extra bit
        when the rounded value plus its mask would reach ``low + range``.
        It then shifts out whole bytes, flushes any cached and deferred
        bytes, and copies the remaining bits into the most significant
        positions of the next forward byte.
        """
        # Smallest bits >= 1 for which range >> (24 - bits) is nonzero.
        bits = 1
        while (self.range >> (24 - bits)) == 0:
            bits += 1

        mask = 0xffffff >> bits
        val = self.low + mask

        over1 = val >> 24
        val &= 0x00ffffff
        high = self.low + self.range
        over2 = high >> 24
        high &= 0x00ffffff
        val = val & ~mask

        if over1 == over2:

            # The rounded value plus mask would reach high: use one more bit.
            if val + mask >= high:
                bits += 1
                mask >>= 1
                val = ((self.low + mask) & 0x00ffffff) & ~mask

            if val < self.low:
                self.carry = 1

        self.low = val
        while bits > 0:
            self.ac_shift()
            bits -= 8
        bits += 8

        val = self.cache

        if self.carry_count > 0:
            self.bytes[self.bp] = self.cache
            self.bp += 1

            while self.carry_count > 1:
                self.bytes[self.bp] = 0xff
                self.bp += 1
                self.carry_count -= 1

            val = 0xff >> (8 - bits)

        # Copy bits 7 .. (8 - bits) of val into the same positions of the
        # next forward byte; its lower bits are left untouched.
        mask = 0x80
        for _ in range(bits):
            if (val & mask) == 0:
                self.bytes[self.bp] &= ~mask
            else:
                self.bytes[self.bp] |= mask
            mask >>= 1

        return self.bytes