xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/tensor_testutil.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_testutil.h"
17 
18 #include <cmath>
19 
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/platform/types.h"
22 
23 namespace tensorflow {
24 namespace test {
25 
IsSameType(const Tensor & x,const Tensor & y)26 ::testing::AssertionResult IsSameType(const Tensor& x, const Tensor& y) {
27   if (x.dtype() != y.dtype()) {
28     return ::testing::AssertionFailure()
29            << "Tensors have different dtypes (" << x.dtype() << " vs "
30            << y.dtype() << ")";
31   }
32   return ::testing::AssertionSuccess();
33 }
34 
IsSameShape(const Tensor & x,const Tensor & y)35 ::testing::AssertionResult IsSameShape(const Tensor& x, const Tensor& y) {
36   if (!x.IsSameSize(y)) {
37     return ::testing::AssertionFailure()
38            << "Tensors have different shapes (" << x.shape().DebugString()
39            << " vs " << y.shape().DebugString() << ")";
40   }
41   return ::testing::AssertionSuccess();
42 }
43 
44 template <typename T>
EqualFailure(const T & x,const T & y)45 static ::testing::AssertionResult EqualFailure(const T& x, const T& y) {
46   return ::testing::AssertionFailure()
47          << std::setprecision(std::numeric_limits<T>::digits10 + 2) << x
48          << " not equal to " << y;
49 }
IsEqual(float x,float y,Tolerance t)50 static ::testing::AssertionResult IsEqual(float x, float y, Tolerance t) {
51   // We consider NaNs equal for testing.
52   if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
53     return ::testing::AssertionSuccess();
54   if (t == Tolerance::kNone) {
55     if (x == y) return ::testing::AssertionSuccess();
56   } else {
57     if (::testing::internal::CmpHelperFloatingPointEQ<float>("", "", x, y))
58       return ::testing::AssertionSuccess();
59   }
60   return EqualFailure(x, y);
61 }
IsEqual(double x,double y,Tolerance t)62 static ::testing::AssertionResult IsEqual(double x, double y, Tolerance t) {
63   // We consider NaNs equal for testing.
64   if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
65     return ::testing::AssertionSuccess();
66   if (t == Tolerance::kNone) {
67     if (x == y) return ::testing::AssertionSuccess();
68   } else {
69     if (::testing::internal::CmpHelperFloatingPointEQ<double>("", "", x, y))
70       return ::testing::AssertionSuccess();
71   }
72   return EqualFailure(x, y);
73 }
IsEqual(Eigen::half x,Eigen::half y,Tolerance t)74 static ::testing::AssertionResult IsEqual(Eigen::half x, Eigen::half y,
75                                           Tolerance t) {
76   // We consider NaNs equal for testing.
77   if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
78     return ::testing::AssertionSuccess();
79 
80   // Below is a reimplementation of CmpHelperFloatingPointEQ<Eigen::half>, which
81   // we cannot use because Eigen::half is not default-constructible.
82 
83   if (Eigen::numext::isnan(x) || Eigen::numext::isnan(y))
84     return EqualFailure(x, y);
85 
86   auto sign_and_magnitude_to_biased = [](uint16_t sam) {
87     const uint16_t kSignBitMask = 0x8000;
88     if (kSignBitMask & sam) return ~sam + 1;  // negative number.
89     return kSignBitMask | sam;                // positive number.
90   };
91 
92   auto xb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(x));
93   auto yb = sign_and_magnitude_to_biased(Eigen::numext::bit_cast<uint16_t>(y));
94   if (t == Tolerance::kNone) {
95     if (xb == yb) return ::testing::AssertionSuccess();
96   } else {
97     auto distance = xb >= yb ? xb - yb : yb - xb;
98     const uint16_t kMaxUlps = 4;
99     if (distance <= kMaxUlps) return ::testing::AssertionSuccess();
100   }
101   return EqualFailure(x, y);
102 }
103 template <typename T>
IsEqual(const T & x,const T & y,Tolerance t)104 static ::testing::AssertionResult IsEqual(const T& x, const T& y, Tolerance t) {
105   if (::testing::internal::CmpHelperEQ<T>("", "", x, y))
106     return ::testing::AssertionSuccess();
107   return EqualFailure(x, y);
108 }
109 template <typename T>
IsEqual(const std::complex<T> & x,const std::complex<T> & y,Tolerance t)110 static ::testing::AssertionResult IsEqual(const std::complex<T>& x,
111                                           const std::complex<T>& y,
112                                           Tolerance t) {
113   if (IsEqual(x.real(), y.real(), t) && IsEqual(x.imag(), y.imag(), t))
114     return ::testing::AssertionSuccess();
115   return EqualFailure(x, y);
116 }
117 
118 template <typename T>
ExpectEqual(const Tensor & x,const Tensor & y,Tolerance t=Tolerance::kDefault)119 static void ExpectEqual(const Tensor& x, const Tensor& y,
120                         Tolerance t = Tolerance::kDefault) {
121   const T* Tx = x.unaligned_flat<T>().data();
122   const T* Ty = y.unaligned_flat<T>().data();
123   auto size = x.NumElements();
124   int max_failures = 10;
125   int num_failures = 0;
126   for (decltype(size) i = 0; i < size; ++i) {
127     EXPECT_TRUE(IsEqual(Tx[i], Ty[i], t)) << "i = " << (++num_failures, i);
128     ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
129   }
130 }
131 
132 template <typename T>
IsClose(const T & x,const T & y,const T & atol,const T & rtol)133 static ::testing::AssertionResult IsClose(const T& x, const T& y, const T& atol,
134                                           const T& rtol) {
135   // We consider NaNs equal for testing.
136   if (Eigen::numext::isnan(x) && Eigen::numext::isnan(y))
137     return ::testing::AssertionSuccess();
138   if (x == y) return ::testing::AssertionSuccess();  // Handle infinity.
139   auto tolerance = atol + rtol * Eigen::numext::abs(x);
140   if (Eigen::numext::abs(x - y) <= tolerance)
141     return ::testing::AssertionSuccess();
142   return ::testing::AssertionFailure() << x << " not close to " << y;
143 }
144 
145 template <typename T>
IsClose(const std::complex<T> & x,const std::complex<T> & y,const T & atol,const T & rtol)146 static ::testing::AssertionResult IsClose(const std::complex<T>& x,
147                                           const std::complex<T>& y,
148                                           const T& atol, const T& rtol) {
149   if (IsClose(x.real(), y.real(), atol, rtol) &&
150       IsClose(x.imag(), y.imag(), atol, rtol))
151     return ::testing::AssertionSuccess();
152   return ::testing::AssertionFailure() << x << " not close to " << y;
153 }
154 
155 // Return type can be different from T, e.g. float for T=std::complex<float>.
156 template <typename T>
GetTolerance(double tolerance)157 static auto GetTolerance(double tolerance) {
158   using Real = typename Eigen::NumTraits<T>::Real;
159   auto default_tol = static_cast<Real>(5.0) * Eigen::NumTraits<T>::epsilon();
160   auto result = tolerance < 0.0 ? default_tol : static_cast<Real>(tolerance);
161   EXPECT_GE(result, static_cast<Real>(0));
162   return result;
163 }
164 
165 template <typename T>
ExpectClose(const Tensor & x,const Tensor & y,double atol,double rtol)166 static void ExpectClose(const Tensor& x, const Tensor& y, double atol,
167                         double rtol) {
168   auto typed_atol = GetTolerance<T>(atol);
169   auto typed_rtol = GetTolerance<T>(rtol);
170 
171   const T* Tx = x.unaligned_flat<T>().data();
172   const T* Ty = y.unaligned_flat<T>().data();
173   auto size = x.NumElements();
174   int max_failures = 10;
175   int num_failures = 0;
176   for (decltype(size) i = 0; i < size; ++i) {
177     EXPECT_TRUE(IsClose(Tx[i], Ty[i], typed_atol, typed_rtol))
178         << "i = " << (++num_failures, i) << " Tx[i] = " << Tx[i]
179         << " Ty[i] = " << Ty[i];
180     ASSERT_LT(num_failures, max_failures)
181         << "Too many mismatches (atol = " << atol << " rtol = " << rtol
182         << "), giving up.";
183   }
184   EXPECT_EQ(num_failures, 0)
185       << "Mismatches detected (atol = " << atol << " rtol = " << rtol << ").";
186 }
187 
ExpectEqual(const Tensor & x,const Tensor & y,Tolerance t)188 void ExpectEqual(const Tensor& x, const Tensor& y, Tolerance t) {
189   ASSERT_TRUE(IsSameType(x, y));
190   ASSERT_TRUE(IsSameShape(x, y));
191 
192   switch (x.dtype()) {
193     case DT_FLOAT:
194       return ExpectEqual<float>(x, y, t);
195     case DT_DOUBLE:
196       return ExpectEqual<double>(x, y, t);
197     case DT_INT32:
198       return ExpectEqual<int32>(x, y);
199     case DT_UINT32:
200       return ExpectEqual<uint32>(x, y);
201     case DT_UINT16:
202       return ExpectEqual<uint16>(x, y);
203     case DT_UINT8:
204       return ExpectEqual<uint8>(x, y);
205     case DT_INT16:
206       return ExpectEqual<int16>(x, y);
207     case DT_INT8:
208       return ExpectEqual<int8>(x, y);
209     case DT_STRING:
210       return ExpectEqual<tstring>(x, y);
211     case DT_COMPLEX64:
212       return ExpectEqual<complex64>(x, y, t);
213     case DT_COMPLEX128:
214       return ExpectEqual<complex128>(x, y, t);
215     case DT_INT64:
216       return ExpectEqual<int64_t>(x, y);
217     case DT_UINT64:
218       return ExpectEqual<uint64>(x, y);
219     case DT_BOOL:
220       return ExpectEqual<bool>(x, y);
221     case DT_QINT8:
222       return ExpectEqual<qint8>(x, y);
223     case DT_QUINT8:
224       return ExpectEqual<quint8>(x, y);
225     case DT_QINT16:
226       return ExpectEqual<qint16>(x, y);
227     case DT_QUINT16:
228       return ExpectEqual<quint16>(x, y);
229     case DT_QINT32:
230       return ExpectEqual<qint32>(x, y);
231     case DT_BFLOAT16:
232       return ExpectEqual<bfloat16>(x, y, t);
233     case DT_HALF:
234       return ExpectEqual<Eigen::half>(x, y, t);
235     default:
236       EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
237   }
238 }
239 
ExpectClose(const Tensor & x,const Tensor & y,double atol,double rtol)240 void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
241   ASSERT_TRUE(IsSameType(x, y));
242   ASSERT_TRUE(IsSameShape(x, y));
243 
244   switch (x.dtype()) {
245     case DT_HALF:
246       return ExpectClose<Eigen::half>(x, y, atol, rtol);
247     case DT_BFLOAT16:
248       return ExpectClose<Eigen::bfloat16>(x, y, atol, rtol);
249     case DT_FLOAT:
250       return ExpectClose<float>(x, y, atol, rtol);
251     case DT_DOUBLE:
252       return ExpectClose<double>(x, y, atol, rtol);
253     case DT_COMPLEX64:
254       return ExpectClose<complex64>(x, y, atol, rtol);
255     case DT_COMPLEX128:
256       return ExpectClose<complex128>(x, y, atol, rtol);
257     default:
258       EXPECT_TRUE(false) << "Unsupported type : " << DataTypeString(x.dtype());
259   }
260 }
261 
IsClose(Eigen::half x,Eigen::half y,double atol,double rtol)262 ::testing::AssertionResult internal_test::IsClose(Eigen::half x, Eigen::half y,
263                                                   double atol, double rtol) {
264   return test::IsClose(x, y, GetTolerance<Eigen::half>(atol),
265                        GetTolerance<Eigen::half>(rtol));
266 }
IsClose(float x,float y,double atol,double rtol)267 ::testing::AssertionResult internal_test::IsClose(float x, float y, double atol,
268                                                   double rtol) {
269   return test::IsClose(x, y, GetTolerance<float>(atol),
270                        GetTolerance<float>(rtol));
271 }
IsClose(double x,double y,double atol,double rtol)272 ::testing::AssertionResult internal_test::IsClose(double x, double y,
273                                                   double atol, double rtol) {
274   return test::IsClose(x, y, GetTolerance<double>(atol),
275                        GetTolerance<double>(rtol));
276 }
277 
278 }  // end namespace test
279 }  // end namespace tensorflow
280