1 /* 2 * Copyright (c) 2020-2021 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_CPP_SPLIT_H 25 #define ARM_COMPUTE_CPP_SPLIT_H 26 27 #include "arm_compute/core/Error.h" 28 #include "arm_compute/core/Helpers.h" 29 #include "arm_compute/core/TensorInfo.h" 30 #include "arm_compute/core/Types.h" 31 #include "arm_compute/core/utils/misc/ShapeCalculator.h" 32 33 #include "support/ToolchainSupport.h" 34 35 #include "arm_compute/runtime/IFunction.h" 36 37 namespace arm_compute 38 { 39 /** Basic function to split a tensor along a given axis */ 40 template <typename SliceType, typename TensorInterfaceType = ITensor> 41 class CPPSplit : public IFunction 42 { 43 public: CPPSplit()44 CPPSplit() 45 : _outputs_vector(), _slice_functions(), _num_outputs(0) 46 { 47 } 48 /** Static function to check if given info will lead to a valid configuration of @ref CPPSplit 49 * 50 * @param[in] input The input tensor info. Data types supported: All. 51 * @param[in] outputs A vector containing the output tensors' info. Data types supported: same as @p input. 52 * The output tensors should match the input tensor dimensions for all shape dimensions apart 53 * from the split dimension 54 * @param[in] axis Axis on which to split the input. 55 * 56 * @return a status 57 */ validate(const ITensorInfo * input,const std::vector<ITensorInfo * > & outputs,unsigned int axis)58 static Status validate(const ITensorInfo *input, const std::vector<ITensorInfo *> &outputs, unsigned int axis) 59 { 60 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input); 61 ARM_COMPUTE_RETURN_ERROR_ON(axis >= input->num_dimensions()); 62 ARM_COMPUTE_RETURN_ERROR_ON(outputs.size() < 2); 63 64 // Get output shape 65 TensorShape output_shape{}; 66 unsigned int total_output_shape_size = 0; 67 68 // Sum the output sizes and fall back to evenly-sized splits if any are zero 69 const bool using_split_shapes = std::none_of(outputs.begin(), outputs.end(), [&total_output_shape_size](ITensorInfo * info) 70 { 71 unsigned int output_shape_size = info->tensor_shape().total_size(); 72 total_output_shape_size += output_shape_size; 73 return output_shape_size == 0; 74 }); 75 76 if(using_split_shapes) 77 { 78 ARM_COMPUTE_RETURN_ERROR_ON(input->tensor_shape().total_size() != total_output_shape_size); 79 } 80 else 81 { 82 output_shape = arm_compute::misc::shape_calculator::compute_split_shape(input, axis, outputs.size()); 83 ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0); 84 } 85 86 // Validate output tensors 87 unsigned int axis_offset = 0; 88 for(const auto &output : outputs) 89 { 90 ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(output); 91 if(using_split_shapes) 92 { 93 output_shape = output->tensor_shape(); 94 ARM_COMPUTE_RETURN_ERROR_ON(output_shape.total_size() == 0); 95 } 96 97 const size_t axis_split_step = output_shape[axis]; 98 99 // Start/End coordinates 100 Coordinates start_coords; 101 Coordinates end_coords; 102 for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d) 103 { 104 end_coords.set(d, -1); 105 } 106 107 // Output auto inizialitation if not yet initialized 108 TensorInfo tmp_output_info = *output->clone(); 109 if(tmp_output_info.tensor_shape().total_size() == 0) 110 { 111 tmp_output_info = input->clone()->set_is_resizable(true).set_tensor_shape(output_shape); 112 } 113 114 // Update coordinate on axis 115 start_coords.set(axis, axis_offset); 116 end_coords.set(axis, axis_offset + axis_split_step); 117 118 ARM_COMPUTE_RETURN_ON_ERROR(SliceType::validate(input, output, start_coords, end_coords)); 119 axis_offset += axis_split_step; 120 } 121 122 return Status{}; 123 } 124 125 /** Initialise the kernel's input and outputs. 126 * 127 * @param[in] input The input tensor. Data types supported: All 128 * @param[out] outputs A vector containing the output tensors. Data types supported: Same as @p input. 129 * The output tensors should match the input tensor dimensions for all shape dimensions apart 130 * from the split dimension. 131 * @param[in] axis Axis on which to split the input. 132 */ configure(const TensorInterfaceType * input,const std::vector<TensorInterfaceType * > & outputs,unsigned int axis)133 void configure(const TensorInterfaceType *input, const std::vector<TensorInterfaceType *> &outputs, unsigned int axis) 134 { 135 // Create Slice functions 136 _num_outputs = outputs.size(); 137 _slice_functions.resize(_num_outputs); 138 139 // Extract output tensor info 140 std::vector<ITensorInfo *> outputs_info; 141 for(auto &output : outputs) 142 { 143 ARM_COMPUTE_ERROR_ON_NULLPTR(output); 144 outputs_info.emplace_back(output->info()); 145 } 146 147 // If any of the outputs have a zero size, fall-back to using evenly-sized output splits 148 const bool outputs_have_sizes = std::none_of(outputs_info.begin(), outputs_info.end(), [](ITensorInfo * info) 149 { 150 return info->tensor_shape().total_size() == 0; 151 }); 152 153 // Validate 154 ARM_COMPUTE_ERROR_THROW_ON(CPPSplit::validate(input->info(), outputs_info, axis)); 155 156 unsigned int axis_offset = 0; 157 unsigned int i = 0; 158 159 for(const auto &output_info : outputs_info) 160 { 161 // Get output shape 162 TensorShape output_shape = (outputs_have_sizes ? 163 output_info->tensor_shape() : 164 arm_compute::misc::shape_calculator::compute_split_shape(input->info(), axis, _num_outputs)); 165 166 const size_t axis_split_step = output_shape[axis]; 167 168 // Start/End coordinates 169 Coordinates start_coords; 170 Coordinates end_coords; 171 172 for(unsigned int d = 0; d < output_shape.num_dimensions(); ++d) 173 { 174 end_coords.set(d, -1); 175 } 176 177 // Update coordinate on axis 178 start_coords.set(axis, axis_offset); 179 end_coords.set(axis, axis_offset + axis_split_step); 180 181 // Configure slice function 182 _slice_functions[i].configure(input, outputs[i], start_coords, end_coords); 183 184 // Set valid region from shape 185 outputs[i]->info()->set_valid_region(ValidRegion(Coordinates(), output_shape)); 186 187 // Update axis offset 188 axis_offset += axis_split_step; 189 ++i; 190 } 191 } 192 193 protected: 194 std::vector<TensorInterfaceType *> _outputs_vector; 195 std::vector<SliceType> _slice_functions; 196 unsigned int _num_outputs; 197 }; 198 199 } // namespace arm_compute 200 #endif /* ARM_COMPUTE_CPP_SPLIT_H */ 201