xref: /aosp_15_r20/external/pytorch/test/cpp/tensorexpr/test_base.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #if defined(USE_GTEST)
4 #include <gtest/gtest.h>
5 #include <test/cpp/common/support.h>
6 #else
7 #include <cmath>
8 #include "c10/util/Exception.h"
9 #include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
10 #define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
11 #define ASSERT_FLOAT_EQ(x, y, ...) \
12   TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
13 #define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
14 #define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
15 #define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
16 #define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
17 #define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)
18 
19 #define ASSERT_NEAR(x, y, a, ...) \
20   TORCH_INTERNAL_ASSERT(std::fabs((x) - (y)) < (a), __VA_ARGS__)
21 
22 #define ASSERT_TRUE TORCH_INTERNAL_ASSERT
23 #define ASSERT_FALSE(x) ASSERT_TRUE(!(x))
24 #define ASSERT_THROWS_WITH(statement, substring)                         \
25   try {                                                                  \
26     (void)statement;                                                     \
27     ASSERT_TRUE(false);                                                  \
28   } catch (const std::exception& e) {                                    \
29     ASSERT_NE(std::string(e.what()).find(substring), std::string::npos); \
30   }
31 #define ASSERT_ANY_THROW(statement)     \
32   {                                     \
33     bool threw = false;                 \
34     try {                               \
35       (void)statement;                  \
36     } catch (const std::exception& e) { \
37       threw = true;                     \
38     }                                   \
39     ASSERT_TRUE(threw);                 \
40   }
41 
42 #endif // defined(USE_GTEST)
43 #include <string>
44 #include <vector>
45 
namespace torch {
namespace jit {
namespace tensorexpr {

// Asserts that `v1` and `v2` have the same length and that each pair of
// corresponding elements is within `threshold` of each other, as judged by
// ASSERT_NEAR. `name` labels the comparison in the failure message, together
// with the index of the first mismatching element. Stops at the first failure.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  for (size_t i = 0; i < v1.size(); i++) {
    // The two assertion backends attach messages differently: gtest streams
    // with `<<`, the TORCH_INTERNAL_ASSERT fallback takes trailing arguments.
#if defined(USE_GTEST)
    ASSERT_NEAR(v1[i], v2[i], threshold) << name << " at index " << i;
#else
    ASSERT_NEAR(v1[i], v2[i], threshold, name, " at index ", i);
#endif
  }
}

// Asserts that every element of `vec` is within `threshold` of `val`, as
// judged by ASSERT_NEAR. `name` labels the comparison in the failure message,
// together with the index of the first mismatching element.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  for (size_t i = 0; i < vec.size(); i++) {
#if defined(USE_GTEST)
    ASSERT_NEAR(vec[i], val, threshold) << name << " at index " << i;
#else
    ASSERT_NEAR(vec[i], val, threshold, name, " at index ", i);
#endif
  }
}

// Asserts that every element of `vec` compares equal (==) to `val`.
template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  for (auto const& elt : vec) {
    ASSERT_EQ(elt, val);
  }
}

// Asserts that `v1` and `v2` have the same length and are element-wise
// equal (==).
template <typename T>
static void assertAllEqual(const std::vector<T>& v1, const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  for (size_t i = 0; i < v1.size(); ++i) {
    ASSERT_EQ(v1[i], v2[i]);
  }
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch
90