xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/mps/UnaryConstants.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 const char* UNARY_KERNEL_TEMPLATE = R"METAL(
4 #include <metal_stdlib>
5 using namespace metal;
6 
// Coefficient tables for the rational approximation of the inverse error
// function used by erfinv_kernel.
//   Central region (|y| <= 0.7), with z = y * y:
//     numerator   = sum of erfinv_a[i] * z^i
//     denominator = 1 + sum of erfinv_b[i] * z^(i + 1)
//   Tail region (|y| > 0.7), with z = sqrt(-log((1 - |y|) / 2)):
//     numerator   = sum of erfinv_c[i] * z^i
//     denominator = 1 + sum of erfinv_d[i] * z^(i + 1)
// The tables carry an erfinv_ prefix so they are not shadowed by the
// single-letter parameter names of other functions in this template.
constant float erfinv_a[4] = {{0.886226899, -1.645349621, 0.914624893, -0.140543331}};
constant float erfinv_b[4] = {{-2.118377725, 1.442710462, -0.329097515, 0.012229801}};
constant float erfinv_c[4] = {{-1.970840454, -1.624906493, 3.429567803, 1.641345311}};
constant float erfinv_d[2] = {{3.543889200, 1.637067800}};

// Element-wise inverse error function.
//   output: destination buffer; each thread writes exactly one element.
//   input:  source buffer; the element is converted to float before use.
//   index:  position of this thread in the grid, used as the element index.
// All arithmetic is done in float and converted to the output element type
// on store. The kernel performs no bounds check on index.
kernel void erfinv_kernel( device {0} *output [[buffer(0)]],
                           device {1} *input [[buffer(1)]],
                           uint index [[thread_position_in_grid]]) {{

  const float y = input[index];
  const float y_abs = abs(y);

  // Outside the open interval (-1, 1): magnitude exactly 1 maps to a signed
  // infinity, any larger magnitude yields NaN.
  if (y_abs >= 1.0f) {{
    output[index] = {0}( y_abs > 1.0f ? NAN : copysign(INFINITY, y));
    return;
  }}

  float x;
  if (y_abs <= 0.7f) {{
    // Central region: odd function of y, evaluated with Horner's scheme in y^2.
    const float z = y * y;
    const float num = ((erfinv_a[3] * z + erfinv_a[2]) * z + erfinv_a[1]) * z + erfinv_a[0];
    const float dem = (((erfinv_b[3] * z + erfinv_b[2]) * z + erfinv_b[1]) * z + erfinv_b[0]) * z + 1.0f;
    x = y * num / dem;
  }} else {{
    // Tail region: work on |y| and restore the sign of y on the numerator.
    // Literals carry the f suffix like the rest of the kernel.
    const float z = sqrt(-1.0f * log((1.0f - y_abs) / 2.0f));
    const float num = ((erfinv_c[3] * z + erfinv_c[2]) * z + erfinv_c[1]) * z + erfinv_c[0];
    const float dem = (erfinv_d[1] * z + erfinv_d[0]) * z + 1.0f;
    x = copysign(num, y) / dem;
  }}

  output[index] = {0}(x);
}}
39 
// Element-wise exponential.
//   output: destination buffer; each thread writes one element.
//   input:  source buffer; each thread reads one element.
//   index:  position of this thread in the grid, used as the element index.
// Uses the precise:: variant of exp and converts the result to the output
// element type on store.
kernel void exp_kernel( device {0} *output [[buffer(0)]],
                        device {1} *input [[ buffer(1)]],
                        uint index [[thread_position_in_grid]]) {{
  const {1} value = input[index];
  const auto result = precise::exp(value);
  output[index] = {0}(result);
}}
45 
// Element-wise complex exponential on two-component vectors, where .x is the
// real part and .y the imaginary part:
//   exp(x + iy) = exp(x) * (cos(y) + i * sin(y))
//   output: destination buffer; each thread writes one complex element.
//   input:  source buffer; each thread reads one complex element.
//   index:  position of this thread in the grid, used as the element index.
kernel void exp_complex_kernel( device {0}2 *output [[buffer(0)]],
                                device {0}2 *input [[ buffer(1)]],
                                uint index [[thread_position_in_grid]]) {{
  // Read the whole element before any store so the result stays correct even
  // when output and input refer to the same buffer: writing the real part
  // first and then re-reading input would feed the new real part into the
  // imaginary-part computation.
  const {0}2 z = input[index];

  // exp(x) is shared by both components, so evaluate it once.
  const float magnitude = precise::exp(z.x);

  output[index] = {0}2({0}(magnitude * precise::cos(z.y)),
                       {0}(magnitude * precise::sin(z.y)));
}}
52 
// Element-wise hyperbolic tangent.
//   output: destination buffer; each thread writes one element.
//   input:  source buffer; each thread reads one element.
//   index:  position of this thread in the grid, used as the element index.
// Uses the precise:: variant of tanh and converts the result to the output
// element type on store.
kernel void tanh_kernel( device {0} *output [[buffer(0)]],
                        device {1} *input [[ buffer(1)]],
                        uint index [[thread_position_in_grid]]) {{
  const {1} value = input[index];
  const auto result = precise::tanh(value);
  output[index] = {0}(result);
}}
58 
59 
#if __METAL_VERSION__ >= 310
// Two-component dot product for bfloat2, compiled only for Metal 3.1 and
// newer. complex_div calls dot on its vector arguments, so this overload
// covers the bfloat2 case.
bfloat dot(bfloat2 a, bfloat2 b) {{
  const bfloat xx = a.x * b.x;
  const bfloat yy = a.y * b.y;
  return xx + yy;
}}
#endif
65 
// Complex division a / b for two-component vectors, where .x is the real part
// and .y the imaginary part. Computed as a * conj(b) / |b|^2:
//   real = (a.x * b.x + a.y * b.y) / |b|^2
//   imag = (a.y * b.x - a.x * b.y) / |b|^2
// No special handling is applied when b is zero.
template<typename T>
T complex_div(T a, T b) {{
  const auto norm_sq = dot(b, b);
  const T a_times_conj_b = T(dot(a, b), a.y * b.x - a.x * b.y);
  return a_times_conj_b / norm_sq;
}}
71 
// Element-wise complex hyperbolic tangent on two-component vectors, where .x
// is the real part and .y the imaginary part. Uses the identity
//   tanh(x + iy) = (tanh(x) + i * tan(y)) / (1 + i * tanh(x) * tan(y))
//   output: destination buffer; each thread writes one complex element.
//   input:  source buffer; each thread reads one complex element.
//   index:  position of this thread in the grid, used as the element index.
kernel void tanh_complex_kernel( device {0}2 *output [[buffer(0)]],
                                 device {0}2 *input [[ buffer(1)]],
                                 uint index [[thread_position_in_grid]]) {{
  const {0}2 z = input[index];

  // Both terms are narrowed to the element type before they are combined.
  const auto tanh_re = {0}(precise::tanh(z.x));
  const auto tan_im = {0}(precise::tan(z.y));

  const {0}2 numerator = {0}2(tanh_re, tan_im);
  const {0}2 denominator = {0}2({0}(1), tanh_re * tan_im);
  output[index] = complex_div(numerator, denominator);
}}
80 )METAL";
81