xref: /aosp_15_r20/external/mesa3d/src/compiler/glsl/astc_decoder.glsl (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1#version 320 es
2precision highp float;
3precision highp int;
4precision highp usamplerBuffer;
5precision highp usampler2D;
6precision highp image2D;
7precision highp uimage2D;
8
9/* Copyright (c) 2020-2022 Hans-Kristian Arntzen
10 * Copyright (c) 2022 Intel Corporation
11 *
12 * Permission is hereby granted, free of charge, to any person obtaining
13 * a copy of this software and associated documentation files (the
14 * "Software"), to deal in the Software without restriction, including
15 * without limitation the rights to use, copy, modify, merge, publish,
16 * distribute, sublicense, and/or sell copies of the Software, and to
17 * permit persons to whom the Software is furnished to do so, subject to
18 * the following conditions:
19 *
20 * The above copyright notice and this permission notice shall be
21 * included in all copies or substantial portions of the Software.
22 *
23 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
24 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
25 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
26 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
27 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
28 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
29 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
30 */
31
32#ifdef VULKAN
33
34precision highp utextureBuffer;
35precision highp utexture2DArray;
36precision highp uimage2DArray;
37precision highp uimage3D;
38precision highp utexture3D;
39
40#extension GL_EXT_samplerless_texture_functions : require
41layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 4) in;
42
43layout(set = 0, binding = 0) writeonly uniform uimage2DArray OutputImage2Darray;
44layout(set = 0, binding = 0) writeonly uniform uimage3D OutputImage3D;
45layout(set = 0, binding = 1) uniform utexture2DArray PayloadInput2Darray;
46layout(set = 0, binding = 1) uniform utexture3D PayloadInput3D;
47layout(set = 0, binding = 2) uniform utextureBuffer LUTRemainingBitsToEndpointQuantizer;
48layout(set = 0, binding = 3) uniform utextureBuffer LUTEndpointUnquantize;
49layout(set = 0, binding = 4) uniform utextureBuffer LUTWeightQuantizer;
50layout(set = 0, binding = 5) uniform utextureBuffer LUTWeightUnquantize;
51layout(set = 0, binding = 6) uniform utextureBuffer LUTTritQuintDecode;
52layout(set = 0, binding = 7) uniform utextureBuffer LUTPartitionTable;
53
54layout(constant_id = 2) const bool DECODE_8BIT = false;
55
56layout(push_constant, std430) uniform pc {
57   ivec2 texel_blk_start;
58   ivec2 texel_end;
59   bool is_3Dimage;
60};
61
62#else /* VULKAN */
63
64layout(local_size_x = %u, local_size_y = %u, local_size_z = 4) in;
65
66#define utextureBuffer usamplerBuffer
67#define utexture2D usampler2D
68
69layout(binding = 0) uniform utextureBuffer LUTRemainingBitsToEndpointQuantizer;
70layout(binding = 1) uniform utextureBuffer LUTEndpointUnquantize;
71layout(binding = 2) uniform utextureBuffer LUTWeightQuantizer;
72layout(binding = 3) uniform utextureBuffer LUTWeightUnquantize;
73layout(binding = 4) uniform utextureBuffer LUTTritQuintDecode;
74layout(binding = 5) uniform utexture2D LUTPartitionTable;
75layout(binding = 6) uniform utexture2D PayloadInput;
76
77layout(rgba8ui, binding = 7) writeonly uniform uimage2D OutputImage;
78const bool DECODE_8BIT = true;
79
80#endif /* VULKAN */
81
82const int MODE_LDR = 0;
83const int MODE_HDR = 1;
84const int MODE_HDR_LDR_ALPHA = 2;
85
86const uvec4 error_color = uvec4(255, 0, 255, 255);
87
/* bitextract.h */
/* Extract an unsigned bitfield of 'bits' bits starting at bit 'offset'
 * from the 128-bit block payload (payload.x holds bits 0-31, payload.y
 * bits 32-63, and so on). Fields that straddle a 32-bit word boundary
 * are stitched together from two partial reads. Returns 0 when
 * bits <= 0. */
int extract_bits(uvec4 payload, int offset, int bits)
{
        int last_offset = offset + bits - 1;
        int result;

        if (bits <= 0)
                result = 0;
        else if ((last_offset >> 5) == (offset >> 5))
                /* Entire field lies within a single 32-bit word. */
                result = int(bitfieldExtract(payload[offset >> 5], offset & 31, bits));
        else
        {
                /* Field crosses a word boundary: take the tail of the first
                 * word, the head of the next word, then merge. */
                int first_bits = 32 - (offset & 31);
                int result_first = int(bitfieldExtract(payload[offset >> 5], offset & 31, first_bits));
                int result_second = int(bitfieldExtract(payload[(offset >> 5) + 1], 0, bits - first_bits));
                result = result_first | (result_second << first_bits);
        }
        return result;
}
107
/* bitextract.h */
/* Same as extract_bits(), but the extracted field is sign-extended:
 * the signed-int overload of bitfieldExtract() replicates the field's
 * top bit into the upper bits of the result. For a straddling field the
 * sign extension comes from the second (high) part. */
int extract_bits_sign(uvec4 payload, int offset, int bits)
{
        int last_offset = offset + bits - 1;
        int result;

        if (bits <= 0)
                result = 0;
        else if ((last_offset >> 5) == (offset >> 5))
                result = bitfieldExtract(int(payload[offset >> 5]), offset & 31, bits);
        else
        {
                int first_bits = 32 - (offset & 31);
                /* Low part is extracted unsigned; only the high part carries
                 * the sign. */
                int result_first = int(bitfieldExtract(payload[offset >> 5], offset & 31, first_bits));
                int result_second = bitfieldExtract(int(payload[(offset >> 5) + 1]), 0, bits - first_bits);
                result = result_first | (result_second << first_bits);
        }
        return result;
}
127
/* bitextract.h */
/* Extract a bitfield like extract_bits(), but return it with its bits in
 * reversed order (the field's MSB becomes the result's LSB). Implemented
 * by reversing the whole 32-bit intermediate with bitfieldReverse() and
 * shifting the reversed field down into the low bits. */
int extract_bits_reverse(uvec4 payload, int offset, int bits)
{
        int last_offset = offset + bits - 1;
        int result;

        if (bits <= 0)
                result = 0;
        else if ((last_offset >> 5) == (offset >> 5))
                result = int(bitfieldReverse(bitfieldExtract(payload[offset >> 5], offset & 31, bits)) >> (32 - bits));
        else
        {
                /* Straddling case: assemble the unreversed field first, then
                 * reverse the combined value once. */
                int first_bits = 32 - (offset & 31);
                uint result_first = bitfieldExtract(payload[offset >> 5], offset & 31, first_bits);
                uint result_second = bitfieldExtract(payload[(offset >> 5) + 1], 0, bits - first_bits);
                result = int(bitfieldReverse(result_first | (result_second << first_bits)) >> (32 - bits));
        }
        return result;
}
147
/* Exchange the values of the two ints in place. */
void swap(inout int a, inout int b)
{
    int old_a = a;
    a = b;
    b = old_a;
}
154
/* Compute coordinates for the texel this invocation decodes.
 * Each workgroup covers a 2x2 tile of ASTC blocks: local invocation z
 * (0..3, local_size_z = 4) selects the block within the tile, while
 * local invocation x/y address the texel inside that block (the block
 * footprint equals gl_WorkGroupSize.xy).
 * Returns (texel coord .xy, block coord .zw). */
ivec4 build_coord()
{
    ivec2 payload_coord = ivec2(gl_WorkGroupID.xy) * 2;
    payload_coord.x += int(gl_LocalInvocationID.z) & 1;
    payload_coord.y += (int(gl_LocalInvocationID.z) >> 1) & 1;
#ifdef VULKAN
    /* Dispatches may decode a sub-rectangle; offset by the starting block. */
    payload_coord += texel_blk_start;
#endif /* VULKAN */
    ivec2 coord = payload_coord * ivec2(gl_WorkGroupSize.xy);
    coord += ivec2(gl_LocalInvocationID.xy);
    return ivec4(coord, payload_coord);
}
167
/* Interpolate between the two decoded endpoints with a 6-bit weight
 * (0..64 per channel). Endpoints are first promoted from their decoded
 * range to 16-bit (or 12->16-bit for HDR) as required by the decode mode:
 *   MODE_HDR:           all channels shifted into the 16-bit HDR range.
 *   MODE_HDR_LDR_ALPHA: RGB treated as HDR, alpha expanded 8->16 bit
 *                       by replication (x * 0x101).
 *   LDR:                8->16 bit by replication, or 8-bit decode with
 *                       mid-point bias when DECODE_8BIT. */
ivec4 interpolate_endpoint(ivec4 ep0, ivec4 ep1, ivec4 weight, int decode_mode)
{
    if (decode_mode == MODE_HDR)
    {
        ep0 <<= 4;
        ep1 <<= 4;
    }
    else if (decode_mode == MODE_HDR_LDR_ALPHA)
    {
        ep0.rgb <<= 4;
        ep1.rgb <<= 4;
        ep0.a *= 0x101;
        ep1.a *= 0x101;
    }
    else if (DECODE_8BIT)
    {
        // This isn't quite right in all cases.
        // In normal ASTC with sRGB, the alpha channel is supposed to
        // be decoded as FP16,
        // even when color components are SRGB 8-bit (?!?!?!?!).
        // This is correct if decode_unorm8 mode is used though,
        // for sanity, we're going to assume unorm8 decoding mode
        // is implied when using sRGB.
        ep0 = (ep0 << 8) | ivec4(0x80);
        ep1 = (ep1 << 8) | ivec4(0x80);
    }
    else
    {
        ep0 *= 0x101;
        ep1 *= 0x101;
    }

    /* Weighted blend with round-to-nearest (+32 before >> 6). */
    ivec4 color = (ep0 * (64 - weight) + ep1 * weight + 32) >> 6;
    return color;
}
203
/* Component-wise logical OR of two bvec4s (no GLSL built-in exists). */
bvec4 bvec_or(bvec4 a, bvec4 b)
{
    return notEqual(ivec4(a) | ivec4(b), ivec4(0));
}
208
/* Convert a 16-bit unorm result (0..0xffff, representing color/0x10000)
 * to an FP16 bit pattern, scalar variant. */
uint round_down_quantize_fp16(int color)
{
    // ASTC has a very peculiar way of converting the decoded result to FP16.
    // 0xffff -> 1.0, and for everything else we get roundDownQuantizeFP16(vec4(c) / vec4(0x10000)).
    int msb = findMSB(color);
    int shamt = msb;
    /* Take the 10 bits directly below the MSB as the FP16 mantissa. */
    int m = ((color << 10) >> shamt) & 0x3ff;
    /* Value lies in [2^(msb-16), 2^(msb-15)), so the biased FP16
     * exponent is (msb - 16) + 15 = msb - 1. */
    int e = msb - 1;
    /* e < 1 means the value is below the normal range: color << 8 is the
     * same value expressed as an FP16 denormal (units of 2^-24). */
    uint decoded = color == 0xffff ? 0x3c00u : uint(e < 1 ? (color << 8) : (m | (e << 10)));
    return decoded;
}
220
/* Vectorized variant of round_down_quantize_fp16(): converts four 16-bit
 * unorm results to FP16 bit patterns. Computes the normal-range encoding
 * unconditionally, then patches in the denormal and exact-1.0 cases with
 * component-wise mix(). */
uvec4 round_down_quantize_fp16(ivec4 color)
{
    // ASTC has a very peculiar way of converting the decoded result to FP16.
    // 0xffff -> 1.0, and for everything else we get roundDownQuantizeFP16(vec4(c) / vec4(0x10000)).
    ivec4 msb = findMSB(color);
    ivec4 shamt = msb;
    /* 10 mantissa bits below the MSB; biased exponent is msb - 1
     * (value is color * 2^-16). */
    ivec4 m = ((color << 10) >> shamt) & 0x3ff;
    ivec4 e = msb - 1;
    uvec4 decoded = uvec4(m | (e << 10));
    /* Below the normal range: color << 8 is the FP16 denormal encoding. */
    uvec4 denorm_decode = uvec4(color << 8);
    decoded = mix(decoded, uvec4(denorm_decode), lessThan(e, ivec4(1)));
    /* 0xffff maps exactly to 1.0 (0x3c00). */
    decoded = mix(decoded, uvec4(0x3c00), equal(color, ivec4(0xffff)));
    return decoded;
}
235
/* Final conversion of interpolated 16-bit results to FP16 bit patterns.
 * LDR channels use the unorm quantization path; HDR channels use the
 * ASTC logarithmic reinterpretation below. */
uvec4 decode_fp16(ivec4 color, int decode_mode)
{
    if (decode_mode != MODE_LDR)
    {
        // Interpret the value as FP16, but with some extra fixups along the way to make the interpolation more
        // logarithmic (apparently). From spec:
        ivec4 e = color >> 11;
        ivec4 m = color & 0x7ff;
        /* Piecewise-linear mantissa fixup: 3m, 4m-512 or 5m-2048
         * depending on which third of the mantissa range m falls in. */
        ivec4 mt = 4 * m - 512;
        mt = mix(mt, ivec4(3 * m), lessThan(m, ivec4(512)));
        mt = mix(mt, ivec4(5 * m - 2048), greaterThanEqual(m, ivec4(1536)));

        ivec4 decoded = (e << 10) + (mt >> 3);
        // +Inf or NaN are decoded to 0x7bff (max finite value).
        decoded = mix(decoded, ivec4(0x7bff), bvec_or(greaterThan(decoded & 0x7fff, ivec4(0x7c00)), equal(decoded, ivec4(0x7c00))));

        /* In mixed mode the alpha channel stays LDR. */
        if (decode_mode == MODE_HDR_LDR_ALPHA)
            decoded.a = int(round_down_quantize_fp16(color.a));

        return uvec4(decoded);
    }
    else
    {
        return round_down_quantize_fp16(color);
    }
}
262
/* Configuration parsed from an ASTC block's mode and config bits
 * (filled in by decode_block_mode). */
struct BlockMode
{
    ivec2 weight_grid_size;    // Weight grid dimensions in texels.
    int weight_mode_index;     // Index into LUTWeightQuantizer.
    int num_partitions;        // Partition count, 1..4.
    int seed;                  // 10-bit partition seed (only set when num_partitions > 1).
    int cem;                   // Color endpoint mode bits.
    int config_bits;           // Total config bits consumed (incl. CCS / extra CEM).
    int primary_config_bits;   // Config bits packed before the weight stream.
    bool dual_plane;           // Block encodes two weight planes.
    bool void_extent;          // Constant-color block; other fields are unused.
};

/* Latched error flag; set whenever a reserved/illegal encoding is
 * encountered during decode. */
bool decode_error = false;
277
/* Parse the block mode, partition count/seed and color endpoint mode
 * (CEM) fields from the low bits of an ASTC block payload, and compute
 * how many non-weight config bits the block consumes. Sets decode_error
 * on reserved encodings or illegal weight-grid sizes. */
BlockMode decode_block_mode(uvec4 payload)
{
    BlockMode mode;
    /* Void-extent blocks are identified by bits [8:0] == 111111100b and
     * carry no further configuration to parse here. */
    mode.void_extent = (payload.x & 0x1ffu) == 0x1fcu;
    if (mode.void_extent)
        return mode;

    mode.dual_plane = (payload.x & (1u << 10u)) != 0u;

    /* Bits [1:0] and [3:2] select the weight-grid layout family. */
    uint higher = (payload.x >> 2u) & 3u;
    uint lower = payload.x & 3u;

    if (lower != 0u)
    {
        /* "Normal" block modes: gather the 4-bit weight mode index from
         * its scattered bit positions. */
        mode.weight_mode_index = int((payload.x >> 4u) & 1u);
        mode.weight_mode_index |= int((payload.x << 1u) & 6u);
        mode.weight_mode_index |= int((payload.x >> 6u) & 8u);

        if (higher < 2u)
        {
            mode.weight_grid_size.x = int(bitfieldExtract(payload.x, 7, 2) + 4u + 4u * higher);
            mode.weight_grid_size.y = int(bitfieldExtract(payload.x, 5, 2) + 2u);
        }
        else if (higher == 2u)
        {
            mode.weight_grid_size.x = int(bitfieldExtract(payload.x, 5, 2) + 2u);
            mode.weight_grid_size.y = int(bitfieldExtract(payload.x, 7, 2) + 8u);
        }
        else
        {
            /* higher == 3: bit 8 selects between two narrow grid shapes. */
            if ((payload.x & (1u << 8u)) != 0u)
            {
                mode.weight_grid_size.x = int(bitfieldExtract(payload.x, 7, 1) + 2u);
                mode.weight_grid_size.y = int(bitfieldExtract(payload.x, 5, 2) + 2u);
            }
            else
            {
                mode.weight_grid_size.x = int(bitfieldExtract(payload.x, 5, 2) + 2u);
                mode.weight_grid_size.y = int(bitfieldExtract(payload.x, 7, 1) + 6u);
            }
        }
    }
    else
    {
        /* lower == 0: extended block modes with large weight grids. */
        int p3 = int(bitfieldExtract(payload.x, 9, 1));
        int hi = int(bitfieldExtract(payload.x, 7, 2));
        int lo = int(bitfieldExtract(payload.x, 5, 2));
        if (hi == 0)
        {
            mode.weight_grid_size.x = 12;
            mode.weight_grid_size.y = lo + 2;
        }
        else if (hi == 1)
        {
            mode.weight_grid_size.x = lo + 2;
            mode.weight_grid_size.y = 12;
        }
        else if (hi == 2)
        {
            /* This shape never has a dual plane; bit 9 is reused as part
             * of the grid height instead of the weight mode index. */
            mode.dual_plane = false;
            p3 = 0;
            mode.weight_grid_size.x = lo + 6;
            mode.weight_grid_size.y = int(bitfieldExtract(payload.x, 9, 2) + 6u);
        }
        else
        {
            if (lo == 0)
                mode.weight_grid_size = ivec2(6, 10);
            else if (lo == 1)
                mode.weight_grid_size = ivec2(10, 6);
            else
                decode_error = true;
        }

        int p0 = int(bitfieldExtract(payload.x, 4, 1));
        int p1 = int(bitfieldExtract(payload.x, 2, 1));
        int p2 = int(bitfieldExtract(payload.x, 3, 1));
        mode.weight_mode_index = p0 + (p1 << 1) + (p2 << 2) + (p3 << 3);
    }

    // 11 bits for block mode.
    // 2 bits for partition select
    // If partitions > 1:
    //   4 bits CEM selector
    //   If dual_plane:
    //     2 bits of CCS
    // else:
    //   10 for partition seed
    //   2 bits for CEM main selector
    //   If CEM[1:0] = 00:
    //     4 bits for CEM extra selector if all same type.
    //   else:
    //     (1 + 2) * num_partitions if different types.
    //     First 4 bits are encoded next to CEM[1:0], otherwise, packed before weights.
    //   If dual_plane:
    //     2 bits of CCS before extra CEM bits.
    const int CONFIG_BITS_BLOCK = 11;
    const int CONFIG_BITS_PARTITION_MODE = 2;
    const int CONFIG_BITS_SEED = 10;
    const int CONFIG_BITS_PRIMARY_MULTI_CEM = 2;
    const int CONFIG_BITS_CEM = 4;
    const int CONFIG_BITS_EXTRA_CEM_PER_PARTITION = 3;
    const int CONFIG_BITS_CCS = 2;

    mode.num_partitions = int(bitfieldExtract(payload.x, CONFIG_BITS_BLOCK, CONFIG_BITS_PARTITION_MODE)) + 1;

    if (mode.num_partitions > 1)
    {
        mode.seed = int(bitfieldExtract(payload.x, CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE, CONFIG_BITS_SEED));
        mode.cem = int(bitfieldExtract(payload.x, CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE + CONFIG_BITS_SEED,
                                       CONFIG_BITS_PRIMARY_MULTI_CEM + CONFIG_BITS_CEM));
    }
    else
        mode.cem = int(bitfieldExtract(payload.x, CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE, CONFIG_BITS_CEM));

    /* Total config bits, per the layout described above. */
    int config_bits;
    if (mode.num_partitions > 1)
    {
        bool single_cem = (mode.cem & 3) == 0;
        if (single_cem)
        {
            config_bits = CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE +
                          CONFIG_BITS_SEED + CONFIG_BITS_PRIMARY_MULTI_CEM + CONFIG_BITS_CEM;
        }
        else
        {
            config_bits = CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE +
                          CONFIG_BITS_SEED + CONFIG_BITS_PRIMARY_MULTI_CEM +
                          CONFIG_BITS_EXTRA_CEM_PER_PARTITION * mode.num_partitions;
        }
    }
    else
    {
        config_bits = CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE + CONFIG_BITS_CEM;
    }

    // Other config bits are packed before the weights.
    int primary_config_bits;
    if (mode.num_partitions > 1)
    {
        primary_config_bits = CONFIG_BITS_BLOCK + CONFIG_BITS_PARTITION_MODE + CONFIG_BITS_SEED +
                              CONFIG_BITS_PRIMARY_MULTI_CEM + CONFIG_BITS_CEM;
    }
    else
        primary_config_bits = config_bits;

    if (mode.dual_plane)
        config_bits += CONFIG_BITS_CCS;

    // This is not allowed.
    if (any(greaterThan(mode.weight_grid_size, ivec2(gl_WorkGroupSize.xy))))
        decode_error = true;
    if (mode.dual_plane && mode.num_partitions > 3)
        decode_error = true;

    mode.config_bits = config_bits;
    mode.primary_config_bits = primary_config_bits;
    return mode;
}
437
/* floor(v / 3) without a divide: multiply by a 0.16 fixed-point
 * reciprocal of 3. Exact for the small non-negative values used here. */
int idiv3_floor(int v)
{
    int fixed_point = v * 0x5556;
    return fixed_point >> 16;
}
442
/* ceil(v / 3) for non-negative v: floor((v + 2) / 3), computed with the
 * same 0.16 fixed-point reciprocal trick as idiv3_floor (inlined). */
int idiv3_ceil(int v)
{
    return ((v + 2) * 0x5556) >> 16;
}
447
/* floor(v / 5) without a divide: multiply by a 0.16 fixed-point
 * reciprocal of 5. Exact for the small non-negative values used here. */
int idiv5_floor(int v)
{
    int fixed_point = v * 0x3334;
    return fixed_point >> 16;
}
452
/* ceil(v / 5) for non-negative v: floor((v + 4) / 5), computed with the
 * same 0.16 fixed-point reciprocal trick as idiv5_floor (inlined). */
int idiv5_ceil(int v)
{
    return ((v + 4) * 0x3334) >> 16;
}
457
/* Build a 128-bit mask with the low 'bits' bits set, spread across a
 * uvec4 (x = bits 0-31, ... w = bits 96-127). Each word computes
 * (1 << n) - 1 for its clamped share of the bit count, then words that
 * are fully covered are forced to all-ones (the clamp to 31 would
 * otherwise produce 0x7fffffff for a full word). */
uvec4 build_bitmask(int bits)
{
    ivec4 num_bits = ivec4(bits, bits - 32, bits - 64, bits - 96);
    uvec4 mask = uvec4(1) << clamp(num_bits, ivec4(0), ivec4(31));
    mask--;
    mask = mix(mask, uvec4(0xffffffffu), greaterThanEqual(uvec4(bits), uvec4(32, 64, 96, 128)));
    return mask;
}
466
/* Decode value 'index' from an ASTC Integer Sequence Encoding (ISE)
 * stream starting at 'start_bit' of the payload.
 * quant = (bits, trits, quints): quant.x is the number of plain LSBs per
 * value; a nonzero quant.y selects trit encoding (5 values share 8 trit
 * bits), a nonzero quant.z selects quint encoding (3 values share 7
 * quint bits), otherwise values are plain quant.x-bit fields.
 * The packed trit/quint bits are looked up in LUTTritQuintDecode
 * (quints are stored at offset 256). */
int decode_integer_sequence(uvec4 payload, int start_bit, int index, ivec3 quant)
{
    int ret;
    if (quant.y != 0)
    {
        // Trit-decoding.
        /* Each block of 5 values occupies 5 * bits + 8 bits. */
        int block = idiv5_floor(index);
        int offset = index - block * 5;
        start_bit += block * (5 * quant.x + 8);

        /* The 8 trit bits are interleaved between the per-value LSB
         * fields at these fixed offsets. */
        int t0_t1_offset = start_bit + (quant.x * 1 + 0);
        int t2_t3_offset = start_bit + (quant.x * 2 + 2);
        int t4_offset    = start_bit + (quant.x * 3 + 4);
        int t5_t6_offset = start_bit + (quant.x * 4 + 5);
        int t7_offset    = start_bit + (quant.x * 5 + 7);

        int t = (extract_bits(payload, t0_t1_offset, 2) << 0) |
                (extract_bits(payload, t2_t3_offset, 2) << 2) |
                (extract_bits(payload, t4_offset, 1) << 4) |
                (extract_bits(payload, t5_t6_offset, 2) << 5) |
                (extract_bits(payload, t7_offset, 1) << 7);

        /* LUT unpacks the 8-bit trit block into five 3-bit trit values. */
        t = int(texelFetch(LUTTritQuintDecode, t).x);
        t = (t >> (3 * offset)) & 7;

        /* Skip this value's share of LSB fields and trit bits. */
        int m_offset = offset * quant.x;
        m_offset += idiv5_ceil(offset * 8);

        if (quant.x != 0)
        {
            int m = extract_bits(payload, m_offset + start_bit, quant.x);
            ret = (t << quant.x) | m;
        }
        else
            ret = t;
    }
    else if (quant.z != 0)
    {
        // Quint-decoding
        /* Each block of 3 values occupies 3 * bits + 7 bits. */
        int block = idiv3_floor(index);
        int offset = index - block * 3;
        start_bit += block * (3 * quant.x + 7);

        int q0_q1_q2_offset = start_bit + (quant.x * 1 + 0);
        int q3_q4_offset    = start_bit + (quant.x * 2 + 3);
        int q5_q6_offset    = start_bit + (quant.x * 3 + 5);

        int q = (extract_bits(payload, q0_q1_q2_offset, 3) << 0) |
                (extract_bits(payload, q3_q4_offset, 2) << 3) |
                (extract_bits(payload, q5_q6_offset, 2) << 5);

        /* Quint LUT entries live at offset 256 in the shared LUT. */
        q = int(texelFetch(LUTTritQuintDecode, 256 + q).x);
        q = (q >> (3 * offset)) & 7;

        int m_offset = offset * quant.x;
        m_offset += idiv3_ceil(offset * 7);

        if (quant.x != 0)
        {
            int m = extract_bits(payload, m_offset + start_bit, quant.x);
            ret = (q << quant.x) | m;
        }
        else
            ret = q;
    }
    else
    {
        /* Plain binary: fixed quant.x bits per value. */
        int bit = index * quant.x;
        ret = extract_bits(payload, start_bit + bit, quant.x);
    }
    return ret;
}
539
/* Scale a texel coordinate within the block to the ~0..1024 fixed-point
 * range used for weight-grid infill: D approximates
 * round(1024 / (block_dim - 1)) per axis (computed in float, truncated),
 * so texel (block_dim - 1) maps to roughly 1024.
 * NOTE(review): assumes block dims > 1 (division by block_dim - 1). */
ivec2 normalize_coord(ivec2 pixel_coord)
{
    ivec2 D = ivec2((vec2((1024 + ivec2(gl_WorkGroupSize.xy >> 1u))) + 0.5) / vec2(gl_WorkGroupSize.xy - 1u));
    ivec2 c = D * pixel_coord;
    return c;
}
546
/* Decode one quantized weight from the (already bit-reversed) weight
 * stream and unquantize it to the 0..64 range via LUTWeightUnquantize.
 * quant.xyz = (bits, trits, quints); quant.w = base offset of this
 * quantization level's entries in the unquantize LUT. */
int decode_weight(uvec4 payload, int weight_index, ivec4 quant)
{
    int primary_weight = decode_integer_sequence(payload, 0, weight_index, quant.xyz);
    primary_weight = int(texelFetch(LUTWeightUnquantize, primary_weight + quant.w).x);
    return primary_weight;
}
553
/* Bilinearly interpolate the weight grid at a texel position.
 * coord is the integer grid cell, fractional the 4-bit (0..15) sub-cell
 * position per axis. stride/offset select the plane when dual-plane
 * weights are interleaved. Neighboring grid samples are only fetched
 * when the corresponding fractional part is nonzero (at fraction 0 the
 * neighbor's interpolation weight would be 0 anyway, and skipping the
 * fetch avoids reading past the grid edge). */
int decode_weight_bilinear(uvec4 payload, ivec2 coord, int weight_resolution,
                           int stride, int offset, ivec2 fractional, ivec4 quant)
{
    int index = coord.y * weight_resolution + coord.x;
    int p00 = decode_weight(payload, stride * index + offset, quant);
    int p10, p01, p11;

    if (fractional.x != 0)
        p10 = decode_weight(payload, stride * (index + 1) + offset, quant);
    else
        p10 = p00;

    if (fractional.y != 0)
    {
        p01 = decode_weight(payload, stride * (index + weight_resolution) + offset, quant);
        if (fractional.x != 0)
            p11 = decode_weight(payload, stride * (index + weight_resolution + 1) + offset, quant);
        else
            p11 = p01;
    }
    else
    {
        p01 = p00;
        p11 = p10;
    }

    /* 4-bit fixed-point bilinear factors with round-to-nearest;
     * w00 + w10 + w01 + w11 == 16. */
    int w11 = (fractional.x * fractional.y + 8) >> 4;
    int w10 = fractional.x - w11;
    int w01 = fractional.y - w11;
    int w00 = 16 - fractional.x - fractional.y + w11;
    return (p00 * w00 + p10 * w10 + p01 * w01 + p11 * w11 + 8) >> 4;
}
586
/* Decode the per-channel interpolation weights (0..64) for this texel.
 * Also outputs weight_cost_bits, the total size of the weight stream,
 * which callers need to locate the config/endpoint bits. Returns the
 * same weight in all four channels for single-plane blocks; for
 * dual-plane blocks the CCS-selected channel gets the second plane. */
ivec4 decode_weights(uvec4 payload, BlockMode mode, ivec2 normalized_pixel, out int weight_cost_bits)
{
    /* quant = (bits, trits, quints, unquantize-LUT offset). */
    ivec4 quant = ivec4(texelFetch(LUTWeightQuantizer, mode.weight_mode_index));
    int num_weights = mode.weight_grid_size.x * mode.weight_grid_size.y;
    num_weights <<= int(mode.dual_plane);
    /* ISE cost: plain bits plus packed trit (8 bits / 5 values) and
     * quint (7 bits / 3 values) shares, rounded up. */
    weight_cost_bits =
        quant.x * num_weights +
        idiv5_ceil(num_weights * 8 * quant.y) +
        idiv3_ceil(num_weights * 7 * quant.z);

    // Decoders must deal with error conditions and return the correct error color.
    if (weight_cost_bits < 24 || weight_cost_bits > 96 || num_weights > 64)
    {
        decode_error = true;
        return ivec4(0);
    }

    /* CCS (color component select): only assigned and consumed on the
     * dual-plane path. It sits just below the weight stream, after any
     * extra CEM bits. */
    int ccs;
    if (mode.dual_plane)
    {
        int extra_cem_bits = 0;
        if ((mode.cem & 3) != 0)
            extra_cem_bits = max(mode.num_partitions * 3 - 4, 0);
        ccs = extract_bits(payload, 126 - weight_cost_bits - extra_cem_bits, 2);
    }

    /* Weights are stored from bit 127 downward in reversed bit order;
     * reversing the whole 128-bit payload lets the ISE decoder read the
     * stream starting at bit 0. Mask off everything past the stream. */
    payload = bitfieldReverse(payload);
    payload = payload.wzyx;
    payload &= build_bitmask(weight_cost_bits);

    // Scale the normalized coordinate to weight grid.
    ivec2 weight_pixel_fixed_point = (normalized_pixel * (mode.weight_grid_size - 1) + 32) >> 6;
    ivec2 weight_pixel = weight_pixel_fixed_point >> 4;
    ivec2 weight_pixel_fractional = weight_pixel_fixed_point & 0xf;

    ivec4 ret;
    int primary_weight = decode_weight_bilinear(payload, weight_pixel, mode.weight_grid_size.x,
                                                1 << int(mode.dual_plane), 0,
                                                weight_pixel_fractional, quant);
    if (mode.dual_plane)
    {
        /* Second plane is interleaved at odd weight indices. */
        int secondary_weight = decode_weight_bilinear(payload, weight_pixel, mode.weight_grid_size.x,
                                                      2, 1,
                                                      weight_pixel_fractional, quant);
        ret = mix(ivec4(primary_weight), ivec4(secondary_weight), equal(ivec4(ccs), ivec4(0, 1, 2, 3)));
    }
    else
        ret = ivec4(primary_weight);

    return ret;
}
638
/* LDR luminance, direct: v0 and v1 are the two grayscale endpoints,
 * replicated to RGB with opaque alpha. */
void decode_endpoint_ldr_luma_direct(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1)
{
    ep0 = ivec4(v0, v0, v0, 0xff);
    ep1 = ivec4(v1, v1, v1, 0xff);
}
645
/* HDR luminance, large range: endpoints expand to 12 bits. When the raw
 * values are out of order, they swap roles and get nudged by +/-8. */
void decode_endpoint_hdr_luma_direct(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1)
{
    bool in_order = v1 >= v0;
    int y0 = in_order ? (v0 << 4) : ((v1 << 4) + 8);
    int y1 = in_order ? (v1 << 4) : ((v0 << 4) - 8);

    /* Alpha = 0x780 (1.0 in the 12-bit HDR representation used here). */
    ep0 = ivec4(ivec3(y0), 0x780);
    ep1 = ivec4(ivec3(y1), 0x780);
}
664
/* HDR luminance, small range: v0's top bit selects one of two packings
 * of a 12-bit base luminance plus a small positive delta taken from v1.
 * The second endpoint is base + delta, clamped to the 12-bit range. */
void decode_endpoint_hdr_luma_direct_small_range(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1)
{
    int y0, y1, d;

    if ((v0 & 0x80) != 0)
    {
        /* Wider base (11 significant bits), 5-bit delta scaled by 4. */
        y0 = ((v1 & 0xe0) << 4) | ((v0 & 0x7f) << 2);
        d = (v1 & 0x1f) << 2;
    }
    else
    {
        /* Narrower base, 4-bit delta scaled by 2. */
        y0 = ((v1 & 0xf0) << 4) | ((v0 & 0x7f) << 1);
        d = (v1 & 0x0f)  << 1;
    }

    y1 = min(y0 + d, 0xfff);

    ep0 = ivec4(ivec3(y0), 0x780);
    ep1 = ivec4(ivec3(y1), 0x780);
}
686
/* LDR luminance, base+offset: v0/v1 pack a base luminance and a 6-bit
 * non-negative offset; the second endpoint saturates at 0xff. */
void decode_endpoint_ldr_luma_base_offset(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1)
{
    int base = (v0 >> 2) | (v1 & 0xc0);
    int offset = v1 & 0x3f;
    int high = min(base + offset, 0xff);

    ep0 = ivec4(ivec3(base), 0xff);
    ep1 = ivec4(ivec3(high), 0xff);
}
696
/* LDR luminance + alpha, direct: v0/v1 are grayscale endpoints,
 * v2/v3 the matching alpha endpoints. */
void decode_endpoint_ldr_luma_alpha_direct(out ivec4 ep0, out ivec4 ep1,
    int v0, int v1, int v2, int v3)
{
    ep0 = ivec4(v0, v0, v0, v2);
    ep1 = ivec4(v1, v1, v1, v3);
}
703
/* ASTC blue-contraction: average R and G toward B, leaving B and A
 * untouched. Applied when endpoint deltas indicate contracted encoding. */
ivec4 blue_contract(int r, int g, int b, int a)
{
    return ivec4((r + b) >> 1, (g + b) >> 1, b, a);
}
713
/* ASTC "bit transfer signed": move the top bit of 'a' into 'b' (as the
 * high bit of b's reconstructed value) and reduce 'a' to a sign-extended
 * 6-bit delta. Reference formulation:
 *   b = (b >> 1) | (a & 0x80); a = (a >> 1) & 0x3f; sign-extend a from 6 bits.
 */
void bit_transfer_signed(inout int a, inout int b)
{
    b >>= 1;
    b |= a & 0x80;
    a >>= 1;
    /* bitfieldExtract() on a signed int both masks to the low 6 bits and
     * sign-extends from bit 5, so the reference's explicit "a &= 0x3f"
     * step is redundant and has been dropped. */
    a = bitfieldExtract(a, 0, 6);
}
722
/* LDR luminance + alpha, base+offset: (v0, v2) are luma/alpha bases and
 * (v1, v3) signed deltas after the bit-transfer unpacking. Both
 * endpoints clamp to 0..0xff. */
void decode_endpoint_ldr_luma_alpha_base_offset(out ivec4 ep0, out ivec4 ep1,
    int v0, int v1, int v2, int v3)
{
    bit_transfer_signed(v1, v0);
    bit_transfer_signed(v3, v2);
    int v0_v1 = clamp(v0 + v1, 0, 0xff);
    int v2_v3 = clamp(v2 + v3, 0, 0xff);
    v0 = clamp(v0, 0, 0xff);
    v2 = clamp(v2, 0, 0xff);
    ep0 = ivec4(ivec3(v0), v2);
    ep1 = ivec4(ivec3(v0_v1), v2_v3);
}
735
/* LDR RGB, base+scale: ep1 is the RGB base (v0, v1, v2); ep0 is the
 * base scaled down by v3/256. Both endpoints are opaque. */
void decode_endpoint_ldr_rgb_base_scale(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3)
{
    ivec3 base = ivec3(v0, v1, v2);
    ep1 = ivec4(base, 0xff);
    ep0 = ivec4((base * v3) >> 8, 0xff);
}
742
/* LDR RGB base+scale with two explicit alphas: like the base+scale mode
 * but ep0/ep1 carry independent alpha endpoints v4 and v5. */
void decode_endpoint_ldr_rgb_base_scale_two_a(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3, int v4, int v5)
{
    ivec3 base = ivec3(v0, v1, v2);
    ep1 = ivec4(base, v5);
    ep0 = ivec4((base * v3) >> 8, v4);
}
749
/* LDR RGB, direct: even values (v0, v2, v4) are endpoint 0 and odd
 * values endpoint 1; if the first endpoint's component sum is larger,
 * the endpoints swap and blue-contraction is applied. Always opaque. */
void decode_endpoint_ldr_rgb_direct(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3, int v4, int v5)
{
    bool plain = (v1 + v3 + v5) >= (v0 + v2 + v4);
    if (plain)
    {
        ep0 = ivec4(v0, v2, v4, 0xff);
        ep1 = ivec4(v1, v3, v5, 0xff);
    }
    else
    {
        ep0 = blue_contract(v1, v3, v5, 0xff);
        ep1 = blue_contract(v0, v2, v4, 0xff);
    }
}
766
/* HDR RGB base+scale: unpack a 12-bit HDR base color and scale from the
 * heavily multiplexed v0..v3 fields. A 3-bit mode chooses how many extra
 * bits each of red/green/blue/scale receives; the 'ohm' (1 << mode)
 * tests below implement the spec's per-mode bit-assignment table.
 * ep1 is the base color, ep0 the base minus scale, both clamped to
 * 12 bits with alpha fixed at 0x780. */
void decode_endpoint_hdr_rgb_scale(out ivec4 ep0, out ivec4 ep1,
    int v0, int v1, int v2, int v3)
{
    // Mind-numbing weird format, just copy from spec ...
    int mode_value = ((v0 & 0xc0) >> 6) | ((v1 & 0x80) >> 5) | ((v2 & 0x80) >> 4);
    int major_component;
    int mode;

    if ((mode_value & 0xc) != 0xc)
    {
        major_component = mode_value >> 2;
        mode = mode_value & 3;
    }
    else if (mode_value != 0xf)
    {
        major_component = mode_value & 3;
        mode = 4;
    }
    else
    {
        major_component = 0;
        mode = 5;
    }

    /* Base field widths before mode-dependent extension. */
    int red = v0 & 0x3f;
    int green = v1 & 0x1f;
    int blue = v2 & 0x1f;
    int scale = v3 & 0x1f;

    /* Spare bits harvested from the tops of v1..v3. */
    int x0 = (v1 >> 6) & 1;
    int x1 = (v1 >> 5) & 1;
    int x2 = (v2 >> 6) & 1;
    int x3 = (v2 >> 5) & 1;
    int x4 = (v3 >> 7) & 1;
    int x5 = (v3 >> 6) & 1;
    int x6 = (v3 >> 5) & 1;

    /* One-hot mode bit; each mask encodes the set of modes for which a
     * given spare bit extends a given field (from the spec table). */
    int ohm = 1 << mode;
    if ((ohm & 0x30) != 0) green |= x0 << 6;
    if ((ohm & 0x3a) != 0) green |= x1 << 5;
    if ((ohm & 0x30) != 0) blue |= x2 << 6;
    if ((ohm & 0x3a) != 0) blue |= x3 << 5;
    if ((ohm & 0x3d) != 0) scale |= x6 << 5;
    if ((ohm & 0x2d) != 0) scale |= x5 << 6;
    if ((ohm & 0x04) != 0) scale |= x4 << 7;
    if ((ohm & 0x3b) != 0) red |= x4 << 6;
    if ((ohm & 0x04) != 0) red |= x3 << 6;
    if ((ohm & 0x10) != 0) red |= x5 << 7;
    if ((ohm & 0x0f) != 0) red |= x2 << 7;
    if ((ohm & 0x05) != 0) red |= x1 << 8;
    if ((ohm & 0x0a) != 0) red |= x0 << 8;
    if ((ohm & 0x05) != 0) red |= x0 << 9;
    if ((ohm & 0x02) != 0) red |= x6 << 9;
    if ((ohm & 0x01) != 0) red |= x3 << 10;
    if ((ohm & 0x02) != 0) red |= x5 << 10;

    /* Align all fields to the full 12-bit range. */
    int shamt = max(mode, 1);
    red <<= shamt;
    green <<= shamt;
    blue <<= shamt;
    scale <<= shamt;

    /* Except in mode 5, green/blue are stored as offsets below red. */
    if (mode != 5)
    {
        green = red - green;
        blue = red - blue;
    }

    /* Rotate so the major (widest) component lands in its channel. */
    if (major_component == 1)
        swap(red, green);
    else if (major_component == 2)
        swap(red, blue);

    ep1 = ivec4(clamp(ivec3(red, green, blue), ivec3(0), ivec3(0xfff)), 0x780);
    ep0 = ivec4(clamp(ivec3(red, green, blue) - scale, ivec3(0), ivec3(0xfff)), 0x780);
}
843
/* HDR RGB, direct: decode two independent 12-bit HDR endpoints from
 * v0..v5. major_component == 3 is the simple absolute encoding; other
 * values select a base (va) + per-endpoint offsets (vb0/vb1, vc,
 * vd0/vd1) representation whose field widths depend on a 3-bit mode,
 * with spare bits distributed via the 'ohm' (1 << mode) table as in the
 * spec. Alpha is fixed at 0x780. */
void decode_endpoint_hdr_rgb_direct(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3, int v4, int v5)
{
    int major_component = ((v4 & 0x80) >> 7) | ((v5 & 0x80) >> 6);

    /* major_component == 3: plain absolute 12-bit endpoints. */
    if (major_component == 3)
    {
        ep0 = ivec4(v0 << 4, v2 << 4, (v4 & 0x7f) << 5, 0x780);
        ep1 = ivec4(v1 << 4, v3 << 4, (v5 & 0x7f) << 5, 0x780);
        return;
    }

    int mode = ((v1 & 0x80) >> 7) | ((v2 & 0x80) >> 6) | ((v3 & 0x80) >> 5);
    int va = v0 | ((v1 & 0x40) << 2);
    int vb0 = v2 & 0x3f;
    int vb1 =  v3 & 0x3f;
    int vc = v1 & 0x3f;
    int vd0 = v4 & 0x7f;
    int vd1 = v5 & 0x7f;

    /* d-field width varies with mode; bitfieldExtract sign-extends the
     * deltas from that width. */
    int d_bits = 7 - (mode & 1);
    if ((mode & 5) == 4)
        d_bits -= 2;

    vd0 = bitfieldExtract(vd0, 0, d_bits);
    vd1 = bitfieldExtract(vd1, 0, d_bits);

    /* Spare bits from the tops of v2..v5. */
    int x0 = (v2 >> 6) & 1;
    int x1 = (v3 >> 6) & 1;
    int x2 = (v4 >> 6) & 1;
    int x3 = (v5 >> 6) & 1;
    int x4 = (v4 >> 5) & 1;
    int x5 = (v5 >> 5) & 1;

    /* One-hot mode bit; masks give the per-mode bit assignments. */
    int ohm = 1 << mode;
    if ((ohm & 0xa4) != 0) va |= x0 << 9;
    if ((ohm & 0x08) != 0) va |= x2 << 9;
    if ((ohm & 0x50) != 0) va |= x4 << 9;
    if ((ohm & 0x50) != 0) va |= x5 << 10;
    if ((ohm & 0xa0) != 0) va |= x1 << 10;
    if ((ohm & 0xc0) != 0) va |= x2 << 11;

    if ((ohm & 0x04) != 0) vc |= x1 << 6;
    if ((ohm & 0xe8) != 0) vc |= x3 << 6;
    if ((ohm & 0x20) != 0) vc |= x2 << 7;

    if ((ohm & 0x5b) != 0) vb0 |= x0 << 6;
    if ((ohm & 0x5b) != 0) vb1 |= x1 << 6;
    if ((ohm & 0x12) != 0) vb0 |= x2 << 7;
    if ((ohm & 0x12) != 0) vb1 |= x3 << 7;

    /* Align to 12 bits: wider modes shift less. */
    int shamt = (mode >> 1) ^ 3;
    va <<= shamt;
    vb0 <<= shamt;
    vb1 <<= shamt;
    vc <<= shamt;
    vd0 <<= shamt;
    vd1 <<= shamt;

    ep1 = ivec4(clamp(ivec3(va, va - vb0, va - vb1), ivec3(0), ivec3(0xfff)), 0x780);
    ep0 = ivec4(clamp(ivec3(va - vc, va - vb0 - vc - vd0, va - vb1 - vc - vd1), ivec3(0), ivec3(0xfff)), 0x780);

    /* Rotate so the major component lands in its channel. */
    if (major_component == 1)
    {
        swap(ep0.r, ep0.g);
        swap(ep1.r, ep1.g);
    }
    else if (major_component == 2)
    {
        swap(ep0.r, ep0.b);
        swap(ep1.r, ep1.b);
    }
}
917
/* LDR RGB, base+offset: (v0, v2, v4) are the RGB base and (v1, v3, v5)
 * signed per-channel deltas after bit-transfer unpacking. A negative
 * delta sum swaps the endpoints and applies blue-contraction. RGB is
 * clamped to 8 bits; alpha is opaque. */
void decode_endpoint_ldr_rgb_base_offset(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3, int v4, int v5)
{
    bit_transfer_signed(v1, v0);
    bit_transfer_signed(v3, v2);
    bit_transfer_signed(v5, v4);
    if (v1 + v3 + v5 >= 0)
    {
        ep0 = ivec4(v0, v2, v4, 0xff);
        ep1 = ivec4(v0 + v1, v2 + v3, v4 + v5, 0xff);
    }
    else
    {
        ep0 = blue_contract(v0 + v1, v2 + v3, v4 + v5, 0xff);
        ep1 = blue_contract(v0, v2, v4, 0xff);
    }

    ep0.rgb = clamp(ep0.rgb, ivec3(0), ivec3(0xff));
    ep1.rgb = clamp(ep1.rgb, ivec3(0), ivec3(0xff));
}
938
/* LDR RGBA direct endpoint mode (CEM 12).
 * Even-indexed values form endpoint 0, odd-indexed values endpoint 1.
 * When endpoint 1's RGB sum is smaller than endpoint 0's, the endpoints
 * swap and blue contraction is applied. */
void decode_endpoint_ldr_rgba_direct(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3,
        int v4, int v5, int v6, int v7)
{
    bool in_order = (v1 + v3 + v5) >= (v0 + v2 + v4);

    if (in_order)
    {
        ep0 = ivec4(v0, v2, v4, v6);
        ep1 = ivec4(v1, v3, v5, v7);
    }
    else
    {
        ep0 = blue_contract(v1, v3, v5, v7);
        ep1 = blue_contract(v0, v2, v4, v6);
    }
}
956
/* LDR RGBA base+offset endpoint mode (CEM 13).
 * Like the RGB base+offset mode but with an alpha pair as well; only the
 * RGB offset sum decides the endpoint swap / blue contraction. All four
 * channels are clamped to 8-bit UNORM range at the end. */
void decode_endpoint_ldr_rgba_base_offset(out ivec4 ep0, out ivec4 ep1,
        int v0, int v1, int v2, int v3, int v4, int v5, int v6, int v7)
{
    bit_transfer_signed(v1, v0);
    bit_transfer_signed(v3, v2);
    bit_transfer_signed(v5, v4);
    bit_transfer_signed(v7, v6);

    ivec4 base = ivec4(v0, v2, v4, v6);
    ivec4 offs = ivec4(v1, v3, v5, v7);
    ivec4 high = base + offs;

    if (offs.r + offs.g + offs.b >= 0)
    {
        ep0 = base;
        ep1 = high;
    }
    else
    {
        ep0 = blue_contract(high.r, high.g, high.b, high.a);
        ep1 = blue_contract(base.r, base.g, base.b, base.a);
    }

    ep0 = clamp(ep0, ivec4(0), ivec4(0xff));
    ep1 = clamp(ep1, ivec4(0), ivec4(0xff));
}
979
/* HDR alpha endpoint decode for CEM 15 (v6/v7 hold the alpha value pair).
 * A 2-bit mode comes from the MSBs of both values. Mode 3 scales both
 * 7-bit values straight to 12 bits; the other modes rebuild a 12-bit base
 * from v6 plus high bits of v7, with the rest of v7 acting as a
 * sign-extended offset. */
void decode_endpoint_hdr_alpha(out int ep0, out int ep1, int v6, int v7)
{
    int mode = ((v6 >> 7) & 1) | ((v7 >> 6) & 2);
    v6 &= 0x7f;
    v7 &= 0x7f;

    if (mode == 3)
    {
        /* Plain scale-up of both 7-bit values. */
        ep0 = v6 << 5;
        ep1 = v7 << 5;
        return;
    }

    /* Steal high bits from v7 to extend v6 into a 12-bit base value. */
    v6 |= (v7 << (mode + 1)) & 0x780;
    /* Remaining v7 bits are a two's-complement offset (XOR/sub trick
     * sign-extends the (6 - mode)-bit field). */
    v7 &= 0x3f >> mode;
    v7 ^= 0x20 >> mode;
    v7 -= 0x20 >> mode;
    v6 <<= 4 - mode;
    v7 <<= 4 - mode;

    ep0 = v6;
    ep1 = clamp(v6 + v7, 0, 0xfff);
}
1005
/* Decodes and unquantizes one endpoint pair from the block payload.
 *
 * payload:             128-bit block data as a uvec4.
 * bit_offset:          bit position where the endpoint ISE stream starts.
 * quant:               endpoint quantization descriptor; xyz parameterize the
 *                      integer sequence decode, w offsets into the
 *                      unquantization LUT.
 * ep_mode:             color endpoint mode (CEM), 0..15; its top two bits give
 *                      the number of extra value pairs to decode.
 * base_endpoint_index: first ISE value index belonging to this partition.
 * num_endpoint_bits:   number of bits holding endpoint data, used to mask the
 *                      payload.
 *
 * Writes the two endpoint colors to ep0/ep1 and the LDR/HDR interpolation
 * mode to decode_mode. Flags decode_error when an HDR endpoint mode is hit
 * while decoding to an 8-bit output.
 *
 * Fix vs. original: removed the unused local `num_values`. */
void decode_endpoint(out ivec4 ep0, out ivec4 ep1, out int decode_mode,
                     uvec4 payload, int bit_offset, ivec4 quant, int ep_mode,
                     int base_endpoint_index, int num_endpoint_bits)
{
    // Mask away everything beyond the endpoint bits so stray weight/config
    // bits cannot leak into the integer sequence decode.
    num_endpoint_bits += bit_offset;
    payload &= build_bitmask(num_endpoint_bits);

    // Could of course use an array, but that doesn't lower nicely to indexed registers on all GPUs.
    int v0, v1, v2, v3, v4, v5, v6, v7;

#define DECODE_EP(i) \
    int(texelFetch(LUTEndpointUnquantize, quant.w + decode_integer_sequence(payload, bit_offset, i + base_endpoint_index, quant.xyz)).x)

    // (ep_mode >> 2) + 1 value pairs are encoded; decode only as many as the
    // CEM class requires, so later branches never read an undecoded value.
    int hi_bits = ep_mode >> 2;
    v0 = DECODE_EP(0);
    v1 = DECODE_EP(1);

    if (hi_bits >= 1)
    {
        v2 = DECODE_EP(2);
        v3 = DECODE_EP(3);
    }

    if (hi_bits >= 2)
    {
        v4 = DECODE_EP(4);
        v5 = DECODE_EP(5);
    }

    if (hi_bits >= 3)
    {
        v6 = DECODE_EP(6);
        v7 = DECODE_EP(7);
    }

    // Dispatch on the color endpoint mode (CEM 0..15).
    switch (ep_mode)
    {
    case 0:
        decode_endpoint_ldr_luma_direct(ep0, ep1,
            v0, v1);
        decode_mode = MODE_LDR;
        break;

    case 1:
        decode_endpoint_ldr_luma_base_offset(ep0, ep1,
            v0, v1);
        decode_mode = MODE_LDR;
        break;

    case 2:
        decode_endpoint_hdr_luma_direct(ep0, ep1,
            v0, v1);
        decode_mode = MODE_HDR;
        break;

    case 3:
        decode_endpoint_hdr_luma_direct_small_range(ep0, ep1,
            v0, v1);
        decode_mode = MODE_HDR;
        break;

    case 4:
        decode_endpoint_ldr_luma_alpha_direct(ep0, ep1,
            v0, v1, v2, v3);
        decode_mode = MODE_LDR;
        break;

    case 5:
        decode_endpoint_ldr_luma_alpha_base_offset(ep0, ep1,
            v0, v1, v2, v3);
        decode_mode = MODE_LDR;
        break;

    case 6:
        decode_endpoint_ldr_rgb_base_scale(ep0, ep1,
            v0, v1, v2, v3);
        decode_mode = MODE_LDR;
        break;

    case 7:
        decode_endpoint_hdr_rgb_scale(ep0, ep1,
            v0, v1, v2, v3);
        decode_mode = MODE_HDR;
        break;

    case 8:
        decode_endpoint_ldr_rgb_direct(ep0, ep1,
            v0, v1, v2, v3, v4, v5);
        decode_mode = MODE_LDR;
        break;

    case 9:
        decode_endpoint_ldr_rgb_base_offset(ep0, ep1,
            v0, v1, v2, v3, v4, v5);
        decode_mode = MODE_LDR;
        break;

    case 10:
        decode_endpoint_ldr_rgb_base_scale_two_a(ep0, ep1,
            v0, v1, v2, v3, v4, v5);
        decode_mode = MODE_LDR;
        break;

    case 11:
    case 14:
    case 15:
        // Modes 11/14/15 share the HDR RGB decode; they differ only in
        // how (or whether) alpha is decoded.
        decode_endpoint_hdr_rgb_direct(ep0, ep1,
            v0, v1, v2, v3, v4, v5);
        if (ep_mode == 14)
        {
            // CEM 14: HDR RGB with direct LDR alpha.
            ep0.a = v6;
            ep1.a = v7;
            decode_mode = MODE_HDR_LDR_ALPHA;
        }
        else if (ep_mode == 15)
        {
            // CEM 15: HDR RGB with HDR alpha.
            decode_endpoint_hdr_alpha(ep0.a, ep1.a, v6, v7);
            decode_mode = MODE_HDR;
        }
        else
            decode_mode = MODE_HDR;
        break;

    case 12:
        decode_endpoint_ldr_rgba_direct(ep0, ep1,
            v0, v1, v2, v3, v4, v5, v6, v7);
        decode_mode = MODE_LDR;
        break;

    case 13:
        decode_endpoint_ldr_rgba_base_offset(ep0, ep1,
            v0, v1, v2, v3, v4, v5, v6, v7);
        decode_mode = MODE_LDR;
        break;
    }

    // An 8-bit output cannot represent HDR results; treat as a decode error.
    if (DECODE_8BIT && decode_mode != MODE_LDR)
        decode_error = true;
}
1146
/* Early-out helper for main(): if any decoder stage has flagged
 * decode_error, write the error color for this texel and return. */
#define CHECK_DECODE_ERROR() do { \
    if (decode_error) \
    { \
        emit_decode_error(coord.xy); \
        return; \
    } \
} while(false)
1154
/* Stores the global error_color at the given texel of the output image. */
void emit_decode_error(ivec2 coord)
{
#ifdef VULKAN
    /* The Vulkan path addresses one array layer / 3D slice per workgroup Z. */
    if (is_3Dimage)
        imageStore(OutputImage3D, ivec3(coord, gl_WorkGroupID.z), error_color);
    else
        imageStore(OutputImage2Darray, ivec3(coord, gl_WorkGroupID.z), error_color);
#else /* VULKAN */
    imageStore(OutputImage, coord, error_color);
#endif /* VULKAN */
}
1166
/* Returns the total number of endpoint pairs the block requires.
 * Single partition: derived from the CEM class alone.
 * Multiple partitions with a shared CEM: class count times partitions.
 * Multiple partitions with per-partition CEMs: base class per partition
 * plus one extra pair per set class-offset bit. */
int compute_num_endpoint_pairs(int num_partitions, int cem)
{
    if (num_partitions <= 1)
        return (cem >> 2) + 1;

    if ((cem & 3) == 0)
        return ((cem >> 4) + 1) * num_partitions;

    return (cem & 3) * num_partitions +
        bitCount(bitfieldExtract(uint(cem), 2, num_partitions));
}
1184
/* Resolves the effective color endpoint mode (CEM) for this partition and
 * the index of its first endpoint value in the ISE stream.
 *
 * Multi-partition packing: if the low two bits of the packed CEM field are
 * zero, all partitions share one CEM (held in the remaining bits). Otherwise
 * the low two bits select a base class (value is class + 1), and extra bits
 * stored immediately below the weight bits supply a one-bit class offset and
 * a two-bit low CEM component for each partition. */
void decode_cem_base_endpoint(uvec4 payload, int weight_cost_bits, inout int cem, out int base_endpoint_index,
    int num_partitions, int partition_index)
{
    if (num_partitions > 1)
    {
        bool single_cem = (cem & 3) == 0;
        if (single_cem)
        {
            /* Shared CEM: every partition consumes the same number of
             * endpoint pairs, so the base index is a simple multiple. */
            cem >>= 2;
            base_endpoint_index = ((cem >> 2) + 1) * partition_index;
        }
        else
        {
            /* Sum the pairs consumed by lower-numbered partitions: each
             * uses the base class plus its own class-offset bit. */
            if (partition_index != 0)
                base_endpoint_index = (cem & 3) * partition_index + bitCount(bitfieldExtract(uint(cem), 2, partition_index));
            else
                base_endpoint_index = 0;

            int base_class = (cem & 3) - 1;
            /* The extra CEM bits sit directly below the weight bits at the
             * top of the 128-bit block. */
            int extra_cem_bits = num_partitions * 3 - 4;
            int extra_bits = extract_bits(payload, 128 - weight_cost_bits - extra_cem_bits, extra_cem_bits);
            cem = (extra_bits << 4) | (cem >> 2);

            /* Reassemble this partition's CEM from its class-offset bit and
             * its two low mode bits. */
            int class_offset_bit = (cem >> partition_index) & 1;
            int ep_bits = (cem >> (num_partitions + 2 * partition_index)) & 3;

            cem = 4 * (base_class + class_offset_bit) + ep_bits;
        }
        /* Convert from pair index to value index (two values per pair). */
        base_endpoint_index *= 2;
    }
    else
    {
        base_endpoint_index = 0;
    }
}
1220
/* Decodes a void-extent block: a single constant color for the whole block,
 * with optional extent coordinates. Returns the raw 16-bit per-channel
 * color from the payload and sets decode_mode from the HDR flag bit.
 * Flags decode_error (and returns zero) on malformed blocks. */
ivec4 void_extent_color(uvec4 payload, out int decode_mode)
{
    /* Bits 10..11 are reserved and must both be set. */
    if (extract_bits(payload, 10, 2) != 3)
    {
        decode_error = true;
        return ivec4(0);
    }

    /* Extent fields: (min_s, max_s, min_t, max_t), 13 bits each. */
    ivec4 extent = ivec4(
        extract_bits(payload, 12, 13),
        extract_bits(payload, 12 + 13, 13),
        extract_bits(payload, 12 + 2 * 13, 13),
        extract_bits(payload, 12 + 3 * 13, 13));

    /* An all-ones extent means "no extent"; any other encoding requires
     * min < max in both dimensions. */
    if (!all(equal(extent, ivec4((1 << 13) - 1))) &&
        any(greaterThanEqual(extent.xz, extent.yw)))
    {
        decode_error = true;
        return ivec4(0);
    }

    decode_mode = (payload.x & (1u << 9)) != 0u ? MODE_HDR : MODE_LDR;

    /* The constant color lives in the top 64 bits, 16 bits per channel. */
    return ivec4(
        extract_bits(payload, 64, 16),
        extract_bits(payload, 64 + 16, 16),
        extract_bits(payload, 64 + 32, 16),
        extract_bits(payload, 64 + 48, 16));
}
1253
/* Entry point: each invocation decodes one texel of one ASTC block and
 * stores the result to the output image.
 * Fix vs. original: removed the unused local `linear_pixel`. */
void main()
{
    // coord.xy is the output texel position; coord.zw addresses the
    // compressed block in the payload input.
    ivec4 coord = build_coord();
#ifdef VULKAN
    if (any(greaterThanEqual(coord.xy, texel_end.xy)))
        return;
#else /* VULKAN */
    if (any(greaterThanEqual(coord.xy, imageSize(OutputImage))))
        return;
#endif /* VULKAN */

    // Position of this texel inside the block footprint.
    ivec2 pixel_coord = ivec2(gl_LocalInvocationID.xy);
    uvec4 payload;
#ifdef VULKAN
    if (is_3Dimage)
        payload = texelFetch(PayloadInput3D, ivec3(coord.zw, gl_WorkGroupID.z), 0);
    else
        payload = texelFetch(PayloadInput2Darray,ivec3(coord.zw, gl_WorkGroupID.z), 0);
#else /* VULKAN */
    payload = texelFetch(PayloadInput, coord.zw, 0);
#endif /* VULKAN */

    BlockMode block_mode = decode_block_mode(payload);
    CHECK_DECODE_ERROR();

    ivec4 final_color;
    int decode_mode;
    if (block_mode.void_extent)
    {
        // Void-extent blocks carry a single constant color.
        final_color = void_extent_color(payload, decode_mode);
        CHECK_DECODE_ERROR();
    }
    else
    {
        int weight_cost_bits;
        ivec4 weights = decode_weights(payload, block_mode, normalize_coord(pixel_coord), weight_cost_bits);

        int partition_index = 0;
        if (block_mode.num_partitions > 1)
        {
            // Partition assignment is precomputed in a LUT indexed by the
            // partition seed and the texel position within the block.
            int lut_x = pixel_coord.x + int(gl_WorkGroupSize.x) * (block_mode.seed & 31);
            int lut_y = pixel_coord.y + int(gl_WorkGroupSize.y) * (block_mode.seed >> 5);
#ifdef VULKAN
            int lut_width = int(gl_WorkGroupSize.x) * 32;
            partition_index = int(texelFetch(LUTPartitionTable, lut_y * lut_width + lut_x).x);
#else /* VULKAN */
            partition_index = int(texelFetch(LUTPartitionTable, ivec2(lut_x, lut_y), 0).x);
#endif /* VULKAN */
            partition_index = (partition_index >> (2 * block_mode.num_partitions - 4)) & 3;
        }

        int available_endpoint_bits = max(128 - block_mode.config_bits - weight_cost_bits, 0);

        // In multi-partition mode, the 6-bit CEM field is encoded as
        // First two bits tell if all CEM field are the same, if not we specify a class offset, and N bits
        // after that will offset the class by 1.
        int num_endpoint_pairs = compute_num_endpoint_pairs(block_mode.num_partitions, block_mode.cem);

        // Error color must be emitted if we need more than 18 integer sequence encoded values of color.
        if (num_endpoint_pairs > 9)
        {
            decode_error = true;
            emit_decode_error(coord.xy);
            return;
        }

        ivec4 endpoint_quant = ivec4(texelFetch(LUTRemainingBitsToEndpointQuantizer,
                128 * (num_endpoint_pairs - 1) + available_endpoint_bits));

        // Only read the bits we need for endpoints.
        int num_endpoint_values = num_endpoint_pairs * 2;
        available_endpoint_bits =
            endpoint_quant.x * num_endpoint_values +
            idiv5_ceil(endpoint_quant.y * 8 * num_endpoint_values) +
            idiv3_ceil(endpoint_quant.z * 7 * num_endpoint_values);

        // No space left for color endpoints.
        if (all(equal(endpoint_quant.xyz, ivec3(0))))
        {
            decode_error = true;
            emit_decode_error(coord.xy);
            return;
        }

        int endpoint_bit_offset = block_mode.primary_config_bits;
        ivec4 ep0, ep1;

        // Decode CEM for multi-partition schemes.
        int cem = block_mode.cem;
        int base_endpoint_index;
        decode_cem_base_endpoint(payload, weight_cost_bits, cem, base_endpoint_index,
                                 block_mode.num_partitions, partition_index);

        decode_endpoint(ep0, ep1, decode_mode, payload, endpoint_bit_offset, endpoint_quant,
                        cem, base_endpoint_index, available_endpoint_bits);
        CHECK_DECODE_ERROR();

        final_color = interpolate_endpoint(ep0, ep1, weights, decode_mode);
    }

    if (DECODE_8BIT)
    {
        // 8-bit output path: drop the low 8 bits of the interpolated result.
#ifdef VULKAN
        if (is_3Dimage)
            imageStore(OutputImage3D, ivec3(coord.xy, gl_WorkGroupID.z), uvec4(final_color >> 8));
        else
            imageStore(OutputImage2Darray, ivec3(coord.xy, gl_WorkGroupID.z), uvec4(final_color >> 8));
#else /* VULKAN */
        imageStore(OutputImage, coord.xy, uvec4(final_color >> 8));
#endif /* VULKAN */
    }
    else
    {
        uvec4 encoded;
        // HDR void-extent color is stored as-is (per the ASTC spec it is
        // already FP16 bit patterns); everything else converts to FP16 here.
        if (block_mode.void_extent && decode_mode == MODE_HDR)
            encoded = uvec4(final_color);
        else
            encoded = decode_fp16(final_color, decode_mode);
#ifdef VULKAN
        if (is_3Dimage)
            imageStore(OutputImage3D, ivec3(coord.xy, gl_WorkGroupID.z), encoded);
        else
            imageStore(OutputImage2Darray, ivec3(coord.xy, gl_WorkGroupID.z), encoded);
#else /* VULKAN */
        imageStore(OutputImage, coord.xy, encoded);
#endif /* VULKAN */
    }
}
1383