xref: /aosp_15_r20/external/pytorch/test/cpp/api/ivalue.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/core/ivalue.h>
4 
5 #include <c10/util/flat_hash_map.h>
6 #include <c10/util/irange.h>
7 #include <c10/util/tempfile.h>
8 
9 #include <torch/torch.h>
10 
11 #include <test/cpp/api/support.h>
12 
13 #include <cstdio>
14 #include <memory>
15 #include <sstream>
16 #include <string>
17 #include <vector>
18 
19 using namespace torch::test;
20 using namespace torch::nn;
21 using namespace torch::optim;
22 
// Checks IValue::deepcopy() on a list of tensors. The copied list must:
//   - keep the identity relationships between entries (entries that held the
//     very same tensor still hold one shared tensor after the copy);
//   - share no storage with the original tensors;
//   - reproduce the original tensor values.
//
// Layout of the list under test (by index):
//   0: t0, a random tensor
//   1: t1, an unrelated random tensor
//   2: t0.detach(), a different tensor object that aliases t0's storage
//   3: t0 again, so entries 0 and 3 are the identical tensor
//   4: a view built from t1 with as_strided
TEST(IValueTest, DeepcopyTensors) {
  torch::Tensor t0 = torch::randn({2, 3});
  torch::Tensor t1 = torch::randn({3, 4});
  torch::Tensor t2 = t0.detach();
  torch::Tensor t3 = t0;
  torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2);
  std::vector<torch::Tensor> tensor_vector = {t0, t1, t2, t3, t4};
  c10::List<torch::Tensor> tensor_list(tensor_vector);
  torch::IValue tensor_list_ivalue(tensor_list);

  // Compares IValues by identity (same underlying object), not by value.
  c10::IValue::CompIdentityIValues ivalue_compare;

  // Make sure our setup configuration is correct.
  ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get()));
  ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get()));
  ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get()));
  ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get()));
  ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get()));

  c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy();
  c10::List<torch::IValue> copied_list = copied_ivalue.toList();

  // Guard the indexing below: the copy must have one entry per original.
  ASSERT_EQ(copied_list.size(), tensor_list.size());

  // The copy must keep the identity relationships of the original list:
  // entries 0 and 3 still hold one shared object, and the others stay
  // distinct.
  ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get()));
  ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get()));
  ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get()));
  ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get()));
  // NOTE: this is actually incorrect. Ideally, these _should_ be aliases.
  // Storage aliasing between *distinct* tensors (entries 0 and 2) is not
  // carried over to the copy; this assertion pins the current behavior.
  ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get()));

  for (const auto i : c10::irange(tensor_list.size())) {
    // A deep copy must not share storage with its source...
    ASSERT_FALSE(copied_list[i].get().isAliasOf(tensor_list[i].get()));
    // ...and must reproduce the source values exactly, so compare with
    // equal() rather than a tolerance-based allclose().
    ASSERT_TRUE(copied_list[i].get().toTensor().equal(tensor_list[i].get()));
  }
}
63 }
64