1*4bdc9457SAndroid Build Coastguard Worker // Copyright 2022 Google LLC 2*4bdc9457SAndroid Build Coastguard Worker // 3*4bdc9457SAndroid Build Coastguard Worker // This source code is licensed under the BSD-style license found in the 4*4bdc9457SAndroid Build Coastguard Worker // LICENSE file in the root directory of this source tree. 5*4bdc9457SAndroid Build Coastguard Worker 6*4bdc9457SAndroid Build Coastguard Worker #pragma once 7*4bdc9457SAndroid Build Coastguard Worker 8*4bdc9457SAndroid Build Coastguard Worker #include <gtest/gtest.h> 9*4bdc9457SAndroid Build Coastguard Worker 10*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack.h> 11*4bdc9457SAndroid Build Coastguard Worker #include <xnnpack/normalization.h> 12*4bdc9457SAndroid Build Coastguard Worker 13*4bdc9457SAndroid Build Coastguard Worker 14*4bdc9457SAndroid Build Coastguard Worker class TransposeNormalizationTester { 15*4bdc9457SAndroid Build Coastguard Worker public: num_dims(size_t num_dims)16*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& num_dims(size_t num_dims) { 17*4bdc9457SAndroid Build Coastguard Worker assert(num_dims != 0); 18*4bdc9457SAndroid Build Coastguard Worker this->num_dims_ = num_dims; 19*4bdc9457SAndroid Build Coastguard Worker return *this; 20*4bdc9457SAndroid Build Coastguard Worker } 21*4bdc9457SAndroid Build Coastguard Worker num_dims()22*4bdc9457SAndroid Build Coastguard Worker inline size_t num_dims() const { return this->num_dims_; } 23*4bdc9457SAndroid Build Coastguard Worker element_size(size_t element_size)24*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& element_size(size_t element_size) { 25*4bdc9457SAndroid Build Coastguard Worker this->element_size_ = element_size; 26*4bdc9457SAndroid Build Coastguard Worker return *this; 27*4bdc9457SAndroid Build Coastguard Worker } 28*4bdc9457SAndroid Build Coastguard Worker element_size()29*4bdc9457SAndroid Build Coastguard Worker inline size_t element_size() const { return this->element_size_; } 30*4bdc9457SAndroid Build Coastguard Worker expected_dims(size_t expected_dims)31*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& expected_dims(size_t expected_dims) { 32*4bdc9457SAndroid Build Coastguard Worker this->expected_dims_ = expected_dims; 33*4bdc9457SAndroid Build Coastguard Worker return *this; 34*4bdc9457SAndroid Build Coastguard Worker } 35*4bdc9457SAndroid Build Coastguard Worker expected_dims()36*4bdc9457SAndroid Build Coastguard Worker inline size_t expected_dims() const { return this->expected_dims_; } 37*4bdc9457SAndroid Build Coastguard Worker expected_element_size(size_t expected_element_size)38*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& expected_element_size(size_t expected_element_size) { 39*4bdc9457SAndroid Build Coastguard Worker this->expected_element_size_ = expected_element_size; 40*4bdc9457SAndroid Build Coastguard Worker return *this; 41*4bdc9457SAndroid Build Coastguard Worker } 42*4bdc9457SAndroid Build Coastguard Worker expected_element_size()43*4bdc9457SAndroid Build Coastguard Worker inline size_t expected_element_size() const { return this->expected_element_size_; } 44*4bdc9457SAndroid Build Coastguard Worker shape(const std::vector<size_t> shape)45*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& shape(const std::vector<size_t> shape) { 46*4bdc9457SAndroid Build Coastguard Worker assert(shape.size() <= XNN_MAX_TENSOR_DIMS); 47*4bdc9457SAndroid Build Coastguard Worker this->shape_ = shape; 48*4bdc9457SAndroid Build Coastguard Worker return *this; 49*4bdc9457SAndroid Build Coastguard Worker } 50*4bdc9457SAndroid Build Coastguard Worker perm(const std::vector<size_t> perm)51*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& perm(const std::vector<size_t> perm) { 52*4bdc9457SAndroid Build Coastguard Worker assert(perm.size() <= XNN_MAX_TENSOR_DIMS); 53*4bdc9457SAndroid Build Coastguard Worker this->perm_ = perm; 54*4bdc9457SAndroid Build Coastguard Worker return *this; 55*4bdc9457SAndroid Build Coastguard Worker } 56*4bdc9457SAndroid Build Coastguard Worker input_stride(const std::vector<size_t> input_stride)57*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& input_stride(const std::vector<size_t> input_stride) { 58*4bdc9457SAndroid Build Coastguard Worker assert(input_stride.size() <= XNN_MAX_TENSOR_DIMS); 59*4bdc9457SAndroid Build Coastguard Worker this->input_stride_ = input_stride; 60*4bdc9457SAndroid Build Coastguard Worker return *this; 61*4bdc9457SAndroid Build Coastguard Worker } 62*4bdc9457SAndroid Build Coastguard Worker output_stride(const std::vector<size_t> output_stride)63*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& output_stride(const std::vector<size_t> output_stride) { 64*4bdc9457SAndroid Build Coastguard Worker assert(output_stride.size() <= XNN_MAX_TENSOR_DIMS); 65*4bdc9457SAndroid Build Coastguard Worker this->output_stride_ = output_stride; 66*4bdc9457SAndroid Build Coastguard Worker return *this; 67*4bdc9457SAndroid Build Coastguard Worker } 68*4bdc9457SAndroid Build Coastguard Worker expected_shape(const std::vector<size_t> expected_shape)69*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& expected_shape(const std::vector<size_t> expected_shape) { 70*4bdc9457SAndroid Build Coastguard Worker this->expected_shape_ = expected_shape; 71*4bdc9457SAndroid Build Coastguard Worker return *this; 72*4bdc9457SAndroid Build Coastguard Worker } 73*4bdc9457SAndroid Build Coastguard Worker expected_shape()74*4bdc9457SAndroid Build Coastguard Worker inline const std::vector<size_t>& expected_shape() const { return this->expected_shape_; } 75*4bdc9457SAndroid Build Coastguard Worker expected_perm(const std::vector<size_t> expected_perm)76*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& expected_perm(const std::vector<size_t> expected_perm) { 77*4bdc9457SAndroid Build Coastguard Worker this->expected_perm_ = expected_perm; 78*4bdc9457SAndroid Build Coastguard Worker return *this; 79*4bdc9457SAndroid Build Coastguard Worker } 80*4bdc9457SAndroid Build Coastguard Worker expected_perm()81*4bdc9457SAndroid Build Coastguard Worker inline const std::vector<size_t>& expected_perm() const { return this->expected_perm_; } 82*4bdc9457SAndroid Build Coastguard Worker expected_input_stride(const std::vector<size_t> expected_input_stride)83*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& expected_input_stride(const std::vector<size_t> expected_input_stride) { 84*4bdc9457SAndroid Build Coastguard Worker this->expected_input_stride_ = expected_input_stride; 85*4bdc9457SAndroid Build Coastguard Worker return *this; 86*4bdc9457SAndroid Build Coastguard Worker } 87*4bdc9457SAndroid Build Coastguard Worker expected_output_stride(const std::vector<size_t> expected_output_stride)88*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& expected_output_stride(const std::vector<size_t> expected_output_stride) { 89*4bdc9457SAndroid Build Coastguard Worker this->expected_output_stride_ = expected_output_stride; 90*4bdc9457SAndroid Build Coastguard Worker return *this; 91*4bdc9457SAndroid Build Coastguard Worker } 92*4bdc9457SAndroid Build Coastguard Worker expected_input_stride()93*4bdc9457SAndroid Build Coastguard Worker inline const std::vector<size_t>& expected_input_stride() const { return this->expected_input_stride_; } 94*4bdc9457SAndroid Build Coastguard Worker expected_output_stride()95*4bdc9457SAndroid Build Coastguard Worker inline const std::vector<size_t>& expected_output_stride() const { return this->expected_output_stride_; } 96*4bdc9457SAndroid Build Coastguard Worker calculate_expected_input_stride()97*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& calculate_expected_input_stride() { 98*4bdc9457SAndroid Build Coastguard Worker expected_input_stride_.resize(expected_dims()); 99*4bdc9457SAndroid Build Coastguard Worker expected_input_stride_[expected_dims() - 1] = expected_element_size(); 100*4bdc9457SAndroid Build Coastguard Worker for(size_t i = expected_dims() - 1; i-- != 0;) { 101*4bdc9457SAndroid Build Coastguard Worker expected_input_stride_[i] = expected_input_stride_[i + 1] * expected_shape_[i + 1]; 102*4bdc9457SAndroid Build Coastguard Worker } 103*4bdc9457SAndroid Build Coastguard Worker return *this; 104*4bdc9457SAndroid Build Coastguard Worker } 105*4bdc9457SAndroid Build Coastguard Worker calculate_expected_output_stride()106*4bdc9457SAndroid Build Coastguard Worker inline TransposeNormalizationTester& calculate_expected_output_stride() { 107*4bdc9457SAndroid Build Coastguard Worker expected_output_stride_.resize(expected_dims()); 108*4bdc9457SAndroid Build Coastguard Worker expected_output_stride_[expected_dims() - 1] = expected_element_size(); 109*4bdc9457SAndroid Build Coastguard Worker for(size_t i = expected_dims() - 1; i-- != 0;) { 110*4bdc9457SAndroid Build Coastguard Worker expected_output_stride_[i] = expected_output_stride_[i + 1] 111*4bdc9457SAndroid Build Coastguard Worker * expected_shape_[expected_perm_[i + 1]]; 112*4bdc9457SAndroid Build Coastguard Worker } 113*4bdc9457SAndroid Build Coastguard Worker return *this; 114*4bdc9457SAndroid Build Coastguard Worker } 115*4bdc9457SAndroid Build Coastguard Worker Test()116*4bdc9457SAndroid Build Coastguard Worker void Test() const { 117*4bdc9457SAndroid Build Coastguard Worker size_t actual_element_size; 118*4bdc9457SAndroid Build Coastguard Worker size_t actual_normalized_dims; 119*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> actual_normalized_shape(num_dims()); 120*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> actual_normalized_perm(num_dims()); 121*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> actual_normalized_input_stride(num_dims()); 122*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> actual_normalized_output_stride(num_dims()); 123*4bdc9457SAndroid Build Coastguard Worker 124*4bdc9457SAndroid Build Coastguard Worker xnn_normalize_transpose_permutation(num_dims(), element_size(), perm_.data(), 125*4bdc9457SAndroid Build Coastguard Worker shape_.data(), input_stride_.empty() ? nullptr : input_stride_.data(), 126*4bdc9457SAndroid Build Coastguard Worker output_stride_.empty() ? nullptr : output_stride_.data(), 127*4bdc9457SAndroid Build Coastguard Worker &actual_normalized_dims, &actual_element_size, actual_normalized_perm.data(), 128*4bdc9457SAndroid Build Coastguard Worker actual_normalized_shape.data(), actual_normalized_input_stride.data(), 129*4bdc9457SAndroid Build Coastguard Worker actual_normalized_output_stride.data()); 130*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(expected_element_size(), actual_element_size); 131*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(expected_dims(), actual_normalized_dims); 132*4bdc9457SAndroid Build Coastguard Worker 133*4bdc9457SAndroid Build Coastguard Worker for (size_t i = 0; i < expected_dims(); ++i) { 134*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(expected_shape()[i], actual_normalized_shape[i]); 135*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(expected_perm()[i], actual_normalized_perm[i]); 136*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(expected_input_stride()[i], actual_normalized_input_stride[i]); 137*4bdc9457SAndroid Build Coastguard Worker EXPECT_EQ(expected_output_stride()[i], actual_normalized_output_stride[i]); 138*4bdc9457SAndroid Build Coastguard Worker } 139*4bdc9457SAndroid Build Coastguard Worker } 140*4bdc9457SAndroid Build Coastguard Worker 141*4bdc9457SAndroid Build Coastguard Worker private: 142*4bdc9457SAndroid Build Coastguard Worker size_t num_dims_; 143*4bdc9457SAndroid Build Coastguard Worker size_t element_size_; 144*4bdc9457SAndroid Build Coastguard Worker size_t expected_dims_; 145*4bdc9457SAndroid Build Coastguard Worker size_t expected_element_size_; 146*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> shape_; 147*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> perm_; 148*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> input_stride_; 149*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> output_stride_; 150*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> expected_shape_; 151*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> expected_perm_; 152*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> expected_input_stride_; 153*4bdc9457SAndroid Build Coastguard Worker std::vector<size_t> expected_output_stride_; 154*4bdc9457SAndroid Build Coastguard Worker }; 155