1// Copyright 2019 The Android Open Source Project
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// For implementation details, please refer to:
16// https://www.khronos.org/registry/OpenGL/extensions/KHR/KHR_texture_compression_astc_hdr.txt
17
// Please refer to this document for operator precedence (slightly different from C):
19// https://www.khronos.org/registry/OpenGL/specs/gl/GLSLangSpec.4.60.html#operators
20
21#version 450
22#include "AstcUnquantMap.comp"
23#include "Common.comp"
24
// All integer arithmetic in this shader uses full (highp) precision.
precision highp int;

// Workgroup shape: 8 x 8 x 1 invocations.
layout(local_size_x = 8, local_size_y = 8, local_size_z = 1) in;

// Per-dispatch parameters supplied by the host.
//   blockSize  - ASTC block footprint in texels (x = width, y = height); the
//                decoder compares the block's weight-grid size against it and
//                writes blockSize.x * blockSize.y texels per block.
//   baseLayer  - not referenced in the visible code; presumably the first
//                array layer to process — TODO confirm.
//   smallBlock - when non-zero, select_partition() doubles x/y before hashing
//                (presumably the host sets it per the spec's small-block rule).
layout(push_constant) uniform ImageFormatBlock {
    uvec2 blockSize;
    uint baseLayer;
    uint smallBlock;
}
u_pushConstant;

// u_image0: compressed input, presumably one 128-bit ASTC block per rgba32ui
// texel. u_image1: decoded RGBA8 output. WITH_TYPE comes from Common.comp
// (not shown) and presumably selects the image dimensionality.
layout(binding = 0, rgba32ui) readonly uniform WITH_TYPE(uimage) u_image0;
layout(binding = 1, rgba8ui) writeonly uniform WITH_TYPE(uimage) u_image1;
38
// kHDRCEM[cem] is true when color endpoint mode `cem` is an HDR mode.
// HDR modes are 2, 3, 7, 11, 14 and 15; every other mode is LDR.
const bool kHDRCEM[16] = {
    false,  // 0
    false,  // 1
    true,   // 2
    true,   // 3
    false,  // 4
    false,  // 5
    false,  // 6
    true,   // 7
    false,  // 8
    false,  // 9
    false,  // 10
    true,   // 11
    false,  // 12
    false,  // 13
    true,   // 14
    true,   // 15
};
45
// Trit decoding table for the integer sequence encoding (spec section C.2.12).
// Indexed by the 8 packed trit bits T[7:0] of one trit block; row entry [i] is
// the i-th decoded trit (0, 1 or 2), i.e. the order {t0, t1, t2, t3, t4}.

const uint kTritEncodings[256][5] = {
    {0, 0, 0, 0, 0}, {1, 0, 0, 0, 0}, {2, 0, 0, 0, 0}, {0, 0, 2, 0, 0}, {0, 1, 0, 0, 0},
    {1, 1, 0, 0, 0}, {2, 1, 0, 0, 0}, {1, 0, 2, 0, 0}, {0, 2, 0, 0, 0}, {1, 2, 0, 0, 0},
    {2, 2, 0, 0, 0}, {2, 0, 2, 0, 0}, {0, 2, 2, 0, 0}, {1, 2, 2, 0, 0}, {2, 2, 2, 0, 0},
    {2, 0, 2, 0, 0}, {0, 0, 1, 0, 0}, {1, 0, 1, 0, 0}, {2, 0, 1, 0, 0}, {0, 1, 2, 0, 0},
    {0, 1, 1, 0, 0}, {1, 1, 1, 0, 0}, {2, 1, 1, 0, 0}, {1, 1, 2, 0, 0}, {0, 2, 1, 0, 0},
    {1, 2, 1, 0, 0}, {2, 2, 1, 0, 0}, {2, 1, 2, 0, 0}, {0, 0, 0, 2, 2}, {1, 0, 0, 2, 2},
    {2, 0, 0, 2, 2}, {0, 0, 2, 2, 2}, {0, 0, 0, 1, 0}, {1, 0, 0, 1, 0}, {2, 0, 0, 1, 0},
    {0, 0, 2, 1, 0}, {0, 1, 0, 1, 0}, {1, 1, 0, 1, 0}, {2, 1, 0, 1, 0}, {1, 0, 2, 1, 0},
    {0, 2, 0, 1, 0}, {1, 2, 0, 1, 0}, {2, 2, 0, 1, 0}, {2, 0, 2, 1, 0}, {0, 2, 2, 1, 0},
    {1, 2, 2, 1, 0}, {2, 2, 2, 1, 0}, {2, 0, 2, 1, 0}, {0, 0, 1, 1, 0}, {1, 0, 1, 1, 0},
    {2, 0, 1, 1, 0}, {0, 1, 2, 1, 0}, {0, 1, 1, 1, 0}, {1, 1, 1, 1, 0}, {2, 1, 1, 1, 0},
    {1, 1, 2, 1, 0}, {0, 2, 1, 1, 0}, {1, 2, 1, 1, 0}, {2, 2, 1, 1, 0}, {2, 1, 2, 1, 0},
    {0, 1, 0, 2, 2}, {1, 1, 0, 2, 2}, {2, 1, 0, 2, 2}, {1, 0, 2, 2, 2}, {0, 0, 0, 2, 0},
    {1, 0, 0, 2, 0}, {2, 0, 0, 2, 0}, {0, 0, 2, 2, 0}, {0, 1, 0, 2, 0}, {1, 1, 0, 2, 0},
    {2, 1, 0, 2, 0}, {1, 0, 2, 2, 0}, {0, 2, 0, 2, 0}, {1, 2, 0, 2, 0}, {2, 2, 0, 2, 0},
    {2, 0, 2, 2, 0}, {0, 2, 2, 2, 0}, {1, 2, 2, 2, 0}, {2, 2, 2, 2, 0}, {2, 0, 2, 2, 0},
    {0, 0, 1, 2, 0}, {1, 0, 1, 2, 0}, {2, 0, 1, 2, 0}, {0, 1, 2, 2, 0}, {0, 1, 1, 2, 0},
    {1, 1, 1, 2, 0}, {2, 1, 1, 2, 0}, {1, 1, 2, 2, 0}, {0, 2, 1, 2, 0}, {1, 2, 1, 2, 0},
    {2, 2, 1, 2, 0}, {2, 1, 2, 2, 0}, {0, 2, 0, 2, 2}, {1, 2, 0, 2, 2}, {2, 2, 0, 2, 2},
    {2, 0, 2, 2, 2}, {0, 0, 0, 0, 2}, {1, 0, 0, 0, 2}, {2, 0, 0, 0, 2}, {0, 0, 2, 0, 2},
    {0, 1, 0, 0, 2}, {1, 1, 0, 0, 2}, {2, 1, 0, 0, 2}, {1, 0, 2, 0, 2}, {0, 2, 0, 0, 2},
    {1, 2, 0, 0, 2}, {2, 2, 0, 0, 2}, {2, 0, 2, 0, 2}, {0, 2, 2, 0, 2}, {1, 2, 2, 0, 2},
    {2, 2, 2, 0, 2}, {2, 0, 2, 0, 2}, {0, 0, 1, 0, 2}, {1, 0, 1, 0, 2}, {2, 0, 1, 0, 2},
    {0, 1, 2, 0, 2}, {0, 1, 1, 0, 2}, {1, 1, 1, 0, 2}, {2, 1, 1, 0, 2}, {1, 1, 2, 0, 2},
    {0, 2, 1, 0, 2}, {1, 2, 1, 0, 2}, {2, 2, 1, 0, 2}, {2, 1, 2, 0, 2}, {0, 2, 2, 2, 2},
    {1, 2, 2, 2, 2}, {2, 2, 2, 2, 2}, {2, 0, 2, 2, 2}, {0, 0, 0, 0, 1}, {1, 0, 0, 0, 1},
    {2, 0, 0, 0, 1}, {0, 0, 2, 0, 1}, {0, 1, 0, 0, 1}, {1, 1, 0, 0, 1}, {2, 1, 0, 0, 1},
    {1, 0, 2, 0, 1}, {0, 2, 0, 0, 1}, {1, 2, 0, 0, 1}, {2, 2, 0, 0, 1}, {2, 0, 2, 0, 1},
    {0, 2, 2, 0, 1}, {1, 2, 2, 0, 1}, {2, 2, 2, 0, 1}, {2, 0, 2, 0, 1}, {0, 0, 1, 0, 1},
    {1, 0, 1, 0, 1}, {2, 0, 1, 0, 1}, {0, 1, 2, 0, 1}, {0, 1, 1, 0, 1}, {1, 1, 1, 0, 1},
    {2, 1, 1, 0, 1}, {1, 1, 2, 0, 1}, {0, 2, 1, 0, 1}, {1, 2, 1, 0, 1}, {2, 2, 1, 0, 1},
    {2, 1, 2, 0, 1}, {0, 0, 1, 2, 2}, {1, 0, 1, 2, 2}, {2, 0, 1, 2, 2}, {0, 1, 2, 2, 2},
    {0, 0, 0, 1, 1}, {1, 0, 0, 1, 1}, {2, 0, 0, 1, 1}, {0, 0, 2, 1, 1}, {0, 1, 0, 1, 1},
    {1, 1, 0, 1, 1}, {2, 1, 0, 1, 1}, {1, 0, 2, 1, 1}, {0, 2, 0, 1, 1}, {1, 2, 0, 1, 1},
    {2, 2, 0, 1, 1}, {2, 0, 2, 1, 1}, {0, 2, 2, 1, 1}, {1, 2, 2, 1, 1}, {2, 2, 2, 1, 1},
    {2, 0, 2, 1, 1}, {0, 0, 1, 1, 1}, {1, 0, 1, 1, 1}, {2, 0, 1, 1, 1}, {0, 1, 2, 1, 1},
    {0, 1, 1, 1, 1}, {1, 1, 1, 1, 1}, {2, 1, 1, 1, 1}, {1, 1, 2, 1, 1}, {0, 2, 1, 1, 1},
    {1, 2, 1, 1, 1}, {2, 2, 1, 1, 1}, {2, 1, 2, 1, 1}, {0, 1, 1, 2, 2}, {1, 1, 1, 2, 2},
    {2, 1, 1, 2, 2}, {1, 1, 2, 2, 2}, {0, 0, 0, 2, 1}, {1, 0, 0, 2, 1}, {2, 0, 0, 2, 1},
    {0, 0, 2, 2, 1}, {0, 1, 0, 2, 1}, {1, 1, 0, 2, 1}, {2, 1, 0, 2, 1}, {1, 0, 2, 2, 1},
    {0, 2, 0, 2, 1}, {1, 2, 0, 2, 1}, {2, 2, 0, 2, 1}, {2, 0, 2, 2, 1}, {0, 2, 2, 2, 1},
    {1, 2, 2, 2, 1}, {2, 2, 2, 2, 1}, {2, 0, 2, 2, 1}, {0, 0, 1, 2, 1}, {1, 0, 1, 2, 1},
    {2, 0, 1, 2, 1}, {0, 1, 2, 2, 1}, {0, 1, 1, 2, 1}, {1, 1, 1, 2, 1}, {2, 1, 1, 2, 1},
    {1, 1, 2, 2, 1}, {0, 2, 1, 2, 1}, {1, 2, 1, 2, 1}, {2, 2, 1, 2, 1}, {2, 1, 2, 2, 1},
    {0, 2, 1, 2, 2}, {1, 2, 1, 2, 2}, {2, 2, 1, 2, 2}, {2, 1, 2, 2, 2}, {0, 0, 0, 1, 2},
    {1, 0, 0, 1, 2}, {2, 0, 0, 1, 2}, {0, 0, 2, 1, 2}, {0, 1, 0, 1, 2}, {1, 1, 0, 1, 2},
    {2, 1, 0, 1, 2}, {1, 0, 2, 1, 2}, {0, 2, 0, 1, 2}, {1, 2, 0, 1, 2}, {2, 2, 0, 1, 2},
    {2, 0, 2, 1, 2}, {0, 2, 2, 1, 2}, {1, 2, 2, 1, 2}, {2, 2, 2, 1, 2}, {2, 0, 2, 1, 2},
    {0, 0, 1, 1, 2}, {1, 0, 1, 1, 2}, {2, 0, 1, 1, 2}, {0, 1, 2, 1, 2}, {0, 1, 1, 1, 2},
    {1, 1, 1, 1, 2}, {2, 1, 1, 1, 2}, {1, 1, 2, 1, 2}, {0, 2, 1, 1, 2}, {1, 2, 1, 1, 2},
    {2, 2, 1, 1, 2}, {2, 1, 2, 1, 2}, {0, 2, 2, 2, 2}, {1, 2, 2, 2, 2}, {2, 2, 2, 2, 2},
    {2, 1, 2, 2, 2},
};
102
// Quint decoding table for the integer sequence encoding (spec section C.2.12).
// Indexed by the 7 packed quint bits Q[6:0] of one quint block; row entry [i]
// is the i-th decoded quint (0..4), i.e. the order {q0, q1, q2}.
const uint kQuintEncodings[128][3] = {
    {0, 0, 0}, {1, 0, 0}, {2, 0, 0}, {3, 0, 0}, {4, 0, 0}, {0, 4, 0}, {4, 4, 0}, {4, 4, 4},
    {0, 1, 0}, {1, 1, 0}, {2, 1, 0}, {3, 1, 0}, {4, 1, 0}, {1, 4, 0}, {4, 4, 1}, {4, 4, 4},
    {0, 2, 0}, {1, 2, 0}, {2, 2, 0}, {3, 2, 0}, {4, 2, 0}, {2, 4, 0}, {4, 4, 2}, {4, 4, 4},
    {0, 3, 0}, {1, 3, 0}, {2, 3, 0}, {3, 3, 0}, {4, 3, 0}, {3, 4, 0}, {4, 4, 3}, {4, 4, 4},
    {0, 0, 1}, {1, 0, 1}, {2, 0, 1}, {3, 0, 1}, {4, 0, 1}, {0, 4, 1}, {4, 0, 4}, {0, 4, 4},
    {0, 1, 1}, {1, 1, 1}, {2, 1, 1}, {3, 1, 1}, {4, 1, 1}, {1, 4, 1}, {4, 1, 4}, {1, 4, 4},
    {0, 2, 1}, {1, 2, 1}, {2, 2, 1}, {3, 2, 1}, {4, 2, 1}, {2, 4, 1}, {4, 2, 4}, {2, 4, 4},
    {0, 3, 1}, {1, 3, 1}, {2, 3, 1}, {3, 3, 1}, {4, 3, 1}, {3, 4, 1}, {4, 3, 4}, {3, 4, 4},
    {0, 0, 2}, {1, 0, 2}, {2, 0, 2}, {3, 0, 2}, {4, 0, 2}, {0, 4, 2}, {2, 0, 4}, {3, 0, 4},
    {0, 1, 2}, {1, 1, 2}, {2, 1, 2}, {3, 1, 2}, {4, 1, 2}, {1, 4, 2}, {2, 1, 4}, {3, 1, 4},
    {0, 2, 2}, {1, 2, 2}, {2, 2, 2}, {3, 2, 2}, {4, 2, 2}, {2, 4, 2}, {2, 2, 4}, {3, 2, 4},
    {0, 3, 2}, {1, 3, 2}, {2, 3, 2}, {3, 3, 2}, {4, 3, 2}, {3, 4, 2}, {2, 3, 4}, {3, 3, 4},
    {0, 0, 3}, {1, 0, 3}, {2, 0, 3}, {3, 0, 3}, {4, 0, 3}, {0, 4, 3}, {0, 0, 4}, {1, 0, 4},
    {0, 1, 3}, {1, 1, 3}, {2, 1, 3}, {3, 1, 3}, {4, 1, 3}, {1, 4, 3}, {0, 1, 4}, {1, 1, 4},
    {0, 2, 3}, {1, 2, 3}, {2, 2, 3}, {3, 2, 3}, {4, 2, 3}, {2, 4, 3}, {0, 2, 4}, {1, 2, 4},
    {0, 3, 3}, {1, 3, 3}, {2, 3, 3}, {3, 3, 3}, {4, 3, 3}, {3, 4, 3}, {0, 3, 4}, {1, 3, 4}};
120
const int kRQuantParamTableLength = 19;
// Color endpoint quantization ranges from Table C.2.16 (binary-only ranges
// included), stored as {trits, quints, bits}. Each row describes a range of
// 3^trits * 5^quints * 2^bits values; the trailing comment is the range's
// maximum value. Rows are ordered from the largest range down to the smallest,
// so scanning in order and stopping at the first row whose encoded size fits
// selects the largest usable range.
// NOTE(review): the 0..4 and 0..2 rows are commented out, but 0..3 and 0..1 are
// kept. Other decoders (e.g. ARM astcenc) treat an endpoint range smaller than
// 0..5 as an error block — confirm which behavior the spec requires.
const uint kRQuantParamTable[kRQuantParamTableLength][3] = {
    {0, 0, 8},  // 255
    {1, 0, 6},  // 191
    {0, 1, 5},  // 159
    {0, 0, 7},  // 127
    {1, 0, 5},  // 95
    {0, 1, 4},  // 79
    {0, 0, 6},  // 63
    {1, 0, 4},  // 47
    {0, 1, 3},  // 39
    {0, 0, 5},  // 31
    {1, 0, 3},  // 23
    {0, 1, 2},  // 19
    {0, 0, 4},  // 15
    {1, 0, 2},  // 11
    {0, 1, 1},  // 9
    {0, 0, 3},  // 7
    {1, 0, 1},  // 5
    //{0, 1, 0}, // 4
    {0, 0, 2},  // 3
    //{1, 0, 0}, // 2
    {0, 0, 1},  // 1
};
146
// Returns bit number `idx` of `u` (0 = least significant) as 0 or 1.
uint bit(uint u, int idx) {
    return bitfieldExtract(u, idx, 1);
}
148
// Extracts `bitCount` bits (at most 32) starting at bit `bitStart` of the
// 128-bit value held in `u`, and returns them right-aligned.
//
// Bit numbering: bit 0 of the 128-bit value is bit 0 of u[3], and bit 127 is
// bit 31 of u[0]. So u[3] is the least significant word. A read may straddle
// two adjacent words.
//
// Edge cases:
//   - bitCount == 0 returns 0.
//   - bits at position 128 or above read as zero, and u is never indexed out
//     of range.
uint bits128(uvec4 u, uint bitStart, uint bitCount) {
    // Zero-width reads used to take the straddling branch: the end index
    // bitStart + bitCount - 1 underflows, or lands in the previous word when
    // bitStart is word-aligned. That branch then shifted by 32, which is
    // undefined in GLSL.
    if (bitCount == 0u || bitStart >= 128u) {
        return 0u;
    }
    // (1u << 32u) is undefined, so build the full-word mask explicitly.
    uint bitMask = (bitCount >= 32u) ? 0xFFFFFFFFu : ((1u << bitCount) - 1u);
    uint firstIdx = bitStart / 32u;
    uint firstOffset = bitStart % 32u;
    uint lastIdx = (bitStart + bitCount - 1u) / 32u;
    if (firstIdx == lastIdx) {
        return (u[3u - firstIdx] >> firstOffset) & bitMask;
    }
    // Straddling read. firstOffset is non-zero here, so firstCount is in
    // [1, 31], and u[3u - firstIdx] >> firstOffset already holds exactly
    // firstCount significant bits.
    uint firstCount = 32u - firstOffset;
    // The next-more-significant word does not exist past bit 127; read it as 0.
    uint highWord = 0u;
    if (firstIdx < 3u) {
        highWord = u[2u - firstIdx];
    }
    return ((highWord << firstCount) | (u[3u - firstIdx] >> firstOffset)) & bitMask;
}
162
// Reads up to `bitCount` bits starting at `bitStart`, but never at or past
// `bitEnd`. Bits that would lie at or beyond `bitEnd` come back as zero, and a
// start position at or past `bitEnd` yields 0.
uint bits128fillZeros(uvec4 u, uint bitStart, uint bitEnd, uint bitCount) {
    uint available = (bitEnd > bitStart) ? (bitEnd - bitStart) : 0u;
    return (available == 0u) ? 0u : bits128(u, bitStart, min(available, bitCount));
}
169
// Total bits occupied by `num_vals` integer-sequence-encoded values (spec
// section C.2.22):
//   ceil(8 * N * trits / 5) + ceil(7 * N * quints / 3) + N * bits
// where each ceil(x / d) is computed as (x + d - 1) / d.
uint get_bit_count(uint num_vals, uint trits, uint quints, uint bits) {
    return (num_vals * trits * 8u + 4u) / 5u
         + (num_vals * quints * 7u + 2u) / 3u
         + num_vals * bits;
}
177
// Describes one packing unit of the integer sequence encoding:
//   trits : 5 values share 8 trit bits  -> pack = 5, packedSize = 8 + 5 * bits
//   quints: 3 values share 7 quint bits -> pack = 3, packedSize = 7 + 3 * bits
//   binary: one value on its own        -> pack = 1, packedSize = bits
// Trits take precedence over quints if both flags are set.
void get_pack_size(uint trits, uint quints, uint bits, out uint pack, out uint packedSize) {
    if (trits != 1u && quints != 1u) {
        pack = 1u;
        packedSize = bits;
        return;
    }
    bool useTrits = (trits == 1u);
    pack = useTrits ? 5u : 3u;
    packedSize = useTrits ? (8u + 5u * bits) : (7u + 3u * bits);
}
190
// Decodes one trit block of the integer sequence encoding (spec C.2.12) that
// begins at bit `start` of `data`. The block stores five n-bit values m[k]
// interleaved with the 8 bits of the packed trit code T:
//   m0 T[1:0] m1 T[3:2] m2 T[4] m3 T[6:5] m4 T[7]
// Bits at or past `end` read as zero. Returns (trit_k << n) | m[k] for k = 0..4.
uint[5] decode_trit(uvec4 data, uint start, uint end, uint n) {
    // Width of the packed-trit field that follows each value's n low bits.
    const int kTritBitsAfterValue[5] = {2, 2, 1, 2, 1};

    uint lowBits[5];
    uint packedTrits = 0u;
    uint tritShift = 0u;
    uint cursor = start;
    for (int k = 0; k < 5; k++) {
        lowBits[k] = bits128fillZeros(data, cursor, end, n);
        cursor += n;

        uint width = uint(kTritBitsAfterValue[k]);
        packedTrits |= bits128fillZeros(data, cursor, end, width) << tritShift;
        cursor += width;
        tritShift += width;
    }

    uint[5] values;
    for (int k = 0; k < 5; k++) {
        values[k] = (kTritEncodings[packedTrits][k] << n) | lowBits[k];
    }
    return values;
}
216
// Decodes one quint block of the integer sequence encoding (spec C.2.12) that
// begins at bit `start` of `data`. The block stores three n-bit values m[i]
// interleaved with the 7 bits of the packed quint code Q:
//   m0 Q[2:0] m1 Q[4:3] m2 Q[6:5]
// Bits at or past `end` read as zero. Returns (quint_i << n) | m[i] for i = 0..2.
uint[3] decode_quint(uvec4 data, uint start, uint end, uint n) {
    const int kNumVals = 3;
    // Width of the packed-quint field that follows each value's n low bits.
    const int kInterleavedBits[3] = {3, 2, 2};

    // Gather the low bits of each value and reassemble the packed quint code.
    // (An unused, and for n == 32 undefined, `(1 << n) - 1` mask was removed.)
    uint m[kNumVals];
    uint encoded = 0;
    uint encoded_bits_read = 0;
    for (int i = 0; i < kNumVals; ++i) {
        m[i] = bits128fillZeros(data, start, end, n);
        start += n;

        uint encoded_bits = bits128fillZeros(data, start, end, kInterleavedBits[i]);
        start += kInterleavedBits[i];
        encoded |= encoded_bits << encoded_bits_read;
        encoded_bits_read += kInterleavedBits[i];
    }

    // Look up the three quints and put each above its value's low bits.
    uint[kNumVals] result;
    for (int i = 0; i < kNumVals; ++i) {
        result[i] = kQuintEncodings[encoded][i] << n | m[i];
    }
    return result;
}
243
// Number of endpoint values used by color endpoint mode `cem`:
// 2 for modes 0-3, 4 for 4-7, 6 for 8-11 and 8 for 12-15.
uint get_v_count(uint cem) {
    return ((cem >> 2) + 1u) << 1;
}
245
// Color endpoint mode (CEM) identifiers, i.e. the per-partition CEM values
// decoded from the block and passed to decode_ldr_for_mode(). The HDR modes
// (2, 3, 7, 11, 14, 15) are the ones flagged in kHDRCEM.
const uint kLDRLumaDirect = 0;
const uint kLDRLumaBaseOffset = 1;
const uint kHDRLumaLargeRange = 2;
const uint kHDRLumaSmallRange = 3;
const uint kLDRLumaAlphaDirect = 4;
const uint kLDRLumaAlphaBaseOffset = 5;
const uint kLDRRGBBaseScale = 6;
const uint kHDRRGBBaseScale = 7;
const uint kLDRRGBDirect = 8;
const uint kLDRRGBBaseOffset = 9;
const uint kLDRRGBBaseScaleTwoA = 10;
const uint kHDRRGBDirect = 11;
const uint kLDRRGBADirect = 12;
const uint kLDRRGBABaseOffset = 13;
const uint kHDRRGBDirectLDRAlpha = 14;
const uint kHDRRGBDirectHDRAlpha = 15;
262
// Exchanges two ivec4 endpoints in place.
void swap(inout ivec4 v1, inout ivec4 v2) {
    ivec4 saved = v2;
    v2 = v1;
    v1 = saved;
}
268
// ASTC "bit transfer signed": halves `b` and gives it the top bit (0x80) of the
// original `a`. `a` is then halved, masked to 6 bits and reinterpreted as a
// signed value in [-32, 31].
void bit_transfer_signed(inout int a, inout int b) {
    b = (b >> 1) | (a & 0x80);
    int low6 = (a >> 1) & 0x3F;
    a = ((low6 & 0x20) != 0) ? low6 - 0x40 : low6;
}
276
// Blue contraction: replaces red and green by their average with blue, using
// integer division (truncating toward zero). Blue and alpha are unchanged.
void blue_contract(inout ivec4 val) {
    val.rg = (val.rg + ivec2(val.b)) / 2;
}
281
// Decodes the LDR color endpoint pair of one partition.
//
//   vals      - unquantized endpoint values for the whole block; this
//               partition's values start at vals[start_idx]. Presumably every
//               value is in [0, 255] (the direct modes do not clamp) — TODO
//               confirm against the unquantization tables.
//   start_idx - index of this partition's first value. Eight values are always
//               fetched, so start_idx + 7 must be < 40; modes that use fewer
//               values ignore the extras.
//   mode      - color endpoint mode, one of the k* CEM constants.
//   c1, c2    - receive the low and high RGBA endpoints.
//
// Modes without a case below (all HDR modes) produce two all-zero endpoints.
void decode_ldr_for_mode(const uint[40] vals, uint start_idx, uint mode, out uvec4 c1,
                         out uvec4 c2) {
    int v0 = int(vals[start_idx + 0]);
    int v1 = int(vals[start_idx + 1]);
    int v2 = int(vals[start_idx + 2]);
    int v3 = int(vals[start_idx + 3]);
    int v4 = int(vals[start_idx + 4]);
    int v5 = int(vals[start_idx + 5]);
    int v6 = int(vals[start_idx + 6]);
    int v7 = int(vals[start_idx + 7]);
    ivec4 endpoint_low_rgba;
    ivec4 endpoint_high_rgba;
    switch (mode) {
        case kLDRLumaDirect: {
            endpoint_low_rgba = ivec4(v0, v0, v0, 255);
            endpoint_high_rgba = ivec4(v1, v1, v1, 255);
        } break;

        case kLDRLumaBaseOffset: {
            // L0 takes its top two bits from v1; the high endpoint saturates at 255.
            const int l0 = (v0 >> 2) | (v1 & 0xC0);
            const int l1 = min(l0 + (v1 & 0x3F), 0xFF);

            endpoint_low_rgba = ivec4(l0, l0, l0, 255);
            endpoint_high_rgba = ivec4(l1, l1, l1, 255);
        } break;

        case kLDRLumaAlphaDirect: {
            endpoint_low_rgba = ivec4(v0, v0, v0, v2);
            endpoint_high_rgba = ivec4(v1, v1, v1, v3);
        } break;

        case kLDRLumaAlphaBaseOffset: {
            bit_transfer_signed(v1, v0);
            bit_transfer_signed(v3, v2);

            endpoint_low_rgba = clamp(ivec4(v0, v0, v0, v2), 0, 255);
            const int high_luma = v0 + v1;
            endpoint_high_rgba = clamp(ivec4(high_luma, high_luma, high_luma, v2 + v3), 0, 255);
        } break;

        case kLDRRGBBaseScale: {
            // High endpoint is the base color; low endpoint is base * v3 / 256.
            endpoint_high_rgba = ivec4(v0, v1, v2, 255);
            for (int i = 0; i < 3; ++i) {
                const int x = endpoint_high_rgba[i];
                endpoint_low_rgba[i] = (x * v3) >> 8;
            }
            endpoint_low_rgba[3] = 255;
        } break;

        case kLDRRGBDirect: {
            const int s0 = v0 + v2 + v4;
            const int s1 = v1 + v3 + v5;

            endpoint_low_rgba = ivec4(v0, v2, v4, 255);
            endpoint_high_rgba = ivec4(v1, v3, v5, 255);

            // A "darker" second endpoint signals swapped, blue-contracted endpoints.
            if (s1 < s0) {
                swap(endpoint_low_rgba, endpoint_high_rgba);
                blue_contract(endpoint_low_rgba);
                blue_contract(endpoint_high_rgba);
            }
        } break;

        case kLDRRGBBaseOffset: {
            bit_transfer_signed(v1, v0);
            bit_transfer_signed(v3, v2);
            bit_transfer_signed(v5, v4);

            endpoint_low_rgba = ivec4(v0, v2, v4, 255);
            endpoint_high_rgba = ivec4(v0 + v1, v2 + v3, v4 + v5, 255);

            // A negative total offset signals swapped, blue-contracted endpoints.
            if (v1 + v3 + v5 < 0) {
                swap(endpoint_low_rgba, endpoint_high_rgba);
                blue_contract(endpoint_low_rgba);
                blue_contract(endpoint_high_rgba);
            }

            endpoint_low_rgba = clamp(endpoint_low_rgba, 0, 255);
            endpoint_high_rgba = clamp(endpoint_high_rgba, 0, 255);
        } break;

        case kLDRRGBBaseScaleTwoA: {
            // Base
            endpoint_low_rgba = endpoint_high_rgba = ivec4(v0, v1, v2, 255);

            // Scale
            endpoint_low_rgba = (endpoint_low_rgba * v3) >> 8;

            // Two A
            endpoint_low_rgba[3] = v4;
            endpoint_high_rgba[3] = v5;
        } break;

        case kLDRRGBADirect: {
            // Signed sums, matching kLDRRGBDirect (previously stored as uint).
            const int s0 = v0 + v2 + v4;
            const int s1 = v1 + v3 + v5;

            endpoint_low_rgba = ivec4(v0, v2, v4, v6);
            endpoint_high_rgba = ivec4(v1, v3, v5, v7);

            if (s1 < s0) {
                swap(endpoint_low_rgba, endpoint_high_rgba);
                blue_contract(endpoint_low_rgba);
                blue_contract(endpoint_high_rgba);
            }
        } break;

        case kLDRRGBABaseOffset: {
            bit_transfer_signed(v1, v0);
            bit_transfer_signed(v3, v2);
            bit_transfer_signed(v5, v4);
            bit_transfer_signed(v7, v6);

            endpoint_low_rgba = ivec4(v0, v2, v4, v6);
            endpoint_high_rgba = ivec4(v0 + v1, v2 + v3, v4 + v5, v6 + v7);

            // Only the RGB offsets decide the swap; alpha is carried along.
            if (v1 + v3 + v5 < 0) {
                swap(endpoint_low_rgba, endpoint_high_rgba);
                blue_contract(endpoint_low_rgba);
                blue_contract(endpoint_high_rgba);
            }

            endpoint_low_rgba = clamp(endpoint_low_rgba, 0, 255);
            endpoint_high_rgba = clamp(endpoint_high_rgba, 0, 255);
        } break;

        default:
            // Unimplemented color encoding.
            // TODO(google): Is this the correct error handling?
            endpoint_high_rgba = endpoint_low_rgba = ivec4(0, 0, 0, 0);
    }
    c1 = uvec4(endpoint_low_rgba);
    c2 = uvec4(endpoint_high_rgba);
}
416
// 32-bit integer hash used by partition selection (the spec's hash52), built
// from xor/shift/add steps with wrap-around arithmetic. Each `h += h << k` step
// is written as the equivalent modular multiplication by (2^k + 1).
uint hash52(uint p) {
    uint h = p;
    h ^= h >> 15u;
    h -= h << 17u;
    h *= 129u;    // h += h << 7
    h *= 17u;     // h += h << 4
    h ^= h >> 5u;
    h *= 65537u;  // h += h << 16
    h ^= h >> 7u;
    h ^= h >> 3u;
    h ^= h << 6u;
    h ^= h >> 17u;
    return h;
}
430
// Chooses which partition texel (x, y) belongs to. It follows the reference
// select_partition() procedure from the ASTC specification step by step.
//
//   seed           - partition seed from the block (the caller passes 10 bits).
//   x, y           - texel coordinates within the block.
//   partitioncount - number of partitions (1..4).
// Returns a partition index in [0, partitioncount - 1].
uint select_partition(uint seed, uint x, uint y, uint partitioncount) {
    if (partitioncount == 1) {
        return 0;
    }
    // Only 2D blocks are handled here, so the z coordinate is always 0.
    uint z = 0;
    // Small blocks (flag supplied by the host) use doubled coordinates.
    if (u_pushConstant.smallBlock != 0) {
        x <<= 1;
        y <<= 1;
    }
    // Each partition count hashes a distinct seed range.
    seed += (partitioncount - 1) * 1024;
    uint rnum = hash52(seed);
    // Twelve 4-bit seeds sliced from the hash. seed9..seed11 overlap earlier
    // slices, and seed12 wraps the top two bits around to the bottom.
    uint seed1 = rnum & 0xF;
    uint seed2 = (rnum >> 4) & 0xF;
    uint seed3 = (rnum >> 8) & 0xF;
    uint seed4 = (rnum >> 12) & 0xF;
    uint seed5 = (rnum >> 16) & 0xF;
    uint seed6 = (rnum >> 20) & 0xF;
    uint seed7 = (rnum >> 24) & 0xF;
    uint seed8 = (rnum >> 28) & 0xF;
    uint seed9 = (rnum >> 18) & 0xF;
    uint seed10 = (rnum >> 22) & 0xF;
    uint seed11 = (rnum >> 26) & 0xF;
    uint seed12 = ((rnum >> 30) | (rnum << 2)) & 0xF;

    // Square each seed (result in [0, 225]).
    seed1 *= seed1;
    seed2 *= seed2;
    seed3 *= seed3;
    seed4 *= seed4;
    seed5 *= seed5;
    seed6 *= seed6;
    seed7 *= seed7;
    seed8 *= seed8;
    seed9 *= seed9;
    seed10 *= seed10;
    seed11 *= seed11;
    seed12 *= seed12;

    // Shift amounts depend on low bits of the (offset) seed and on the
    // partition count.
    uint sh1, sh2, sh3;
    if ((seed & 1) != 0) {
        sh1 = ((seed & 2) != 0 ? 4 : 5);
        sh2 = (partitioncount == 3 ? 6 : 5);
    } else {
        sh1 = (partitioncount == 3 ? 6 : 5);
        sh2 = ((seed & 2) != 0 ? 4 : 5);
    }
    sh3 = ((seed & 0x10) != 0) ? sh1 : sh2;

    seed1 >>= sh1;
    seed2 >>= sh2;
    seed3 >>= sh1;
    seed4 >>= sh2;
    seed5 >>= sh1;
    seed6 >>= sh2;
    seed7 >>= sh1;
    seed8 >>= sh2;
    seed9 >>= sh3;
    seed10 >>= sh3;
    seed11 >>= sh3;
    seed12 >>= sh3;

    // One linear ramp over (x, y, z) per candidate partition, reduced to 6 bits.
    uint a = seed1 * x + seed2 * y + seed11 * z + (rnum >> 14);
    uint b = seed3 * x + seed4 * y + seed12 * z + (rnum >> 10);
    uint c = seed5 * x + seed6 * y + seed9 * z + (rnum >> 6);
    uint d = seed7 * x + seed8 * y + seed10 * z + (rnum >> 2);

    a &= 0x3F;
    b &= 0x3F;
    c &= 0x3F;
    d &= 0x3F;

    // Partitions that do not exist get 0 so they cannot win over partition 0.
    if (partitioncount < 4) d = 0;
    if (partitioncount < 3) c = 0;

    // The largest ramp value wins; ties go to the lower partition index.
    if (a >= b && a >= c && a >= d)
        return 0;
    else if (b >= c && b >= d)
        return 1;
    else if (c >= d)
        return 2;
    else
        return 3;
}
513
// Returns a decoded block whose first blockSize.x * blockSize.y texels (the
// row-major footprint) are all `color`. Entries beyond the footprint are left
// unwritten. The 144-entry size presumably covers the largest 12x12 footprint.
uvec4[144] single_color_block(uvec4 color) {
    uvec4 texels[144];
    uint texelCount = u_pushConstant.blockSize.x * u_pushConstant.blockSize.y;
    for (uint i = 0u; i < texelCount; ++i) {
        texels[i] = color;
    }
    return texels;
}
523
// Returns a block filled with the error color: opaque magenta (255, 0, 255, 255).
uvec4[144] error_color_block() {
    const uvec4 kErrorColor = uvec4(255u, 0u, 255u, 255u);
    return single_color_block(kErrorColor);
}
525
526uvec4[144] astc_decode_block(const uvec4 u) {
527    uint d;
528    uint hdr;
529    uint b;
530    uint a;
531    uint r;
532    uint width;
533    uint height;
534    uvec4 cem;
535    uint weightGrid[120];
536    const uint u3 = u[3];
537    const uint b87 = u3 >> 7 & 3;
538    const uint b65 = u3 >> 5 & 3;
539    const uint b32 = u3 >> 2 & 3;
540    a = b65;
541    b = b87;
542    d = bit(u3, 10);
543    hdr = bit(u3, 9);
544    if ((u3 & 3) == 0) {
545        r = b32 << 1 | bit(u3, 4);
546        if (b87 == 0) {
547            width = 12;
548            height = a + 2;
549        } else if (b87 == 1) {
550            width = a + 2;
551            height = 12;
552        } else if (b87 == 3) {
553            if (b65 == 0) {
554                width = 6;
555                height = 10;
556            } else if (b65 == 1) {
557                width = 10;
558                height = 6;
559            } else if ((u3 & 0xDFF) == 0xDFC) {
560                // Void-extent
561                // In void extend, the last 12 bits should be
562                // 1 1 D 1 1 1 1 1 1 1 0 0
563                // Where D is the HDR bit
564
565                uvec4 color =
566                    uvec4(u[1] >> 8 & 0xff, u[1] >> 24 & 0xff, u[0] >> 8 & 0xff, u[0] >> 24 & 0xff);
567                return single_color_block(color);
568            } else {  // reserved
569                return error_color_block();
570            }
571        } else {  // b87 == 2
572            b = u3 >> 9 & 3;
573            width = a + 6;
574            height = b + 6;
575            d = 0;
576            hdr = 0;
577        }
578    } else {
579        r = (u3 & 3) << 1 | bit(u3, 4);
580        if (b32 == 0) {
581            width = b + 4;
582            height = a + 2;
583        } else if (b32 == 1) {
584            width = b + 8;
585            height = a + 2;
586        } else if (b32 == 2) {
587            width = a + 2;
588            height = b + 8;
589        } else if (bit(u3, 8) == 0) {
590            width = a + 2;
591            height = (b & 1) + 6;
592        } else {
593            width = (b & 1) + 2;
594            height = a + 2;
595        }
596    }
597
598    if (width > u_pushConstant.blockSize.x || height > u_pushConstant.blockSize.y) {
599        return error_color_block();
600    }
601    // Decode weight
602    uint trits = 0;
603    uint quints = 0;
604    uint bits = 0;
605    const uint weightCounts = height * width * (d + 1);
606    const int kMaxNumWeights = 64;
607    if (kMaxNumWeights < weightCounts) {
608        return error_color_block();
609    }
610    {
611        if (hdr == 0) {
612            switch (r) {
613                case 2:
614                    bits = 1;
615                    break;
616                case 3:
617                    trits = 1;
618                    break;
619                case 4:
620                    bits = 2;
621                    break;
622                case 5:
623                    quints = 1;
624                    break;
625                case 6:
626                    trits = 1;
627                    bits = 1;
628                    break;
629                case 7:
630                    bits = 3;
631                    break;
632                default:
633                    return error_color_block();
634            }
635        } else {
636            switch (r) {
637                case 2:
638                    bits = 1;
639                    quints = 1;
640                    break;
641                case 3:
642                    trits = 1;
643                    bits = 2;
644                    break;
645                case 4:
646                    bits = 4;
647                    break;
648                case 5:
649                    quints = 1;
650                    bits = 2;
651                    break;
652                case 6:
653                    trits = 1;
654                    bits = 3;
655                    break;
656                case 7:
657                    bits = 5;
658                    break;
659                default:
660                    return error_color_block();
661            }
662        }
663        uint packedSize = 0;
664        uint pack = 0;
665        get_pack_size(trits, quints, bits, pack, packedSize);
666        uint srcIdx = 0;
667        uint dstIdx = 0;
668        uvec4 uReversed = bitfieldReverse(u);
669        const uint weightBitCount = get_bit_count(weightCounts, trits, quints, bits);
670        const int kWeightGridMinBitLength = 24;
671        const int kWeightGridMaxBitLength = 96;
672        if (weightBitCount < kWeightGridMinBitLength || weightBitCount > kWeightGridMaxBitLength) {
673            return error_color_block();
674        }
675        uReversed = uvec4(uReversed[3], uReversed[2], uReversed[1], uReversed[0]);
676        const uint kUnquantBinMulTable[] = {0x3f, 0x15, 0x9, 0x4, 0x2, 0x1};
677        const uint kUnquantBinMovTable[] = {0x8, 0x8, 0x8, 0x2, 0x4, 0x8};
678        while (dstIdx < weightCounts) {
679            if (trits == 1) {
680                uint decoded[5] = decode_trit(uReversed, srcIdx, weightBitCount, bits);
681                // uint decoded[5] = {0, 0, 0, 0, 0};
682                for (int i = 0; i < 5; i++) {
683                    weightGrid[dstIdx] =
684                        kUnquantTritWeightMap[kUnquantTritWeightMapBitIdx[bits] + decoded[i]];
685                    if (weightGrid[dstIdx] > 32) {
686                        weightGrid[dstIdx] += 1;
687                    }
688                    dstIdx++;
689                    if (dstIdx >= weightCounts) {
690                        break;
691                    }
692                }
693            } else if (quints == 1) {
694                uint decoded[3] = decode_quint(uReversed, srcIdx, weightBitCount, bits);
695                for (int i = 0; i < 3; i++) {
696                    // TODO: handle overflow in the last
697                    weightGrid[dstIdx] =
698                        kUnquantQuintWeightMap[kUnquantQuintWeightMapBitIdx[bits] + decoded[i]];
699                    if (weightGrid[dstIdx] > 32) {
700                        weightGrid[dstIdx] += 1;
701                    }
702                    dstIdx++;
703                    if (dstIdx >= weightCounts) {
704                        break;
705                    }
706                }
707            } else {
708                uint decodedRaw = bits128(uReversed, srcIdx, packedSize);
709                uint decoded = decodedRaw * kUnquantBinMulTable[bits - 1] |
710                               decodedRaw >> kUnquantBinMovTable[bits - 1];
711                weightGrid[dstIdx] = decoded;
712                if (weightGrid[dstIdx] > 32) {
713                    weightGrid[dstIdx] += 1;
714                }
715                dstIdx++;
716            }
717            srcIdx += packedSize;
718        }
719    }
720    uint partitionCount = (u3 >> 11 & 3) + 1;
721    if (d == 1 && partitionCount == 4) {
722        return error_color_block();
723    }
724    const uint weightStart = 128 - get_bit_count(weightCounts, trits, quints, bits);
725    uint dualPlaneStart = 0;
726    // Decode cem mode
727    if (partitionCount == 1) {
728        // Single-partition mode
729        cem[0] = u3 >> 13 & 0xf;
730        dualPlaneStart = weightStart - d * 2;
731    } else {
732        // Multi-partition mode
733        // Calculate CEM for all 4 partitions, even when partitionCount < 4
734        uint partMode = u3 >> 23 & 3;
735        const uint kExtraMBitsTable[4] = {0, 2, 5, 8};
736        const uint extraMBitCount = (partMode == 0) ? 0 : kExtraMBitsTable[partitionCount - 1];
737        const uint extraMStart = weightStart - extraMBitCount;
738        dualPlaneStart = extraMStart - d * 2;
739
740        if (partMode == 0) {
741            uint cem_all = u3 >> 25 & 0xf;
742            cem = uvec4(cem_all, cem_all, cem_all, cem_all);
743        } else {
744            uint cemBase = partMode - 1;
745            uvec4 cemHigh = cemBase + uvec4(bit(u3, 25), bit(u3, 26), bit(u3, 27), bit(u3, 28));
746            const uint extraM = bits128(u, extraMStart, extraMBitCount);
747            const uint kMainMBitsTable[4] = {0, 2, 1, 0};
748            const uint mainMBitCount = kMainMBitsTable[partitionCount - 1];
749            const uint m = extraM << mainMBitCount | ((u3 >> 27 & 3) >> (2 - mainMBitCount));
750            cem = cemHigh << 2 | uvec4(m & 3, m >> 2 & 3, m >> 4 & 3, m >> 6 & 3);
751        }
752    }
753    // Decode end points
754    uvec4 endPoints[4][2];
755    {
756        uint totalV = 0;
757        for (uint part = 0; part < partitionCount; part++) {
758            totalV += get_v_count(cem[part]);
759        }
760        const uint epStart = (partitionCount == 1) ? 17 : 29;
761        const uint totalAvailBits = dualPlaneStart - epStart;
762        if (totalAvailBits >= 128) {
763            // overflowed
764            return error_color_block();
765        }
766        uint epQuints = 0;
767        uint epTrits = 0;
768        uint epBits = 0;
769        uint i;
770        for (i = 0; i < kRQuantParamTableLength; i++) {
771            epTrits = kRQuantParamTable[i][0];
772            epQuints = kRQuantParamTable[i][1];
773            epBits = kRQuantParamTable[i][2];
774            if (get_bit_count(totalV, epTrits, epQuints, epBits) <= totalAvailBits) {
775                break;
776            }
777        }
778        if (i >= kRQuantParamTableLength) {
779            return error_color_block();
780        }
781
782        const uint epBitCount = get_bit_count(totalV, epTrits, epQuints, epBits);
783        const uint epEnd = epStart + epBitCount;
784        uint packedSize = 0;
785        uint pack = 0;
786        get_pack_size(epTrits, epQuints, epBits, pack, packedSize);
787
788        // Decode end point parameters into buffer
789        uint vBuffer[40];
790        uint srcIdx = epStart;
791        uint dstIdx = 0;
792        const uint kUnquantBinMulTable[8] = {0xff, 0x55, 0x24, 0x11, 0x8, 0x4, 0x2, 0x1};
793        const uint kUnquantBinMovTable[8] = {8, 8, 1, 8, 2, 4, 6, 8};
794        while (dstIdx < totalV) {
795            if (epTrits == 1) {
796                uint decoded[5] = decode_trit(u, srcIdx, epEnd, epBits);
797                for (int i = 0; i < 5; i++) {
798                    vBuffer[dstIdx] =
799                        kUnquantTritColorMap[kUnquantTritColorMapBitIdx[epBits] + decoded[i]];
800                    dstIdx++;
801                    if (dstIdx >= totalV) {
802                        break;
803                    }
804                }
805            } else if (epQuints == 1) {
806                uint decoded[3] = decode_quint(u, srcIdx, epEnd, epBits);
807                for (int i = 0; i < 3; i++) {
808                    vBuffer[dstIdx] =
809                        kUnquantQuintColorMap[kUnquantQuintColorMapBitIdx[epBits] + decoded[i]];
810                    dstIdx++;
811                    if (dstIdx >= totalV) {
812                        break;
813                    }
814                }
815            } else {
816                uint src = bits128(u, srcIdx, packedSize);
817                uint decoded =
818                    src * kUnquantBinMulTable[epBits - 1] | src >> kUnquantBinMovTable[epBits - 1];
819                vBuffer[dstIdx] = decoded;
820                dstIdx++;
821            }
822            srcIdx += packedSize;
823        }
824        uint bufferIdx = 0;
825        for (uint part = 0; part < partitionCount; part++) {
826            // TODO: HDR support
827            decode_ldr_for_mode(vBuffer, bufferIdx, cem[part], endPoints[part][0],
828                                endPoints[part][1]);
829            bufferIdx += get_v_count(cem[part]);
830        }
831    }
832    uvec4 ret[144];
833    {
834        uvec2 dst = (1024 + u_pushConstant.blockSize / 2) / (u_pushConstant.blockSize - 1);
835        uint dd = d + 1;
836        for (uint h = 0; h < u_pushConstant.blockSize.y; h++) {
837            for (uint w = 0; w < u_pushConstant.blockSize.x; w++) {
838                uint part = select_partition(u3 >> 13 & 1023, w, h, partitionCount);
839                if (kHDRCEM[cem[part]]) {
840                    // HDR not supported
841                    ret[h * u_pushConstant.blockSize.x + w] = uvec4(0xff, 0, 0xff, 0xff);
842                    continue;
843                }
844                // Calculate weight
845                uvec2 st = uvec2(w, h);
846                uvec2 cst = dst * st;
847                uvec2 gst = (cst * (uvec2(width, height) - 1) + 32) >> 6;
848                uvec2 jst = gst >> 4;
849                uvec2 fst = gst & 0xf;
850                uint v0 = jst.x + jst.y * width;
851                uvec2 p00 = uvec2(weightGrid[v0 * dd], weightGrid[v0 * dd + 1]);
852                uvec2 p01 = uvec2(weightGrid[(v0 + 1) * dd], weightGrid[(v0 + 1) * dd + 1]);
853                uvec2 p10 = uvec2(weightGrid[(v0 + width) * dd], weightGrid[(v0 + width) * dd + 1]);
854                uvec2 p11 =
855                    uvec2(weightGrid[(v0 + width + 1) * dd], weightGrid[(v0 + width + 1) * dd + 1]);
856                uint w11 = (fst.x * fst.y + 8) >> 4;
857                uint w10 = fst.y - w11;
858                uint w01 = fst.x - w11;
859                uint w00 = 16 - fst.x - fst.y + w11;
860                uvec2 i = (p00 * w00 + p01 * w01 + p10 * w10 + p11 * w11 + 8) >> 4;
861
862                uvec4 c0 = endPoints[part][0];
863                uvec4 c1 = endPoints[part][1];
864                uvec4 c = (c0 * (64 - i[0]) + c1 * i[0] + 32) / 64;
865                if (d == 1) {
866                    uint ccs = bits128(u, dualPlaneStart, 2);
867                    c[ccs] = (c0[ccs] * (64 - i[1]) + c1[ccs] * i[1] + 32) / 64;
868                }
869                ret[h * u_pushConstant.blockSize.x + w] = c;
870            }
871        }
872    }
873    return ret;
874}
875
// Number of texel rows written per compressed block, one variant per image
// type. main() picks a variant via WITH_TYPE(block_y_size_) (presumably a
// token-pasting macro appending the image-type suffix — confirm in Common.comp).

// 1D array images have a single row, so only the first row of the decoded
// block is stored regardless of the block's nominal height.
uint block_y_size_1DArray() {
    return 1u;
}

// 2D array images store the full block height from the push constants.
uint block_y_size_2DArray() {
    return u_pushConstant.blockSize.y;
}

// 3D images also store the full block height. NOTE(review): main() writes a
// single z slice per block, so block depth is not handled here.
uint block_y_size_3D() {
    return u_pushConstant.blockSize.y;
}
881
// Reverses the byte order inside each 32-bit component of `a` independently:
// byte 0 <-> byte 3 and byte 1 <-> byte 2 (a per-lane byte swap).
uvec4 flip32(uvec4 a) {
    // Outermost bytes: a logical shift moves them to the opposite end and
    // discards every other byte, so no mask is needed.
    uvec4 outerBytes = (a << 24u) | (a >> 24u);
    // Inner bytes: shift by one byte, then mask to keep only the moved byte.
    uvec4 innerBytes = ((a << 8u) & 0x00ff0000u) | ((a >> 8u) & 0x0000ff00u);
    return outerBytes | innerBytes;
}
886
// Compute entry point. Each invocation handles one compressed block:
// gl_GlobalInvocationID.xy is the block coordinate and .z (offset by
// baseLayer) is the layer/slice. The block's 128 bits are read from u_image0,
// decoded by astc_decode_block, and the resulting texels are written to
// u_image1 starting at blockCoord.xy * blockSize.
void main(void) {
    ivec3 blockCoord = ivec3(gl_GlobalInvocationID.xyz);
    blockCoord.z += int(u_pushConstant.baseLayer);

    // Reverse the order of the four 32-bit words before decoding.
    // NOTE(review): presumably astc_decode_block expects the most-significant
    // word in component 0 — confirm against its bit-extraction helpers.
    uvec4 packedWords = uvec4(imageLoad(u_image0, WITH_TYPE(getPos)(blockCoord))).wzyx;
    uvec4[144] texels = astc_decode_block(packedWords);

    uvec2 blockDim = u_pushConstant.blockSize;
    // Row count depends on the image type (1 for 1D arrays, full height otherwise).
    uint rowCount = WITH_TYPE(block_y_size_)();
    // Top-left texel of this block in the destination image.
    ivec2 origin = blockCoord.xy * ivec2(blockDim);

    for (uint row = 0u; row < rowCount; ++row) {
        // Decoded texels are laid out row-major with a stride of blockDim.x.
        uint rowStart = row * blockDim.x;
        for (uint col = 0u; col < blockDim.x; ++col) {
            ivec2 texelXY = origin + ivec2(col, row);
            imageStore(u_image1, WITH_TYPE(getPos)(ivec3(texelXY, blockCoord.z)),
                       texels[rowStart + col]);
        }
    }
}
903