xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_argument_spec.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <torch/csrc/jit/api/function_impl.h>
5 #include <torch/csrc/jit/runtime/argument_spec.h>
6 #include <torch/jit.h>
7 
8 #include "test/cpp/jit/test_utils.h"
9 
10 namespace torch {
11 namespace jit {
12 
13 namespace {
14 
device(const autograd::Variable & v)15 at::Device device(const autograd::Variable& v) {
16   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
17   return v.device();
18 }
19 
isEqual(at::IntArrayRef lhs,at::IntArrayRef rhs)20 bool isEqual(at::IntArrayRef lhs, at::IntArrayRef rhs) {
21   return lhs.size() == rhs.size() &&
22       std::equal(lhs.begin(), lhs.end(), rhs.begin());
23 }
24 
isEqual(const CompleteArgumentInfo & ti,const autograd::Variable & v)25 bool isEqual(const CompleteArgumentInfo& ti, const autograd::Variable& v) {
26   if (!ti.defined())
27     return ti.defined() == v.defined();
28   return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() &&
29       ti.type() == v.scalar_type() && isEqual(ti.sizes(), v.sizes()) &&
30       isEqual(ti.strides(), v.strides());
31 }
32 
isEqual(const ArgumentInfo & ti,const autograd::Variable & v)33 bool isEqual(const ArgumentInfo& ti, const autograd::Variable& v) {
34   if (!ti.defined())
35     return ti.defined() == v.defined();
36   return ti.device() == device(v) && ti.requires_grad() == v.requires_grad() &&
37       ti.type() == v.scalar_type() && ti.dim() == v.dim();
38 }
39 
var(at::TensorOptions t,at::IntArrayRef sizes,bool requires_grad)40 autograd::Variable var(
41     at::TensorOptions t,
42     at::IntArrayRef sizes,
43     bool requires_grad) {
44   return autograd::make_variable(at::rand(sizes, t), requires_grad);
45 }
undef()46 autograd::Variable undef() {
47   return autograd::Variable();
48 }
49 } // namespace
50 
// Exercises CompleteArgumentSpec: equality and hashing across identical
// inputs, per-argument metadata, the with_grad flag, sensitivity to a stride
// change, use as an unordered_set key, and a stack with a non-tensor entry.
TEST(ArgumentSpecTest, CompleteArgumentSpec_CUDA) {
  auto const cpu_float = at::CPU(at::kFloat);
  auto const cpu_double = at::CPU(at::kDouble);
  auto const gpu_float = at::CUDA(at::kFloat);
  auto const gpu_double = at::CUDA(at::kDouble);

  // Builds the five-argument input stack and transposes argument 1 so that at
  // least one tensor has non-standard strides.
  const auto make_inputs = [&] {
    auto stack = createStack(
        {var(cpu_float, {1}, true),
         var(cpu_double, {1, 2}, false),
         var(gpu_float, {}, true),
         var(gpu_double, {4, 5, 6}, false),
         undef()});
    stack[1].toTensor().transpose_(0, 1);
    return stack;
  };

  auto inputs = make_inputs();
  // Same layout as `inputs`, but backed by freshly generated values.
  auto inputs_twin = make_inputs();

  // Two specs built from the very same stack agree in hash and value.
  CompleteArgumentSpec spec_a(true, inputs);
  CompleteArgumentSpec spec_b(true, inputs);
  ASSERT_EQ(spec_a.hashCode(), spec_b.hashCode());

  ASSERT_EQ(spec_a, spec_b);
  // A spec built from the twin stack agrees as well.
  CompleteArgumentSpec spec_twin(true, inputs_twin);
  ASSERT_EQ(spec_twin, spec_a);
  ASSERT_EQ(spec_twin.hashCode(), spec_a.hashCode());

  // Every recorded argument must describe the tensor it was built from.
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    ASSERT_TRUE(isEqual(spec_a.at(idx), inputs[idx].toTensor()));
  }
  // Dropping gradient tracking must yield a different spec.
  CompleteArgumentSpec spec_no_grad(/*with_grad=*/false, inputs);
  ASSERT_TRUE(spec_no_grad != spec_a);

  std::unordered_set<CompleteArgumentSpec> seen;
  seen.insert(spec_a); // copied, not moved: spec_a is still needed below
  ASSERT_TRUE(seen.count(spec_b) > 0);
  ASSERT_EQ(seen.count(spec_no_grad), 0);
  seen.insert(std::move(spec_no_grad));
  ASSERT_EQ(seen.count(CompleteArgumentSpec(true, inputs)), 1);

  // Undo the transpose on the twin: it now differs from `inputs` in the
  // layout of argument 1 only, and the complete spec must notice.
  inputs_twin[1].toTensor().transpose_(0, 1);
  CompleteArgumentSpec spec_restrided(true, inputs_twin);
  ASSERT_FALSE(spec_restrided == spec_a);
  ASSERT_EQ(seen.count(spec_restrided), 0);

  // A non-tensor value between two tensors: the entry at index 2 must still
  // report the two dimensions of the second tensor.
  Stack mixed = {var(cpu_float, {1, 2}, true), 3, var(cpu_float, {1, 2}, true)};
  CompleteArgumentSpec spec_mixed(true, mixed);
  ASSERT_EQ(spec_mixed.at(2).sizes().size(), 2);
}
107 
// TODO: the VaryingShape hashing test below (and its hashCode helper) was
// disabled for unknown reasons and is not compiled or run.
109 // static size_t hashCode(const TensorTypePtr& ptr) {
110 //   return std::hash<TensorType>()(*ptr.get());
111 // }
112 
113 // TEST(ArgumentSpecTest, VaryingShape) {
114 //   c10::VaryingShape<int64_t> vs(std::optional<size_t>{});
115 //   auto ptt_empty1 = TensorType::create({}, {}, vs, vs, false);
116 //   auto ptt_empty2 = TensorType::create({}, {}, vs, vs, false);
117 //   ASSERT_EQ(hashCode(ptt_empty1), hashCode(ptt_empty2));
118 
119 //   c10::VaryingShape<int64_t> vs22(std::vector<int64_t>{2, 2});
120 //   auto ptt_vs22_vs22_1 = TensorType::create({}, {}, vs22, vs22, false);
121 //   auto ptt_vs22_vs22_2 = TensorType::create({}, {}, vs22, vs22, false);
122 //   ASSERT_EQ(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs22_2));
123 
124 //   c10::VaryingShape<int64_t> vs23(std::vector<int64_t>{2, 3});
125 //   auto ptt_vs22_vs23_2 = TensorType::create({}, {}, vs22, vs23, false);
126 //   ASSERT_NE(hashCode(ptt_vs22_vs22_1), hashCode(ptt_vs22_vs23_2));
127 
128 //   auto ptt_vs22_vs22_1_true = TensorType::create({}, {}, vs22, vs22, true);
129 //   auto ptt_vs22_vs22_2_true = TensorType::create({}, {}, vs22, vs22, true);
130 //   ASSERT_EQ(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_2_true));
131 
132 //   auto ptt_vs22_vs22_1_false = TensorType::create({}, {}, vs22, vs22, false);
133 //   ASSERT_NE(hashCode(ptt_vs22_vs22_1_true), hashCode(ptt_vs22_vs22_1_false));
134 // }
135 
// Exercises ArgumentSpec as produced by ArgumentSpecCreator for a five-input
// graph: equality and hashing across identical inputs, per-tensor metadata,
// the with_grad flag, use as an unordered_set key, and the expectation that a
// stride-only difference does not change the spec.
TEST(ArgumentSpecTest, Basic_CUDA) {
  auto& cpu_float = at::CPU(at::kFloat);
  auto& cpu_double = at::CPU(at::kDouble);
  auto& gpu_float = at::CUDA(at::kFloat);
  auto& gpu_double = at::CUDA(at::kDouble);

  // Compile a pass-through function and take its graph as the spec template.
  auto graph = toGraphFunction(jit::compile(R"JIT(
   def fn(a, b, c, d, e):
      return a, b, c, d, e
   )JIT")
                                   ->get_function("fn"))
                   .graph();

  ArgumentSpecCreator creator(*graph);

  // Builds the five-argument input stack and transposes argument 1 so that at
  // least one tensor has non-standard strides.
  const auto make_inputs = [&] {
    auto stack = createStack(
        {var(cpu_float, {1}, true),
         var(cpu_double, {1, 2}, false),
         var(gpu_float, {}, true),
         var(gpu_double, {4, 5, 6}, false),
         undef()});
    stack[1].toTensor().transpose_(0, 1);
    return stack;
  };

  auto inputs = make_inputs();
  // Same layout as `inputs`, but backed by freshly generated values.
  auto inputs_twin = make_inputs();

  // Two specs created from the very same stack agree in hash and value.
  ArgumentSpec spec_a = creator.create(true, inputs);
  ArgumentSpec spec_b = creator.create(true, inputs);
  ASSERT_EQ(spec_a.hashCode(), spec_b.hashCode());

  ASSERT_EQ(spec_a, spec_b);
  // A spec created from the twin stack agrees as well.
  ArgumentSpec spec_twin = creator.create(true, inputs_twin);
  ASSERT_EQ(spec_twin, spec_a);
  ASSERT_EQ(spec_twin.hashCode(), spec_a.hashCode());

  // Every recorded tensor entry must describe the tensor it was built from.
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    ASSERT_TRUE(isEqual(spec_a.tensorAt(idx), inputs[idx].toTensor()));
  }
  // Dropping gradient tracking must yield a different spec.
  ArgumentSpec spec_no_grad = creator.create(/*with_grad=*/false, inputs);
  ASSERT_TRUE(spec_no_grad != spec_a);

  std::unordered_set<ArgumentSpec> seen;
  seen.insert(spec_a); // copied, not moved: spec_a is still needed below
  ASSERT_TRUE(seen.count(spec_b) > 0);
  ASSERT_EQ(seen.count(spec_no_grad), 0);
  seen.insert(std::move(spec_no_grad));
  ASSERT_EQ(seen.count(creator.create(true, inputs)), 1);

  // Undo the transpose on the twin: it now differs from `inputs` in the
  // layout of argument 1 only. Such specs used to compare different; they are
  // now expected to compare equal.
  inputs_twin[1].toTensor().transpose_(0, 1);
  ArgumentSpec spec_restrided = creator.create(true, inputs_twin);
  ASSERT_TRUE(spec_restrided == spec_a);
  ASSERT_EQ(seen.count(spec_restrided), 1);
}
199 
200 } // namespace jit
201 } // namespace torch
202