1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/ivalue.h>
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Worker #include <c10/util/flat_hash_map.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/tempfile.h>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
12*da0073e9SAndroid Build Coastguard Worker
13*da0073e9SAndroid Build Coastguard Worker #include <cstdio>
14*da0073e9SAndroid Build Coastguard Worker #include <memory>
15*da0073e9SAndroid Build Coastguard Worker #include <sstream>
16*da0073e9SAndroid Build Coastguard Worker #include <string>
17*da0073e9SAndroid Build Coastguard Worker #include <vector>
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker using namespace torch::test;
20*da0073e9SAndroid Build Coastguard Worker using namespace torch::nn;
21*da0073e9SAndroid Build Coastguard Worker using namespace torch::optim;
22*da0073e9SAndroid Build Coastguard Worker
// Checks that IValue::deepcopy() on a list of tensors:
//   * preserves the identity relationships between list elements (entries that
//     were the same IValue stay the same, distinct ones stay distinct),
//   * yields tensors holding the same values as the originals, and
//   * yields tensors that do not alias the originals (a true deep copy).
TEST(IValueTest, DeepcopyTensors) {
  torch::Tensor t0 = torch::randn({2, 3});
  torch::Tensor t1 = torch::randn({3, 4});
  // Aliases t0's data but is not identical to t0 (asserted below).
  torch::Tensor t2 = t0.detach();
  // Identical to t0 (asserted below).
  torch::Tensor t3 = t0;
  // A strided view over t1; not identical to t1 (asserted below).
  torch::Tensor t4 = t1.as_strided({2, 3}, {3, 1}, 2);
  std::vector<torch::Tensor> tensor_vector = {t0, t1, t2, t3, t4};
  c10::List<torch::Tensor> tensor_list(tensor_vector);
  torch::IValue tensor_list_ivalue(tensor_list);

  c10::IValue::CompIdentityIValues ivalue_compare;

  // Make sure our setup configuration is correct.
  ASSERT_TRUE(ivalue_compare(tensor_list[0].get(), tensor_list[3].get()));
  ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[1].get()));
  ASSERT_FALSE(ivalue_compare(tensor_list[0].get(), tensor_list[2].get()));
  ASSERT_FALSE(ivalue_compare(tensor_list[1].get(), tensor_list[4].get()));
  ASSERT_TRUE(tensor_list[0].get().isAliasOf(tensor_list[2].get()));

  c10::IValue copied_ivalue = tensor_list_ivalue.deepcopy();
  c10::List<torch::IValue> copied_list = copied_ivalue.toList();
  // Fail cleanly (rather than index out of range below) if the copy has the
  // wrong number of elements.
  ASSERT_EQ(copied_list.size(), tensor_list.size());

  // The identity relationships of the original list must survive the copy.
  ASSERT_TRUE(ivalue_compare(copied_list[0].get(), copied_list[3].get()));
  ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[1].get()));
  ASSERT_FALSE(ivalue_compare(copied_list[0].get(), copied_list[2].get()));
  ASSERT_FALSE(ivalue_compare(copied_list[1].get(), copied_list[4].get()));
  // NOTE: this is actually incorrect. Ideally, these _should_ be aliases.
  ASSERT_FALSE(copied_list[0].get().isAliasOf(copied_list[2].get()));

  // Each copied element must hold the same values as its original while not
  // sharing data with it; otherwise this would not be a *deep* copy.
  for (const auto i : c10::irange(tensor_list.size())) {
    const torch::Tensor original = tensor_list[i].get().toTensor();
    const torch::Tensor copied = copied_list[i].get().toTensor();
    ASSERT_TRUE(copied.allclose(original)) << "element " << i;
    ASSERT_FALSE(copied_list[i].get().isAliasOf(tensor_list[i].get()))
        << "element " << i;
  }
}
64