1 /* 2 * Copyright (c) Meta Platforms, Inc. and affiliates. 3 * All rights reserved. 4 * 5 * This source code is licensed under the BSD-style license found in the 6 * LICENSE file in the root directory of this source tree. 7 */ 8 9 #pragma once 10 11 #include <executorch/runtime/core/exec_aten/exec_aten.h> 12 #include <gmock/gmock.h> // For MATCHER_P 13 14 #include <optional> 15 16 namespace executorch { 17 namespace runtime { 18 namespace testing { 19 20 namespace internal { 21 constexpr double kDefaultRtol = 1e-5; 22 constexpr double kDefaultAtol = 1e-8; 23 // Per 24 // https://en.wikipedia.org/wiki/Half-precision_floating-point_format, 25 // float16 has about 3.3 digits of precision. 26 constexpr double kDefaultHalfAtol = 1e-3; 27 28 // Following similar reasoning to float16, BFloat16 has 29 // math.log10(2**8) = 2.4 digits of precision. 30 constexpr double kDefaultBFloat16Atol = 1e-2; 31 } // namespace internal 32 33 /** 34 * Returns true if the tensors are of the same shape and dtype, and if all 35 * elements are close to each other. 36 * 37 * TODO(T132992348): This function will currently fail an ET_CHECK if the 38 * strides of the tensors are not identical. Add support for comparing 39 * tensors with different strides. 40 * 41 * Note that gtest users can write `EXPECT_THAT(tensor1, IsCloseTo(tensor2))` or 42 * `EXPECT_THAT(tensor1, Not(IsCloseTo(tensor2)))`, or use the helper macros 43 * `EXPECT_TENSOR_CLOSE()` and `EXPECT_TENSOR_NOT_CLOSE()`. 44 * 45 * For exact equality, use `EXPECT_THAT(tensor1, IsEqualTo(tensor2))` or 46 * `EXPECT_THAT(tensor1, Not(IsEqualTo(tensor2)))`, or the helper macros 47 * `EXPECT_TENSOR_EQ()` and `EXPECT_TENSOR_NE()`. 48 * 49 * An element A is close to B when one is true: 50 * 51 * (1) A is equal to B. 52 * (2) A and B are both NaN, are both -infinity, or are both +infinity. 53 * (3) The error abs(A - B) is finite and less than the max error 54 * (atol + abs(rtol * B)). 55 * 56 * If both rtol/atol are zero, this function checks for exact equality. 57 * 58 * NOTE: rtol/atol are ignored and treated as zero for non-floating-point 59 * dtypes. 60 * 61 * @param[in] a The first tensor to compare. 62 * @param[in] b The second tensor to compare. 63 * @param[in] rtol Relative tolerance; see note above. 64 * @param[in] atol Absolute tolerance; see note above. 65 * @retval true All corresponding elements of the two tensors are within the 66 * specified tolerance of each other. 67 * @retval false One or more corresponding elements of the two tensors are 68 * outside of the specified tolerance of each other. 69 */ 70 bool tensors_are_close( 71 const executorch::aten::Tensor& a, 72 const executorch::aten::Tensor& b, 73 double rtol = internal::kDefaultRtol, 74 std::optional<double> opt_atol = std::nullopt); 75 76 /** 77 * Returns true if the tensors are of the same numel and dtype, and if all 78 * elements are close to each other. The tensor shapes do not need to be same. 79 * 80 * Note that gtest users can write `EXPECT_THAT(tensor1, 81 * IsDataCloseTo(tensor2))` or `EXPECT_THAT(tensor1, 82 * Not(IsDataCloseTo(tensor2)))`, or use the helper macros 83 * `EXPECT_TENSOR_DATA_CLOSE()` and `EXPECT_TENSOR_DATA_NOT_CLOSE()`. 84 * 85 * For exact equality, use `EXPECT_THAT(tensor1, IsDataEqualTo(tensor2))` or 86 * `EXPECT_THAT(tensor1, Not(IsDataEqualTo(tensor2)))`, or the helper macros 87 * `EXPECT_TENSOR_DATA_EQ()` and `EXPECT_TENSOR_DATA_NE()`. 88 * 89 * The defination of an element A is close to B is in the comment of the 90 * function `tensors_are_close` 91 * 92 * @param[in] a The first tensor to compare. 93 * @param[in] b The second tensor to compare. 94 * @param[in] rtol Relative tolerance; see note above. 95 * @param[in] atol Absolute tolerance; see note above. 96 * @retval true All corresponding elements of the two tensors are within the 97 * specified tolerance of each other. 98 * @retval false One or more corresponding elements of the two tensors are 99 * outside of the specified tolerance of each other. 100 */ 101 bool tensor_data_is_close( 102 const executorch::aten::Tensor& a, 103 const executorch::aten::Tensor& b, 104 double rtol = internal::kDefaultRtol, 105 std::optional<double> opt_atol = std::nullopt); 106 107 /** 108 * Returns true if the two lists are of the same length, and 109 * tensor_data_is_close(tensors_a[i], tensors_b[i], rtol, atol) is true for all 110 * i. 111 */ 112 bool tensor_lists_are_close( 113 const executorch::aten::Tensor* tensors_a, 114 size_t num_tensors_a, 115 const executorch::aten::Tensor* tensors_b, 116 size_t num_tensors_b, 117 double rtol = internal::kDefaultRtol, 118 std::optional<double> opt_atol = std::nullopt); 119 120 /** 121 * Lets gtest users write `EXPECT_THAT(tensor1, IsCloseTo(tensor2))` or 122 * `EXPECT_THAT(tensor1, Not(IsCloseTo(tensor2)))`. 123 * 124 * See also `EXPECT_TENSOR_CLOSE()` and `EXPECT_TENSOR_NOT_CLOSE()`. 125 */ 126 MATCHER_P(IsCloseTo, other, "") { 127 return tensors_are_close(arg, other); 128 } 129 130 /** 131 * Lets gtest users write 132 * `EXPECT_THAT(tensor1, IsCloseToWithTol(tensor2, rtol, atol))` 133 * or `EXPECT_THAT(tensor1, Not(IsCloseToWithTol(tensor2, rtol, atol)))`. 134 * 135 * See also `EXPECT_TENSOR_CLOSE_WITH_TOL()` and 136 * `EXPECT_TENSOR_NOT_CLOSE_WITH_TOL()`. 137 */ 138 MATCHER_P3(IsCloseToWithTol, other, rtol, atol, "") { 139 return tensors_are_close(arg, other, rtol, atol); 140 } 141 142 /** 143 * Lets gtest users write `EXPECT_THAT(tensor1, IsEqualTo(tensor2))` or 144 * `EXPECT_THAT(tensor1, Not(IsEqualTo(tensor2)))`. 145 * 146 * See also `EXPECT_TENSOR_EQ()` and `EXPECT_TENSOR_NE()`. 147 */ 148 MATCHER_P(IsEqualTo, other, "") { 149 return tensors_are_close(arg, other, /*rtol=*/0, /*atol=*/0); 150 } 151 152 /** 153 * Lets gtest users write `EXPECT_THAT(tensor1, IsDataCloseTo(tensor2))` or 154 * `EXPECT_THAT(tensor1, Not(IsDataCloseTo(tensor2)))`. 155 * 156 * See also `EXPECT_TENSOR_DATA_CLOSE()` and `EXPECT_TENSOR_DATA_NOT_CLOSE()`. 157 */ 158 MATCHER_P(IsDataCloseTo, other, "") { 159 return tensor_data_is_close(arg, other); 160 } 161 162 /** 163 * Lets gtest users write 164 * `EXPECT_THAT(tensor1, IsDataCloseToWithTol(tensor2, rtol, atol))` 165 * or `EXPECT_THAT(tensor1, Not(IsDataCloseToWithTol(tensor2, rtol, atol)))`. 166 * 167 * See also `EXPECT_TENSOR_CLOSE_WITH_TOL()` and 168 * `EXPECT_TENSOR_NOT_CLOSE_WITH_TOL()`. 169 */ 170 MATCHER_P3(IsDataCloseToWithTol, other, rtol, atol, "") { 171 return tensor_data_is_close(arg, other, rtol, atol); 172 } 173 174 /** 175 * Lets gtest users write `EXPECT_THAT(tensor1, IsDataEqualTo(tensor2))` or 176 * `EXPECT_THAT(tensor1, Not(IsDataEqualTo(tensor2)))`. 177 * 178 * See also `EXPECT_TENSOR_DATA_EQ()` and `EXPECT_TENSOR_DATA_NE()`. 179 */ 180 MATCHER_P(IsDataEqualTo, other, "") { 181 return tensor_data_is_close(arg, other, /*rtol=*/0, /*atol=*/0); 182 } 183 184 /** 185 * Lets gtest users write `EXPECT_THAT(tensor_list1, 186 * IsListCloseTo(tensor_list2))` or `EXPECT_THAT(tensor_list1, 187 * Not(IsListCloseTo(tensor_list2)))`. 188 * 189 * The lists can be any container of Tensor that supports ::data() and ::size(). 190 * 191 * See also `EXPECT_TENSOR_LISTS_CLOSE()` and `EXPECT_TENSOR_LISTS_NOT_CLOSE()`. 192 */ 193 MATCHER_P(IsListCloseTo, other, "") { 194 return tensor_lists_are_close( 195 arg.data(), arg.size(), other.data(), other.size()); 196 } 197 198 /** 199 * Lets gtest users write `EXPECT_THAT(tensor_list1, 200 * IsListEqualTo(tensor_list2))` or `EXPECT_THAT(tensor_list1, 201 * Not(IsListEqualTo(tensor_list2)))`. 202 * 203 * The lists can be any container of Tensor that supports ::data() and ::size(). 204 * 205 * See also `EXPECT_TENSOR_LISTS_EQ()` and `EXPECT_TENSOR_LISTS_NE()`. 206 */ 207 MATCHER_P(IsListEqualTo, other, "") { 208 return tensor_lists_are_close( 209 arg.data(), 210 arg.size(), 211 other.data(), 212 other.size(), 213 /*rtol=*/0, 214 /*atol=*/0); 215 } 216 217 /* 218 * NOTE: Although it would be nice to make `EXPECT_EQ(t1, t2)` and friends work, 219 * that would require implementing `bool operator==(Tensor, Tensor)`. 220 * 221 * at::Tensor implements `Tensor operator==(Tensor, Tensor)`, returning an 222 * element-by-element comparison. This causes an ambiguous conflict with the 223 * `bool`-returning operator. 224 */ 225 #define EXPECT_TENSOR_EQ(t1, t2) \ 226 EXPECT_THAT((t1), ::executorch::runtime::testing::IsEqualTo(t2)) 227 #define EXPECT_TENSOR_NE(t1, t2) \ 228 EXPECT_THAT((t1), ::testing::Not(executorch::runtime::testing::IsEqualTo(t2))) 229 #define ASSERT_TENSOR_EQ(t1, t2) \ 230 ASSERT_THAT((t1), ::executorch::runtime::testing::IsEqualTo(t2)) 231 #define ASSERT_TENSOR_NE(t1, t2) \ 232 ASSERT_THAT((t1), ::testing::Not(executorch::runtime::testing::IsEqualTo(t2))) 233 234 #define EXPECT_TENSOR_CLOSE(t1, t2) \ 235 EXPECT_THAT((t1), ::executorch::runtime::testing::IsCloseTo(t2)) 236 #define EXPECT_TENSOR_NOT_CLOSE(t1, t2) \ 237 EXPECT_THAT((t1), ::testing::Not(executorch::runtime::testing::IsCloseTo(t2))) 238 #define ASSERT_TENSOR_CLOSE(t1, t2) \ 239 ASSERT_THAT((t1), ::executorch::runtime::testing::IsCloseTo(t2)) 240 #define ASSERT_TENSOR_NOT_CLOSE(t1, t2) \ 241 ASSERT_THAT((t1), ::testing::Not(executorch::runtime::testing::IsCloseTo(t2))) 242 243 #define EXPECT_TENSOR_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 244 EXPECT_THAT( \ 245 (t1), ::executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol)) 246 #define EXPECT_TENSOR_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 247 EXPECT_THAT( \ 248 (t1), \ 249 ::testing::Not( \ 250 executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol))) 251 #define ASSERT_TENSOR_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 252 ASSERT_THAT( \ 253 (t1), ::executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol)) 254 #define ASSERT_TENSOR_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 255 ASSERT_THAT( \ 256 (t1), \ 257 ::testing::Not( \ 258 executorch::runtime::testing::IsCloseToWithTol(t2, rtol, atol))) 259 260 #define EXPECT_TENSOR_DATA_EQ(t1, t2) \ 261 EXPECT_THAT((t1), ::executorch::runtime::testing::IsDataEqualTo(t2)) 262 #define EXPECT_TENSOR_DATA_NE(t1, t2) \ 263 EXPECT_THAT( \ 264 (t1), ::testing::Not(executorch::runtime::testing::IsDataEqualTo(t2))) 265 #define ASSERT_TENSOR_DATA_EQ(t1, t2) \ 266 ASSERT_THAT((t1), ::executorch::runtime::testing::IsDataEqualTo(t2)) 267 #define ASSERT_TENSOR_DATA_NE(t1, t2) \ 268 ASSERT_THAT( \ 269 (t1), ::testing::Not(executorch::runtime::testing::IsDataEqualTo(t2))) 270 271 #define EXPECT_TENSOR_DATA_CLOSE(t1, t2) \ 272 EXPECT_THAT((t1), ::executorch::runtime::testing::IsDataCloseTo(t2)) 273 #define EXPECT_TENSOR_DATA_NOT_CLOSE(t1, t2) \ 274 EXPECT_THAT( \ 275 (t1), ::testing::Not(executorch::runtime::testing::IsDataCloseTo(t2))) 276 #define ASSERT_TENSOR_DATA_CLOSE(t1, t2) \ 277 ASSERT_THAT((t1), ::executorch::runtime::testing::IsDataCloseTo(t2)) 278 #define ASSERT_TENSOR_DATA_NOT_CLOSE(t1, t2) \ 279 ASSERT_THAT( \ 280 (t1), ::testing::Not(executorch::runtime::testing::IsDataCloseTo(t2))) 281 282 #define EXPECT_TENSOR_DATA_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 283 EXPECT_THAT( \ 284 (t1), \ 285 ::executorch::runtime::testing::IsDataCloseToWithTol(t2, rtol, atol)) 286 #define EXPECT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 287 EXPECT_THAT( \ 288 (t1), \ 289 ::testing::Not( \ 290 executorch::runtime::testing::IsDataCloseToWithTol(t2, rtol, atol))) 291 #define ASSERT_TENSOR_DATA_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 292 ASSERT_THAT( \ 293 (t1), \ 294 ::executorch::runtime::testing::IsDataCloseToWithTol(t2, rtol, atol)) 295 #define ASSERT_TENSOR_DATA_NOT_CLOSE_WITH_TOL(t1, t2, rtol, atol) \ 296 ASSERT_THAT( \ 297 (t1), \ 298 ::testing::Not( \ 299 executorch::runtime::testing::IsDataCloseToWithTol(t2, rtol, atol))) 300 301 /* 302 * Helpers for comparing lists of Tensors. 303 */ 304 305 #define EXPECT_TENSOR_LISTS_EQ(t1, t2) \ 306 EXPECT_THAT((t1), ::executorch::runtime::testing::IsListEqualTo(t2)) 307 #define EXPECT_TENSOR_LISTS_NE(t1, t2) \ 308 EXPECT_THAT( \ 309 (t1), ::testing::Not(executorch::runtime::testing::IsListEqualTo(t2))) 310 #define ASSERT_TENSOR_LISTS_EQ(t1, t2) \ 311 ASSERT_THAT((t1), ::executorch::runtime::testing::IsListEqualTo(t2)) 312 #define ASSERT_TENSOR_LISTS_NE(t1, t2) \ 313 ASSERT_THAT( \ 314 (t1), ::testing::Not(executorch::runtime::testing::IsListEqualTo(t2))) 315 316 #define EXPECT_TENSOR_LISTS_CLOSE(t1, t2) \ 317 EXPECT_THAT((t1), ::executorch::runtime::testing::IsListCloseTo(t2)) 318 #define EXPECT_TENSOR_LISTS_NOT_CLOSE(t1, t2) \ 319 EXPECT_THAT( \ 320 (t1), ::testing::Not(executorch::runtime::testing::IsListCloseTo(t2))) 321 #define ASSERT_TENSOR_LISTS_CLOSE(t1, t2) \ 322 ASSERT_THAT((t1), ::executorch::runtime::testing::IsListCloseTo(t2)) 323 #define ASSERT_TENSOR_LISTS_NOT_CLOSE(t1, t2) \ 324 ASSERT_THAT( \ 325 (t1), ::testing::Not(executorch::runtime::testing::IsListCloseTo(t2))) 326 327 } // namespace testing 328 } // namespace runtime 329 } // namespace executorch 330 331 // ATen already defines operator<<() for Tensor and ScalarType. 332 #ifndef USE_ATEN_LIB 333 334 /* 335 * These functions must be declared in the original namespaces of their 336 * associated types so that C++ can find them. 337 */ 338 namespace executorch { 339 namespace runtime { 340 namespace etensor { 341 342 /** 343 * Prints the ScalarType to the stream as a human-readable string. 344 * 345 * See also executorch::runtime::toString(ScalarType t) in ScalarTypeUtil.h. 346 */ 347 std::ostream& operator<<(std::ostream& os, const ScalarType& t); 348 349 /** 350 * Prints the Tensor to the stream as a human-readable string. 351 */ 352 std::ostream& operator<<(std::ostream& os, const Tensor& t); 353 354 } // namespace etensor 355 } // namespace runtime 356 } // namespace executorch 357 358 #endif // !USE_ATEN_LIB 359 360 namespace torch { 361 namespace executor { 362 namespace testing { 363 // TODO(T197294990): Remove these deprecated aliases once all users have moved 364 // to the new `::executorch` namespaces. 365 using ::executorch::runtime::testing::IsCloseTo; 366 using ::executorch::runtime::testing::IsCloseToWithTol; 367 using ::executorch::runtime::testing::IsDataCloseTo; 368 using ::executorch::runtime::testing::IsDataCloseToWithTol; 369 using ::executorch::runtime::testing::IsDataEqualTo; 370 using ::executorch::runtime::testing::IsEqualTo; 371 using ::executorch::runtime::testing::IsListCloseTo; 372 using ::executorch::runtime::testing::IsListEqualTo; 373 using ::executorch::runtime::testing::tensor_data_is_close; 374 using ::executorch::runtime::testing::tensor_lists_are_close; 375 using ::executorch::runtime::testing::tensors_are_close; 376 } // namespace testing 377 } // namespace executor 378 } // namespace torch 379