#pragma once

#include <ATen/Config.h>
#include <ATen/Parallel.h>
#include <ATen/OpMathType.h>
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
#include <c10/util/complex.h>

// This header implements various unary operations using an MKL VML style
// interface.

// It implements various functions with a simple interface.
// For example, it enables the user to call vsin(float* out, const float* in,
// size). This function takes a pointer to a contiguous output array of floats
// and a constant input array. It will then apply sin to each value in the
// input array and write the result into the output array. out and in may point
// to the same memory, i.e. this fully supports in-place operations. These
// functions also implement their own parallelization, so take precautions when
// calling these from threaded functions.

// When MKL is available it will call into MKL's VML library, similar to NumPy.
// If MKL is not available it will use SLEEF.

// This file might be compiled under AVX or AVX2 when called from e.g.
// UnaryOpsKernel.cpp
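
// Usage sketch (illustrative only; the buffers and their size n are
// hypothetical):
//
//   std::vector<float> in(n), out(n);
//   at::vml::vsin(out.data(), in.data(), static_cast<int64_t>(n));
//   // In-place is also supported:
//   at::vml::vsin(in.data(), in.data(), static_cast<int64_t>(n));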

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <limits>
#include <type_traits>

#if AT_MKL_ENABLED() && !defined(__APPLE__)
#include <mkl.h>
#endif

namespace at::vml {
inline namespace CPU_CAPABILITY {

using namespace vec;

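// Elementwise reciprocal square root: out[i] = 1 / sqrt(in[i]). Parallelized
// with at::parallel_for (grain size 2048) and vectorized through vec::map.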
template <typename scalar_t>
inline void vrsqrt(scalar_t* out, scalar_t* in, int64_t size) {
  parallel_for(0, size, 2048, [out, in](int64_t begin, int64_t end) {
    map(
        [](const Vectorized<scalar_t>& x) {
          return Vectorized<scalar_t>((scalar_t)(1)) / x.sqrt();
        },
        out + begin,
        in + begin,
        end - begin);
  });
}

// NB: We ignore numerical errors by convention and leave them to the user

#define IMPLEMENT_VML(op)                                              \
  template <typename scalar_t>                                         \
  inline void v##op(scalar_t* out, const scalar_t* in, int64_t size) { \
    using vec_t = Vectorized<vec_scalar_t<scalar_t>>;                  \
    vec::map([](vec_t x) { return x.op(); }, out, in, size);           \
  }

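// For illustration, IMPLEMENT_VML(sin) expands to roughly the following
// (a sketch mirroring the macro above):
//
//   template <typename scalar_t>
//   inline void vsin(scalar_t* out, const scalar_t* in, int64_t size) {
//     using vec_t = Vectorized<vec_scalar_t<scalar_t>>;
//     vec::map([](vec_t x) { return x.sin(); }, out, in, size);
//   }
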
IMPLEMENT_VML(abs)
IMPLEMENT_VML(acos)
IMPLEMENT_VML(asin)
IMPLEMENT_VML(atan)
IMPLEMENT_VML(atanh)
IMPLEMENT_VML(ceil)
IMPLEMENT_VML(cos)
// IMPLEMENT_VML(cosh)
IMPLEMENT_VML(erf)
IMPLEMENT_VML(erfc)
IMPLEMENT_VML(erfinv)
IMPLEMENT_VML(exp)
IMPLEMENT_VML(expm1)
IMPLEMENT_VML(floor)
IMPLEMENT_VML(i0)
IMPLEMENT_VML(i0e)
IMPLEMENT_VML(digamma)
IMPLEMENT_VML(reciprocal)
IMPLEMENT_VML(log)
IMPLEMENT_VML(log10)
IMPLEMENT_VML(log1p)
IMPLEMENT_VML(log2)
IMPLEMENT_VML(neg)
IMPLEMENT_VML(sin)
// IMPLEMENT_VML(sinh)
IMPLEMENT_VML(sqrt)
IMPLEMENT_VML(round)
IMPLEMENT_VML(rsqrt)
IMPLEMENT_VML(tan)
IMPLEMENT_VML(tanh)
IMPLEMENT_VML(trunc)
IMPLEMENT_VML(lgamma)

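// NB: When MKL is enabled, the explicit float/double specializations below
// take precedence over the generic implementations above for the ops they
// cover.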
#if AT_MKL_ENABLED() && !defined(__APPLE__)

// NB: LP64 MKL is the most commonly used build and thus we assume it here.
// Under LP64, MKL_INT is a 32-bit integer; under ILP64 it is 64-bit. The
// assertion below accepts either.
static_assert(
    std::is_same_v<MKL_INT, int32_t> || std::is_same_v<MKL_INT, int64_t>,
    "MKL_INT is assumed to be int32_t or int64_t");
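
// MKL VML entry points take an MKL_INT element count, so inputs longer than
// std::numeric_limits<MKL_INT>::max() are processed in max-sized chunks plus a
// final remainder call. The VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE mode
// requests high accuracy, leaves FTZ/DAZ off (denormals preserved), and
// ignores errors, matching the convention noted above.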
#define IMPLEMENT_VML_MKL_STUB(op, mklop, type, mkltype)                \
  template <>                                                           \
  inline void v##op(type * out, const type * in, int64_t size) {        \
    int64_t max_mkl_ind = std::numeric_limits<MKL_INT>::max();          \
    if (size <= static_cast<int64_t>(max_mkl_ind)) {                    \
      vm##mkltype##mklop(                                               \
          size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE); \
    } else {                                                            \
      MKL_INT ind = 0;                                                  \
      int64_t chunks = size / max_mkl_ind;                              \
      int64_t rest = size % max_mkl_ind;                                \
      for (; ind < chunks; ind++) {                                     \
        vm##mkltype##mklop(                                             \
            max_mkl_ind,                                                \
            in + ind * max_mkl_ind,                                     \
            out + ind * max_mkl_ind,                                    \
            VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);              \
      }                                                                 \
      vm##mkltype##mklop(                                               \
          rest,                                                         \
          in + ind * max_mkl_ind,                                       \
          out + ind * max_mkl_ind,                                      \
          VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);                \
    }                                                                   \
  }

#define IMPLEMENT_VML_MKL(op, mklop)          \
  IMPLEMENT_VML_MKL_STUB(op, mklop, float, s) \
  IMPLEMENT_VML_MKL_STUB(op, mklop, double, d)

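// For illustration, IMPLEMENT_VML_MKL(sin, Sin) specializes vsin for float and
// double so they dispatch to MKL's vmsSin and vmdSin. A sketch of the float
// specialization, eliding the chunked path for sizes beyond MKL_INT:
//
//   template <>
//   inline void vsin(float* out, const float* in, int64_t size) {
//     vmsSin(size, in, out, VML_HA | VML_FTZDAZ_OFF | VML_ERRMODE_IGNORE);
//   }
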
// NB: abs, cosh and sinh were temporarily disabled due to issues with Apple
// NB: expm1 is disabled because on some configs it produces expm1(nan)=-1
IMPLEMENT_VML_MKL(acos, Acos)
IMPLEMENT_VML_MKL(asin, Asin)
IMPLEMENT_VML_MKL(atan, Atan)
IMPLEMENT_VML_MKL(cos, Cos)
// IMPLEMENT_VML_MKL(cosh, Cosh)
IMPLEMENT_VML_MKL(erf, Erf)
IMPLEMENT_VML_MKL(erfc, Erfc)
IMPLEMENT_VML_MKL(erfinv, ErfInv)
IMPLEMENT_VML_MKL(exp, Exp)
// IMPLEMENT_VML_MKL(expm1, Expm1)
IMPLEMENT_VML_MKL(log, Ln)
IMPLEMENT_VML_MKL(log10, Log10)
IMPLEMENT_VML_MKL(sin, Sin)
// IMPLEMENT_VML_MKL(sinh, Sinh)
IMPLEMENT_VML_MKL(sqrt, Sqrt)
IMPLEMENT_VML_MKL(tan, Tan)
IMPLEMENT_VML_MKL(tanh, Tanh)
IMPLEMENT_VML_MKL(trunc, Trunc)

// Not vectorized in MKL version tested
// IMPLEMENT_VML_MKL(abs, Abs)
// IMPLEMENT_VML_MKL(log1p, Log1p)

#if INTEL_MKL_VERSION >= 20180406
IMPLEMENT_VML_MKL(log2, Log2)
#endif

#endif

} // namespace CPU_CAPABILITY
} // namespace at::vml