xref: /aosp_15_r20/external/skia/tests/sksl/compute/MatrixMultiply.metal (revision c8dee2aa9b3f27cf6c858bd81872bdeb2c07ed17)
1*c8dee2aaSAndroid Build Coastguard Worker#include <metal_stdlib>
2*c8dee2aaSAndroid Build Coastguard Worker#include <simd/simd.h>
3*c8dee2aaSAndroid Build Coastguard Worker#ifdef __clang__
4*c8dee2aaSAndroid Build Coastguard Worker#pragma clang diagnostic ignored "-Wall"
5*c8dee2aaSAndroid Build Coastguard Worker#endif
6*c8dee2aaSAndroid Build Coastguard Workerusing namespace metal;
// Compute-stage built-in inputs. The single member mirrors Metal's
// [[thread_position_in_grid]] value under the SkSL name sk_GlobalInvocationID.
struct Inputs { uint3 sk_GlobalInvocationID; };
// Matrix-dimension table bound at buffer(0). It is declared with one element,
// but computeMain reads and writes entries 0..2, so the one-element array acts
// as a placeholder for a runtime-sized array whose real length is that of the
// bound buffer (presumably at least three int2 entries; confirm with the host).
struct sizeBuffer { int2 sizes[1]; };
// Left-hand operand elements, bound read-only at buffer(1). The [1] length is a
// placeholder; the kernel indexes past it, so the real element count is that of
// the bound buffer.
struct inputs1 { float data1[1]; };
// Right-hand operand elements, bound read-only at buffer(2). As with inputs1,
// the [1] length is a placeholder for the bound buffer's real element count.
struct inputs2 { float data2[1]; };
// Output elements of the product, bound writable at buffer(3). The [1] length
// is a placeholder for the bound buffer's real element count.
struct result { float resultData[1]; };
// Pointers to each buffer accepted by computeMain, in binding order. Member
// order and qualifiers match the kernel's buffer(0)..buffer(3) parameters.
struct Globals {
    device sizeBuffer    *_anonInterface0;  // buffer(0): dimension table (writable)
    const device inputs1 *_anonInterface1;  // buffer(1): left operand (read-only)
    const device inputs2 *_anonInterface2;  // buffer(2): right operand (read-only)
    device result        *_anonInterface3;  // buffer(3): product output (writable)
};
// Naive matrix-multiply compute kernel: each invocation computes one element of
// (left matrix in buffer(1)) x (right matrix in buffer(2)) and stores it in
// buffer(3).
//
// Dimension table layout, as implied by the indexing below (all row-major):
//   sizes[0] = (rows, cols) of the left matrix   (data1)
//   sizes[1] = (rows, cols) of the right matrix  (data2)
//   sizes[2] = (rows, cols) of the result        (written by this kernel)
// NOTE(review): the inner dimension is taken from sizes[0].y; sizes[1].x is
// assumed to equal it and is never checked.
//
// The invocation's grid position (x, y) selects the result element (row, col).
// Invocations outside the result matrix exit early instead of reading and
// writing past the ends of the buffers, so the dispatch grid may be rounded up
// to whole threadgroups.
kernel void computeMain(uint3 sk_GlobalInvocationID [[thread_position_in_grid]], device sizeBuffer& _anonInterface0 [[buffer(0)]], const device inputs1& _anonInterface1 [[buffer(1)]], const device inputs2& _anonInterface2 [[buffer(2)]], device result& _anonInterface3 [[buffer(3)]]) {
    // Load the operand dimensions once; they are invariant across the loop.
    const int2 leftSize = _anonInterface0.sizes[0];
    const int2 rightSize = _anonInterface0.sizes[1];

    // Publish the result dimensions (left rows x right cols). The value is the
    // same for every invocation, but a non-atomic store from the whole grid to
    // one location is a data race, so only the grid-origin invocation stores it.
    if (all(sk_GlobalInvocationID == uint3(0))) {
        _anonInterface0.sizes[2] = int2(leftSize.x, rightSize.y);
    }

    const int row = int(sk_GlobalInvocationID.x);
    const int col = int(sk_GlobalInvocationID.y);

    // Over-dispatched invocations have no result element to compute.
    if (row >= leftSize.x || col >= rightSize.y) {
        return;
    }

    // Dot product of left row `row` with right column `col`.
    float sum = 0.0;
    for (int i = 0; i < leftSize.y; ++i) {
        sum += _anonInterface1.data1[i + row * leftSize.y] *
               _anonInterface2.data2[col + i * rightSize.y];
    }
    _anonInterface3.resultData[col + row * rightSize.y] = sum;
}
44