xref: /aosp_15_r20/external/executorch/runtime/core/exec_aten/testing_util/tensor_util.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <cmath>
10 #include <cstring>
11 #include <ostream>
12 
13 #include <executorch/runtime/core/exec_aten/exec_aten.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16 #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
17 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
18 
19 using exec_aten::BFloat16;
20 using exec_aten::Half;
21 using exec_aten::ScalarType;
22 using exec_aten::Tensor;
23 
24 namespace executorch {
25 namespace runtime {
26 namespace testing {
27 
28 namespace {
29 
/**
 * Returns true if the two arrays are close according to the description on
 * `tensors_are_close()`.
 *
 * T must be a floating point type. Non-floating point data should be compared
 * directly.
 */
template <typename T>
bool data_is_close(
    const T* a,
    const T* b,
    size_t numel,
    double rtol,
    double atol) {
  // %p requires a void* argument; passing a typed pointer to a variadic
  // function under %p is undefined behavior, and the "0x" prefix is already
  // produced by most %p implementations (so "0x%p" printed "0x0x...").
  ET_CHECK_MSG(
      numel == 0 || (a != nullptr && b != nullptr),
      "Pointers must not be null when numel > 0: numel %zu, a %p, b %p",
      numel,
      static_cast<const void*>(a),
      static_cast<const void*>(b));
  if (a == b) {
    // Same buffer (or both null with numel == 0): trivially close.
    return true;
  }
  for (size_t i = 0; i < numel; i++) {
    const auto ai = a[i];
    const auto bi = b[i];

    if (std::isnan(ai) && std::isnan(bi)) {
      // NaN == NaN
    } else if (
        !std::isfinite(ai) && !std::isfinite(bi) && ((ai > 0) == (bi > 0))) {
      // -Inf == -Inf
      // +Inf == +Inf
    } else if (rtol == 0 && atol == 0) {
      // Exact comparison; avoid unnecessary math.
      if (ai != bi) {
        return false;
      }
    } else {
      auto allowed_error = atol + std::abs(rtol * bi);
      auto actual_error = std::abs(ai - bi);
      if (!std::isfinite(actual_error) || actual_error > allowed_error) {
        return false;
      }
    }
  }
  return true;
}
78 
default_atol_for_type(ScalarType t)79 double default_atol_for_type(ScalarType t) {
80   if (t == ScalarType::Half) {
81     return internal::kDefaultHalfAtol;
82   }
83   return internal::kDefaultAtol;
84 }
85 } // namespace
86 
tensors_are_close(const Tensor & a,const Tensor & b,double rtol,std::optional<double> opt_atol)87 bool tensors_are_close(
88     const Tensor& a,
89     const Tensor& b,
90     double rtol,
91     std::optional<double> opt_atol) {
92   if (a.scalar_type() != b.scalar_type() || a.sizes() != b.sizes()) {
93     return false;
94   }
95 
96   // TODO(T132992348): support comparison between tensors of different strides
97   ET_CHECK_MSG(
98       a.strides() == b.strides(),
99       "The two inputs of `tensors_are_close` function shall have same strides");
100 
101   // Since the two tensors have same shape and strides, any two elements that
102   // share same index from underlying data perspective will also share same
103   // index from tensor perspective, whatever the size and strides really are.
104   // e.g. if a[i_1, i_2, ... i_n] = a.const_data_ptr()[m], we can assert
105   // b[i_1, i_2, ... i_n] = b.const_data_ptr()[m])
106   // So we can just compare the two underlying data sequentially to figure out
107   // if the two tensors are same.
108 
109   double atol = opt_atol.value_or(default_atol_for_type(a.scalar_type()));
110 
111   if (a.nbytes() == 0) {
112     // Note that this case is important. It's valid for a zero-size tensor to
113     // have a null data pointer, but in some environments it's invalid to pass a
114     // null pointer to memcmp() even when the size is zero.
115     return true;
116   } else if (a.scalar_type() == ScalarType::Float) {
117     return data_is_close<float>(
118         a.const_data_ptr<float>(),
119         b.const_data_ptr<float>(),
120         a.numel(),
121         rtol,
122         atol);
123   } else if (a.scalar_type() == ScalarType::Double) {
124     return data_is_close<double>(
125         a.const_data_ptr<double>(),
126         b.const_data_ptr<double>(),
127         a.numel(),
128         rtol,
129         atol);
130   } else if (a.scalar_type() == ScalarType::Half) {
131     return data_is_close<Half>(
132         a.const_data_ptr<Half>(),
133         b.const_data_ptr<Half>(),
134         a.numel(),
135         rtol,
136         atol);
137   } else if (a.scalar_type() == ScalarType::BFloat16) {
138     return data_is_close<BFloat16>(
139         a.const_data_ptr<BFloat16>(),
140         b.const_data_ptr<BFloat16>(),
141         a.numel(),
142         rtol,
143         atol);
144   } else {
145     // Non-floating-point types can be compared bitwise.
146     return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;
147   }
148 }
149 
150 /**
151  * Asserts that the provided tensors have the same sequence of close
152  * underlying data elements and same numel. Note that this function is mainly
153  * about comparing underlying data between two tensors, not relevant with how
154  * tensor interpret the underlying data.
155  */
tensor_data_is_close(const Tensor & a,const Tensor & b,double rtol,std::optional<double> opt_atol)156 bool tensor_data_is_close(
157     const Tensor& a,
158     const Tensor& b,
159     double rtol,
160     std::optional<double> opt_atol) {
161   if (a.scalar_type() != b.scalar_type() || a.numel() != b.numel()) {
162     return false;
163   }
164 
165   double atol = opt_atol.value_or(default_atol_for_type(a.scalar_type()));
166   if (a.nbytes() == 0) {
167     // Note that this case is important. It's valid for a zero-size tensor to
168     // have a null data pointer, but in some environments it's invalid to pass a
169     // null pointer to memcmp() even when the size is zero.
170     return true;
171   } else if (a.scalar_type() == ScalarType::Float) {
172     return data_is_close<float>(
173         a.const_data_ptr<float>(),
174         b.const_data_ptr<float>(),
175         a.numel(),
176         rtol,
177         atol);
178   } else if (a.scalar_type() == ScalarType::Double) {
179     return data_is_close<double>(
180         a.const_data_ptr<double>(),
181         b.const_data_ptr<double>(),
182         a.numel(),
183         rtol,
184         atol);
185   } else {
186     // Non-floating-point types can be compared bitwise.
187     return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;
188   }
189 }
190 
tensor_lists_are_close(const exec_aten::Tensor * tensors_a,size_t num_tensors_a,const exec_aten::Tensor * tensors_b,size_t num_tensors_b,double rtol,std::optional<double> opt_atol)191 bool tensor_lists_are_close(
192     const exec_aten::Tensor* tensors_a,
193     size_t num_tensors_a,
194     const exec_aten::Tensor* tensors_b,
195     size_t num_tensors_b,
196     double rtol,
197     std::optional<double> opt_atol) {
198   if (num_tensors_a != num_tensors_b) {
199     return false;
200   }
201   for (size_t i = 0; i < num_tensors_a; i++) {
202     if (!tensors_are_close(tensors_a[i], tensors_b[i], rtol, opt_atol)) {
203       return false;
204     }
205   }
206   return true;
207 }
208 
209 } // namespace testing
210 } // namespace runtime
211 } // namespace executorch
212 
213 // ATen already defines operator<<() for Tensor and ScalarType.
214 #ifndef USE_ATEN_LIB
215 
216 /*
217  * These functions must be declared in the original namespaces of their
218  * associated types so that C++ can find them.
219  */
220 namespace executorch {
221 namespace runtime {
222 namespace etensor {
223 
224 /**
225  * Prints the ScalarType to the stream as a human-readable string.
226  */
operator <<(std::ostream & os,const ScalarType & t)227 std::ostream& operator<<(std::ostream& os, const ScalarType& t) {
228   const char* s = torch::executor::toString(t);
229   if (std::strcmp(s, "UNKNOWN_SCALAR") == 0) {
230     return os << "Unknown(" << static_cast<int32_t>(t) << ")";
231   } else {
232     return os << s;
233   }
234 }
235 
236 namespace {
237 
/**
 * Prints the elements of `data` to the stream as comma-separated strings.
 */
template <typename T>
std::ostream& print_data(std::ostream& os, const T* data, size_t numel) {
  // TODO(dbort): Make this smarter: show dimensions, listen to strides,
  // break up or truncate data when it's huge
  //
  // Use size_t for the index: `auto i = 0` deduced int, causing a
  // signed/unsigned comparison against numel and truncation for huge inputs.
  // `i + 1 < numel` also avoids the (benign but fragile) `numel - 1`
  // unsigned wraparound when numel == 0.
  for (size_t i = 0; i < numel; i++) {
    os << data[i];
    if (i + 1 < numel) {
      os << ", ";
    }
  }
  return os;
}
253 
/**
 * Prints the elements of `data` to the stream as comma-separated strings.
 *
 * Overload for byte tensors as c++ default prints them as chars where as
 * debugging is typically easier with numbers here (tensors dont store string
 * data).
 *
 * NOTE: This is a plain overload rather than a function-template full
 * specialization: a non-template overload is an exact match for uint8_t and
 * is preferred by overload resolution at every call site in this file, and
 * overloads are the recommended way to customize function templates.
 */
std::ostream& print_data(std::ostream& os, const uint8_t* data, size_t numel) {
  // TODO(dbort): Make this smarter: show dimensions, listen to strides,
  // break up or truncate data when it's huge
  //
  // size_t index fixes the signed/unsigned comparison that `auto i = 0`
  // (deduced int) introduced against numel.
  for (size_t i = 0; i < numel; i++) {
    // Widen so the stream prints a number instead of a raw character.
    os << static_cast<uint64_t>(data[i]);
    if (i + 1 < numel) {
      os << ", ";
    }
  }
  return os;
}
273 
274 } // namespace
275 
276 /**
277  * Prints the Tensor to the stream as a human-readable string.
278  */
operator <<(std::ostream & os,const Tensor & t)279 std::ostream& operator<<(std::ostream& os, const Tensor& t) {
280   os << "ETensor(sizes={";
281   for (auto dim = 0; dim < t.dim(); dim++) {
282     os << t.size(dim);
283     if (dim < t.dim() - 1) {
284       os << ", ";
285     }
286   }
287   os << "}, dtype=" << t.scalar_type() << ", data={";
288 
289   // Map from the ScalarType to the C type.
290 #define PRINT_CASE(ctype, stype)                          \
291   case ScalarType::stype:                                 \
292     print_data(os, t.const_data_ptr<ctype>(), t.numel()); \
293     break;
294 
295   switch (t.scalar_type()) {
296     ET_FORALL_REAL_TYPES_AND3(Half, Bool, BFloat16, PRINT_CASE)
297     default:
298       ET_CHECK_MSG(
299           false,
300           "Unhandled dtype %s",
301           torch::executor::toString(t.scalar_type()));
302   }
303 
304 #undef PRINT_CASE
305 
306   os << "})";
307 
308   return os;
309 }
310 
311 } // namespace etensor
312 } // namespace runtime
313 } // namespace executorch
314 
315 #endif // !USE_ATEN_LIB
316