xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/memory_format_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 
5 using namespace at;
6 
// 4-D shapes exercised by SetMemoryFormat: one shape with no singleton
// dimensions plus several shapes containing singleton dimensions.
// [4, 1, 1, 1] is absent from this list; SetMemoryFormat checks it
// separately as an ambiguous case.
// Declared `const`: nothing modifies the table, and a const namespace-scope
// variable has internal linkage, so the name `sizes` cannot collide with
// another translation unit.
const std::vector<std::vector<int64_t>> sizes = {
    {4, 4, 4, 4},
    {4, 4, 1, 1},
    {4, 1, 4, 4},
    {4, 1, 4, 1},
    {4, 1, 1, 4},
    {1, 4, 1, 4},
    {1, 4, 4, 1}};
8 
// resize_() with an explicit memory format must be reflected by
// suggest_memory_format(). For every shape in `sizes`, a fresh tensor is
// resized to ChannelsLast and then back to Contiguous, checking after each.
TEST(MemoryFormatTest, SetMemoryFormat) {
  // Iterate by const reference: rand() and resize_() only need a view of the
  // shape, so copying each vector is unnecessary.
  for (const auto& size : sizes) {
    Tensor t = at::rand(size);
    for (const auto memory_format :
         {at::MemoryFormat::ChannelsLast, at::MemoryFormat::Contiguous}) {
      t.resize_(size, memory_format);
      EXPECT_EQ(t.suggest_memory_format(), memory_format)
          << "size=" << at::IntArrayRef(size);
    }
  }

  Tensor t = at::rand({4, 1, 1, 1});
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::Contiguous);
  t.resize_({4, 1, 1, 1}, at::MemoryFormat::ChannelsLast);
  // TODO: Should be able to handle this after accumulated permutation is implemented;
  // Ambiguous case where we fallback to Contiguous;
  // This should be `EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);`
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::Contiguous);
}
27 
// suggest_memory_format() after in-place transposes of a freshly allocated
// (contiguous) tensor. Only the transpose sequences checked with EXPECT_EQ
// below are expected to be reported as ChannelsLast; single transposes are not.
TEST(MemoryFormatTest, TransposeMemoryFormat) {
  Tensor t = at::rand({2, 3, 4, 5});
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::Contiguous);
  t.transpose_(1, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t.transpose_(2, 3);
  EXPECT_EQ(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({2, 3, 4, 5});
  t.transpose_(1, 2);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  t = at::rand({2, 3, 4, 5});
  t.transpose_(2, 3);
  EXPECT_NE(t.suggest_memory_format(), at::MemoryFormat::ChannelsLast);

  // corner cases: shapes with singleton dimensions. Every shape gets the same
  // checks, in the same order: no single transpose yields ChannelsLast, while
  // transpose_(2, 3) followed by transpose_(1, 2) does.
  const std::vector<std::vector<int64_t>> corner_shapes = {
      {1, 4, 1, 4}, {1, 4, 4, 1}};
  const int64_t single_swaps[][2] = {{1, 3}, {1, 2}, {2, 3}};
  for (const auto& shape : corner_shapes) {
    // Tag every failure in this iteration with the shape under test.
    SCOPED_TRACE(::testing::Message() << "shape=" << at::IntArrayRef(shape));
    for (const auto& dims : single_swaps) {
      Tensor once = at::rand(shape);
      once.transpose_(dims[0], dims[1]);
      EXPECT_NE(once.suggest_memory_format(), at::MemoryFormat::ChannelsLast)
          << "transpose_(" << dims[0] << ", " << dims[1] << ")";
    }
    Tensor twice = at::rand(shape);
    twice.transpose_(2, 3);
    twice.transpose_(1, 2);
    EXPECT_EQ(twice.suggest_memory_format(), at::MemoryFormat::ChannelsLast);
  }
}
71 
sliceStepTwo(Tensor & t,int dim,at::MemoryFormat format)72 inline void sliceStepTwo(Tensor& t, int dim, at::MemoryFormat format) {
73   t = t.slice(dim, 0, 3, 2);
74   EXPECT_TRUE(t.suggest_memory_format() == format);
75   t = t.slice(dim, 0, 3, 2);
76   EXPECT_TRUE(t.suggest_memory_format() == format);
77 }
78 
// Step-2 slicing (via sliceStepTwo) should preserve the memory format that
// suggest_memory_format() reports, for contiguous and channels-last tensors
// of several shapes. Known ambiguous results are pinned with TODOs.
TEST(MemoryFormatTest, SliceStepTwoMemoryFormat) {
  // Random tensor of `shape` restrided to ChannelsLast. Taking the shape once
  // avoids repeating (and possibly mismatching) it between rand and resize_.
  const auto channels_last_rand = [](at::IntArrayRef shape) {
    Tensor r = at::rand(shape);
    r.resize_(shape, at::MemoryFormat::ChannelsLast);
    return r;
  };

  Tensor t = at::rand({4, 4, 4, 4});
  sliceStepTwo(t, 1, MemoryFormat::Contiguous);
  sliceStepTwo(t, 2, MemoryFormat::Contiguous);
  sliceStepTwo(t, 3, MemoryFormat::Contiguous);

  t = at::rand({4, 4, 4, 4});
  sliceStepTwo(t, 2, MemoryFormat::Contiguous);
  sliceStepTwo(t, 3, MemoryFormat::Contiguous);
  sliceStepTwo(t, 1, MemoryFormat::Contiguous);

  t = channels_last_rand({4, 4, 4, 4});
  sliceStepTwo(t, 1, MemoryFormat::ChannelsLast);
  sliceStepTwo(t, 2, MemoryFormat::ChannelsLast);
  sliceStepTwo(t, 3, MemoryFormat::ChannelsLast);

  t = channels_last_rand({4, 4, 4, 4});
  sliceStepTwo(t, 2, MemoryFormat::ChannelsLast);
  sliceStepTwo(t, 3, MemoryFormat::ChannelsLast);
  sliceStepTwo(t, 1, MemoryFormat::ChannelsLast);

  t = at::rand({4, 4, 1, 1});
  sliceStepTwo(t, 1, MemoryFormat::Contiguous);
  t = channels_last_rand({4, 4, 1, 1});
  t = t.slice(1, 0, 3, 2);
  EXPECT_EQ(t.suggest_memory_format(), MemoryFormat::ChannelsLast);
  t = t.slice(1, 0, 3, 2);
  // TODO: Should be able to handle this after accumulated permutation is implemented;
  // won't be able to tell how we ended up here
  // [4, 1, 1, 4]@[4, 4, 4, 1] slice twice at dim3
  // [4, 4, 1, 1]@[4, 1, 4, 4] slice twice at dim1
  // EXPECT_EQ(t.suggest_memory_format(), MemoryFormat::ChannelsLast);
  EXPECT_EQ(t.suggest_memory_format(), MemoryFormat::Contiguous);

  t = at::rand({4, 1, 4, 4});
  sliceStepTwo(t, 2, MemoryFormat::Contiguous);
  sliceStepTwo(t, 3, MemoryFormat::Contiguous);
  t = channels_last_rand({4, 1, 4, 4});
  sliceStepTwo(t, 2, MemoryFormat::ChannelsLast);
  sliceStepTwo(t, 3, MemoryFormat::ChannelsLast);

  t = at::rand({4, 1, 1, 4});
  sliceStepTwo(t, 3, MemoryFormat::Contiguous);
  t = channels_last_rand({4, 1, 1, 4});
  sliceStepTwo(t, 3, MemoryFormat::ChannelsLast);

  t = at::rand({4, 1, 4, 1});
  sliceStepTwo(t, 2, MemoryFormat::Contiguous);
  t = channels_last_rand({4, 1, 4, 1});
  sliceStepTwo(t, 2, MemoryFormat::ChannelsLast);
}
136 
sliceFirst(Tensor & t,int dim,at::MemoryFormat format)137 inline void sliceFirst(Tensor& t, int dim, at::MemoryFormat format) {
138   t = t.slice(dim, 0, 1, 1);
139   EXPECT_TRUE(t.suggest_memory_format() == format);
140 }
141 
// Narrowing dimensions to their first index (via sliceFirst) should keep the
// memory format reported by suggest_memory_format(). This covers contiguous
// and channels-last tensors of several shapes; one known ambiguous result
// is pinned with a TODO.
TEST(MemoryFormatTest, SliceFirstMemoryFormat) {
  // Random tensor of `shape`, restrided to ChannelsLast.
  const auto make_channels_last = [](at::IntArrayRef shape) {
    Tensor fresh = at::rand(shape);
    fresh.resize_(shape, at::MemoryFormat::ChannelsLast);
    return fresh;
  };

  // Dense 4x4x4x4: slicing dims in either order keeps the format.
  Tensor t = at::rand({4, 4, 4, 4});
  for (int d : {1, 2, 3}) {
    sliceFirst(t, d, MemoryFormat::Contiguous);
  }
  t = at::rand({4, 4, 4, 4});
  for (int d : {2, 3, 1}) {
    sliceFirst(t, d, MemoryFormat::Contiguous);
  }
  t = make_channels_last({4, 4, 4, 4});
  for (int d : {1, 2, 3}) {
    sliceFirst(t, d, MemoryFormat::ChannelsLast);
  }
  t = make_channels_last({4, 4, 4, 4});
  for (int d : {2, 3, 1}) {
    sliceFirst(t, d, MemoryFormat::ChannelsLast);
  }

  // Shapes containing singleton dimensions.
  t = at::rand({4, 4, 1, 1});
  sliceFirst(t, 1, MemoryFormat::Contiguous);
  t = make_channels_last({4, 4, 1, 1});
  sliceFirst(t, 1, MemoryFormat::ChannelsLast);

  t = at::rand({4, 1, 4, 4});
  for (int d : {2, 3}) {
    sliceFirst(t, d, MemoryFormat::Contiguous);
  }
  t = make_channels_last({4, 1, 4, 4});
  for (int d : {2, 3}) {
    sliceFirst(t, d, MemoryFormat::ChannelsLast);
  }

  t = at::rand({4, 1, 1, 4});
  sliceFirst(t, 3, MemoryFormat::Contiguous);
  t = make_channels_last({4, 1, 1, 4});
  sliceFirst(t, 3, MemoryFormat::ChannelsLast);

  t = at::rand({4, 1, 4, 1});
  sliceFirst(t, 2, MemoryFormat::Contiguous);
  t = make_channels_last({4, 1, 4, 1});
  // TODO: Should be able to handle this after accumulated permutation is implemented;
  // [4, 1, 4, 1]@[4, 1, 1, 1] after slice becomes [4, 1, 1, 1]@[4, 1, 1, 1]
  // sliceFirst(t, 2, MemoryFormat::ChannelsLast);
  sliceFirst(t, 2, MemoryFormat::Contiguous);
}
194