1 #pragma once
2
3 #if defined(USE_GTEST)
4 #include <gtest/gtest.h>
5 #include <test/cpp/common/support.h>
6 #else
7 #include <cmath>
8 #include "c10/util/Exception.h"
9 #include "test/cpp/tensorexpr/gtest_assert_float_eq.h"
// Fallback assertion macros used when gtest is not available. Each check is
// forwarded to TORCH_INTERNAL_ASSERT; any extra arguments after the operands
// are passed through to TORCH_INTERNAL_ASSERT after the condition.
#define ASSERT_EQ(x, y, ...) TORCH_INTERNAL_ASSERT((x) == (y), __VA_ARGS__)
#define ASSERT_FLOAT_EQ(x, y, ...) \
  TORCH_INTERNAL_ASSERT(AlmostEquals((x), (y)), __VA_ARGS__)
#define ASSERT_NE(x, y, ...) TORCH_INTERNAL_ASSERT((x) != (y), __VA_ARGS__)
#define ASSERT_GT(x, y, ...) TORCH_INTERNAL_ASSERT((x) > (y), __VA_ARGS__)
#define ASSERT_GE(x, y, ...) TORCH_INTERNAL_ASSERT((x) >= (y), __VA_ARGS__)
#define ASSERT_LT(x, y, ...) TORCH_INTERNAL_ASSERT((x) < (y), __VA_ARGS__)
#define ASSERT_LE(x, y, ...) TORCH_INTERNAL_ASSERT((x) <= (y), __VA_ARGS__)

// Mirrors gtest's ASSERT_NEAR: both operands are converted to double before
// subtracting (so unsigned operands cannot wrap around), and a difference
// exactly equal to the tolerance `a` is accepted (`<=`, not `<`), so a
// tolerance of 0 passes for identical values.
#define ASSERT_NEAR(x, y, a, ...)                                        \
  TORCH_INTERNAL_ASSERT(                                                 \
      std::fabs(static_cast<double>(x) - static_cast<double>(y)) <= (a), \
      __VA_ARGS__)

#define ASSERT_TRUE TORCH_INTERNAL_ASSERT
#define ASSERT_FALSE(x) ASSERT_TRUE(!(x))

// Asserts that evaluating `statement` throws a std::exception whose what()
// contains `substring`. The "did not throw" and "wrong message" checks run
// outside the try block. Otherwise the assertion's own failure could be
// caught by the handler and mistaken for the expected exception. The
// prefixed local names avoid shadowing variables used in `statement` or
// `substring`.
#define ASSERT_THROWS_WITH(statement, substring)                           \
  {                                                                        \
    bool assert_throws_with_threw = false;                                 \
    std::string assert_throws_with_message;                                \
    try {                                                                  \
      (void)statement;                                                     \
    } catch (const std::exception& assert_throws_with_error) {             \
      assert_throws_with_threw = true;                                     \
      assert_throws_with_message = assert_throws_with_error.what();        \
    }                                                                      \
    ASSERT_TRUE(assert_throws_with_threw);                                 \
    ASSERT_NE(                                                             \
        assert_throws_with_message.find(substring), std::string::npos);    \
  }

// Asserts that evaluating `statement` throws an exception of any type,
// matching gtest's ASSERT_ANY_THROW (not only std::exception subclasses).
#define ASSERT_ANY_THROW(statement) \
  {                                 \
    bool threw = false;             \
    try {                           \
      (void)statement;              \
    } catch (...) {                 \
      threw = true;                 \
    }                               \
    ASSERT_TRUE(threw);             \
  }
41
42 #endif // defined(USE_GTEST)
43 #include <string>
44 #include <vector>
45
namespace torch {
namespace jit {
namespace tensorexpr {

// Element-wise closeness check for two vectors.
//
// First asserts that `v1` and `v2` hold the same number of elements, then
// applies ASSERT_NEAR with tolerance `threshold` to each index-aligned pair.
// `name` is accepted for call-site readability but is not used by the checks.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& v1,
    const std::vector<U>& v2,
    V threshold,
    const std::string& name = "") {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t count = v1.size();
  size_t i = 0;
  while (i < count) {
    ASSERT_NEAR(v1[i], v2[i], threshold);
    ++i;
  }
}

// Closeness check of every element of `vec` against the single value `val`,
// using ASSERT_NEAR with tolerance `threshold`. An empty `vec` performs no
// checks. `name` is accepted but not used by the checks.
template <typename U, typename V>
void ExpectAllNear(
    const std::vector<U>& vec,
    const U& val,
    V threshold,
    const std::string& name = "") {
  for (size_t i = 0, count = vec.size(); i != count; ++i) {
    ASSERT_NEAR(vec[i], val, threshold);
  }
}

// Asserts, via ASSERT_EQ, that every element of `vec` equals `val`.
template <typename T>
static void assertAllEqual(const std::vector<T>& vec, const T& val) {
  for (const auto& elt : vec) {
    ASSERT_EQ(elt, val);
  }
}

// Asserts that `v1` and `v2` have the same length and compare equal element
// by element, via ASSERT_EQ.
template <typename T>
static void assertAllEqual(const std::vector<T>& v1, const std::vector<T>& v2) {
  ASSERT_EQ(v1.size(), v2.size());
  const size_t count = v1.size();
  for (size_t i = 0; i != count; ++i) {
    ASSERT_EQ(v1[i], v2[i]);
  }
}

} // namespace tensorexpr
} // namespace jit
} // namespace torch
90