1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/tensor_testutil.h"
17
18 #include <cmath>
19
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/platform/types.h"
22
23 namespace tensorflow {
24 namespace test {
25
IsSameType(const Tensor & x,const Tensor & y)26 ::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y) {
27 if (x.dtype() != y.dtype()) {
28 return ::testing::AssertionFailure()
29 << "Tensors have different dtypes (" << x.dtype() << " vs "
30 << y.dtype() << ")";
31 }
32 return ::testing::AssertionSuccess();
33 }
34
IsSameShape(const Tensor & x,const Tensor & y)35 ::testing::AssertionResult IsSameShape(const Tensor& x, const Tensor& y) {
36 if (!x.IsSameSize(y)) {
37 return ::testing::AssertionFailure()
38 << "Tensors have different shapes (" << x.shape().DebugString()
39 << " vs " << y.shape().DebugString() << ")";
40 }
41 return ::testing::AssertionSuccess();
42 }
43
44 template <typename T>
EqualFailure(const T & x,const T & y)45 static ::testing::AssertionResult EqualFailure(const T& x, const T& y) {
46 return ::testing::AssertionFailure()
47 << std::setprecision(std::numeric_limits<T>::digits10 + 2) << x
48 << " not equal to " << y;
49 }
IsEqual(float x,float y,Tolerance t)50 static ::testing::AssertionResult IsEqual(float x, float y, Tolerance t) {
51 // We consider NaNs equal for testing.
52 if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
53 return ::testing::AssertionSuccess();
54 if (t == Tolerance::kNone) {
55 if (x == y) return ::testing::AssertionSuccess();
56 } else {
57 if (::testing::internal::CmpHelperFloatingPointEQ<float>("", "", x, y))
58 return ::testing::AssertionSuccess();
59 }
60 return EqualFailure(x, y);
61 }
IsEqual(double x,double y,Tolerance t)62 static ::testing::AssertionResult IsEqual(double x, double y, Tolerance t) {
63 // We consider NaNs equal for testing.
64 if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
65 return ::testing::AssertionSuccess();
66 if (t == Tolerance::kNone) {
67 if (x == y) return ::testing::AssertionSuccess();
68 } else {
69 if (::testing::internal::CmpHelperFloatingPointEQ<double>("", "", x, y))
70 return ::testing::AssertionSuccess();
71 }
72 return EqualFailure(x, y);
73 }
IsEqual(Eigen::half x,Eigen::half y,Tolerance t)74 static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y,
75 Tolerance t) {
76 // We consider NaNs equal for testing.
77 if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
78 return ::testing::AssertionSuccess();
79
80 // Below is a reimplementation of CmpHelperFloatingPointEQ<Eigen::half>, which
81 // we cannot use because Eigen::half is not default-constructible.
82
83 if (Eigen::numext::isnan(x) || Eigen::numext::isnan(y))
84 return EqualFailure(x, y);
85
86 auto sign_and_magnitude_to_biased = [](uint16_t sam) {
87 const uint16_t kSignBitMask = 0x8000;
88 if (kSignBitMask & sam) return ~sam + 1; // negative number.
89 return kSignBitMask | sam; // positive number.
90 };
91
92 auto xb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(x));
93 auto yb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(y));
94 if (t == Tolerance::kNone) {
95 if (xb == yb) return ::testing::AssertionSuccess();
96 } else {
97 auto distance = xb >= yb ? xb - yb : yb - xb;
98 const uint16_t kMaxUlps = 4;
99 if (distance <= kMaxUlps) return ::testing::AssertionSuccess();
100 }
101 return EqualFailure(x, y);
102 }
103 template <typename T>
IsEqual(const T & x,const T & y,Tolerance t)104 static ::testing::AssertionResult IsEqual(const T& x, const T& y, Tolerance t) {
105 if (::testing::internal::CmpHelperEQ<T>("", "", x, y))
106 return ::testing::AssertionSuccess();
107 return EqualFailure(x, y);
108 }
109 template <typename T>
IsEqual(const std::complex<T> & x,const std::complex<T> & y,Tolerance t)110 static ::testing::AssertionResult IsEqual(const std::complex<T>& x,
111 const std::complex<T>& y,
112 Tolerance t) {
113 if (IsEqual(x.real(), y.real(), t) && IsEqual(x.imag(), y.imag(), t))
114 return ::testing::AssertionSuccess();
115 return EqualFailure(x, y);
116 }
117
118 template <typename T>
ExpectEqual(const Tensor & x,const Tensor & y,Tolerance t=Tolerance::kDefault)119 static void ExpectEqual(const Tensor& x, const Tensor& y,
120 Tolerance t = Tolerance::kDefault) {
121 const T* Tx = x.unaligned_flat<T>().data();
122 const T* Ty = y.unaligned_flat<T>().data();
123 auto size = x.NumElements();
124 int max_failures = 10;
125 int num_failures = 0;
126 for (decltype(size) i = 0; i < size; ++i) {
127 EXPECT_TRUE(IsEqual(Tx[i], Ty[i], t)) << "i = " << (++num_failures, i);
128 ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
129 }
130 }
131
132 template <typename T>
IsClose(const T & x,const T & y,const T & atol,const T & rtol)133 static ::testing::AssertionResult IsClose(const T& x, const T& y, const T& atol,
134 const T& rtol) {
135 // We consider NaNs equal for testing.
136 if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
137 return ::testing::AssertionSuccess();
138 if (x == y) return ::testing::AssertionSuccess(); // Handle infinity.
139 auto tolerance = atol + rtol * Eigen::numext::abs(x);
140 if (Eigen::numext::abs(x - y) <= tolerance)
141 return ::testing::AssertionSuccess();
142 return ::testing::AssertionFailure() << x << " not close to " << y;
143 }
144
145 template <typename T>
IsClose(const std::complex<T> & x,const std::complex<T> & y,const T & atol,const T & rtol)146 static ::testing::AssertionResult IsClose(const std::complex<T>& x,
147 const std::complex<T>& y,
148 const T& atol, const T& rtol) {
149 if (IsClose(x.real(), y.real(), atol, rtol) &&
150 IsClose(x.imag(), y.imag(), atol, rtol))
151 return ::testing::AssertionSuccess();
152 return ::testing::AssertionFailure() << x << " not close to " << y;
153 }
154
155 // Return type can be different from T, e.g. float for T=std::complex<float>.
156 template <typename T>
GetTolerance(double tolerance)157 static auto GetTolerance(double tolerance) {
158 using Real = typename Eigen::NumTraits<T>::Real;
159 auto default_tol = static_cast<Real>(5.0) * Eigen::NumTraits<T>::epsilon();
160 auto result = tolerance < 0.0 ? default_tol : static_cast<Real>(tolerance);
161 EXPECT_GE(result, static_cast<Real>(0));
162 return result;
163 }
164
165 template <typename T>
ExpectClose(const Tensor & x,const Tensor & y,double atol,double rtol)166 static void ExpectClose(const Tensor& x, const Tensor& y, double atol,
167 double rtol) {
168 auto typed_atol = GetTolerance<T>(atol);
169 auto typed_rtol = GetTolerance<T>(rtol);
170
171 const T* Tx = x.unaligned_flat<T>().data();
172 const T* Ty = y.unaligned_flat<T>().data();
173 auto size = x.NumElements();
174 int max_failures = 10;
175 int num_failures = 0;
176 for (decltype(size) i = 0; i < size; ++i) {
177 EXPECT_TRUE(IsClose(Tx[i], Ty[i], typed_atol, typed_rtol))
178 << "i = " << (++num_failures, i) << " Tx[i] = " << Tx[i]
179 << " Ty[i] = " << Ty[i];
180 ASSERT_LT(num_failures, max_failures)
181 << "Too many mismatches (atol = " << atol << " rtol = " << rtol
182 << "), giving up.";
183 }
184 EXPECT_EQ(num_failures, 0)
185 << "Mismatches detected (atol = " << atol << " rtol = " << rtol << ").";
186 }
187
ExpectEqual(const Tensor & x,const Tensor & y,Tolerance t)188 void ExpectEqual(const Tensor& x, const Tensor& y, Tolerance t) {
189 ASSERT_TRUE(IsSameType(x, y));
190 ASSERT_TRUE(IsSameShape(x, y));
191
192 switch (x.dtype()) {
193 case DT_FLOAT:
194 return ExpectEqual<float>(x, y, t);
195 case DT_DOUBLE:
196 return ExpectEqual<double>(x, y, t);
197 case DT_INT32:
198 return ExpectEqual<int32>(x, y);
199 case DT_UINT32:
200 return ExpectEqual<uint32>(x, y);
201 case DT_UINT16:
202 return ExpectEqual<uint16>(x, y);
203 case DT_UINT8:
204 return ExpectEqual<uint8>(x, y);
205 case DT_INT16:
206 return ExpectEqual<int16>(x, y);
207 case DT_INT8:
208 return ExpectEqual<int8>(x, y);
209 case DT_STRING:
210 return ExpectEqual<tstring>(x, y);
211 case DT_COMPLEX64:
212 return ExpectEqual<complex64>(x, y, t);
213 case DT_COMPLEX128:
214 return ExpectEqual<complex128>(x, y, t);
215 case DT_INT64:
216 return ExpectEqual<int64_t>(x, y);
217 case DT_UINT64:
218 return ExpectEqual<uint64>(x, y);
219 case DT_BOOL:
220 return ExpectEqual<bool>(x, y);
221 case DT_QINT8:
222 return ExpectEqual<qint8>(x, y);
223 case DT_QUINT8:
224 return ExpectEqual<quint8>(x, y);
225 case DT_QINT16:
226 return ExpectEqual<qint16>(x, y);
227 case DT_QUINT16:
228 return ExpectEqual<quint16>(x, y);
229 case DT_QINT32:
230 return ExpectEqual<qint32>(x, y);
231 case DT_BFLOAT16:
232 return ExpectEqual<bfloat16>(x, y, t);
233 case DT_HALF:
234 return ExpectEqual<Eigen::half>(x, y, t);
235 default:
236 EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
237 }
238 }
239
ExpectClose(const Tensor & x,const Tensor & y,double atol,double rtol)240 void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
241 ASSERT_TRUE(IsSameType(x, y));
242 ASSERT_TRUE(IsSameShape(x, y));
243
244 switch (x.dtype()) {
245 case DT_HALF:
246 return ExpectClose<Eigen::half>(x, y, atol, rtol);
247 case DT_BFLOAT16:
248 return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
249 case DT_FLOAT:
250 return ExpectClose<float>(x, y, atol, rtol);
251 case DT_DOUBLE:
252 return ExpectClose<double>(x, y, atol, rtol);
253 case DT_COMPLEX64:
254 return ExpectClose<complex64>(x, y, atol, rtol);
255 case DT_COMPLEX128:
256 return ExpectClose<complex128>(x, y, atol, rtol);
257 default:
258 EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
259 }
260 }
261
IsClose(Eigen::half x,Eigen::half y,double atol,double rtol)262 ::testing::AssertionResult internal_test::IsClose(Eigen::half x, Eigen::half y,
263 double atol, double rtol) {
264 return test::IsClose(x, y, GetTolerance<Eigen::half>(atol),
265 GetTolerance<Eigen::half>(rtol));
266 }
IsClose(float x,float y,double atol,double rtol)267 ::testing::AssertionResult internal_test::IsClose(float x, float y, double atol,
268 double rtol) {
269 return test::IsClose(x, y, GetTolerance<float>(atol),
270 GetTolerance<float>(rtol));
271 }
IsClose(double x,double y,double atol,double rtol)272 ::testing::AssertionResult internal_test::IsClose(double x, double y,
273 double atol, double rtol) {
274 return test::IsClose(x, y, GetTolerance<double>(atol),
275 GetTolerance<double>(rtol));
276 }
277
278 } // end namespace test
279 } // end namespace tensorflow
280