// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/undefined_tensor_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <c10/core/UndefinedTensorImpl.h>

#include <sstream>
#include <string>

using namespace at;

TEST(TestUndefined,UndefinedTest)9 TEST(TestUndefined, UndefinedTest) {
10   manual_seed(123);
11 
12   // mainly test ops on undefined tensors don't segfault and give a reasonable errror message.
13   Tensor und;
14   Tensor ft = ones({1}, CPU(kFloat));
15 
16   std::stringstream ss;
17   ss << und << std::endl;
18   ASSERT_FALSE(und.defined());
19   ASSERT_EQ(std::string("UndefinedType"), und.toString());
20 
21   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
22   ASSERT_ANY_THROW(und.strides());
23   ASSERT_EQ(und.dim(), 1);
24   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
25   ASSERT_ANY_THROW([]() { return Tensor(); }() = Scalar(5));
26   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
27   ASSERT_ANY_THROW(und.add(und));
28   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
29   ASSERT_ANY_THROW(und.add(ft));
30   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
31   ASSERT_ANY_THROW(ft.add(und));
32   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
33   ASSERT_ANY_THROW(und.add(5));
34   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
35   ASSERT_ANY_THROW(und.mm(und));
36 
37   // public variable API
38   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
39   ASSERT_ANY_THROW(und.variable_data());
40   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
41   ASSERT_ANY_THROW(und.tensor_data());
42   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
43   ASSERT_ANY_THROW(und.is_view());
44   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
45   ASSERT_ANY_THROW(und._base());
46   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
47   ASSERT_ANY_THROW(und.name());
48   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
49   ASSERT_ANY_THROW(und.grad_fn());
50   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
51   ASSERT_ANY_THROW(und.remove_hook(0));
52   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
53   ASSERT_ANY_THROW(und.register_hook([](const Tensor& x) -> Tensor { return x; }));
54 
55   // copy_
56   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
57   ASSERT_ANY_THROW(und.copy_(und));
58   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
59   ASSERT_ANY_THROW(und.copy_(ft));
60   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
61   ASSERT_ANY_THROW(ft.copy_(und));
62 
63   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
64   ASSERT_ANY_THROW(und.toBackend(Backend::CPU));
65   // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
66   ASSERT_ANY_THROW(ft.toBackend(Backend::Undefined));
67 
68   Tensor to_move = ones({1}, CPU(kFloat));
69   Tensor m(std::move(to_move));
70   // NOLINTNEXTLINE(bugprone-use-after-move)
71   ASSERT_FALSE(to_move.defined());
72   ASSERT_EQ(to_move.unsafeGetTensorImpl(), UndefinedTensorImpl::singleton());
73 }
74