/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
#include <executorch/kernels/test/TestUtil.h>
#include <executorch/kernels/test/supported_features.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>

#include <gtest/gtest.h>

using namespace ::testing;

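// Test fixture for the max_pool2d_with_indices.out kernel. The helper below
// forwards its arguments, together with the fixture's context_, to
// torch::executor::aten::max_pool2d_with_indices_outf.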
class OpMaxPool2DWithIndicesOutTest : public OperatorTest {
 protected:
  ::std::tuple<exec_aten::Tensor&, exec_aten::Tensor&>
  op_max_pool2d_with_indices_out(
      const exec_aten::Tensor& self,
      exec_aten::ArrayRef<int64_t> kernel_size,
      exec_aten::ArrayRef<int64_t> stride,
      exec_aten::ArrayRef<int64_t> padding,
      exec_aten::ArrayRef<int64_t> dilation,
      bool ceil_mode,
      exec_aten::Tensor& out,
      exec_aten::Tensor& indices) {
    return torch::executor::aten::max_pool2d_with_indices_outf(
        context_,
        self,
        kernel_size,
        stride,
        padding,
        dilation,
        ceil_mode,
        out,
        indices);
  }
};

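// 4-D input of shape (2, 3, 5, 5) pooled with a 2x2 kernel, stride 1, no
// padding, and no dilation, producing a (2, 3, 4, 4) output. The expected
// indices are the flattened positions of each maximum within its 5x5 input
// plane.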
TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest4D) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Long> tfLong;

  exec_aten::Tensor self = tfFloat.make(
      {2, 3, 5, 5},
      {28.75, -38.875, -7.0, -13.5, 70.75, 53.75, 69.625, 97.375,
       25.375, 99.5, -72.125, -87.25, 79.25, 42.0, -24.75, -15.5,
       12.5, -86.0, 85.5, -0.25, 67.125, 77.0, 53.375, -61.125,
       50.0, 3.875, 42.25, -37.375, 51.0, -60.875, 87.0, 32.25,
       73.5, 68.875, -84.375, -98.75, -30.125, 94.25, 1.625, -86.25,
       -56.5, -68.0, 74.25, -51.25, 8.125, 71.375, -53.125, 4.875,
       77.5, -89.875, 4.5, -46.5, -46.375, -92.625, -85.5, -23.0,
       -8.875, -12.0, -46.625, -88.625, 66.75, 87.75, 90.25, -45.0,
       -78.125, 63.25, 28.75, 28.125, -30.375, 17.75, -16.0, 5.0,
       11.125, 88.625, -47.625, 72.25, 32.0, -7.625, 61.625, -63.125,
       -22.75, 83.125, -40.375, -78.25, 49.5, -39.125, -89.625, 47.875,
       -61.375, 7.75, 16.875, -96.375, -22.5, 8.5, 74.25, 12.75,
       90.125, 73.875, -71.75, -10.0, 41.25, 1.125, 10.375, -34.625,
       29.75, -27.5, 26.625, 81.0, -8.875, 17.625, 84.375, -23.625,
       -53.875, -26.0, -67.375, -90.75, 16.375, 45.625, 99.5, 56.25,
       -87.625, -65.5, -79.75, 31.875, 79.75, 6.375, 44.625, -55.25,
       -5.5, -68.875, -38.625, 54.125, -3.125, 5.75, 29.25, -39.5,
       26.75, 68.25, -24.625, -53.0, 51.0, 90.625, 65.375, 43.875,
       90.875, -41.625, 99.875, 6.375, -31.25, -94.0});
  ::std::vector<int64_t> kernel_size_vec = {2, 2};
  exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
      kernel_size_vec.data(), kernel_size_vec.size());
  ::std::vector<int64_t> stride_vec = {1, 1};
  exec_aten::ArrayRef<int64_t> stride =
      exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
  ::std::vector<int64_t> padding_vec = {0, 0};
  exec_aten::ArrayRef<int64_t> padding =
      exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
  ::std::vector<int64_t> dilation_vec = {1, 1};
  exec_aten::ArrayRef<int64_t> dilation =
      exec_aten::ArrayRef<int64_t>(dilation_vec.data(), dilation_vec.size());
  bool ceil_mode = false;
  exec_aten::Tensor out = tfFloat.zeros({2, 3, 4, 4});
  exec_aten::Tensor indices = tfLong.zeros({2, 3, 4, 4});
  exec_aten::Tensor out_expected = tfFloat.make(
      {2, 3, 4, 4},
      {69.625, 97.375, 97.375, 99.5, 69.625, 97.375, 97.375, 99.5, 12.5,
       79.25, 85.5, 85.5, 77.0, 77.0, 85.5, 85.5, 87.0, 73.5,
       73.5, 68.875, 87.0, 94.25, 94.25, 68.875, -30.125, 94.25, 94.25,
       8.125, 71.375, 74.25, 77.5, 77.5, 4.5, -8.875, -12.0, -46.625,
       87.75, 90.25, 90.25, -45.0, 87.75, 90.25, 90.25, 17.75, 63.25,
       28.75, 88.625, 88.625, 83.125, 83.125, 61.625, 61.625, 83.125, 83.125,
       47.875, 49.5, 16.875, 47.875, 47.875, 74.25, 90.125, 90.125, 73.875,
       74.25, 41.25, 81.0, 81.0, 29.75, 84.375, 81.0, 81.0, 17.625,
       84.375, 45.625, 99.5, 99.5, 16.375, 45.625, 99.5, 99.5, 54.125,
       54.125, 5.75, 29.25, 54.125, 68.25, 68.25, 29.25, 90.625, 90.625,
       68.25, 90.875, 99.875, 99.875, 65.375, 90.875});
  exec_aten::Tensor indices_expected = tfLong.make(
      {2, 3, 4, 4},
      {6, 7, 7, 9, 6, 7, 7, 9, 16, 12, 18, 18, 21, 21, 18, 18,
       5, 7, 7, 8, 5, 12, 12, 8, 11, 12, 12, 19, 20, 17, 23, 23,
       0, 6, 7, 8, 11, 12, 12, 13, 11, 12, 12, 19, 15, 16, 23, 23,
       6, 6, 3, 3, 6, 6, 12, 9, 15, 12, 12, 19, 21, 21, 22, 19,
       0, 7, 7, 4, 10, 7, 7, 9, 10, 17, 18, 18, 16, 17, 18, 18,
       6, 6, 8, 9, 6, 12, 12, 9, 16, 16, 12, 19, 21, 21, 17, 19});
  op_max_pool2d_with_indices_out(
      self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
  EXPECT_TENSOR_CLOSE(out, out_expected);
  EXPECT_TENSOR_CLOSE(indices, indices_expected);
}

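// 4-D input of shape (2, 3, 8, 8) pooled with an asymmetric configuration:
// 2x3 kernel, stride (2, 1), padding (0, 1), and dilation (2, 1), producing a
// (2, 3, 3, 8) output. Indices are flattened positions within each 8x8 input
// plane.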
TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest4D_2) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Long> tfLong;

  exec_aten::Tensor self = tfFloat.make(
      {2, 3, 8, 8},
      {47.375, -18.625, 12.0, 5.375, -40.375, -75.875, 51.0, -48.25,
       -5.0, -50.625, -96.875, -53.25, 82.25, 0.125, -13.125, 89.75,
       53.0, 67.125, 79.625, 55.0, 77.75, -85.5, 35.25, -44.625,
       25.625, -44.625, -27.75, 75.875, -96.375, 37.375, 67.0, -73.75,
       -77.25, 86.5, -27.5, 93.25, -40.0, -66.125, 57.5, 16.0,
       -94.125, 18.125, -22.625, -94.25, 41.5, 52.875, 36.875, 0.875,
       92.75, -94.375, -80.375, -55.625, -11.0, -23.75, -8.375, 33.625,
       -27.25, -2.25, 70.25, -60.375, -96.125, -33.25, 43.375, -38.75,
       -58.5, -32.375, -83.75, -15.375, -4.375, 7.625, 58.875, 59.25,
       -12.5, 85.75, 67.125, -32.0, 32.125, 41.0, 47.625, 72.25,
       -36.5, 5.125, -30.75, -37.625, -76.25, 96.25, 55.125, -88.0,
       -78.0, -18.125, -23.75, -8.75, -3.25, 82.0, 10.5, 48.625,
       -64.0, 11.5, -66.75, 40.75, 38.25, 44.875, 21.0, 7.0,
       30.125, -59.125, -31.5, -56.0, 75.0, -31.375, -25.0, -23.75,
       48.625, -20.5, 52.375, -27.25, 84.75, -39.125, -56.5, -80.875,
       -60.5, -26.125, 42.375, 82.625, 43.375, -5.5, 73.375, 71.25,
       -12.25, 58.125, -1.375, -97.625, -43.375, 64.125, -71.125, -55.25,
       5.375, 56.0, 20.0, -70.125, 78.625, 20.625, 62.5, 71.375,
       93.0, -50.25, 26.125, -16.875, 50.375, -96.25, 32.75, 53.375,
       -61.5, -48.375, 72.375, -15.5, 2.875, -66.0, -31.75, -1.0,
       65.625, -46.75, 69.875, 74.5, 53.25, 13.25, -41.625, 96.625,
       75.875, 97.625, -98.25, -32.875, -75.0, -31.75, 72.875, -89.625,
       56.625, 10.375, -40.5, -53.625, 5.75, -52.75, 19.5, 71.125,
       80.875, 16.25, -32.625, 88.25, -98.5, 92.5, -68.25, -83.25,
       43.375, 20.25, 1.875, -69.25, -92.75, 80.0, -17.125, -52.125,
       -82.0, 64.5, -7.875, -0.125, -74.5, 46.375, -68.875, 63.125,
       88.0, 12.0, 87.75, 58.125, -83.375, 21.125, -34.0, -60.125,
       37.75, 17.125, -44.875, -22.875, 72.5, -99.0, -30.25, 81.125,
       -73.875, -90.25, -30.0, 64.75, -57.875, -4.0, -33.125, -22.875,
       -74.0, -46.0, 4.125, -33.375, -72.625, 83.25, 25.125, 54.25,
       43.0, -22.5, -22.25, -72.375, 80.25, 5.125, -7.0, 73.5,
       -28.625, -33.75, -91.125, -76.75, -67.625, 79.25, -56.625, 85.0,
       -53.125, 11.0, -44.0, -75.5, -23.375, 54.5, 21.5, 36.5,
       33.0, 83.375, 81.625, 3.5, -7.5, 7.125, -87.5, 12.125,
       9.625, -90.125, -3.75, 60.375, -80.75, -8.375, 59.25, 30.25,
       -75.75, 78.25, 13.5, -69.125, 48.0, -56.5, -10.5, -65.5,
       -80.875, 7.0, 64.375, -81.0, 48.625, -33.25, -53.75, -68.75,
       -12.0, 3.0, -56.625, 17.0, 36.0, 28.25, 36.75, -39.5,
       -23.5, -53.875, -17.375, -84.375, 0.625, 12.125, -45.5, 14.25,
       75.5, -92.875, 9.125, 38.25, 20.375, -46.625, 59.875, -79.375,
       52.625, 6.75, 11.875, -57.5, 91.75, -55.5, -86.875, 11.375,
       -41.0, 1.25, 41.375, 70.5, 43.625, -65.625, 7.625, -90.875,
       84.5, 26.0, -52.25, 4.125, 27.75, 17.875, -95.75, 66.5,
       -88.875, -67.125, 21.5, -0.875, 35.875, 53.0, 62.0, -78.5,
       -65.125, -57.375, 5.625, 54.75, 70.0, -94.25, 50.625, -0.25,
       89.5, -32.0, -83.125, -48.625, -23.0, 75.5, 44.0, 75.0,
       58.0, 18.125, -94.75, -69.375, 70.375, -51.75, -86.75, 81.5,
       75.75, 61.625, -14.5, -60.75, -58.125, -3.25, 36.25, -95.125});
  ::std::vector<int64_t> kernel_size_vec = {2, 3};
  exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
      kernel_size_vec.data(), kernel_size_vec.size());
  ::std::vector<int64_t> stride_vec = {2, 1};
  exec_aten::ArrayRef<int64_t> stride =
      exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
  ::std::vector<int64_t> padding_vec = {0, 1};
  exec_aten::ArrayRef<int64_t> padding =
      exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
  ::std::vector<int64_t> dilation_vec = {2, 1};
  exec_aten::ArrayRef<int64_t> dilation =
      exec_aten::ArrayRef<int64_t>(dilation_vec.data(), dilation_vec.size());
  bool ceil_mode = false;
  exec_aten::Tensor out = tfFloat.zeros({2, 3, 3, 8});
  exec_aten::Tensor indices = tfLong.zeros({2, 3, 3, 8});
  exec_aten::Tensor out_expected = tfFloat.make(
      {2, 3, 3, 8},
      {67.125, 79.625, 79.625, 79.625, 77.75, 77.75, 51.0, 51.0, 86.5,
       86.5, 93.25, 93.25, 93.25, 77.75, 57.5, 57.5, 92.75, 92.75,
       93.25, 93.25, 93.25, 57.5, 57.5, 57.5, 5.125, 5.125, 5.125,
       -4.375, 96.25, 96.25, 96.25, 59.25, 11.5, 11.5, 40.75, 40.75,
       96.25, 96.25, 96.25, 55.125, 48.625, 52.375, 52.375, 84.75, 84.75,
       84.75, 44.875, 21.0, 93.0, 93.0, 58.125, 50.375, 64.125, 64.125,
       64.125, 53.375, 93.0, 93.0, 74.5, 74.5, 74.5, 53.25, 96.625,
       96.625, 65.625, 69.875, 74.5, 74.5, 74.5, 53.25, 96.625, 96.625,
       88.0, 88.0, 87.75, 87.75, 80.0, 80.0, 80.0, -17.125, 88.0,
       88.0, 87.75, 87.75, 64.75, 21.125, 21.125, -22.875, 43.0, 43.0,
       64.75, 80.25, 80.25, 80.25, 73.5, 73.5, 11.0, 11.0, 60.375,
       60.375, 60.375, 59.25, 59.25, 59.25, 9.625, 64.375, 64.375, 64.375,
       60.375, 59.25, 59.25, 59.25, 7.0, 64.375, 64.375, 64.375, 48.625,
       48.625, 14.25, 14.25, 84.5, 84.5, 26.0, 91.75, 91.75, 91.75,
       66.5, 66.5, 84.5, 84.5, 54.75, 70.0, 70.0, 70.0, 66.5,
       66.5, 58.0, 58.0, 54.75, 70.375, 70.375, 70.375, 81.5, 81.5});
  exec_aten::Tensor indices_expected = tfLong.make(
      {2, 3, 3, 8},
      {17, 18, 18, 18, 20, 20, 6, 6, 33, 33, 35, 35, 35, 20, 38, 38, 48, 48,
       35, 35, 35, 38, 38, 38, 17, 17, 17, 4, 21, 21, 21, 7, 33, 33, 35, 35,
       21, 21, 21, 22, 48, 50, 50, 52, 52, 52, 37, 38, 16, 16, 1, 20, 5, 5,
       5, 23, 16, 16, 35, 35, 35, 36, 39, 39, 32, 34, 35, 35, 35, 36, 39, 39,
       16, 16, 18, 18, 5, 5, 5, 6, 16, 16, 18, 18, 35, 21, 21, 39, 48, 48,
       35, 52, 52, 52, 55, 55, 1, 1, 19, 19, 19, 22, 22, 22, 16, 34, 34, 34,
       19, 22, 22, 22, 33, 34, 34, 34, 36, 36, 55, 55, 16, 16, 17, 4, 4, 4,
       23, 23, 16, 16, 35, 36, 36, 36, 23, 23, 48, 48, 35, 52, 52, 52, 55, 55});
  op_max_pool2d_with_indices_out(
      self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
  EXPECT_TENSOR_CLOSE(out, out_expected);
  EXPECT_TENSOR_CLOSE(indices, indices_expected);
}

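// 3-D (unbatched) input of shape (2, 12, 12) pooled with a 4x3 kernel,
// stride (3, 2), padding (2, 1), and dilation (1, 2), producing a (2, 5, 5)
// output. Indices are flattened positions within each 12x12 input plane.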
TEST_F(OpMaxPool2DWithIndicesOutTest, SanityTest3D) {
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Float> tfFloat;
  torch::executor::testing::TensorFactory<exec_aten::ScalarType::Long> tfLong;

  exec_aten::Tensor self = tfFloat.make(
      {2, 12, 12},
      {73.625, 15.5, 30.875, 89.25, -55.625, -62.875, 25.0, -50.75,
       -47.125, 12.125, -73.125, -89.875, 53.625, -63.125, -44.375, 86.0,
       53.625, -84.125, -6.75, 20.125, -24.25, -43.5, -11.125, -34.625,
       -7.5, -13.0, 74.375, 33.75, -44.875, 49.125, -59.5, -88.5,
       -46.5, -33.0, 48.125, 80.875, 38.875, -58.875, 0.875, -48.625,
       -46.125, -87.25, -66.625, 14.375, -68.25, -77.0, -50.5, -15.625,
       86.875, 89.875, -37.25, 7.5, -16.75, -6.625, 55.875, 40.5,
       -83.875, -77.625, -55.375, 32.25, -17.5, -83.125, 43.375, 17.5,
       2.75, -51.25, 20.25, -77.375, -68.0, -72.625, -47.5, -78.875,
       -49.375, -52.125, -7.125, -25.125, -77.5, -8.625, -3.125, 99.375,
       71.875, 19.625, 21.125, -47.0, 44.5, -90.625, -75.75, -87.25,
       79.75, -42.125, -90.0, 22.5, 2.5, 73.5, -65.125, -50.375,
       -71.625, 19.25, -60.125, -91.75, -43.375, -60.875, 16.375, 86.875,
       -93.25, -78.375, 82.5, 14.75, 20.125, 19.625, 33.875, 84.875,
       60.625, 41.5, 2.0, -4.875, -52.5, 74.375, -40.125, -60.125,
       88.5, 51.875, -59.75, 49.5, -81.0, -93.5, 43.0, -99.625,
       40.375, -84.0, 76.5, 27.5, 59.125, -19.5, -55.25, -50.5,
       81.875, 86.0, -19.75, 51.5, 70.875, -90.5, 74.375, 62.5,
       -0.625, -31.375, -71.25, 42.75, 42.5, 67.125, 26.125, 85.375,
       -11.75, -34.375, -97.125, 5.875, -45.25, -50.125, 74.5, -62.0,
       -81.5, -84.875, 70.75, 33.375, -27.5, -54.25, 94.25, 74.625,
       -30.0, 16.875, 39.875, -0.5, 0.25, -80.125, 85.375, 42.5,
       13.125, -82.375, -30.75, -95.75, 34.75, -60.125, -51.625, -10.375,
       -30.75, -65.5, -96.0, -95.25, 60.125, -33.125, 67.125, -88.0,
       -26.125, 75.875, 29.5, -27.75, -28.875, 21.375, -2.0, -29.125,
       11.0, -68.5, -36.75, -85.375, -4.625, 9.0, -31.75, -63.5,
       98.75, -1.375, 17.125, 61.25, -50.5, 41.375, -18.375, -92.25,
       -50.375, -40.625, 14.0, 18.5, 22.625, 10.375, 58.875, -86.0,
       -9.625, 5.125, -69.625, -50.25, -26.875, 26.0, 57.25, -94.5,
       -53.125, 98.0, 37.375, 35.0, -20.125, -9.375, -13.375, -41.125,
       41.75, 95.125, 82.75, -71.5, -43.125, -37.75, -91.25, -14.0,
       -55.5, 52.5, -30.125, 93.875, -26.75, 83.125, 2.625, -63.75,
       52.875, 31.25, 57.625, 42.75, -2.875, -45.5, 99.0, -18.625,
       38.375, 88.125, 36.625, -36.875, -35.25, 13.125, -31.875, -50.875,
       10.5, -38.75, 1.625, 67.125, 3.0, -87.0, 42.0, -31.25,
       -77.875, -7.125, -94.0, -99.0, 24.75, -21.625, -98.375, 15.875});
  ::std::vector<int64_t> kernel_size_vec = {4, 3};
  exec_aten::ArrayRef<int64_t> kernel_size = exec_aten::ArrayRef<int64_t>(
      kernel_size_vec.data(), kernel_size_vec.size());
  ::std::vector<int64_t> stride_vec = {3, 2};
  exec_aten::ArrayRef<int64_t> stride =
      exec_aten::ArrayRef<int64_t>(stride_vec.data(), stride_vec.size());
  ::std::vector<int64_t> padding_vec = {2, 1};
  exec_aten::ArrayRef<int64_t> padding =
      exec_aten::ArrayRef<int64_t>(padding_vec.data(), padding_vec.size());
  ::std::vector<int64_t> dilation_vec = {1, 2};
  exec_aten::ArrayRef<int64_t> dilation =
      exec_aten::ArrayRef<int64_t>(dilation_vec.data(), dilation_vec.size());
  bool ceil_mode = false;
  exec_aten::Tensor out = tfFloat.zeros({2, 5, 5});
  exec_aten::Tensor indices = tfLong.zeros({2, 5, 5});
  exec_aten::Tensor out_expected = tfFloat.make(
      {2, 5, 5},
      {89.25, 89.25, 89.25, 20.125, 20.125, 89.875, 89.875, 86.0, 49.125,
       80.875, 89.875, 89.875, 99.375, 99.375, 99.375, 84.875, 84.875, 86.875,
       86.875, 86.875, 51.875, 86.0, 86.0, 86.0, 62.5, 42.75, 67.125,
       85.375, 85.375, 85.375, 75.875, 75.875, 42.5, 42.5, 74.625, 75.875,
       98.0, 98.0, 98.0, 61.25, 95.125, 98.0, 98.0, 98.0, 93.875,
       88.125, 88.125, 13.125, 13.125, 67.125});
  exec_aten::Tensor indices_expected = tfLong.make(
      {2, 5, 5}, {3, 3, 3, 19, 19, 49, 49, 15, 29, 35, 49, 49, 79,
                  79, 79, 111, 111, 103, 103, 103, 121, 137, 137, 137, 143, 3,
                  5, 7, 7, 7, 49, 49, 31, 31, 23, 49, 89, 89, 89,
                  67, 97, 89, 89, 89, 107, 121, 121, 125, 125, 131});
  op_max_pool2d_with_indices_out(
      self, kernel_size, stride, padding, dilation, ceil_mode, out, indices);
  EXPECT_TENSOR_CLOSE(out, out_expected);
  EXPECT_TENSOR_CLOSE(indices, indices_expected);
}