#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <c10/core/UndefinedTensorImpl.h>
#include <sstream>
#include <string>

using namespace at;
8
// Checks that operations on an undefined (default-constructed) Tensor raise
// an exception instead of crashing, and that the few queries that do work on
// an undefined tensor return the expected values.
//
// ASSERT_ANY_THROW accepts any exception type; the error message text is not
// checked here.
TEST(TestUndefined, UndefinedTest) {
  manual_seed(123);

  // Mainly test that ops on undefined tensors don't segfault and give a
  // reasonable error message.
  Tensor und;
  Tensor ft = ones({1}, CPU(kFloat));

  // Streaming an undefined tensor must not crash. '\n' is used instead of
  // std::endl because flushing a std::stringstream serves no purpose.
  std::stringstream ss;
  ss << und << '\n';
  ASSERT_FALSE(und.defined());
  ASSERT_EQ(std::string("UndefinedType"), und.toString());

  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.strides());
  // Unlike strides(), dim() does NOT throw on an undefined tensor; it
  // reports 1. This pins the current (arguably surprising) behavior rather
  // than a desired contract. NOTE(review): if undefined tensors start
  // rejecting size queries, turn this into ASSERT_ANY_THROW(und.dim()).
  ASSERT_EQ(und.dim(), 1);
  // Assigning a Scalar to an undefined temporary must throw.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5));

  // Arithmetic with an undefined operand on either side must throw.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.add(und));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.add(ft));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(ft.add(und));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.add(5));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.mm(und));

  // Public variable (autograd) API: every accessor must throw.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.variable_data());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.tensor_data());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.is_view());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und._base());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.name());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.grad_fn());
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.remove_hook(0));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.register_hook([](const Tensor& x) -> Tensor { return x; }));

  // copy_ must throw whenever either the source or the destination is undefined.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.copy_(und));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.copy_(ft));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(ft.copy_(und));

  // Backend conversion must throw both when converting an undefined tensor
  // and when targeting Backend::Undefined.
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(und.toBackend(Backend::CPU));
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_ANY_THROW(ft.toBackend(Backend::Undefined));

  // A moved-from Tensor is left undefined and points at the shared
  // UndefinedTensorImpl singleton.
  Tensor to_move = ones({1}, CPU(kFloat));
  Tensor m(std::move(to_move));
  // NOLINTNEXTLINE(bugprone-use-after-move)
  ASSERT_FALSE(to_move.defined());
  ASSERT_EQ(to_move.unsafeGetTensorImpl(), UndefinedTensorImpl::singleton());
}
74