xref: /aosp_15_r20/external/XNNPACK/test/transpose-normalization-tester.h (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC
2*4bdc9457SAndroid Build Coastguard Worker //
3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the
4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree.
5*4bdc9457SAndroid Build Coastguard Worker 
6*4bdc9457SAndroid Build Coastguard Worker #pragma once
7*4bdc9457SAndroid Build Coastguard Worker 
8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h>
9*4bdc9457SAndroid Build Coastguard Worker 
10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h>
11*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/normalization.h>
12*4bdc9457SAndroid Build Coastguard Worker 
13*4bdc9457SAndroid Build Coastguard Worker 
14*4bdc9457SAndroid Build Coastguard Worker class TransposeNormalizationTester {
15*4bdc9457SAndroid Build Coastguard Worker  public:
num_dims(size_t num_dims)16*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& num_dims(size_t num_dims) {
17*4bdc9457SAndroid Build Coastguard Worker     assert(num_dims != 0);
18*4bdc9457SAndroid Build Coastguard Worker     this->num_dims_ = num_dims;
19*4bdc9457SAndroid Build Coastguard Worker     return *this;
20*4bdc9457SAndroid Build Coastguard Worker   }
21*4bdc9457SAndroid Build Coastguard Worker 
num_dims()22*4bdc9457SAndroid Build Coastguard Worker   inline size_t num_dims() const { return this->num_dims_; }
23*4bdc9457SAndroid Build Coastguard Worker 
element_size(size_t element_size)24*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& element_size(size_t element_size) {
25*4bdc9457SAndroid Build Coastguard Worker     this->element_size_ = element_size;
26*4bdc9457SAndroid Build Coastguard Worker     return *this;
27*4bdc9457SAndroid Build Coastguard Worker   }
28*4bdc9457SAndroid Build Coastguard Worker 
element_size()29*4bdc9457SAndroid Build Coastguard Worker   inline size_t element_size() const { return this->element_size_; }
30*4bdc9457SAndroid Build Coastguard Worker 
expected_dims(size_t expected_dims)31*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& expected_dims(size_t expected_dims) {
32*4bdc9457SAndroid Build Coastguard Worker     this->expected_dims_ = expected_dims;
33*4bdc9457SAndroid Build Coastguard Worker     return *this;
34*4bdc9457SAndroid Build Coastguard Worker   }
35*4bdc9457SAndroid Build Coastguard Worker 
expected_dims()36*4bdc9457SAndroid Build Coastguard Worker   inline size_t expected_dims() const { return this->expected_dims_; }
37*4bdc9457SAndroid Build Coastguard Worker 
expected_element_size(size_t expected_element_size)38*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& expected_element_size(size_t expected_element_size) {
39*4bdc9457SAndroid Build Coastguard Worker     this->expected_element_size_ = expected_element_size;
40*4bdc9457SAndroid Build Coastguard Worker     return *this;
41*4bdc9457SAndroid Build Coastguard Worker   }
42*4bdc9457SAndroid Build Coastguard Worker 
expected_element_size()43*4bdc9457SAndroid Build Coastguard Worker   inline size_t expected_element_size() const { return this->expected_element_size_; }
44*4bdc9457SAndroid Build Coastguard Worker 
shape(const std::vector<size_t> shape)45*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& shape(const std::vector<size_t> shape) {
46*4bdc9457SAndroid Build Coastguard Worker     assert(shape.size() <= XNN_MAX_TENSOR_DIMS);
47*4bdc9457SAndroid Build Coastguard Worker     this->shape_ = shape;
48*4bdc9457SAndroid Build Coastguard Worker     return *this;
49*4bdc9457SAndroid Build Coastguard Worker   }
50*4bdc9457SAndroid Build Coastguard Worker 
perm(const std::vector<size_t> perm)51*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& perm(const std::vector<size_t> perm) {
52*4bdc9457SAndroid Build Coastguard Worker     assert(perm.size() <= XNN_MAX_TENSOR_DIMS);
53*4bdc9457SAndroid Build Coastguard Worker     this->perm_ = perm;
54*4bdc9457SAndroid Build Coastguard Worker     return *this;
55*4bdc9457SAndroid Build Coastguard Worker   }
56*4bdc9457SAndroid Build Coastguard Worker 
input_stride(const std::vector<size_t> input_stride)57*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& input_stride(const std::vector<size_t> input_stride) {
58*4bdc9457SAndroid Build Coastguard Worker     assert(input_stride.size() <= XNN_MAX_TENSOR_DIMS);
59*4bdc9457SAndroid Build Coastguard Worker     this->input_stride_ = input_stride;
60*4bdc9457SAndroid Build Coastguard Worker     return *this;
61*4bdc9457SAndroid Build Coastguard Worker   }
62*4bdc9457SAndroid Build Coastguard Worker 
output_stride(const std::vector<size_t> output_stride)63*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& output_stride(const std::vector<size_t> output_stride) {
64*4bdc9457SAndroid Build Coastguard Worker     assert(output_stride.size() <= XNN_MAX_TENSOR_DIMS);
65*4bdc9457SAndroid Build Coastguard Worker     this->output_stride_ = output_stride;
66*4bdc9457SAndroid Build Coastguard Worker     return *this;
67*4bdc9457SAndroid Build Coastguard Worker   }
68*4bdc9457SAndroid Build Coastguard Worker 
expected_shape(const std::vector<size_t> expected_shape)69*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& expected_shape(const std::vector<size_t> expected_shape) {
70*4bdc9457SAndroid Build Coastguard Worker     this->expected_shape_ = expected_shape;
71*4bdc9457SAndroid Build Coastguard Worker     return *this;
72*4bdc9457SAndroid Build Coastguard Worker   }
73*4bdc9457SAndroid Build Coastguard Worker 
expected_shape()74*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& expected_shape() const { return this->expected_shape_; }
75*4bdc9457SAndroid Build Coastguard Worker 
expected_perm(const std::vector<size_t> expected_perm)76*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& expected_perm(const std::vector<size_t> expected_perm) {
77*4bdc9457SAndroid Build Coastguard Worker     this->expected_perm_ = expected_perm;
78*4bdc9457SAndroid Build Coastguard Worker     return *this;
79*4bdc9457SAndroid Build Coastguard Worker   }
80*4bdc9457SAndroid Build Coastguard Worker 
expected_perm()81*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& expected_perm() const { return this->expected_perm_; }
82*4bdc9457SAndroid Build Coastguard Worker 
expected_input_stride(const std::vector<size_t> expected_input_stride)83*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& expected_input_stride(const std::vector<size_t> expected_input_stride) {
84*4bdc9457SAndroid Build Coastguard Worker     this->expected_input_stride_ = expected_input_stride;
85*4bdc9457SAndroid Build Coastguard Worker     return *this;
86*4bdc9457SAndroid Build Coastguard Worker   }
87*4bdc9457SAndroid Build Coastguard Worker 
expected_output_stride(const std::vector<size_t> expected_output_stride)88*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& expected_output_stride(const std::vector<size_t> expected_output_stride) {
89*4bdc9457SAndroid Build Coastguard Worker     this->expected_output_stride_ = expected_output_stride;
90*4bdc9457SAndroid Build Coastguard Worker     return *this;
91*4bdc9457SAndroid Build Coastguard Worker   }
92*4bdc9457SAndroid Build Coastguard Worker 
expected_input_stride()93*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& expected_input_stride() const { return this->expected_input_stride_; }
94*4bdc9457SAndroid Build Coastguard Worker 
expected_output_stride()95*4bdc9457SAndroid Build Coastguard Worker   inline const std::vector<size_t>& expected_output_stride() const { return this->expected_output_stride_; }
96*4bdc9457SAndroid Build Coastguard Worker 
calculate_expected_input_stride()97*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& calculate_expected_input_stride() {
98*4bdc9457SAndroid Build Coastguard Worker     expected_input_stride_.resize(expected_dims());
99*4bdc9457SAndroid Build Coastguard Worker     expected_input_stride_[expected_dims() - 1] = expected_element_size();
100*4bdc9457SAndroid Build Coastguard Worker     for(size_t i = expected_dims() - 1; i-- != 0;) {
101*4bdc9457SAndroid Build Coastguard Worker       expected_input_stride_[i] = expected_input_stride_[i + 1] * expected_shape_[i + 1];
102*4bdc9457SAndroid Build Coastguard Worker     }
103*4bdc9457SAndroid Build Coastguard Worker     return *this;
104*4bdc9457SAndroid Build Coastguard Worker   }
105*4bdc9457SAndroid Build Coastguard Worker 
calculate_expected_output_stride()106*4bdc9457SAndroid Build Coastguard Worker   inline TransposeNormalizationTester& calculate_expected_output_stride() {
107*4bdc9457SAndroid Build Coastguard Worker     expected_output_stride_.resize(expected_dims());
108*4bdc9457SAndroid Build Coastguard Worker     expected_output_stride_[expected_dims() - 1] = expected_element_size();
109*4bdc9457SAndroid Build Coastguard Worker     for(size_t i = expected_dims() - 1; i-- != 0;) {
110*4bdc9457SAndroid Build Coastguard Worker       expected_output_stride_[i] = expected_output_stride_[i + 1]
111*4bdc9457SAndroid Build Coastguard Worker           * expected_shape_[expected_perm_[i + 1]];
112*4bdc9457SAndroid Build Coastguard Worker     }
113*4bdc9457SAndroid Build Coastguard Worker     return *this;
114*4bdc9457SAndroid Build Coastguard Worker   }
115*4bdc9457SAndroid Build Coastguard Worker 
Test()116*4bdc9457SAndroid Build Coastguard Worker   void Test() const {
117*4bdc9457SAndroid Build Coastguard Worker     size_t actual_element_size;
118*4bdc9457SAndroid Build Coastguard Worker     size_t actual_normalized_dims;
119*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> actual_normalized_shape(num_dims());
120*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> actual_normalized_perm(num_dims());
121*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> actual_normalized_input_stride(num_dims());
122*4bdc9457SAndroid Build Coastguard Worker     std::vector<size_t> actual_normalized_output_stride(num_dims());
123*4bdc9457SAndroid Build Coastguard Worker 
124*4bdc9457SAndroid Build Coastguard Worker     xnn_normalize_transpose_permutation(num_dims(), element_size(), perm_.data(),
125*4bdc9457SAndroid Build Coastguard Worker                                         shape_.data(), input_stride_.empty() ? nullptr : input_stride_.data(),
126*4bdc9457SAndroid Build Coastguard Worker                                         output_stride_.empty() ? nullptr : output_stride_.data(),
127*4bdc9457SAndroid Build Coastguard Worker                                         &actual_normalized_dims, &actual_element_size, actual_normalized_perm.data(),
128*4bdc9457SAndroid Build Coastguard Worker                                         actual_normalized_shape.data(), actual_normalized_input_stride.data(),
129*4bdc9457SAndroid Build Coastguard Worker                                         actual_normalized_output_stride.data());
130*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(expected_element_size(), actual_element_size);
131*4bdc9457SAndroid Build Coastguard Worker     EXPECT_EQ(expected_dims(), actual_normalized_dims);
132*4bdc9457SAndroid Build Coastguard Worker 
133*4bdc9457SAndroid Build Coastguard Worker     for (size_t i = 0; i < expected_dims(); ++i) {
134*4bdc9457SAndroid Build Coastguard Worker       EXPECT_EQ(expected_shape()[i], actual_normalized_shape[i]);
135*4bdc9457SAndroid Build Coastguard Worker       EXPECT_EQ(expected_perm()[i], actual_normalized_perm[i]);
136*4bdc9457SAndroid Build Coastguard Worker       EXPECT_EQ(expected_input_stride()[i], actual_normalized_input_stride[i]);
137*4bdc9457SAndroid Build Coastguard Worker       EXPECT_EQ(expected_output_stride()[i], actual_normalized_output_stride[i]);
138*4bdc9457SAndroid Build Coastguard Worker     }
139*4bdc9457SAndroid Build Coastguard Worker   }
140*4bdc9457SAndroid Build Coastguard Worker 
141*4bdc9457SAndroid Build Coastguard Worker  private:
142*4bdc9457SAndroid Build Coastguard Worker   size_t num_dims_;
143*4bdc9457SAndroid Build Coastguard Worker   size_t element_size_;
144*4bdc9457SAndroid Build Coastguard Worker   size_t expected_dims_;
145*4bdc9457SAndroid Build Coastguard Worker   size_t expected_element_size_;
146*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> shape_;
147*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> perm_;
148*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> input_stride_;
149*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> output_stride_;
150*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> expected_shape_;
151*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> expected_perm_;
152*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> expected_input_stride_;
153*4bdc9457SAndroid Build Coastguard Worker   std::vector<size_t> expected_output_stride_;
154*4bdc9457SAndroid Build Coastguard Worker };
155