1 #pragma once
2
3 #include <test/cpp/common/support.h>
4
5 #include <gtest/gtest.h>
6
7 #include <ATen/TensorIndexing.h>
8 #include <c10/util/Exception.h>
9 #include <torch/nn/cloneable.h>
10 #include <torch/types.h>
11 #include <torch/utils.h>
12
#include <mutex>
#include <string>
#include <utility>
#include <vector>
15
16 namespace torch {
17 namespace test {
18
19 // Lets you use a container without making a new class,
20 // for experimental implementations
21 class SimpleContainer : public nn::Cloneable<SimpleContainer> {
22 public:
reset()23 void reset() override {}
24
25 template <typename ModuleHolder>
26 ModuleHolder add(
27 ModuleHolder module_holder,
28 std::string name = std::string()) {
29 return Module::register_module(std::move(name), module_holder);
30 }
31 };
32
33 struct SeedingFixture : public ::testing::Test {
SeedingFixtureSeedingFixture34 SeedingFixture() {
35 torch::manual_seed(0);
36 }
37 };
38
39 struct WarningCapture : public WarningHandler {
WarningCaptureWarningCapture40 WarningCapture() : prev_(WarningUtils::get_warning_handler()) {
41 WarningUtils::set_warning_handler(this);
42 }
43
~WarningCaptureWarningCapture44 ~WarningCapture() override {
45 WarningUtils::set_warning_handler(prev_);
46 }
47
messagesWarningCapture48 const std::vector<std::string>& messages() {
49 return messages_;
50 }
51
strWarningCapture52 std::string str() {
53 return c10::Join("\n", messages_);
54 }
55
processWarningCapture56 void process(const c10::Warning& warning) override {
57 messages_.push_back(warning.msg());
58 }
59
60 private:
61 WarningHandler* prev_;
62 std::vector<std::string> messages_;
63 };
64
pointer_equal(at::Tensor first,at::Tensor second)65 inline bool pointer_equal(at::Tensor first, at::Tensor second) {
66 return first.data_ptr() == second.data_ptr();
67 }
68
69 // This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
70 // torch.Tensor)` branch in `TestCase.assertEqual` in
71 // torch/testing/_internal/common_utils.py
72 inline void assert_tensor_equal(
73 at::Tensor a,
74 at::Tensor b,
75 bool allow_inf = false) {
76 ASSERT_TRUE(a.sizes() == b.sizes());
77 if (a.numel() > 0) {
78 if (a.device().type() == torch::kCPU &&
79 (a.scalar_type() == torch::kFloat16 ||
80 a.scalar_type() == torch::kBFloat16)) {
81 // CPU half and bfloat16 tensors don't have the methods we need below
82 a = a.to(torch::kFloat32);
83 }
84 if (a.device().type() == torch::kCUDA &&
85 a.scalar_type() == torch::kBFloat16) {
86 // CUDA bfloat16 tensors don't have the methods we need below
87 a = a.to(torch::kFloat32);
88 }
89 b = b.to(a);
90
91 if ((a.scalar_type() == torch::kBool) !=
92 (b.scalar_type() == torch::kBool)) {
93 TORCH_CHECK(false, "Was expecting both tensors to be bool type.");
94 } else {
95 if (a.scalar_type() == torch::kBool && b.scalar_type() == torch::kBool) {
96 // we want to respect precision but as bool doesn't support subtraction,
97 // boolean tensor has to be converted to int
98 a = a.to(torch::kInt);
99 b = b.to(torch::kInt);
100 }
101
102 auto diff = a - b;
103 if (a.is_floating_point()) {
104 // check that NaNs are in the same locations
105 auto nan_mask = torch::isnan(a);
106 ASSERT_TRUE(torch::equal(nan_mask, torch::isnan(b)));
107 diff.index_put_({nan_mask}, 0);
108 // inf check if allow_inf=true
109 if (allow_inf) {
110 auto inf_mask = torch::isinf(a);
111 auto inf_sign = inf_mask.sign();
112 ASSERT_TRUE(torch::equal(inf_sign, torch::isinf(b).sign()));
113 diff.index_put_({inf_mask}, 0);
114 }
115 }
116 // TODO: implement abs on CharTensor (int8)
117 if (diff.is_signed() && diff.scalar_type() != torch::kInt8) {
118 diff = diff.abs();
119 }
120 auto max_err = diff.max().item<double>();
121 ASSERT_LE(max_err, 1e-5);
122 }
123 }
124 }
125
126 // This mirrors the `isinstance(x, torch.Tensor) and isinstance(y,
127 // torch.Tensor)` branch in `TestCase.assertNotEqual` in
128 // torch/testing/_internal/common_utils.py
assert_tensor_not_equal(at::Tensor x,at::Tensor y)129 inline void assert_tensor_not_equal(at::Tensor x, at::Tensor y) {
130 if (x.sizes() != y.sizes()) {
131 return;
132 }
133 ASSERT_GT(x.numel(), 0);
134 y = y.type_as(x);
135 y = x.is_cuda() ? y.to({torch::kCUDA, x.get_device()}) : y.cpu();
136 auto nan_mask = x != x;
137 if (torch::equal(nan_mask, y != y)) {
138 auto diff = x - y;
139 if (diff.is_signed()) {
140 diff = diff.abs();
141 }
142 diff.index_put_({nan_mask}, 0);
143 // Use `item()` to work around:
144 // https://github.com/pytorch/pytorch/issues/22301
145 auto max_err = diff.max().item<double>();
146 ASSERT_GE(max_err, 1e-5);
147 }
148 }
149
// Counts non-overlapping occurrences of `substr` in `str`, scanning left to
// right. For example, "aa" occurs twice in "aaaa" and once in "aaa".
// An empty `substr` yields 0: `find("")` matches at every position, so the
// scan would otherwise never advance.
inline int count_substr_occurrences(
    const std::string& str,
    const std::string& substr) {
  if (substr.empty()) {
    return 0;
  }
  int count = 0;
  for (auto pos = str.find(substr); pos != std::string::npos;
       pos = str.find(substr, pos + substr.size())) {
    ++count;
  }
  return count;
}
163
164 // A RAII, thread local (!) guard that changes default dtype upon
165 // construction, and sets it back to the original dtype upon destruction.
166 //
167 // Usage of this guard is synchronized across threads, so that at any given
168 // time, only one guard can take effect.
169 struct AutoDefaultDtypeMode {
170 static std::mutex default_dtype_mutex;
171
AutoDefaultDtypeModeAutoDefaultDtypeMode172 AutoDefaultDtypeMode(c10::ScalarType default_dtype)
173 : prev_default_dtype(
174 torch::typeMetaToScalarType(torch::get_default_dtype())) {
175 default_dtype_mutex.lock();
176 torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
177 }
~AutoDefaultDtypeModeAutoDefaultDtypeMode178 ~AutoDefaultDtypeMode() {
179 default_dtype_mutex.unlock();
180 torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
181 }
182 c10::ScalarType prev_default_dtype;
183 };
184
assert_tensor_creation_meta(torch::Tensor & x,torch::autograd::CreationMeta creation_meta)185 inline void assert_tensor_creation_meta(
186 torch::Tensor& x,
187 torch::autograd::CreationMeta creation_meta) {
188 auto autograd_meta = x.unsafeGetTensorImpl()->autograd_meta();
189 TORCH_CHECK(autograd_meta);
190 auto view_meta =
191 static_cast<torch::autograd::DifferentiableViewMeta*>(autograd_meta);
192 TORCH_CHECK(view_meta->has_bw_view());
193 ASSERT_EQ(view_meta->get_creation_meta(), creation_meta);
194 }
195 } // namespace test
196 } // namespace torch
197