xref: /aosp_15_r20/external/pytorch/test/cpp/api/tensor_flatten.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <test/cpp/api/support.h>
3 
4 #include <torch/csrc/autograd/variable.h>
5 #include <torch/csrc/utils/tensor_flatten.h>
6 #include <torch/torch.h>
7 
8 using namespace torch::test;
9 
// Regression test for torch::utils::unflatten_dense_tensors() when the
// reference tensor list contains empty tensors. The empty results must not
// share storage with any non-empty result (nor with each other once resized);
// without the fix in unflatten_dense_tensors() they aliased the flat buffer.
TEST(UnflattenDenseTensorTest, TestEmptyTensor) {
  auto emptyTensor1 = at::tensor(std::vector<int>());
  auto emptyTensor2 = at::tensor(std::vector<int>());
  auto tensor1 = at::tensor({1, 2, 3});
  auto tensor2 = at::tensor({4, 5});
  auto tensorList =
      std::vector<at::Tensor>({tensor1, emptyTensor1, emptyTensor2, tensor2});
  // Flat buffer holding the non-empty tensors' elements back to back, in
  // list order (the empty tensors contribute no elements).
  auto flatTensor = at::tensor({1, 2, 3, 4, 5});
  auto unflatten_results =
      torch::utils::unflatten_dense_tensors(flatTensor, tensorList);

  // size() is size_t; use an unsigned literal to avoid a signed/unsigned
  // comparison inside ASSERT_EQ (-Wsign-compare).
  ASSERT_EQ(unflatten_results.size(), 4u);
  ASSERT_EQ(unflatten_results.at(0).numel(), 3);
  ASSERT_EQ(unflatten_results.at(1).numel(), 0);
  ASSERT_EQ(unflatten_results.at(2).numel(), 0);
  ASSERT_EQ(unflatten_results.at(3).numel(), 2);

  // The non-empty results must hold the matching slices of flatTensor, not
  // merely have the right element counts.
  ASSERT_TRUE(unflatten_results.at(0).equal(tensor1));
  ASSERT_TRUE(unflatten_results.at(3).equal(tensor2));

  // Empty results have no allocated memory yet, so data_ptr() is nullptr.
  ASSERT_EQ(unflatten_results.at(1).data_ptr(), nullptr);
  ASSERT_EQ(unflatten_results.at(2).data_ptr(), nullptr);
  // Without the fix in unflatten_dense_tensors() for empty tensors, the
  // unflattened empty tensor unflatten_results.at(1) would share the same
  // storage as a non-empty tensor such as unflatten_results.at(3). With the
  // fix, empty and non-empty results do not share storage.
  ASSERT_NE(
      unflatten_results.at(1).data_ptr(), unflatten_results.at(3).data_ptr());

  unflatten_results.at(1).resize_(1);
  unflatten_results.at(2).resize_(1);
  // After resizing, the two formerly-empty tensors must own distinct storage.
  // Without the fix in unflatten_dense_tensors() for empty tensors, the
  // resized tensors would share the same storage.
  ASSERT_NE(
      unflatten_results.at(1).data_ptr(), unflatten_results.at(2).data_ptr());
}
44