#ifndef MPSCNNShaders_h
#define MPSCNNShaders_h

static const char* PT_METAL_SHADERS = R"PT_METAL_SHADERS(
#include <metal_stdlib>
using namespace metal;

constant ushort ushort_arg_0[[function_constant(0)]];
constant ushort ushort_arg_1[[function_constant(1)]];
constant ushort ushort_arg_2[[function_constant(2)]];
constant ushort ushort_arg_3[[function_constant(3)]];
constant ushort ushort_arg_4[[function_constant(4)]];
constant ushort ushort_arg_5[[function_constant(5)]];
constant ushort ushort_arg_6[[function_constant(6)]];
constant ushort ushort_arg_7[[function_constant(7)]];
constant ushort ushort_arg_8[[function_constant(8)]];
constant ushort ushort_arg_9[[function_constant(9)]];
constant ushort ushort_arg_10[[function_constant(10)]];
constant ushort ushort_arg_11[[function_constant(11)]];
constant float float_arg_0 [[function_constant(12)]];
constant float float_arg_1 [[function_constant(13)]];

inline constexpr ushort divRoundUp(ushort x, ushort y) { return (x + (y - 1)) / y; }

enum broadcastOp {
    Add,
    Sub,
    Mul,
    Div,
};

void elementwise_broadcast_nonarray(texture2d<half, access::read> in0,
                                    texture2d<half, access::read> in1,
                                    texture2d<half, access::write> out,
                                    ushort2 gid,
                                    broadcastOp op) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    ushort2 gid0 = gid.xy * in0_stride;
    ushort2 gid1 = gid.xy * in1_stride;

    if (op == Add) {
        out.write(in0.read(gid0) + in1.read(gid1), gid);
    } else if (op == Sub) {
        out.write(in0.read(gid0) - in1.read(gid1), gid);
    } else if (op == Mul) {
        out.write(in0.read(gid0) * in1.read(gid1), gid);
    } else if (op == Div) {
        out.write(in0.read(gid0) / in1.read(gid1), gid);
    }
}

void elementwise_broadcast(texture2d_array<half, access::read> in0,
                           texture2d_array<half, access::read> in1,
                           texture2d_array<half, access::write> out,
                           ushort3 gid,
                           broadcastOp op) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }

    ushort2 in0_stride = ushort2(in0.get_width() > 1, in0.get_height() > 1);
    ushort2 in1_stride = ushort2(in1.get_width() > 1, in1.get_height() > 1);

    ushort2 gid0 = gid.xy * in0_stride;
    ushort2 gid1 = gid.xy * in1_stride;

    if (op == Add) {
        out.write(in0.read(gid0, gid.z) + in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if (op == Sub) {
        out.write(in0.read(gid0, gid.z) - in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if (op == Mul) {
        out.write(in0.read(gid0, gid.z) * in1.read(gid1, gid.z), gid.xy, gid.z);
    } else if (op == Div) {
        out.write(in0.read(gid0, gid.z) / in1.read(gid1, gid.z), gid.xy, gid.z);
    }
}
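// Note: the two helpers above implement numpy-style broadcasting along H and
// W by zeroing the read stride of any singleton dimension. For example, if
// in1 is a 1x1 texture holding a per-tensor scalar, in1_stride is
// ushort2(0, 0), so every thread reads in1 at texel (0, 0) while reading in0
// at its own (x, y).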
kernel void elementwise_add_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Add);
}

kernel void elementwise_add(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Add);
}

kernel void elementwise_sub_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Sub);
}

kernel void elementwise_sub(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Sub);
}

kernel void elementwise_mul_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Mul);
}

kernel void elementwise_mul(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Mul);
}

kernel void elementwise_div_nonarray(texture2d<half, access::read> in0[[texture(0)]],
                                     texture2d<half, access::read> in1[[texture(1)]],
                                     texture2d<half, access::write> out[[texture(2)]],
                                     ushort2 gid[[thread_position_in_grid]]) {
    elementwise_broadcast_nonarray(in0, in1, out, gid, Div);
}

kernel void elementwise_div(texture2d_array<half, access::read> in0[[texture(0)]],
                            texture2d_array<half, access::read> in1[[texture(1)]],
                            texture2d_array<half, access::write> out[[texture(2)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    elementwise_broadcast(in0, in1, out, gid, Div);
}
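// Note: the copy kernels below convert between NCHW float buffers and
// half-precision RGBA textures. Channels are packed four at a time into the
// texel lanes ("CHWP4"): array slice z = n * divRoundUp(C, 4) + c holds
// channels [4*c, 4*c + 3] of image n. Worked example: C = 6 packs each pixel
// as slice 0 = (c0, c1, c2, c3) and slice 1 = (c4, c5, 0, 0); lanes beyond C
// are zero-filled on upload and skipped on download.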
kernel void copy_nchw_to_metal(constant float* in[[buffer(0)]],
                               texture2d_array<half, access::write> out[[texture(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z - n * divRoundUp(C, 4);
#define CHW_TO_CHWP4(idx, n, c_, h, w)                                         \
    if ((c_) < C) {                                                            \
        trns[idx] = in[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)]; \
    } else {                                                                   \
        trns[idx] = 0.0h;                                                      \
    }
    half4 trns;
    CHW_TO_CHWP4(0, n, c * 4 + 0, gid.y, gid.x);
    CHW_TO_CHWP4(1, n, c * 4 + 1, gid.y, gid.x);
    CHW_TO_CHWP4(2, n, c * 4 + 2, gid.y, gid.x);
    CHW_TO_CHWP4(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHW_TO_CHWP4
    out.write(trns, gid.xy, gid.z);
}

kernel void copy_nchw_to_metal_nonarray(constant float* in[[buffer(0)]],
                                        texture2d<half, access::write> out[[texture(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    half4 trns;
#define CHW_TO_CHWP4(idx, c, h, w)                            \
    if ((c) < C) {                                            \
        trns[idx] = in[int(c) * H * W + int(h) * W + int(w)]; \
    } else {                                                  \
        trns[idx] = 0.0h;                                     \
    }
    CHW_TO_CHWP4(0, 0, gid.y, gid.x);
    CHW_TO_CHWP4(1, 1, gid.y, gid.x);
    CHW_TO_CHWP4(2, 2, gid.y, gid.x);
    CHW_TO_CHWP4(3, 3, gid.y, gid.x);
#undef CHW_TO_CHWP4
    out.write(trns, gid.xy);
}

kernel void copy_metal_to_nchw(texture2d_array<half, access::read> in[[texture(0)]],
                               device float* out[[buffer(0)]],
                               ushort3 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z - n * divRoundUp(C, 4);
    half4 cs = in.read(gid.xy, gid.z);
#define CHWP4_TO_CHW(idx, n, c_, h, w)                                        \
    if ((c_) < C) {                                                           \
        out[n * H * W * C + int(c_) * H * W + int(h) * W + int(w)] = cs[idx]; \
    }
    CHWP4_TO_CHW(0, n, c * 4 + 0, gid.y, gid.x);
    CHWP4_TO_CHW(1, n, c * 4 + 1, gid.y, gid.x);
    CHWP4_TO_CHW(2, n, c * 4 + 2, gid.y, gid.x);
    CHWP4_TO_CHW(3, n, c * 4 + 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy_metal_to_nchw_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                        device float* out[[buffer(0)]],
                                        ushort2 gid[[thread_position_in_grid]]) {
    const ushort C = ushort_arg_0;
    const ushort H = ushort_arg_1;
    const ushort W = ushort_arg_2;
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    half4 cs = in.read(gid.xy);
#define CHWP4_TO_CHW(idx, c, h, w)                           \
    if ((c) < C) {                                           \
        out[int(c) * H * W + int(h) * W + int(w)] = cs[idx]; \
    }
    CHWP4_TO_CHW(0, 0, gid.y, gid.x);
    CHWP4_TO_CHW(1, 1, gid.y, gid.x);
    CHWP4_TO_CHW(2, 2, gid.y, gid.x);
    CHWP4_TO_CHW(3, 3, gid.y, gid.x);
#undef CHWP4_TO_CHW
}

kernel void copy(texture2d_array<half, access::read> in[[texture(0)]],
                 texture2d_array<half, access::write> out[[texture(1)]],
                 ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_, gid.z), gid_, gid.z);
}

kernel void copy_nonarray(texture2d<half, access::read> in[[texture(0)]],
                          texture2d<half, access::write> out[[texture(1)]],
                          ushort2 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    out.write(in.read(gid), gid);
}

kernel void copy_offset(texture2d_array<half, access::read> in[[texture(0)]],
                        texture2d_array<half, access::write> out[[texture(1)]],
                        constant ushort* offset_buf[[buffer(0)]],
                        ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_, gid.z), gid_, gid.z + offset_buf[0]);
}

kernel void copy_offset_nonarray(texture2d<half, access::read> in[[texture(0)]],
                                 texture2d_array<half, access::write> out[[texture(1)]],
                                 constant ushort* offset_buf[[buffer(0)]],
                                 ushort3 gid[[thread_position_in_grid]]) {
    if (gid.x >= out.get_width() || gid.y >= out.get_height()) {
        return;
    }
    ushort2 gid_ = gid.xy;
    out.write(in.read(gid_), gid_, gid.z + offset_buf[0]);
}
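// Note: the kernels below are specialized at pipeline-creation time via the
// function constants declared at the top of this library. A tensor is bound
// as a texture2d_array when it has more than one batch image or more than
// four channels (e.g. ushort_arg_3 > 1 || ushort_arg_2 > 4); otherwise a
// plain texture2d is used, and the unused parameter is eliminated by the
// function_constant guard on its [[texture(n)]] argument.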
constant bool store_features_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool store_features_out_is_tex = !store_features_out_is_arr;
kernel void store_features(texture2d_array<half, access::read> in[[texture(0)]],
                           texture2d<half, access::write> out_tex[[texture(1), function_constant(store_features_out_is_tex)]],
                           texture2d_array<half, access::write> out_arr[[texture(1), function_constant(store_features_out_is_arr)]],
                           constant ushort* offset_buf[[buffer(0)]],
                           ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;
    if (store_features_out_is_arr)
        out_arr.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_, gid.z);
    else
        out_tex.write(in.read(gid_, gid.z * offset_buf[1] + offset_buf[0]), gid_);
}

constant bool append_features_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool append_features_in_is_tex = !append_features_in_is_arr;
kernel void append_features(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_in_is_tex)]],
                            texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_in_is_arr)]],
                            texture2d_array<half, access::write> out[[texture(1)]],
                            constant ushort* offset_buf[[buffer(0)]],
                            ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;

    ushort batch = gid.z / offset_buf[0];
    ushort feature = gid.z % offset_buf[0];
    ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
    ushort inz = batch * offset_buf[3] + feature;

    half4 intex;
    if (append_features_in_is_arr) {
        intex = in_arr.read(gid_, inz);
    } else {
        intex = in_tex.read(gid_);
    }
    out.write(intex, gid_, outz);
}

constant bool prev_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool prev_is_tex = !prev_is_arr;
constant bool append_features_off_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool append_features_off_in_is_tex = !append_features_off_in_is_arr;
kernel void append_features_off(texture2d<half, access::read> in_tex[[texture(0), function_constant(append_features_off_in_is_tex)]],
                                texture2d_array<half, access::read> in_arr[[texture(0), function_constant(append_features_off_in_is_arr)]],
                                texture2d<half, access::read> prev_tex[[texture(1), function_constant(prev_is_tex)]],
                                texture2d_array<half, access::read> prev_arr[[texture(1), function_constant(prev_is_arr)]],
                                texture2d_array<half, access::write> out[[texture(2)]],
                                constant ushort* offset_buf[[buffer(0)]],
                                ushort3 gid[[thread_position_in_grid]]) {
    ushort2 gid_ = gid.xy;

    ushort batch = gid.z / offset_buf[0];
    ushort feature = gid.z % offset_buf[0];
    ushort outz = batch * offset_buf[1] + offset_buf[2] + feature;
    ushort inz = batch * offset_buf[3] + feature;
    half4 outtex;
    if (prev_is_arr)
        outtex = prev_arr.read(gid_, batch);
    else
        outtex = prev_tex.read(gid_);
    half4 intex1;
    if (append_features_off_in_is_arr)
        intex1 = in_arr.read(gid_, inz);
    else
        intex1 = in_tex.read(gid_);
    if (feature == 0) {
        if (offset_buf[5] == 1)
            outtex.yzw = intex1.xyz;
        else if (offset_buf[5] == 2)
            outtex.zw = intex1.xy;
        else
            outtex.w = intex1.x;
        out.write(outtex, gid_, outz);
        return;
    }
    half4 intex0;
    if (append_features_off_in_is_arr)
        intex0 = in_arr.read(gid_, inz - 1);
    else
        intex0 = intex1;
    if (offset_buf[5] == 1) {
        outtex.x = intex0.w;
        outtex.yzw = intex1.xyz;
    } else if (offset_buf[5] == 2) {
        outtex.xy = intex0.zw;
        outtex.zw = intex1.xy;
    } else {
        outtex.xyz = intex0.yzw;
        outtex.w = intex1.x;
    }

    out.write(outtex, gid_, outz);
}
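// Note (inferred from the indexing above): offset_buf[5] in append_features_off
// appears to be the channel count of the pre-existing features modulo 4, i.e.
// the lane at which appended channels start inside a texel. With
// offset_buf[5] == 2, lanes .xy of each output texel keep previous data and
// lanes .zw take the first two appended channels; later texels stitch lanes
// from two adjacent input reads (intex0/intex1) to stay packed.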
constant bool clamp_is_arr = (ushort_arg_1 > 1 || ushort_arg_0 > 4);
constant bool clamp_is_tex = !clamp_is_arr;
kernel void clamp(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(clamp_is_arr)]],
                  texture2d<half, access::read> in_tex[[texture(0), function_constant(clamp_is_tex)]],
                  texture2d_array<half, access::write> out_arr[[texture(1), function_constant(clamp_is_arr)]],
                  texture2d<half, access::write> out_tex[[texture(1), function_constant(clamp_is_tex)]],
                  ushort3 gid[[thread_position_in_grid]]) {
    const ushort w = clamp_is_arr ? out_arr.get_width() : out_tex.get_width();
    const ushort h = clamp_is_arr ? out_arr.get_height() : out_tex.get_height();
    if (gid.x >= w || gid.y >= h) {
        return;
    }
    const float4 min_(float_arg_0, float_arg_0, float_arg_0, float_arg_0);
    const float4 max_(float_arg_1, float_arg_1, float_arg_1, float_arg_1);
    ushort2 gid_ = gid.xy;
    if (clamp_is_arr) {
        float4 value = (float4)in_arr.read(gid_, gid.z);
        half4 clamped = (half4)clamp(value, min_, max_);
        out_arr.write(clamped, gid_, gid.z);
    } else {
        float4 value = (float4)in_tex.read(gid_);
        half4 clamped = (half4)clamp(value, min_, max_);
        out_tex.write(clamped, gid_);
    }
}

constant bool hardswish_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardswish_is_tex = !hardswish_is_arr;
kernel void hardswish(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardswish_is_arr)]],
                      texture2d<half, access::read> in_tex[[texture(0), function_constant(hardswish_is_tex)]],
                      texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardswish_is_arr)]],
                      texture2d<half, access::write> out_tex[[texture(1), function_constant(hardswish_is_tex)]],
                      ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (hardswish_is_arr) {
        half4 value = in_arr.read(gid_, gid.z);
        half4 mask1 = half4(value < 3.0);
        half4 mask2 = half4(value > -3.0);
        half4 outval = mask2 * (mask1 * (value * (value + 3.0) / 6.0) + (1 - mask1) * value);
        out_arr.write(outval, gid_, gid.z);
    } else {
        half4 value = in_tex.read(gid_);
        half4 mask1 = half4(value < 3.0);
        half4 mask2 = half4(value > -3.0);
        half4 outval = mask2 * (mask1 * (value * (value + 3.0) / 6.0) + (1 - mask1) * value);
        out_tex.write(outval, gid_);
    }
}

constant bool hardshrink_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool hardshrink_is_tex = !hardshrink_is_arr;
kernel void hardshrink(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(hardshrink_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(hardshrink_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(hardshrink_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(hardshrink_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    const half lambda = (half)float_arg_0;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (hardshrink_is_arr) {
        half4 value = in_arr.read(gid_, gid.z);
        half4 mask1 = half4(value <= lambda);
        half4 mask2 = half4(value >= -lambda);
        half4 outval = (1 - mask1) * value + (1 - mask2) * value;
        out_arr.write(outval, gid_, gid.z);
    } else {
        half4 value = in_tex.read(gid_);
        half4 mask1 = half4(value <= lambda);
        half4 mask2 = half4(value >= -lambda);
        half4 outval = (1 - mask1) * value + (1 - mask2) * value;
        out_tex.write(outval, gid_);
    }
}
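// Note: the mask arithmetic above evaluates the usual piecewise definitions
// branch-free, per lane:
//   hardswish(x)  = 0 for x <= -3, x for x >= 3, else x * (x + 3) / 6
//   hardshrink(x) = x for |x| > lambda, else 0
// A comparison such as (value < 3.0) yields a bool4 that converts to 0/1 in
// each half4 lane, so the masks select between regimes without branching.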
constant bool leaky_relu_is_arr = (ushort_arg_0 > 1 || ushort_arg_1 > 4);
constant bool leaky_relu_is_tex = !leaky_relu_is_arr;
kernel void leaky_relu(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::read> in_tex[[texture(0), function_constant(leaky_relu_is_tex)]],
                       texture2d_array<half, access::write> out_arr[[texture(1), function_constant(leaky_relu_is_arr)]],
                       texture2d<half, access::write> out_tex[[texture(1), function_constant(leaky_relu_is_tex)]],
                       ushort3 gid[[thread_position_in_grid]]) {
    const ushort oH = ushort_arg_2;
    const ushort oW = ushort_arg_3;
    const half negative_slope = (half)float_arg_0;
    if (gid.x >= oW || gid.y >= oH) {
        return;
    }
    ushort2 gid_ = gid.xy;
    if (leaky_relu_is_arr) {
        half4 value = in_arr.read(gid_, gid.z);
        half4 is_negative = half4(value < 0.0);
        half4 outval = is_negative * value * negative_slope + (1 - is_negative) * value;
        out_arr.write(outval, gid_, gid.z);
    } else {
        half4 value = in_tex.read(gid_);
        half4 is_negative = half4(value < 0.0);
        half4 outval = is_negative * value * negative_slope + (1 - is_negative) * value;
        out_tex.write(outval, gid_);
    }
}
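// Note: reflection_pad2d (below) mirrors reads about the input edges without
// repeating the border pixel. For an output pixel inside the left pad,
// xoff_pre = 2 * (pad_left - gid.x) walks the read back into the interior:
// e.g. pad_left = 2, gid.x = 0 gives inpos.x = 0 + 4 - 0 - 2 = 2, the pixel
// two inside the left edge. The ushort operands promote to int before
// arithmetic, so max(..., 0) clamps as intended with no wraparound.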
constant bool out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool out_is_tex = !out_is_arr;
constant bool in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool in_is_tex = !in_is_arr;
kernel void reflection_pad2d(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(in_is_arr)]],
                             texture2d<half, access::read> in_tex[[texture(0), function_constant(in_is_tex)]],
                             texture2d_array<half, access::write> out_arr[[texture(1), function_constant(out_is_arr)]],
                             texture2d<half, access::write> out_tex[[texture(1), function_constant(out_is_tex)]],
                             ushort3 gid[[thread_position_in_grid]]) {
    const ushort H2 = ushort_arg_0;
    const ushort W2 = ushort_arg_1;
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }

    const ushort pad_left = ushort_arg_8;
    const ushort pad_right = ushort_arg_9;
    const ushort pad_top = ushort_arg_10;
    const ushort pad_bottom = ushort_arg_11;

    const ushort2 out_size = ushort2(W2, H2);
    const ushort xoff_pre = 2 * max(pad_left - gid.x, 0);
    const ushort xoff_post = 2 * max(gid.x - (out_size.x - 1 - pad_right), 0);
    const ushort yoff_pre = 2 * max(pad_top - gid.y, 0);
    const ushort yoff_post = 2 * max(gid.y - (out_size.y - 1 - pad_bottom), 0);
    ushort2 inpos = ushort2(
        gid.x + xoff_pre - xoff_post - pad_left,
        gid.y + yoff_pre - yoff_post - pad_top);

    half4 intex;
    if (in_is_arr) {
        intex = in_arr.read(inpos, gid.z);
    } else {
        intex = in_tex.read(inpos);
    }

    if (out_is_arr) {
        out_arr.write(intex, gid.xy, gid.z);
    } else {
        out_tex.write(intex, gid.xy);
    }
}

constant bool reshape_out_is_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool reshape_out_is_tex = !reshape_out_is_arr;
constant bool reshape_in_is_arr = (ushort_arg_7 > 1 || ushort_arg_6 > 4);
constant bool reshape_in_is_tex = !reshape_in_is_arr;
kernel void reshape(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(reshape_in_is_arr)]],
                    texture2d<half, access::read> in_tex[[texture(0), function_constant(reshape_in_is_tex)]],
                    texture2d_array<half, access::write> out_arr[[texture(1), function_constant(reshape_out_is_arr)]],
                    texture2d<half, access::write> out_tex[[texture(1), function_constant(reshape_out_is_tex)]],
                    ushort3 gid[[thread_position_in_grid]]) {
    const ushort H2 = ushort_arg_0;
    const ushort W2 = ushort_arg_1;
    const ushort C2 = ushort_arg_2;
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }
    const ushort H1 = ushort_arg_4;
    const ushort W1 = ushort_arg_5;
    const ushort C1 = ushort_arg_6;
    const ushort N1 = ushort_arg_7;

    const size_t numel1 = H1 * W1 * C1 * N1;
    const ushort slices2 = divRoundUp(C2, 4);
    const ushort slices1 = divRoundUp(C1, 4);
    const ushort n2 = gid.z / slices2; // image index
    const ushort s2 = gid.z - n2 * slices2; // slice offset
    half4 value;
    for (int idx = 0; idx < 4; ++idx) {
        // Compute the "linear index" of the output element and convert it to
        // the equivalent "linear index" of the input element.
        ushort offset = 4 * s2 + idx;
        size_t linear_idx = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
        if (linear_idx >= numel1) {
            value[idx] = 0;
            continue;
        }
        auto x1 = linear_idx % W1;
        auto y1 = ((int)(linear_idx / W1)) % H1;
        auto s1 = ((int)(linear_idx / W1 / H1) % C1);
        auto n1 = ((int)(linear_idx / W1 / H1 / C1) % N1);
        auto z1 = (int)s1 / 4 + n1 * slices1;
        auto pos = s1 % 4;
        if (reshape_in_is_arr) {
            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
        } else {
            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
        }
    }
    if (reshape_out_is_arr) {
        out_arr.write(value, gid.xy, gid.z);
    } else {
        out_tex.write(value, gid.xy);
    }
}
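// Worked example for reshape above: with an input of N1 = 1, C1 = 4, H1 = 2,
// W1 = 2 reshaped to C2 = 2, H2 = 2, W2 = 4, the output element at
// gid = (3, 0), s2 = 0, lane idx = 0 has linear index 0 + 0*8 + 0*4 + 3 = 3,
// which decodes to x1 = 3 % 2 = 1, y1 = (3 / 2) % 2 = 1, s1 = 0, so the
// value is lane .x of input texel (1, 1) in slice 0.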
constant bool transpose_in_is_arr = (ushort_arg_3 > 1 || ushort_arg_4 > 4);
constant bool transpose_in_is_tex = !transpose_in_is_arr;
constant bool transpose_out_is_arr = (ushort_arg_5 > 1 || ushort_arg_6 > 4);
constant bool transpose_out_is_tex = !transpose_out_is_arr;
kernel void transpose(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(transpose_in_is_arr)]],
                      texture2d<half, access::read> in_tex[[texture(0), function_constant(transpose_in_is_tex)]],
                      texture2d_array<half, access::write> out_arr[[texture(1), function_constant(transpose_out_is_arr)]],
                      texture2d<half, access::write> out_tex[[texture(1), function_constant(transpose_out_is_tex)]],
                      constant ushort* inSizeBuffer[[buffer(0)]],
                      constant ushort* outSizeBuffer[[buffer(1)]],
                      ushort3 gid[[thread_position_in_grid]]) {
    const ushort dim0 = ushort_arg_0;
    const ushort dim1 = ushort_arg_1;
    const ushort dim = ushort_arg_2;
    const ushort N1 = ushort_arg_3;
    const ushort C1 = ushort_arg_4;
    const ushort N2 = ushort_arg_5;
    const ushort C2 = ushort_arg_6;
    ushort W1, W2, H1, H2;
    if (transpose_in_is_arr) {
        W1 = in_arr.get_width();
        H1 = in_arr.get_height();
    } else {
        W1 = in_tex.get_width();
        H1 = in_tex.get_height();
    }
    if (transpose_out_is_arr) {
        W2 = out_arr.get_width();
        H2 = out_arr.get_height();
    } else {
        W2 = out_tex.get_width();
        H2 = out_tex.get_height();
    }
    if (gid.x >= W2 || gid.y >= H2) {
        return;
    }
    const size_t numel = H2 * W2 * C2 * N2;
    const ushort slices2 = divRoundUp(C2, 4);
    const ushort slices1 = divRoundUp(C1, 4);
    const ushort n2 = gid.z / slices2;
    const ushort s2 = gid.z - n2 * slices2;
    half4 value;
    ushort4 threadIndexBufferLower{1, 1, 1, 1};
    ushort4 threadIndexBufferUpper{1, 1, 1, 1};
    for (int idx = 0; idx < 4; ++idx) {
        ushort offset = 4 * s2 + idx;
        size_t linear_idx2 = n2 * C2 * H2 * W2 + offset * H2 * W2 + gid.y * W2 + gid.x;
        if (linear_idx2 >= numel) {
            value[idx] = 0;
            continue;
        }

        ushort d2 = 0;
        for (int j = dim - 1; j >= 0; --j) {
            d2 = outSizeBuffer[j];
            if (j > 3) {
                threadIndexBufferUpper[j - 3] = linear_idx2 % d2;
            } else {
                threadIndexBufferLower[j] = linear_idx2 % d2;
            }
            linear_idx2 /= d2;
        }

        // Swap dims. Dimensions 0-3 index the "lower" vector and dimensions
        // 4+ the "upper" one, so dim == 3 must take the lower path; the
        // conditions therefore use <= 3.
        ushort tmp;
        if (dim0 > 3) {
            tmp = threadIndexBufferUpper[dim0 - 3];
        } else {
            tmp = threadIndexBufferLower[dim0];
        }
        if (dim0 > 3 && dim1 > 3) {
            threadIndexBufferUpper[dim0 - 3] = threadIndexBufferUpper[dim1 - 3];
        } else if (dim0 > 3 && dim1 <= 3) {
            threadIndexBufferUpper[dim0 - 3] = threadIndexBufferLower[dim1];
        } else if (dim0 <= 3 && dim1 > 3) {
            threadIndexBufferLower[dim0] = threadIndexBufferUpper[dim1 - 3];
        } else {
            threadIndexBufferLower[dim0] = threadIndexBufferLower[dim1];
        }
        if (dim1 > 3) {
            threadIndexBufferUpper[dim1 - 3] = tmp;
        } else {
            threadIndexBufferLower[dim1] = tmp;
        }

        size_t linear_idx1 = 0;
        ushort m = 1;
        ushort d1 = 0;
        for (int k = dim - 1; k >= 0; --k) {
            if (k > 3) {
                d1 = threadIndexBufferUpper[k - 3];
            } else {
                d1 = threadIndexBufferLower[k];
            }
            linear_idx1 += d1 * m;
            m *= inSizeBuffer[k];
        }

        auto x1 = linear_idx1 % W1;
        auto y1 = ((int)(linear_idx1 / W1)) % H1;
        auto c1 = ((int)(linear_idx1 / W1 / H1) % C1);
        auto n1 = ((int)(linear_idx1 / W1 / H1 / C1) % N1);
        auto z1 = (int)c1 / 4 + n1 * slices1;
        auto pos = c1 % 4;
        if (transpose_in_is_arr) {
            value[idx] = in_arr.read(ushort2(x1, y1), z1)[pos];
        } else {
            value[idx] = in_tex.read(ushort2(x1, y1))[pos];
        }
    }
    if (transpose_out_is_arr) {
        out_arr.write(value, gid.xy, gid.z);
    } else {
        out_tex.write(value, gid.xy);
    }
}
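// Worked example for split_channels (below): splitting C1 = 6 channels off a
// 10-channel input gives s1 = 2 and c_offset = 2. Slice 0 copies straight to
// out1; at the boundary slice (gid.z == s1 - 1) out1 takes lanes .xy, and
// out2's first texel is stitched from lanes .zw plus the first two lanes of
// the next input slice, so both outputs stay tightly packed.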
constant bool split_channels_in_is_arr = (ushort_arg_0 > 4);
constant bool split_channels_in_is_tex = !split_channels_in_is_arr;
constant bool split_channels_out1_is_arr = (ushort_arg_1 > 4);
constant bool split_channels_out1_is_tex = !split_channels_out1_is_arr;
constant bool split_channels_out2_is_arr = (ushort_arg_2 > 4);
constant bool split_channels_out2_is_tex = !split_channels_out2_is_arr;
// A naive implementation that splits the input texture into two along the
// channel dimension.
kernel void split_channels(texture2d_array<half, access::read> in_arr[[texture(0), function_constant(split_channels_in_is_arr)]],
                           texture2d<half, access::read> in_tex[[texture(0), function_constant(split_channels_in_is_tex)]],
                           texture2d_array<half, access::write> out1_arr[[texture(1), function_constant(split_channels_out1_is_arr)]],
                           texture2d<half, access::write> out1_tex[[texture(1), function_constant(split_channels_out1_is_tex)]],
                           texture2d_array<half, access::write> out2_arr[[texture(2), function_constant(split_channels_out2_is_arr)]],
                           texture2d<half, access::write> out2_tex[[texture(2), function_constant(split_channels_out2_is_tex)]],
                           ushort3 gid[[thread_position_in_grid]]) {
    ushort W, H;
    if (split_channels_in_is_arr) {
        W = in_arr.get_width();
        H = in_arr.get_height();
    } else {
        W = in_tex.get_width();
        H = in_tex.get_height();
    }
    if (gid.x >= W || gid.y >= H) {
        return;
    }
    const ushort C1 = ushort_arg_1;
    const ushort s1 = divRoundUp(C1, 4);
    const ushort c_offset = C1 % 4;
    half4 tmp1(0.0, 0.0, 0.0, 0.0);
    half4 tmp2(0.0, 0.0, 0.0, 0.0);
    half4 in41 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z) : in_tex.read(gid.xy);
    half4 in42 = split_channels_in_is_arr ? in_arr.read(gid.xy, gid.z + 1) : half4(0, 0, 0, 0);
    if (gid.z < s1 - 1) {
        // gid.z < s1 - 1 implies s1 > 1, i.e. C1 > 4, so out1 is an array here.
        if (split_channels_out1_is_arr) {
            out1_arr.write(in41, gid.xy, gid.z);
        }
    } else if (gid.z == s1 - 1) {
        if (c_offset == 0) {
            if (split_channels_out1_is_arr) {
                out1_arr.write(in41, gid.xy, gid.z);
            } else {
                out1_tex.write(in41, gid.xy);
            }
            return;
        } else if (c_offset == 1) {
            tmp1.x = in41.x;
            tmp2.xyz = in41.yzw;
            tmp2.w = in42.x;
        } else if (c_offset == 2) {
            tmp1.xy = in41.xy;
            tmp2.xy = in41.zw;
            tmp2.zw = in42.xy;
        } else {
            tmp1.xyz = in41.xyz;
            tmp2.x = in41.w;
            tmp2.yzw = in42.xyz;
        }
        if (split_channels_out1_is_arr) {
            out1_arr.write(tmp1, gid.xy, gid.z);
        } else {
            out1_tex.write(tmp1, gid.xy);
        }
        if (split_channels_out2_is_arr) {
            out2_arr.write(tmp2, gid.xy, 0);
        } else {
            out2_tex.write(tmp2, gid.xy);
        }
    } else {
        if (c_offset == 0) {
            if (split_channels_out2_is_arr) {
                out2_arr.write(in41, gid.xy, gid.z - s1);
            } else {
                out2_tex.write(in41, gid.xy);
            }
            return;
        } else if (c_offset == 1) {
            tmp2.xyz = in41.yzw;
            tmp2.w = in42.x;
        } else if (c_offset == 2) {
            tmp2.xy = in41.zw;
            tmp2.zw = in42.xy;
        } else {
            tmp2.x = in41.w;
            tmp2.yzw = in42.xyz;
        }
        if (split_channels_out2_is_arr) {
            out2_arr.write(tmp2, gid.xy, gid.z - s1 + 1);
        } else {
            out2_tex.write(tmp2, gid.xy);
        }
    }
}

constant bool ra_has_in_arr = (ushort_arg_3 > 1 || ushort_arg_2 > 4);
constant bool ra_has_out_arr = (ushort_arg_4 > 1 || ushort_arg_2 > 4);
constant bool ra_has_in_tex = !ra_has_in_arr;
constant bool ra_has_out_tex = !ra_has_out_arr;
kernel void roi_align(texture2d_array<half, access::sample> ina[[texture(0), function_constant(ra_has_in_arr)]],
                      texture2d<half, access::sample> in[[texture(0), function_constant(ra_has_in_tex)]],
                      texture2d_array<half, access::write> outa[[texture(1), function_constant(ra_has_out_arr)]],
                      texture2d<half, access::write> out[[texture(1), function_constant(ra_has_out_tex)]],
                      constant half4* rois[[buffer(0)]],
                      ushort3 gid[[thread_position_in_grid]]) {
    ushort out_width, out_height;
    if (ra_has_out_arr) {
        out_width = outa.get_width();
        out_height = outa.get_height();
    } else {
        out_width = out.get_width();
        out_height = out.get_height();
    }
    if (gid.x >= out_width || gid.y >= out_height) {
        return;
    }
    // The host passes spatial_scale as a fixed-point ushort scaled by 10000.
    const half spatial_scale = half(ushort_arg_0) / 10000;
    const ushort sampling_ratio = ushort_arg_1;
    const ushort C = ushort_arg_2;
    const ushort pw = gid.x;
    const ushort ph = gid.y;
    const ushort n = gid.z / divRoundUp(C, 4);
    const ushort c = gid.z % divRoundUp(C, 4);

    const half4 roi_scaled = rois[n] * spatial_scale;
    const half roi_start_w = roi_scaled[0];
    const half roi_start_h = roi_scaled[1];
    const half roi_end_w = roi_scaled[2];
    const half roi_end_h = roi_scaled[3];

    // Force malformed ROIs to be 1x1.
    const half roi_width = max(roi_end_w - roi_start_w, (half)1.);
    const half roi_height = max(roi_end_h - roi_start_h, (half)1.);

    const half bin_size_h = static_cast<half>(roi_height) / static_cast<half>(out_height);
    const half bin_size_w = static_cast<half>(roi_width) / static_cast<half>(out_width);

    const ushort roi_bin_grid_h = sampling_ratio > 0 ? sampling_ratio : ceil(roi_height / static_cast<half>(out_height));
    const ushort roi_bin_grid_w = sampling_ratio > 0 ? sampling_ratio : ceil(roi_width / static_cast<half>(out_width));

    const half count = roi_bin_grid_h * roi_bin_grid_w;
    half4 output_val = 0.0;

    constexpr sampler s2(coord::pixel, address::clamp_to_edge, filter::linear);

    for (int iy = 0; iy < roi_bin_grid_h; iy++) {
        for (int ix = 0; ix < roi_bin_grid_w; ix++) {
            // Shift the pixel by 0.5. This is critical to achieve high accuracy.
            const half y = roi_start_h + ph * bin_size_h +
                (iy + 0.5) * bin_size_h / static_cast<half>(roi_bin_grid_h);
            const half x = roi_start_w + pw * bin_size_w +
                (ix + 0.5) * bin_size_w / static_cast<half>(roi_bin_grid_w);
            if (ra_has_in_arr) {
                output_val += ina.sample(s2, float2(x, y), c);
            } else {
                output_val += in.sample(s2, float2(x, y));
            }
        }
    }
    output_val /= count;
    if (ra_has_out_arr) {
        outa.write(static_cast<half4>(output_val), gid.xy, gid.z);
    } else {
        out.write(static_cast<half4>(output_val), gid.xy);
    }
}

)PT_METAL_SHADERS";
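// Illustrative host-side sketch (not part of the original header): one way a
// kernel above might be specialized with the function constants at indices
// 0-13. The `device` and `lib` objects are assumed to exist.
//
//     MTLFunctionConstantValues* cv = [MTLFunctionConstantValues new];
//     ushort C = 3, H = 224, W = 224;
//     [cv setConstantValue:&C type:MTLDataTypeUShort atIndex:0]; // ushort_arg_0
//     [cv setConstantValue:&H type:MTLDataTypeUShort atIndex:1]; // ushort_arg_1
//     [cv setConstantValue:&W type:MTLDataTypeUShort atIndex:2]; // ushort_arg_2
//     NSError* error = nil;
//     id<MTLFunction> fn =
//         [lib newFunctionWithName:@"copy_nchw_to_metal_nonarray"
//                   constantValues:cv
//                            error:&error];
//     id<MTLComputePipelineState> pso =
//         [device newComputePipelineStateWithFunction:fn error:&error];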
#endif /* MPSCNNShaders_h */