1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4
5 using namespace at;
6
// Shapes iterated by the SetMemoryFormat test: one shape with every dimension
// equal to 4, followed by 4-d shapes in which one or more dimensions are 1.
std::vector<std::vector<int64_t>> sizes = {
    {4, 4, 4, 4},
    {4, 4, 1, 1},
    {4, 1, 4, 4},
    {4, 1, 4, 1},
    {4, 1, 1, 4},
    {1, 4, 1, 4},
    {1, 4, 4, 1},
};
8
// Verifies that, for every shape in `sizes`, resizing a tensor with an
// explicit memory format makes suggest_memory_format() report that format,
// and pins the current result for the ambiguous {4, 1, 1, 1} shape.
TEST(MemoryFormatTest, SetMemoryFormat) {
  // Iterate by const reference: the loop body never modifies the shape, so
  // copying each vector per iteration is unnecessary.
  for (const auto& size : sizes) {
    Tensor t = at::rand(size);
    // The same tensor is resized to each format in turn.
    for (auto memory_format :
         {at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous}) {
      t.resize_(size, memory_format);
      // EXPECT_EQ reports the suggested format on failure.
      EXPECT_EQ(t.suggest_memory_format(), memory_format);
    }
  }

  Tensor t = at::rand({4, 1, 1, 1});
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::Contiguous);
  t.resize_({4, 1, 1, 1}, at::MemoryFormat::ChannelsLast);
  // TODO: Should be able to handle this after accumulated permutation is
  // implemented. This is an ambiguous case where we fall back to Contiguous.
  // The expectation should eventually become:
  //   EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::Contiguous);
}
27
// Checks how in-place transposes of freshly created 4-d tensors affect
// suggest_memory_format(), including shapes that contain size-1 dimensions.
// EXPECT_EQ / EXPECT_NE are used so a failing check reports the format that
// was actually suggested.
TEST(MemoryFormatTest, TransposeMemoryFormat) {
  Tensor t = at::rand({2, 3, 4, 5});
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::Contiguous);
  // Swapping dims 1 and 3 alone must not be reported as ChannelsLast ...
  t.transpose_(1, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  // ... but following it with a swap of dims 2 and 3 must be.
  t.transpose_(2, 3);
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);

  // Any other single swap on a fresh tensor is not ChannelsLast either.
  t = at::rand({2, 3, 4, 5});
  t.transpose_(1, 2);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({2, 3, 4, 5});
  t.transpose_(2, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);

  // Corner cases: shape {1, 4, 1, 4}. Each single swap is expected not to
  // yield ChannelsLast; swapping (2, 3) and then (1, 2) is expected to.
  t = at::rand({1, 4, 1, 4});
  t.transpose_(1, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({1, 4, 1, 4});
  t.transpose_(1, 2);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({1, 4, 1, 4});
  t.transpose_(2, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({1, 4, 1, 4});
  t.transpose_(2, 3);
  t.transpose_(1, 2);
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);

  // Corner cases: shape {1, 4, 4, 1}, same sequence of swaps as above.
  t = at::rand({1, 4, 4, 1});
  t.transpose_(1, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({1, 4, 4, 1});
  t.transpose_(1, 2);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({1, 4, 4, 1});
  t.transpose_(2, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({1, 4, 4, 1});
  t.transpose_(2, 3);
  t.transpose_(1, 2);
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
}
71
sliceStepTwo(Tensor & t,int dim,at::MemoryFormat format)72 inline void sliceStepTwo(Tensor& t, int dim, at::MemoryFormat format) {
73 t = t.slice(dim, 0, 3, 2);
74 EXPECT_TRUE(t.suggest_memory_format() == format);
75 t = t.slice(dim, 0, 3, 2);
76 EXPECT_TRUE(t.suggest_memory_format() == format);
77 }
78
// Runs sliceStepTwo over several shapes, dimension orders and both memory
// formats, checking that step-2 slicing keeps the suggested memory format.
TEST(MemoryFormatTest, SliceStepTwoMemoryFormat) {
  // Creates a random tensor of `shape`; when `format` is ChannelsLast the
  // tensor is first resized to that format. sliceStepTwo is then applied
  // along each entry of `dims`, in order, on the same (progressively sliced)
  // tensor, always expecting `format`.
  const auto run_chain = [](IntArrayRef shape,
                            MemoryFormat format,
                            const std::vector<int>& dims) {
    Tensor chained = at::rand(shape);
    if (format == MemoryFormat::ChannelsLast) {
      chained.resize_(shape, at::MemoryFormat::ChannelsLast);
    }
    for (const int dim : dims) {
      sliceStepTwo(chained, dim, format);
    }
  };

  // Shape {4, 4, 4, 4}: two dimension orders per format.
  run_chain({4, 4, 4, 4}, MemoryFormat::Contiguous, {1, 2, 3});
  run_chain({4, 4, 4, 4}, MemoryFormat::Contiguous, {2, 3, 1});
  run_chain({4, 4, 4, 4}, MemoryFormat::ChannelsLast, {1, 2, 3});
  run_chain({4, 4, 4, 4}, MemoryFormat::ChannelsLast, {2, 3, 1});

  // Shape {4, 4, 1, 1}: the ChannelsLast variant is spelled out because the
  // second slice currently falls back to Contiguous.
  run_chain({4, 4, 1, 1}, MemoryFormat::Contiguous, {1});
  {
    Tensor narrowed = at::rand({4, 4, 1, 1});
    narrowed.resize_({4, 4, 1, 1}, at::MemoryFormat::ChannelsLast);
    narrowed = narrowed.slice(1, 0, 3, 2);
    EXPECT_TRUE(narrowed.suggest_memory_format() == MemoryFormat::ChannelsLast);
    narrowed = narrowed.slice(1, 0, 3, 2);
    // TODO: Should be able to handle this after accumulated permutation is
    // implemented; we won't be able to tell how we ended up here:
    //   [4, 1, 1, 4]@[4, 4, 4, 1] slice twice at dim3
    //   [4, 4, 1, 1]@[4, 1, 4, 4] slice twice at dim1
    // EXPECT_TRUE(narrowed.suggest_memory_format() == MemoryFormat::ChannelsLast);
    EXPECT_TRUE(narrowed.suggest_memory_format() == MemoryFormat::Contiguous);
  }

  // Shape {4, 1, 4, 4}.
  run_chain({4, 1, 4, 4}, MemoryFormat::Contiguous, {2, 3});
  run_chain({4, 1, 4, 4}, MemoryFormat::ChannelsLast, {2, 3});

  // Shape {4, 1, 1, 4}.
  run_chain({4, 1, 1, 4}, MemoryFormat::Contiguous, {3});
  run_chain({4, 1, 1, 4}, MemoryFormat::ChannelsLast, {3});

  // Shape {4, 1, 4, 1}.
  run_chain({4, 1, 4, 1}, MemoryFormat::Contiguous, {2});
  run_chain({4, 1, 4, 1}, MemoryFormat::ChannelsLast, {2});
}
136
sliceFirst(Tensor & t,int dim,at::MemoryFormat format)137 inline void sliceFirst(Tensor& t, int dim, at::MemoryFormat format) {
138 t = t.slice(dim, 0, 1, 1);
139 EXPECT_TRUE(t.suggest_memory_format() == format);
140 }
141
// Runs sliceFirst over several shapes, dimension orders and both memory
// formats, checking the memory format suggested after slicing.
TEST(MemoryFormatTest, SliceFirstMemoryFormat) {
  // Creates a random tensor of `shape`; when `format` is ChannelsLast the
  // tensor is first resized to that format. sliceFirst is then applied along
  // each entry of `dims`, in order, on the same (progressively sliced)
  // tensor, always expecting `format`.
  const auto run_chain = [](IntArrayRef shape,
                            MemoryFormat format,
                            const std::vector<int>& dims) {
    Tensor chained = at::rand(shape);
    if (format == MemoryFormat::ChannelsLast) {
      chained.resize_(shape, at::MemoryFormat::ChannelsLast);
    }
    for (const int dim : dims) {
      sliceFirst(chained, dim, format);
    }
  };

  // Shape {4, 4, 4, 4}: two dimension orders per format.
  run_chain({4, 4, 4, 4}, MemoryFormat::Contiguous, {1, 2, 3});
  run_chain({4, 4, 4, 4}, MemoryFormat::Contiguous, {2, 3, 1});
  run_chain({4, 4, 4, 4}, MemoryFormat::ChannelsLast, {1, 2, 3});
  run_chain({4, 4, 4, 4}, MemoryFormat::ChannelsLast, {2, 3, 1});

  // Shape {4, 4, 1, 1}.
  run_chain({4, 4, 1, 1}, MemoryFormat::Contiguous, {1});
  run_chain({4, 4, 1, 1}, MemoryFormat::ChannelsLast, {1});

  // Shape {4, 1, 4, 4}.
  run_chain({4, 1, 4, 4}, MemoryFormat::Contiguous, {2, 3});
  run_chain({4, 1, 4, 4}, MemoryFormat::ChannelsLast, {2, 3});

  // Shape {4, 1, 1, 4}.
  run_chain({4, 1, 1, 4}, MemoryFormat::Contiguous, {3});
  run_chain({4, 1, 1, 4}, MemoryFormat::ChannelsLast, {3});

  // Shape {4, 1, 4, 1}: the ChannelsLast variant is spelled out because the
  // slice is currently expected to report Contiguous.
  run_chain({4, 1, 4, 1}, MemoryFormat::Contiguous, {2});
  {
    Tensor collapsed = at::rand({4, 1, 4, 1});
    collapsed.resize_({4, 1, 4, 1}, at::MemoryFormat::ChannelsLast);
    // TODO: Should be able to handle this after accumulated permutation is
    // implemented;
    // [4, 1, 4, 1]@[4, 1, 1, 1] after slice becomes [4, 1, 1, 1]@[4, 1, 1, 1]
    // sliceFirst(collapsed, 2, MemoryFormat::ChannelsLast);
    sliceFirst(collapsed, 2, MemoryFormat::Contiguous);
  }
}
194