#pragma once

// Metal Shading Language source template for the unary kernels defined below:
// erfinv_kernel, exp_kernel, exp_complex_kernel, tanh_kernel and
// tanh_complex_kernel.
//
// Placeholders: {0} is the output scalar type (used for the output buffer and
// for the final casts) and {1} is the input scalar type. `{0}2` is the matching
// 2-vector, used by the *_complex kernels to hold a complex number as
// (.x = real part, .y = imaginary part). Literal braces in the Metal code are
// escaped as `{{` / `}}`, so this string is presumably expanded with
// fmt / std::format-style positional formatting before being handed to the
// Metal compiler -- confirm at the call site.
//
// `constexpr` makes the pointer const and therefore gives this header-defined
// variable internal linkage. A plain (non-const) `const char*` global has
// external linkage and produces duplicate-symbol link errors as soon as the
// header is included from more than one translation unit.
constexpr const char* UNARY_KERNEL_TEMPLATE = R"METAL(
#include <metal_stdlib>
using namespace metal;

constant float a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
constant float b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
constant float c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float d[2] = {{3.543889200, 1.637067800}};

kernel void erfinv_kernel( device {0} *output [[buffer(0)]],
                           device {1} *input [[buffer(1)]],
                           uint index [[thread_position_in_grid]]) {{

  float y = input[index];
  float x, z, num, dem; /*working variables */
  /* coefficients in rational expansion */

  float y_abs = abs(y);
  if (y_abs >= 1.0f) {{
    output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
    return;
  }}
  if (y_abs <= 0.7f) {{
    z = y * y;
    num = ((a[3] * z + a[2]) * z + a[1])*z + a[0];
    dem = (((b[3] * z + b[2]) * z + b[1]) * z +b[0]) * z + 1.0f;
    x = y * num / dem;
  }} else {{
    z = sqrt(-1.0f * log((1.0f - y_abs) / 2.0f));
    num = ((c[3] * z + c[2]) * z + c[1]) * z + c[0];
    dem = (d[1] * z + d[0]) * z + 1.0f;
    x = copysign(num, y) / dem;
  }}

  output[index] = {0}(x);
}}

kernel void exp_kernel( device {0} *output [[buffer(0)]],
                        device {1} *input [[ buffer(1)]],
                        uint index [[thread_position_in_grid]]) {{
  output[index] = {0}(precise::exp(input[index]));
}}

kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]],
                                device {0}2 *input [[ buffer(1)]],
                                uint index [[thread_position_in_grid]]) {{
  output[index].x = {0}(precise::exp(input[index].x)*precise::cos(input[index].y));
  output[index].y = {0}(precise::exp(input[index].x)*precise::sin(input[index].y));
}}

kernel void tanh_kernel( device {0} *output [[buffer(0)]],
                         device {1} *input [[ buffer(1)]],
                         uint index [[thread_position_in_grid]]) {{
  output[index] = {0}(precise::tanh(input[index]));
}}


#if __METAL_VERSION__ >= 310
bfloat dot(bfloat2 a, bfloat2 b) {{
  return a.x * b.x + a.y * b.y;
}}
#endif

template<typename T>
T complex_div(T a, T b) {{
  auto denom = dot(b, b);
  return T(dot(a, b), a.y * b.x - a.x * b.y)/denom;
}}

kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]],
                                 device {0}2 *input [[ buffer(1)]],
                                 uint index [[thread_position_in_grid]]) {{
  //tanh(x+iy)=(tanh(x)+itan(y))/(1+itahnh(x)*tan(y));
  auto tanh_x = {0}(precise::tanh(input[index].x));
  auto tan_y = {0}(precise::tan(input[index].y));
  output[index] = complex_div({0}2(tanh_x, tan_y), {0}2({0}(1), tanh_x * tan_y));
}}
)METAL";