// Source: /aosp_15_r20/external/skia/tests/sksl/compute/Workgroup.metal (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1#include <metal_stdlib>
2#include <simd/simd.h>
3#ifdef __clang__
4#pragma clang diagnostic ignored "-Wall"
5#endif
6using namespace metal;
// Built-in kernel inputs gathered into one value; computeMain fills it from its
// [[thread_position_in_grid]] argument.
struct Inputs { uint3 sk_GlobalInvocationID; };
// Read-only source buffer layout (bound at buffer(0) by computeMain).
// NOTE(review): the declared length of 1 looks like a placeholder for a
// runtime-sized array -- computeMain reads in_data[id * 2u + 1u], well past index 0.
struct inputs { float in_data[1]; };
// Writable destination buffer layout (bound at buffer(1) by computeMain).
// NOTE(review): as with `inputs`, length 1 appears to stand in for a runtime-sized
// array -- computeMain writes out_data[id * 2u + 1u].
struct outputs { float out_data[1]; };
// Pointers to the kernel's two buffer interfaces. computeMain brace-initializes
// this positionally, so the member order (source first, destination second) matters.
struct Globals {
    const device inputs* _anonInterface0;  // read-only source, buffer(0)
    device outputs*      _anonInterface1;  // destination, buffer(1)
};
// Scratch storage that computeMain places in the threadgroup address space:
// 512 floats, two per invocation.
struct Threadgroups {
    metal::array<float, 512> shared_data;
};
// Stores `value` at index `i` of the threadgroup scratch array.
// No bounds check: callers must keep i < 512.
void store_vIf(threadgroup Threadgroups& _threadgroups, uint i, float value) {
    threadgroup float& slot = _threadgroups.shared_data[i];
    slot = value;
}
// Compute entry point: copies two adjacent input floats per invocation into
// threadgroup memory, then runs an in-place inclusive prefix sum over the
// 512 entries in nine passes, then copies the pair back out.
//
// In pass `level`, the scratch array is viewed as blocks of 2^(level+1) entries.
// Each invocation adds the last entry of its block's left half into one entry
// of the right half. After nine passes, every entry holds the running total.
//
// NOTE(review): `tid` comes from [[thread_position_in_grid]], but shared_data has
// only 512 slots. Any invocation with x >= 256 therefore indexes it out of bounds.
// This assumes a dispatch of at most 256 invocations in x.
kernel void computeMain(uint3 sk_GlobalInvocationID [[thread_position_in_grid]], const device inputs& _anonInterface0 [[buffer(0)]], device outputs& _anonInterface1 [[buffer(1)]]) {
    Globals _globals{&_anonInterface0, &_anonInterface1};
    (void)_globals;
    threadgroup Threadgroups _threadgroups{{}};
    (void)_threadgroups;
    Inputs _in = { sk_GlobalInvocationID };

    const uint tid = _in.sk_GlobalInvocationID.x;
    const uint lo = tid * 2u;   // first element owned by this invocation
    const uint hi = lo + 1u;    // second element owned by this invocation
    const device inputs* src = _globals._anonInterface0;
    device outputs* dst = _globals._anonInterface1;

    // Load this invocation's pair, then wait until every pair is in place.
    _threadgroups.shared_data[lo] = src->in_data[lo];
    _threadgroups.shared_data[hi] = src->in_data[hi];
    threadgroup_barrier(mem_flags::mem_threadgroup);

    const uint steps = 9u;  // 2^9 == 512 scratch entries
    for (uint level = 0u; level < steps; ++level) {
        const uint mask = (1u << level) - 1u;
        // Last index of this invocation's left half-block. `<<` binds looser than
        // `+`, so the explicit parentheses match the original's evaluation.
        const uint rd_id = ((tid >> level) << (level + 1u)) + mask;
        // Target slot in the right half-block.
        const uint wr_id = rd_id + 1u + (tid & mask);
        const float sum = _threadgroups.shared_data[wr_id] + _threadgroups.shared_data[rd_id];
        store_vIf(_threadgroups, wr_id, sum);
        // rd_id slots are never written in the same pass, but the next pass reads
        // this pass's results, so synchronize before continuing.
        threadgroup_barrier(mem_flags::mem_threadgroup);
    }

    dst->out_data[lo] = _threadgroups.shared_data[lo];
    dst->out_data[hi] = _threadgroups.shared_data[hi];
    return;
}
51