#ifndef MPSCNNShaders_h
#define MPSCNNShaders_h

static const char* PT_METAL_SHADERS = R"PT_METAL_SHADERS(
#include <metal_stdlib>
using namespace metal;

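// Scalar kernel arguments are passed as Metal function constants (indices 0-13)
// and are bound on the host when the compute pipeline state is specialized.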
constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];
constant ushort ushort_arg_10[[function_constant(10)]];
constant ushort ushort_arg_11[[function_constant(11)]];
constant float float_arg_0 [[function_constant(12)]];
constant float float_arg_1 [[function_constant(13)]];

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }

enum broadcastOp {
    Add,
    Sub,
    Mul,
    Div,
};

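// Element-wise binary ops with broadcasting along W and H: an input whose
// width or height is 1 has the corresponding stride zeroed, so the same
// texel is re-read for every output position along that axis.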
void elementwise_broadcast_nonarray(texture2d<half, access::read> in0,
                                    texture2d<half, access::read> in1,
                                    texture2d<half, access::write> out,
                                    ushort2 gid,
                                    broadcastOp op) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    ushort2 gid0 = gid.xy * in0_stride;
    ushort2 gid1 = gid.xy * in1_stride;

    if(op == Add) {
        out.write(in0.read(gid0) + in1.read(gid1), gid);
    } else if(op == Sub) {
        out.write(in0.read(gid0) - in1.read(gid1), gid);
    } else if(op == Mul) {
        out.write(in0.read(gid0) * in1.read(gid1), gid);
    } else if(op == Div) {
        out.write(in0.read(gid0) / in1.read(gid1), gid);
    }
}

void elementwise_broadcast(texture2d_array<half, access::read> in0,
                           texture2d_array<half, access::read> in1,
                           texture2d_array<half, access::write> out,
                           ushort3 gid,
                           broadcastOp op) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }

    ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    ushort2 gid0 = gid.xy * in0_stride;
    ushort2 gid1 = gid.xy * in1_stride;

    if(op == Add) {
        out.write(in0.read(gid0, gid.z) + in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if(op == Sub) {
        out.write(in0.read(gid0, gid.z) - in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if(op == Mul) {
        out.write(in0.read(gid0, gid.z) * in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if(op == Div) {
        out.write(in0.read(gid0, gid.z) / in1.read(gid1, gid.z), gid.xy, gid.z);
    }
}

kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Add);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Add);
}

kernel void elementwise_sub_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Sub);
}

kernel void elementwise_sub(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Sub);
}

kernel void elementwise_mul_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Mul);
}

kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Mul);
}

kernel void elementwise_div_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Div);
}

kernel void elementwise_div(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Div);
}

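// Copies an NCHW float buffer into a texture2d_array of half4: channels are
// packed four per array slice, so slice index = n * ceil(C / 4) + c / 4 and
// the texel lane selects c % 4. Channels past C are zero-filled.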
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z - n * divRoundUp(C, 4);
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                     \
if ((c_) < C) {                                                          \
trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
} else {                                                                 \
trns[idx] = 0.0h;                                                      \
}
    half4 trns;
    CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
    CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
    CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
    CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHW_TO_CHWP4
    out.write(trns, gid.xy, gid.z);
}

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    half4 trns;
#define CHW_TO_CHWP4(idx, c, h, w)                        \
if ((c) < C) {                                          \
trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
} else {                                                \
trns[idx] = 0.0h;                                     \
}
    CHW_TO_CHWP4(0, 0, gid.y, gid.x);
    CHW_TO_CHWP4(1, 1, gid.y, gid.x);
    CHW_TO_CHWP4(2, 2, gid.y, gid.x);
    CHW_TO_CHWP4(3, 3, gid.y, gid.x);
#undef CHW_TO_CHWP4
    out.write(trns, gid.xy);
}

kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z - n * divRoundUp(C, 4);
    half4 cs = in.read(gid.xy, gid.z);
#define CHWP4_TO_CHW(idx, n, c_, h, w)                                    \
if ((c_) < C) {                                                         \
out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
}
    CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
    CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
    CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
    CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    half4 cs = in.read(gid.xy);
#define CHWP4_TO_CHW(idx, c, h, w)                       \
if ((c) < C) {                                         \
out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \
}
    CHWP4_TO_CHW(0, 0, gid.y, gid.x);
    CHWP4_TO_CHW(1, 1, gid.y, gid.x);
    CHWP4_TO_CHW(2, 2, gid.y, gid.x);
    CHWP4_TO_CHW(3, 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy(texture2d_array<half, access::read> in[[texture(0)]],
                 texture2d_array<half, access::write> out[[texture(1)]],
                 ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_, gid.z), gid_, gid.z);
}

kernel void copy_nonarray(texture2d<half, access::read> in[[texture(0)]],
                          texture2d<half, access::write> out[[texture(1)]],
                          ushort2 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    out.write(in.read(gid), gid);
}

kernel void copy_offset(texture2d_array<half, access::read> in[[texture(0)]],
                        texture2d_array<half, access::write> out[[texture(1)]],
                        constant ushort* offset_buf[[buffer(0)]],
                        ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_, gid.z), gid_, gid.z + offset_buf[0]);
}

kernel void copy_offset_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                 texture2d_array<half, access::write> out[[texture(1)]],
                                 constant ushort* offset_buf[[buffer(0)]],
                                 ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_), gid_, gid.z + offset_buf[0]);
}

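// Gathers array slices from the input, starting at offset_buf[0] and stepping
// by offset_buf[1], into a contiguous output (array or plain texture).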
constant bool store_features_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool store_features_out_is_tex = !store_features_out_is_arr;
kernel void store_features(texture2d_array<half, access::read> in[[texture(0)]],
                           texture2d<half, access::write> out_tex[[texture(1), function_constant(store_features_out_is_tex)]],
                           texture2d_array<half, access::write> out_arr[[texture(1), function_constant(store_features_out_is_arr)]],
                           constant ushort* offset_buf[[buffer(0)]],
                           ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;
    if (store_features_out_is_arr)
      out_arr.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_, gid.z);
    else
      out_tex.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_);
}

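// Appends the input's feature maps behind an existing block of channels in
// `out`. offset_buf packs, per batch: {slices appended, output slices,
// slice offset into the output, input slices}.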
constant bool append_features_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool append_features_in_is_tex = !append_features_in_is_arr;
kernel void append_features(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_in_is_tex)]],
                            texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_in_is_arr)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            constant ushort* offset_buf[[buffer(0)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;

    ushort batch = gid.z / offset_buf[0];
    ushort feature = gid.z % offset_buf[0];
    ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
    ushort inz = batch * offset_buf[3] + feature;

    half4 intex;
    if (append_features_in_is_arr) {
      intex = in_arr.read(gid_, inz);
    }
    else {
      intex = in_tex.read(gid_);
    }
    out.write(intex, gid_, outz);
}

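// Variant of append_features used when the channel offset is not a multiple
// of four: texels from the previous contents (`prev_*`) and the shifted input
// lanes are stitched together, with offset_buf[5] holding the lane shift
// (1, 2, or 3) within a texel.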
constant bool prev_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool prev_is_tex = !prev_is_arr;
constant bool append_features_off_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool append_features_off_in_is_tex = !append_features_off_in_is_arr;
kernel void append_features_off(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_off_in_is_tex)]],
                                texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_off_in_is_arr)]],
                                texture2d<half, access::read> prev_tex[[texture(1), function_constant(prev_is_tex)]],
                                texture2d_array<half, access::read> prev_arr[[texture(1), function_constant(prev_is_arr)]],
                                texture2d_array<half, access::write> out[[texture(2)]],
                                constant ushort* offset_buf[[buffer(0)]],
                                ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;

    ushort batch = gid.z / offset_buf[0];
    ushort feature = gid.z % offset_buf[0];
    ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
    ushort inz = batch * offset_buf[3] + feature;
    half4 outtex;
    if (prev_is_arr)
      outtex = prev_arr.read(gid_, batch);
    else
      outtex = prev_tex.read(gid_);
    half4 intex1;
    if (append_features_in_is_arr)
      intex1 = in_arr.read(gid_, inz);
    else
      intex1 = in_tex.read(gid_);
    if (feature == 0) {
      if (offset_buf[5] == 1)
        outtex.yzw = intex1.xyz;
      else if (offset_buf[5] == 2)
        outtex.zw = intex1.xy;
      else
        outtex.w = intex1.x;
      out.write(outtex, gid_, outz);
      return;
    }
    half4 intex0;
    if (append_features_in_is_arr)
      intex0 = in_arr.read(gid_, inz-1);
    else
      intex0 = intex1;
    if (offset_buf[5] == 1) {
      outtex.x = intex0.w;
      outtex.yzw = intex1.xyz;
    }
    else if (offset_buf[5] == 2) {
      outtex.xy = intex0.zw;
      outtex.zw = intex1.xy;
    }
    else {
      outtex.xyz = intex0.yzw;
      outtex.w = intex1.x;
    }

    out.write(outtex, gid_, outz);
}

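// Clamps every texel to the range [float_arg_0, float_arg_1].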
constant bool clamp_is_arr = (ushort_arg_1 > 1 || ushort_arg_0 > 4);
constant bool clamp_is_tex = !clamp_is_arr;
kernel void clamp(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(clamp_is_arr)]],
                  texture2d<half, access::read> in_tex[[texture(0), function_constant(clamp_is_tex)]],
                  texture2d_array<half, access::write> out_arr[[texture(1), function_constant(clamp_is_arr)]],
                  texture2d<half, access::write> out_tex[[texture(1), function_constant(clamp_is_tex)]],
                  ushort3 gid[[thread_position_in_grid]]) {
    const ushort w = clamp_is_arr ? out_arr.get_width() : out_tex.get_width();
    const ushort h = clamp_is_arr ? out_arr.get_height() : out_tex.get_height();
    if (gid.x >= w || gid.y >= h) {
        return;
    }
    const float4 min_(float_arg_0, float_arg_0, float_arg_0, float_arg_0);
    const float4 max_(float_arg_1, float_arg_1, float_arg_1, float_arg_1);
    ushort2 gid_ = gid.xy;
    if(clamp_is_arr){
        float4 value = (float4)in_arr.read(gid_, gid.z);
        half4 clamped = (half4)clamp(value, min_, max_);
        out_arr.write(clamped, gid_, gid.z);
    } else {
        float4 value = (float4)in_tex.read(gid_);
        half4 clamped = (half4)clamp(value, min_, max_);
        out_tex.write(clamped, gid_);
    }
}

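// hardswish(x) = x * relu6(x + 3) / 6, expressed with comparison masks
// instead of branches so all four texel lanes are handled uniformly.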
constant bool hardswish_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardswish_is_tex = !hardswish_is_arr;
kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardswish_is_arr)]],
                      texture2d<half, access::read> in_tex[[texture(0), function_constant(hardswish_is_tex)]],
                      texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardswish_is_arr)]],
                      texture2d<half, access::write> out_tex[[texture(1), function_constant(hardswish_is_tex)]],
                      ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (hardswish_is_arr) {
      half4 value = in_arr.read(gid_, gid.z);
      half4 mask1 = half4(value < 3.0);
      half4 mask2 = half4(value > -3.0);
      half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
      out_arr.write(outval, gid_, gid.z);
    } else {
      half4 value = in_tex.read(gid_);
      half4 mask1 = half4(value < 3.0);
      half4 mask2 = half4(value > -3.0);
      half4 outval = mask2*(mask1*(value*(value + 3.0)/6.0) + (1 - mask1)*value);
      out_tex.write(outval, gid_);
    }
}

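// hardshrink(x): zero inside [-lambda, lambda], identity outside; lambda
// arrives through float_arg_0.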
constant bool hardshrink_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardshrink_is_tex = !hardshrink_is_arr;
kernel void hardshrink(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardshrink_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(hardshrink_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardshrink_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(hardshrink_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    const half lambda = (half)float_arg_0;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (hardshrink_is_arr) {
      half4 value = in_arr.read(gid_, gid.z);
      half4 mask1 = half4(value <= lambda);
      half4 mask2 = half4(value >= -lambda);
      half4 outval = (1 - mask1)*value + (1 - mask2)*value;
      out_arr.write(outval, gid_, gid.z);
    } else {
      half4 value = in_tex.read(gid_);
      half4 mask1 = half4(value <= lambda);
      half4 mask2 = half4(value >= -lambda);
      half4 outval = (1 - mask1)*value + (1 - mask2)*value;
      out_tex.write(outval, gid_);
    }
}

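// leaky_relu(x): negative lanes are scaled by negative_slope (float_arg_0),
// non-negative lanes pass through unchanged.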
constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(leaky_relu_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(leaky_relu_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    const half negative_slope = (half)float_arg_0;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (leaky_relu_is_arr) {
      half4 value = in_arr.read(gid_, gid.z);
      half4 is_negative = half4(value < 0.0);
      half4 outval = is_negative*value*negative_slope + (1-is_negative)*value;
      out_arr.write(outval, gid_, gid.z);
    } else {
      half4 value = in_tex.read(gid_);
      half4 is_negative = half4(value < 0.0);
      half4 outval = is_negative*value*negative_slope + (1-is_negative)*value;
      out_tex.write(outval, gid_);
    }
}

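// reflection_pad2d: output positions that fall inside the padded border are
// mapped back to mirrored input coordinates; the 2*max(...) pre/post offsets
// implement the reflection about each edge (ushort_arg_8..11 hold the
// left/right/top/bottom padding).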
constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool out_is_tex = !out_is_arr;
constant bool in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool in_is_tex = !in_is_arr;
kernel void reflection_pad2d(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(in_is_arr)]],
                             texture2d<half, access::read> in_tex[[texture(0), function_constant(in_is_tex)]],
                             texture2d_array<half, access::write> out_arr[[texture(1), function_constant(out_is_arr)]],
                             texture2d<half, access::write> out_tex[[texture(1), function_constant(out_is_tex)]],
                             ushort3 gid[[thread_position_in_grid]]) {
  const ushort H2 = ushort_arg_0;
  const ushort W2 = ushort_arg_1;
  if (gid.x >= W2 || gid.y >= H2) {
      return;
  }

  const ushort pad_left = ushort_arg_8;
  const ushort pad_right = ushort_arg_9;
  const ushort pad_top = ushort_arg_10;
  const ushort pad_bottom = ushort_arg_11;

  const ushort2 out_size = ushort2(W2, H2);
  const ushort xoff_pre  = 2*max(pad_left - gid.x, 0);
  const ushort xoff_post = 2*max(gid.x - (out_size.x - 1 - pad_right), 0);
  const ushort yoff_pre  = 2*max(pad_top - gid.y, 0);
  const ushort yoff_post = 2*max(gid.y - (out_size.y - 1 - pad_bottom), 0);
  ushort2 inpos = ushort2(
      gid.x + xoff_pre - xoff_post - pad_left,
      gid.y + yoff_pre - yoff_post - pad_top);

  half4 intex;
  if (in_is_arr) {
    intex = in_arr.read(inpos, gid.z);
  } else {
    intex = in_tex.read(inpos);
  }

  if (out_is_arr) {
      out_arr.write(intex, gid.xy, gid.z);
  } else {
      out_tex.write(intex, gid.xy);
  }
}

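// reshape: each lane of the output texel is mapped to its NCHW linear index,
// which is then decomposed back into coordinates of the input tensor
// (H1/W1/C1/N1 come from ushort_arg_4..7); lanes past the input's numel are
// zero-filled.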
constant bool reshape_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool reshape_out_is_tex = !reshape_out_is_arr;
constant bool reshape_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool reshape_in_is_tex = !reshape_in_is_arr;
kernel void reshape(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(reshape_in_is_arr)]],
                    texture2d<half, access::read> in_tex[[texture(0), function_constant(reshape_in_is_tex)]],
                    texture2d_array<half, access::write> out_arr[[texture(1), function_constant(reshape_out_is_arr)]],
                    texture2d<half, access::write> out_tex[[texture(1), function_constant(reshape_out_is_tex)]],
                    ushort3 gid[[thread_position_in_grid]]) {
    const ushort H2 = ushort_arg_0;
    const ushort W2 = ushort_arg_1;
    const ushort C2 = ushort_arg_2;
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }
    const ushort H1 = ushort_arg_4;
    const ushort W1 = ushort_arg_5;
    const ushort C1 = ushort_arg_6;
    const ushort N1 = ushort_arg_7;

    const size_t numel1 = H1 * W1 * C1 * N1;
    const ushort slices2 = divRoundUp(C2, 4);
    const ushort slices1 = divRoundUp(C1, 4);
    const ushort n2 = gid.z / slices2; // image index
    const ushort s2 = gid.z - n2 * slices2; // slice offset
    half4 value;
    for (int idx = 0; idx < 4; ++idx){
        // we compute the "linear index" of the output element,
        // and convert it to the equivalent "linear index" of the input element.
        ushort offset = 4 * s2 + idx;
        size_t linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
        if(linear_idx >= numel1){
            value[idx] = 0;
            continue;
        }
        auto x1 = linear_idx % W1;
        auto y1 = ((int)(linear_idx/W1)) % H1;
        auto s1 = ((int)(linear_idx/W1/H1) % C1);
        auto n1 = ((int)(linear_idx/W1/H1/C1) % N1);
        auto z1 = (int)s1 / 4 + n1 * slices1;
        auto pos = s1 % 4;
        if(reshape_in_is_arr) {
            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
        } else {
            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
        }
    }
    if(reshape_out_is_arr) {
        out_arr.write(value, gid.xy, gid.z);
    } else {
        out_tex.write(value, gid.xy);
    }
}

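// transpose: a generic N-d transpose. Each output lane's linear index is
// decomposed into per-dimension coordinates (dims 0-3 in the lower ushort4,
// dims 4+ in the upper one), the two transposed dimensions are exchanged,
// and the coordinates are folded back into a linear index of the input.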
constant bool transpose_in_is_arr = (ushort_arg_3 > 1 || ushort_arg_4 > 4);
constant bool transpose_in_is_tex = !transpose_in_is_arr;
constant bool transpose_out_is_arr = (ushort_arg_5 > 1 || ushort_arg_6 > 4);
constant bool transpose_out_is_tex = !transpose_out_is_arr;
kernel void transpose(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(transpose_in_is_arr)]],
                      texture2d<half, access::read> in_tex[[texture(0), function_constant(transpose_in_is_tex)]],
                      texture2d_array<half, access::write> out_arr[[texture(1), function_constant(transpose_out_is_arr)]],
                      texture2d<half, access::write> out_tex[[texture(1), function_constant(transpose_out_is_tex)]],
                      constant ushort* inSizeBuffer [[buffer(0)]],
                      constant ushort* outSizeBuffer [[buffer(1)]],
                      ushort3 gid[[thread_position_in_grid]]) {

    const ushort dim0 = ushort_arg_0;
    const ushort dim1 = ushort_arg_1;
    const ushort dim = ushort_arg_2;
    const ushort N1 = ushort_arg_3;
    const ushort C1 = ushort_arg_4;
    const ushort N2 = ushort_arg_5;
    const ushort C2 = ushort_arg_6;
    ushort W1, W2, H1, H2;
    if(transpose_in_is_arr) {
        W1 = in_arr.get_width();
        H1 = in_arr.get_height();
    } else {
        W1 = in_tex.get_width();
        H1 = in_tex.get_height();
    }
    if(transpose_out_is_arr) {
        W2 = out_arr.get_width();
        H2 = out_arr.get_height();
    } else {
        W2 = out_tex.get_width();
        H2 = out_tex.get_height();
    }
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }
    const size_t numel = H2 * W2 * C2 * N2;
    const ushort slices2 = divRoundUp(C2, 4);
    const ushort slices1 = divRoundUp(C1, 4);
    const ushort n2 = gid.z / slices2;
    const ushort s2 = gid.z - n2 * slices2;
    half4 value;
    ushort4 threadIndexBufferLower{1, 1, 1, 1};
    ushort4 threadIndexBufferUpper{1, 1, 1, 1};
    for (int idx = 0; idx < 4; ++idx){
        ushort offset = 4 * s2 + idx;
        size_t linear_idx2 = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
        if(linear_idx2 >= numel) {
            value[idx] = 0;
            continue;
        }

        ushort d2 = 0;
        for(int j = dim-1; j>=0; --j){
            d2 = outSizeBuffer[j];
            if(j > 3) {
                threadIndexBufferUpper[j-3] = linear_idx2 % d2;
            } else {
                threadIndexBufferLower[j] = linear_idx2 % d2;
            }
            linear_idx2 /= d2;
        }

        // swap the two transposed dimensions; dims 0-3 live in the lower
        // ushort4, dims 4 and above in the upper one.
        ushort tmp;
        if(dim0 > 3) {
            tmp = threadIndexBufferUpper[dim0-3];
        } else {
            tmp = threadIndexBufferLower[dim0];
        }
        if(dim0 > 3 && dim1 > 3) {
            threadIndexBufferUpper[dim0-3] = threadIndexBufferUpper[dim1-3];
        } else if (dim0 > 3 && dim1 <= 3) {
            threadIndexBufferUpper[dim0-3] = threadIndexBufferLower[dim1];
        } else if (dim0 <= 3 && dim1 > 3) {
            threadIndexBufferLower[dim0] = threadIndexBufferUpper[dim1-3];
        } else {
            threadIndexBufferLower[dim0] = threadIndexBufferLower[dim1];
        }
        if(dim1 > 3) {
            threadIndexBufferUpper[dim1-3] = tmp;
        } else {
            threadIndexBufferLower[dim1] = tmp;
        }

        size_t linear_idx1 = 0;
        ushort m = 1;
        ushort d1 = 0;
        for(int k = dim-1; k>=0; --k) {
            if(k > 3) {
                d1 = threadIndexBufferUpper[k-3];
            } else {
                d1 = threadIndexBufferLower[k];
            }
            linear_idx1 += d1 * m;
            m *= inSizeBuffer[k];
        }

        auto x1 = linear_idx1 % W1;
        auto y1 = ((int)(linear_idx1/W1)) % H1;
        auto c1 = ((int)(linear_idx1/W1/H1) % C1);
        auto n1 = ((int)(linear_idx1/W1/H1/C1) % N1);
        auto z1 = (int)c1 / 4 + n1 * slices1;
        auto pos = c1 % 4;
        if(transpose_in_is_arr) {
            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
        } else {
            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
        }
    }
    if(transpose_out_is_arr) {
        out_arr.write(value, gid.xy, gid.z);
    } else {
        out_tex.write(value, gid.xy);
    }
}

constant bool split_channels_in_is_arr = (ushort_arg_0 > 4);
constant bool split_channels_in_is_tex = !split_channels_in_is_arr;
constant bool split_channels_out1_is_arr = (ushort_arg_1 > 4);
constant bool split_channels_out1_is_tex = !split_channels_out1_is_arr;
constant bool split_channels_out2_is_arr = (ushort_arg_2 > 4);
constant bool split_channels_out2_is_tex = !split_channels_out2_is_arr;
// A naive implementation that splits the input texture in two along the
// channel dimension: the first ushort_arg_1 channels go to out1, the rest to out2.
kernel void split_channels(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(split_channels_in_is_arr)]],
                           texture2d<half, access::read> in_tex[[texture(0), function_constant(split_channels_in_is_tex)]],
                           texture2d_array<half, access::write> out1_arr[[texture(1), function_constant(split_channels_out1_is_arr)]],
                           texture2d<half, access::write> out1_tex[[texture(1), function_constant(split_channels_out1_is_tex)]],
                           texture2d_array<half, access::write> out2_arr[[texture(2), function_constant(split_channels_out2_is_arr)]],
                           texture2d<half, access::write> out2_tex[[texture(2), function_constant(split_channels_out2_is_tex)]],
                           ushort3 gid[[thread_position_in_grid]]) {
    ushort W, H;
    if(split_channels_in_is_arr) {
        W = in_arr.get_width();
        H = in_arr.get_height();
    } else {
        W = in_tex.get_width();
        H = in_tex.get_height();
    }
    if(gid.x >= W || gid.y >= H){
        return;
    }
    const ushort C1 = ushort_arg_1;
    const ushort s1 = divRoundUp(C1, 4);
    const ushort c_offset = C1 % 4;
    half4 tmp1(0.0, 0.0, 0.0, 0.0);
    half4 tmp2(0.0, 0.0, 0.0, 0.0);
    half4 in41 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z) : in_tex.read(gid.xy);
    half4 in42 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z+1) : half4(0,0,0,0);
    if(gid.z < s1 - 1) {
        if(split_channels_out1_is_arr) {
            out1_arr.write(in41, gid.xy, gid.z);
        }
    }
    else if(gid.z == s1 - 1) {
        if(c_offset == 0){
            if(split_channels_out1_is_arr) {
                out1_arr.write(in41, gid.xy, gid.z);
            } else {
                out1_tex.write(in41, gid.xy);
            }
            return;
        } else if(c_offset == 1) {
            tmp1.x = in41.x;
            tmp2.xyz = in41.yzw;
            tmp2.w = in42.x;
        } else if (c_offset == 2) {
            tmp1.xy = in41.xy;
            tmp2.xy = in41.zw;
            tmp2.zw = in42.xy;
        } else {
            tmp1.xyz = in41.xyz;
            tmp2.x = in41.w;
            tmp2.yzw = in42.xyz;
        }
        if(split_channels_out1_is_arr) {
            out1_arr.write(tmp1, gid.xy, gid.z);
        } else {
            out1_tex.write(tmp1, gid.xy);
        }
        if(split_channels_out2_is_arr) {
            out2_arr.write(tmp2, gid.xy, 0);
        } else {
            out2_tex.write(tmp2, gid.xy);
        }
    }
    else {
        if (c_offset == 0) {
            if(split_channels_out2_is_arr) {
                out2_arr.write(in41, gid.xy, gid.z - s1);
            } else {
                out2_tex.write(in41, gid.xy);
            }
            return;
        }
        else if (c_offset == 1) {
            tmp2.xyz = in41.yzw;
            tmp2.w = in42.x;
        } else if (c_offset == 2) {
            tmp2.xy = in41.zw;
            tmp2.zw = in42.xy;
        } else {
            tmp2.x = in41.w;
            tmp2.yzw = in42.xyz;
        }
        if(split_channels_out2_is_arr) {
            out2_arr.write(tmp2, gid.xy, gid.z - s1 + 1);
        } else {
            out2_tex.write(tmp2, gid.xy);
        }
    }
}

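// roi_align: averages bilinear samples over a sampling grid inside each ROI
// bin. The spatial scale is encoded as ushort_arg_0 / 10000, and the ROI
// boxes arrive as half4 {x1, y1, x2, y2} in buffer(0).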
constant bool ra_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool ra_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool ra_has_in_tex = (!ra_has_in_arr);
constant bool ra_has_out_tex = (!ra_has_out_arr);
kernel void roi_align(texture2d_array<half, access::sample> ina[[texture(0), function_constant(ra_has_in_arr)]],
                      texture2d<half, access::sample> in[[texture(0), function_constant(ra_has_in_tex)]],
                      texture2d_array<half, access::write> outa[[texture(1), function_constant(ra_has_out_arr)]],
                      texture2d<half, access::write> out[[texture(1), function_constant(ra_has_out_tex)]],
                      constant half4* rois[[buffer(0)]],
                      ushort3 gid[[thread_position_in_grid]]) {

    ushort out_width, out_height;
    if (ra_has_out_arr) {
        out_width = outa.get_width();
        out_height = outa.get_height();
    } else {
        out_width = out.get_width();
        out_height = out.get_height();
    }
    if (gid.x >= out_width || gid.y >= out_height) {
        return;
    }
    const half spatial_scale = half(ushort_arg_0) / 10000;
    const ushort sampling_ratio = ushort_arg_1;
    const ushort C = ushort_arg_2;
    const ushort pw = gid.x;
    const ushort ph = gid.y;
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z % divRoundUp(C, 4);

    const half4 roi_scaled = rois[n] * spatial_scale;
    const half roi_start_w = roi_scaled[0];
    const half roi_start_h = roi_scaled[1];
    const half roi_end_w = roi_scaled[2];
    const half roi_end_h = roi_scaled[3];

    // Force malformed ROIs to be 1x1
    const half roi_width = max(roi_end_w - roi_start_w, (half)1.);
    const half roi_height = max(roi_end_h - roi_start_h, (half)1.);

    const half bin_size_h = static_cast<half>(roi_height) / static_cast<half>(out_height);
    const half bin_size_w = static_cast<half>(roi_width) / static_cast<half>(out_width);

    const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<half>(out_height));
    const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<half>(out_width));

    const half count = roi_bin_grid_h * roi_bin_grid_w;
    half4 output_val = 0.0;

    constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
            // Shift the pixel by 0.5. This is critical to achieve high accuracy.
            const half y = roi_start_h + ph * bin_size_h +
                (iy + 0.5) * bin_size_h / static_cast<half>(roi_bin_grid_h);
            const half x = roi_start_w + pw * bin_size_w +
                (ix + 0.5) * bin_size_w / static_cast<half>(roi_bin_grid_w);
            if (ra_has_in_arr) {
                output_val += ina.sample(s2, float2(x, y), c);
            } else {
                output_val += in.sample(s2, float2(x, y));
            }
        }
    }
    output_val /= count;
    if (ra_has_out_arr) {
        outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
    } else {
        out.write(static_cast<half4>(output_val), gid.xy);
    }
}

)PT_METAL_SHADERS";

#endif /* MPSCNNShaders_h */