xref: /aosp_15_r20/external/executorch/kernels/test/op_pixel_unshuffle_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1 /*
2  * Copyright (c) Meta Platforms, Inc. and affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
10 #include <executorch/kernels/test/TestUtil.h>
11 #include <executorch/kernels/test/supported_features.h>
12 #include <executorch/runtime/core/exec_aten/exec_aten.h>
13 #include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
14 #include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
15 
16 #include <gtest/gtest.h>
17 
18 using namespace ::testing;
19 using exec_aten::Scalar;
20 using exec_aten::ScalarType;
21 using exec_aten::Tensor;
22 using torch::executor::testing::SupportedFeatures;
23 using torch::executor::testing::TensorFactory;
24 
25 class OpPixelUnshuffleOutTest : public OperatorTest {
26  protected:
op_pixel_unshuffle_out(const Tensor & self,int64_t upscale_factor,Tensor & out)27   Tensor& op_pixel_unshuffle_out(
28       const Tensor& self,
29       int64_t upscale_factor,
30       Tensor& out) {
31     return torch::executor::aten::pixel_unshuffle_outf(
32         context_, self, upscale_factor, out);
33   }
34 
35   template <ScalarType DTYPE_IN>
test_pixel_unshuffle()36   void test_pixel_unshuffle() {
37     TensorFactory<DTYPE_IN> tf_in;
38 
39     const std::vector<int32_t> sizes = {1, 1, 4, 4};
40     const std::vector<int32_t> out_sizes = {1, 4, 2, 2};
41 
42     // Destination for the pixel_unshuffle.
43     Tensor out = tf_in.zeros(out_sizes);
44 
45     op_pixel_unshuffle_out(
46         tf_in.make(
47             sizes, {0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15}),
48         2,
49         out);
50     EXPECT_TENSOR_EQ(
51         out,
52         tf_in.make(
53             out_sizes, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}));
54   }
55 };
56 
57 //
58 // Correctness Tests
59 //
60 
61 /**
62  * Uses the function templates above to test all input dtypes.
63  */
TEST_F(OpPixelUnshuffleOutTest,AllRealDtypesSupported)64 TEST_F(OpPixelUnshuffleOutTest, AllRealDtypesSupported) {
65 #define ENUMERATE_TEST_ENTRY(ctype, dtype) \
66   test_pixel_unshuffle<ScalarType::dtype>();
67 
68   ET_FORALL_REAL_TYPES(ENUMERATE_TEST_ENTRY)
69 
70 #undef ENUMERATE_TEST_ENTRY
71 }
72 
TEST_F(OpPixelUnshuffleOutTest, LargerInputRank) {
  TensorFactory<ScalarType::Int> tf;

  // Pixel unshuffle allows a 3D (or higher) input tensor; make sure the extra
  // leading dimensions don't cause issues. Distinct values (rather than all
  // ones) are used so that a kernel which mixes up elements, within a slice
  // or across the extra dimensions, is caught, not only one that produces
  // the wrong output shape.
  const std::vector<int32_t> sizes = {1, 4, 1, 1, 4, 4};
  const std::vector<int32_t> out_sizes = {1, 4, 1, 4, 2, 2};

  // One {1, 4, 4} input slice whose 2x unshuffle is 0, 1, ..., 15 (the same
  // layout used by test_pixel_unshuffle).
  const std::vector<int32_t> slice = {
      0, 4, 1, 5, 8, 12, 9, 13, 2, 6, 3, 7, 10, 14, 11, 15};
  // Product of the leading dimensions {1, 4, 1}.
  constexpr size_t kNumSlices = 4;
  const size_t slice_numel = slice.size();

  // Offset slice b by b * 16, so that it unshuffles to 16*b, ..., 16*b + 15
  // and the whole output is 0, 1, ..., 63 in row-major order.
  std::vector<int32_t> in_data;
  in_data.reserve(kNumSlices * slice_numel);
  for (size_t b = 0; b < kNumSlices; ++b) {
    const auto offset = static_cast<int32_t>(b * slice_numel);
    for (const int32_t v : slice) {
      in_data.push_back(offset + v);
    }
  }

  std::vector<int32_t> expected_data(kNumSlices * slice_numel);
  for (size_t i = 0; i < expected_data.size(); ++i) {
    expected_data[i] = static_cast<int32_t>(i);
  }

  Tensor a = tf.make(sizes, in_data);
  Tensor out = tf.zeros(out_sizes);

  op_pixel_unshuffle_out(a, 2, out);
  EXPECT_TENSOR_EQ(out, tf.make(out_sizes, expected_data));
}
86 
// Invalid-argument tests: each case below is expected to make the kernel fail.
TEST_F(OpPixelUnshuffleOutTest, InvalidInputShapeDies) {
  TensorFactory<ScalarType::Int> tf;
  constexpr int64_t kDownscaleFactor = 2;

  // The second-to-last input dimension (7) is not divisible by the downscale
  // factor (2), so the input cannot be unshuffled and the kernel must report
  // an error.
  const Tensor input = tf.ones(/*sizes=*/{1, 1, 7, 8});
  Tensor out = tf.zeros(/*sizes=*/{1, 4, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_pixel_unshuffle_out(input, kDownscaleFactor, out));
}
99 
TEST_F(OpPixelUnshuffleOutTest, WrongInputRankDies) {
  TensorFactory<ScalarType::Int> tf;
  constexpr int64_t kDownscaleFactor = 2;

  // Pixel unshuffle needs an input of rank 3 or higher; a rank-2 input must
  // make the kernel report an error.
  const Tensor input = tf.ones(/*sizes=*/{1, 2});
  Tensor out = tf.zeros(/*sizes=*/{1, 2});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_pixel_unshuffle_out(input, kDownscaleFactor, out));
}
110 
TEST_F(OpPixelUnshuffleOutTest, DifferentDtypeDies) {
  TensorFactory<ScalarType::Int> tf_int;
  TensorFactory<ScalarType::Float> tf_float;
  constexpr int64_t kDownscaleFactor = 3;

  // The shapes agree ({1, 2, 12, 12} unshuffled by 3 gives {1, 18, 4, 4});
  // only the dtypes differ (Int input, Float output). Input and output must
  // share a dtype, so the kernel must report an error.
  const Tensor input = tf_int.ones(/*sizes=*/{1, 2, 12, 12});
  Tensor out = tf_float.zeros(/*sizes=*/{1, 18, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_pixel_unshuffle_out(input, kDownscaleFactor, out));
}
123 
// NOTE(review): for pixel_unshuffle the factor is a *downscale* factor; the
// test name is left unchanged to keep its gtest ID stable.
TEST_F(OpPixelUnshuffleOutTest, NegativeUpscaleFactorDies) {
  TensorFactory<ScalarType::Int> tf;
  constexpr int64_t kNegativeFactor = -3;

  // The shapes would be consistent for a factor of +3; only the sign is
  // wrong, and the kernel must report an error.
  const Tensor input = tf.ones(/*sizes=*/{1, 2, 12, 12});
  Tensor out = tf.zeros(/*sizes=*/{1, 18, 4, 4});

  ET_EXPECT_KERNEL_FAILURE(
      context_, op_pixel_unshuffle_out(input, kNegativeFactor, out));
}
131