xref: /aosp_15_r20/external/executorch/runtime/core/exec_aten/testing_util/tensor_util.h (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <executorch/runtime/core/exec_aten/exec_aten.h>
12 #include <gmock/gmock.h> // For MATCHER_P
13 
14 #include <optional>
15 
16 namespace executorch {
17 namespace runtime {
18 namespace testing {
19 
namespace internal {
// Default relative tolerance; used as the default `rtol` argument of the
// comparison functions declared below.
inline constexpr double kDefaultRtol = 1e-5;

// Default absolute tolerance.
inline constexpr double kDefaultAtol = 1e-8;

// Per
// https://en.wikipedia.org/wiki/Half-precision_floating-point_format,
// float16 has about 3.3 digits of precision (log10(2**11) ~= 3.3), so it
// needs a much looser absolute tolerance than float/double.
inline constexpr double kDefaultHalfAtol = 1e-3;

// Following similar reasoning to float16, BFloat16 has
// math.log10(2**8) = 2.4 digits of precision.
inline constexpr double kDefaultBFloat16Atol = 1e-2;
} // namespace internal
32 
33 /**
34  *  Returns true if the tensors are of the same shape and dtype, and if all
35  * elements are close to each other.
36  *
37  * TODO(T132992348): This function will currently fail an ET_CHECK if the
38  * strides of the tensors are not identical. Add support for comparing
39  * tensors with different strides.
40  *
41  * Note that gtest users can write `EXPECT_THAT(tensor1, IsCloseTo(tensor2))` or
42  * `EXPECT_THAT(tensor1, Not(IsCloseTo(tensor2)))`, or use the helper macros
43  * `EXPECT_TENSOR_CLOSE()` and `EXPECT_TENSOR_NOT_CLOSE()`.
44  *
45  * For exact equality, use `EXPECT_THAT(tensor1, IsEqualTo(tensor2))` or
46  * `EXPECT_THAT(tensor1, Not(IsEqualTo(tensor2)))`, or the helper macros
47  * `EXPECT_TENSOR_EQ()` and `EXPECT_TENSOR_NE()`.
48  *
49  * An element A is close to B when one is true:
50  *
51  * (1) A is equal to B.
52  * (2) A and B are both NaN, are both -infinity, or are both +infinity.
53  * (3) The error abs(A - B) is finite and less than the max error
54  *     (atol + abs(rtol * B)).
55  *
56  * If both rtol/atol are zero, this function checks for exact equality.
57  *
58  * NOTE: rtol/atol are ignored and treated as zero for non-floating-point
59  * dtypes.
60  *
61  * @param[in] a The first tensor to compare.
62  * @param[in] b The second tensor to compare.
63  * @param[in] rtol Relative tolerance; see note above.
64  * @param[in] atol Absolute tolerance; see note above.
65  * @retval true All corresponding elements of the two tensors are within the
66  *     specified tolerance of each other.
67  * @retval false One or more corresponding elements of the two tensors are
68  *     outside of the specified tolerance of each other.
69  */
70 bool tensors_are_close(
71     const executorch::aten::Tensor& a,
72     const executorch::aten::Tensor& b,
73     double rtol = internal::kDefaultRtol,
74     std::optional<double> opt_atol = std::nullopt);
75 
76 /**
77  * Returns true if the tensors are of the same numel and dtype, and if all
78  * elements are close to each other. The tensor shapes do not need to be same.
79  *
80  * Note that gtest users can write `EXPECT_THAT(tensor1,
81  * IsDataCloseTo(tensor2))` or `EXPECT_THAT(tensor1,
82  * Not(IsDataCloseTo(tensor2)))`, or use the helper macros
83  * `EXPECT_TENSOR_DATA_CLOSE()` and `EXPECT_TENSOR_DATA_NOT_CLOSE()`.
84  *
85  * For exact equality, use `EXPECT_THAT(tensor1, IsDataEqualTo(tensor2))` or
86  * `EXPECT_THAT(tensor1, Not(IsDataEqualTo(tensor2)))`, or the helper macros
87  * `EXPECT_TENSOR_DATA_EQ()` and `EXPECT_TENSOR_DATA_NE()`.
88  *
89  * The defination of an element A is close to B is in the comment of the
90  * function `tensors_are_close`
91  *
92  * @param[in] a The first tensor to compare.
93  * @param[in] b The second tensor to compare.
94  * @param[in] rtol Relative tolerance; see note above.
95  * @param[in] atol Absolute tolerance; see note above.
96  * @retval true All corresponding elements of the two tensors are within the
97  *     specified tolerance of each other.
98  * @retval false One or more corresponding elements of the two tensors are
99  *     outside of the specified tolerance of each other.
100  */
101 bool tensor_data_is_close(
102     const executorch::aten::Tensor& a,
103     const executorch::aten::Tensor& b,
104     double rtol = internal::kDefaultRtol,
105     std::optional<double> opt_atol = std::nullopt);
106 
107 /**
108  * Returns true if the two lists are of the same length, and
109  * tensor_data_is_close(tensors_a[i], tensors_b[i], rtol, atol) is true for all
110  * i.
111  */
112 bool tensor_lists_are_close(
113     const executorch::aten::Tensor* tensors_a,
114     size_t num_tensors_a,
115     const executorch::aten::Tensor* tensors_b,
116     size_t num_tensors_b,
117     double rtol = internal::kDefaultRtol,
118     std::optional<double> opt_atol = std::nullopt);
119 
120 /**
121  * Lets gtest users write `EXPECT_THAT(tensor1, IsCloseTo(tensor2))` or
122  * `EXPECT_THAT(tensor1, Not(IsCloseTo(tensor2)))`.
123  *
124  * See also `EXPECT_TENSOR_CLOSE()` and `EXPECT_TENSOR_NOT_CLOSE()`.
125  */
126 MATCHER_P(IsCloseTo, other, "") {
127   return tensors_are_close(arg, other);
128 }
129 
130 /**
131  * Lets gtest users write
132  * `EXPECT_THAT(tensor1, IsCloseToWithTol(tensor2, rtol, atol))`
133  * or `EXPECT_THAT(tensor1, Not(IsCloseToWithTol(tensor2, rtol, atol)))`.
134  *
135  * See also `EXPECT_TENSOR_CLOSE_WITH_TOL()` and
136  * `EXPECT_TENSOR_NOT_CLOSE_WITH_TOL()`.
137  */
138 MATCHER_P3(IsCloseToWithTol, other, rtol, atol, "") {
139   return tensors_are_close(arg, other, rtol, atol);
140 }
141 
142 /**
143  * Lets gtest users write `EXPECT_THAT(tensor1, IsEqualTo(tensor2))` or
144  * `EXPECT_THAT(tensor1, Not(IsEqualTo(tensor2)))`.
145  *
146  * See also `EXPECT_TENSOR_EQ()` and `EXPECT_TENSOR_NE()`.
147  */
148 MATCHER_P(IsEqualTo, other, "") {
149   return tensors_are_close(arg, other, /*rtol=*/0, /*atol=*/0);
150 }
151 
152 /**
153  * Lets gtest users write `EXPECT_THAT(tensor1, IsDataCloseTo(tensor2))` or
154  * `EXPECT_THAT(tensor1, Not(IsDataCloseTo(tensor2)))`.
155  *
156  * See also `EXPECT_TENSOR_DATA_CLOSE()` and `EXPECT_TENSOR_DATA_NOT_CLOSE()`.
157  */
158 MATCHER_P(IsDataCloseTo, other, "") {
159   return tensor_data_is_close(arg, other);
160 }
161 
162 /**
163  * Lets gtest users write
164  * `EXPECT_THAT(tensor1, IsDataCloseToWithTol(tensor2, rtol, atol))`
165  * or `EXPECT_THAT(tensor1, Not(IsDataCloseToWithTol(tensor2, rtol, atol)))`.
166  *
167  * See also `EXPECT_TENSOR_CLOSE_WITH_TOL()` and
168  * `EXPECT_TENSOR_NOT_CLOSE_WITH_TOL()`.
169  */
170 MATCHER_P3(IsDataCloseToWithTol, other, rtol, atol, "") {
171   return tensor_data_is_close(arg, other, rtol, atol);
172 }
173 
174 /**
175  * Lets gtest users write `EXPECT_THAT(tensor1, IsDataEqualTo(tensor2))` or
176  * `EXPECT_THAT(tensor1, Not(IsDataEqualTo(tensor2)))`.
177  *
178  * See also `EXPECT_TENSOR_DATA_EQ()` and `EXPECT_TENSOR_DATA_NE()`.
179  */
180 MATCHER_P(IsDataEqualTo, other, "") {
181   return tensor_data_is_close(arg, other, /*rtol=*/0, /*atol=*/0);
182 }
183 
184 /**
185  * Lets gtest users write `EXPECT_THAT(tensor_list1,
186  * IsListCloseTo(tensor_list2))` or `EXPECT_THAT(tensor_list1,
187  * Not(IsListCloseTo(tensor_list2)))`.
188  *
189  * The lists can be any container of Tensor that supports ::data() and ::size().
190  *
191  * See also `EXPECT_TENSOR_LISTS_CLOSE()` and `EXPECT_TENSOR_LISTS_NOT_CLOSE()`.
192  */
193 MATCHER_P(IsListCloseTo, other, "") {
194   return tensor_lists_are_close(
195       arg.data(), arg.size(), other.data(), other.size());
196 }
197 
198 /**
199  * Lets gtest users write `EXPECT_THAT(tensor_list1,
200  * IsListEqualTo(tensor_list2))` or `EXPECT_THAT(tensor_list1,
201  * Not(IsListEqualTo(tensor_list2)))`.
202  *
203  * The lists can be any container of Tensor that supports ::data() and ::size().
204  *
205  * See also `EXPECT_TENSOR_LISTS_EQ()` and `EXPECT_TENSOR_LISTS_NE()`.
206  */
207 MATCHER_P(IsListEqualTo, other, "") {
208   return tensor_lists_are_close(
209       arg.data(),
210       arg.size(),
211       other.data(),
212       other.size(),
213       /*rtol=*/0,
214       /*atol=*/0);
215 }
216 
/*
 * NOTE: Although it would be nice to make `EXPECT_EQ(t1, t2)` and friends work,
 * that would require implementing `bool operator==(Tensor, Tensor)`.
 *
 * at::Tensor implements `Tensor operator==(Tensor, Tensor)`, returning an
 * element-by-element comparison. This causes an ambiguous conflict with the
 * `bool`-returning operator.
 *
 * Every name in the macros below is fully qualified (leading `::`) so the
 * expansion does not depend on the namespace of the test that uses it.
 */
#define EXPECT_TENSOR_EQ(t1, t2) \
  EXPECT_THAT((t1), ::executorch::runtime::testing::IsEqualTo(t2))
#define EXPECT_TENSOR_NE(t1, t2) \
  EXPECT_THAT(                   \
      (t1), ::testing::Not(::executorch::runtime::testing::IsEqualTo(t2)))
#define ASSERT_TENSOR_EQ(t1, t2) \
  ASSERT_THAT((t1), ::executorch::runtime::testing::IsEqualTo(t2))
#define ASSERT_TENSOR_NE(t1, t2) \
  ASSERT_THAT(                   \
      (t1), ::testing::Not(::executorch::runtime::testing::IsEqualTo(t2)))
233 
// Approximate comparison with default tolerances. Names are fully qualified
// so the expansion does not depend on the caller's namespace.
#define EXPECT_TENSOR_CLOSE(t1, t2) \
  EXPECT_THAT((t1), ::executorch::runtime::testing::IsCloseTo(t2))
#define EXPECT_TENSOR_NOT_CLOSE(t1, t2) \
  EXPECT_THAT(                          \
      (t1), ::testing::Not(::executorch::runtime::testing::IsCloseTo(t2)))
#define ASSERT_TENSOR_CLOSE(t1, t2) \
  ASSERT_THAT((t1), ::executorch::runtime::testing::IsCloseTo(t2))
#define ASSERT_TENSOR_NOT_CLOSE(t1, t2) \
  ASSERT_THAT(                          \
      (t1), ::testing::Not(::executorch::runtime::testing::IsCloseTo(t2)))
242 
// Approximate comparison with caller-supplied tolerances. Names are fully
// qualified so the expansion does not depend on the caller's namespace.
#define EXPECT_TENSOR_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  EXPECT_THAT(                                           \
      (t1), ::executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol))
#define EXPECT_TENSOR_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  EXPECT_THAT(                                               \
      (t1),                                                  \
      ::testing::Not(                                        \
          ::executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol)))
#define ASSERT_TENSOR_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  ASSERT_THAT(                                           \
      (t1), ::executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol))
#define ASSERT_TENSOR_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  ASSERT_THAT(                                               \
      (t1),                                                  \
      ::testing::Not(                                        \
          ::executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol)))
259 
// Exact comparison of tensor data (shapes may differ). Names are fully
// qualified so the expansion does not depend on the caller's namespace.
#define EXPECT_TENSOR_DATA_EQ(t1, t2) \
  EXPECT_THAT((t1), ::executorch::runtime::testing::IsDataEqualTo(t2))
#define EXPECT_TENSOR_DATA_NE(t1, t2) \
  EXPECT_THAT(                        \
      (t1), ::testing::Not(::executorch::runtime::testing::IsDataEqualTo(t2)))
#define ASSERT_TENSOR_DATA_EQ(t1, t2) \
  ASSERT_THAT((t1), ::executorch::runtime::testing::IsDataEqualTo(t2))
#define ASSERT_TENSOR_DATA_NE(t1, t2) \
  ASSERT_THAT(                        \
      (t1), ::testing::Not(::executorch::runtime::testing::IsDataEqualTo(t2)))
270 
// Approximate comparison of tensor data with default tolerances. Names are
// fully qualified so the expansion does not depend on the caller's namespace.
#define EXPECT_TENSOR_DATA_CLOSE(t1, t2) \
  EXPECT_THAT((t1), ::executorch::runtime::testing::IsDataCloseTo(t2))
#define EXPECT_TENSOR_DATA_NOT_CLOSE(t1, t2) \
  EXPECT_THAT(                               \
      (t1), ::testing::Not(::executorch::runtime::testing::IsDataCloseTo(t2)))
#define ASSERT_TENSOR_DATA_CLOSE(t1, t2) \
  ASSERT_THAT((t1), ::executorch::runtime::testing::IsDataCloseTo(t2))
#define ASSERT_TENSOR_DATA_NOT_CLOSE(t1, t2) \
  ASSERT_THAT(                               \
      (t1), ::testing::Not(::executorch::runtime::testing::IsDataCloseTo(t2)))
281 
// Approximate comparison of tensor data with caller-supplied tolerances.
// Names are fully qualified so the expansion does not depend on the caller's
// namespace.
#define EXPECT_TENSOR_DATA_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  EXPECT_THAT(                                                \
      (t1),                                                   \
      ::executorch::runtime::testing::IsDataCloseToWithTol(t2, rtol, atol))
#define EXPECT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  EXPECT_THAT(                                                    \
      (t1),                                                       \
      ::testing::Not(                                             \
          ::executorch::runtime::testing::IsDataCloseToWithTol(   \
              t2, rtol, atol)))
#define ASSERT_TENSOR_DATA_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  ASSERT_THAT(                                                \
      (t1),                                                   \
      ::executorch::runtime::testing::IsDataCloseToWithTol(t2, rtol, atol))
#define ASSERT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \
  ASSERT_THAT(                                                    \
      (t1),                                                       \
      ::testing::Not(                                             \
          ::executorch::runtime::testing::IsDataCloseToWithTol(   \
              t2, rtol, atol)))
300 
/*
 * Helpers for comparing lists of Tensors.
 *
 * Names are fully qualified so the expansion does not depend on the caller's
 * namespace.
 */

#define EXPECT_TENSOR_LISTS_EQ(t1, t2) \
  EXPECT_THAT((t1), ::executorch::runtime::testing::IsListEqualTo(t2))
#define EXPECT_TENSOR_LISTS_NE(t1, t2) \
  EXPECT_THAT(                         \
      (t1), ::testing::Not(::executorch::runtime::testing::IsListEqualTo(t2)))
#define ASSERT_TENSOR_LISTS_EQ(t1, t2) \
  ASSERT_THAT((t1), ::executorch::runtime::testing::IsListEqualTo(t2))
#define ASSERT_TENSOR_LISTS_NE(t1, t2) \
  ASSERT_THAT(                         \
      (t1), ::testing::Not(::executorch::runtime::testing::IsListEqualTo(t2)))
315 
// Approximate comparison of tensor lists with default tolerances. Names are
// fully qualified so the expansion does not depend on the caller's namespace.
#define EXPECT_TENSOR_LISTS_CLOSE(t1, t2) \
  EXPECT_THAT((t1), ::executorch::runtime::testing::IsListCloseTo(t2))
#define EXPECT_TENSOR_LISTS_NOT_CLOSE(t1, t2) \
  EXPECT_THAT(                                \
      (t1), ::testing::Not(::executorch::runtime::testing::IsListCloseTo(t2)))
#define ASSERT_TENSOR_LISTS_CLOSE(t1, t2) \
  ASSERT_THAT((t1), ::executorch::runtime::testing::IsListCloseTo(t2))
#define ASSERT_TENSOR_LISTS_NOT_CLOSE(t1, t2) \
  ASSERT_THAT(                                \
      (t1), ::testing::Not(::executorch::runtime::testing::IsListCloseTo(t2)))
326 
327 } // namespace testing
328 } // namespace runtime
329 } // namespace executorch
330 
331 // ATen already defines operator<<() for Tensor and ScalarType.
332 #ifndef USE_ATEN_LIB
333 
334 /*
335  * These functions must be declared in the original namespaces of their
336  * associated types so that C++ can find them.
337  */
338 namespace executorch {
339 namespace runtime {
340 namespace etensor {
341 
342 /**
343  * Prints the ScalarType to the stream as a human-readable string.
344  *
345  * See also executorch::runtime::toString(ScalarType t) in ScalarTypeUtil.h.
346  */
347 std::ostream& operator<<(std::ostream& os, const ScalarType& t);
348 
349 /**
350  * Prints the Tensor to the stream as a human-readable string.
351  */
352 std::ostream& operator<<(std::ostream& os, const Tensor& t);
353 
354 } // namespace etensor
355 } // namespace runtime
356 } // namespace executorch
357 
358 #endif // !USE_ATEN_LIB
359 
360 namespace torch {
361 namespace executor {
362 namespace testing {
363 // TODO(T197294990): Remove these deprecated aliases once all users have moved
364 // to the new `::executorch` namespaces.
365 using ::executorch::runtime::testing::IsCloseTo;
366 using ::executorch::runtime::testing::IsCloseToWithTol;
367 using ::executorch::runtime::testing::IsDataCloseTo;
368 using ::executorch::runtime::testing::IsDataCloseToWithTol;
369 using ::executorch::runtime::testing::IsDataEqualTo;
370 using ::executorch::runtime::testing::IsEqualTo;
371 using ::executorch::runtime::testing::IsListCloseTo;
372 using ::executorch::runtime::testing::IsListEqualTo;
373 using ::executorch::runtime::testing::tensor_data_is_close;
374 using ::executorch::runtime::testing::tensor_lists_are_close;
375 using ::executorch::runtime::testing::tensors_are_close;
376 } // namespace testing
377 } // namespace executor
378 } // namespace torch
379