1 /* 2 * Copyright (c) 2018-2019 Arm Limited. 3 * 4 * SPDX-License-Identifier: MIT 5 * 6 * Permission is hereby granted, free of charge, to any person obtaining a copy 7 * of this software and associated documentation files (the "Software"), to 8 * deal in the Software without restriction, including without limitation the 9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10 * sell copies of the Software, and to permit persons to whom the Software is 11 * furnished to do so, subject to the following conditions: 12 * 13 * The above copyright notice and this permission notice shall be included in all 14 * copies or substantial portions of the Software. 15 * 16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 * SOFTWARE. 23 */ 24 #ifndef ARM_COMPUTE_UTILS_HELPERS_TENSOR_TRANSFORM_H 25 #define ARM_COMPUTE_UTILS_HELPERS_TENSOR_TRANSFORM_H 26 27 #include "arm_compute/core/Types.h" 28 29 namespace arm_compute 30 { 31 namespace helpers 32 { 33 namespace tensor_transform 34 { 35 /** Computes stride of a given index 36 * 37 * @param[in] index Index of tensor to calculate absolute start position 38 * @param[in] strides Slice strides 39 * 40 * @return Stride at a given index 41 */ 42 int calculate_stride_on_index(int index, Coordinates strides); 43 44 /** Computes absolute start position of a given index for a strided slice operation 45 * 46 * @param[in] input_shape Input tensor shape 47 * @param[in] index Index of tensor to calculate absolute start position 48 * @param[in] starts Start coordinates 49 * @param[in] strides Slice strides 50 * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and 51 * the fullest possible range in that dimension is used instead. 52 * 53 * @return Absolute start position of a given index 54 */ 55 int calculate_start_on_index(TensorShape input_shape, int index, Coordinates starts, Coordinates strides, int32_t begin_mask); 56 57 /** Returns the absolute end position of a given index for a strided slice operation 58 * 59 * @param[in] input_shape Input tensor shape 60 * @param[in] index Index of tensor to calculate absolute start position 61 * @param[in] start_on_index Absolute start coordinate for given index 62 * @param[in] ends End coordinates 63 * @param[in] strides Slice strides 64 * @param[in] end_mask (Optional) If the ith bit of end_mask is set, end[i] is ignored and 65 * the fullest possible range in that dimension is used instead. 66 * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. 67 * A slice of size 1 starting from starts[i] in the dimension must be preserved. 68 * 69 * @return Absolute end position of a given index 70 */ 71 int calculate_end_on_index(TensorShape input_shape, int index, int start_on_index, Coordinates ends, Coordinates strides, 72 int32_t end_mask = 0, int32_t shrink_axis_mask = 0); 73 74 /** Calculate start, end and stride coordinates for a strided slice 75 * 76 * @param[in] input_shape Input tensor shape 77 * @param[in] starts Start coordinates 78 * @param[in] ends End coordinates 79 * @param[in] strides Slice strides 80 * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and 81 * the fullest possible range in that dimension is used instead. 82 * @param[in] end_mask (Optional) If the ith bit of end_mask is set, end[i] is ignored and 83 * the fullest possible range in that dimension is used instead. 84 * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. 85 * A slice of size 1 starting from starts[i] in the dimension must be preserved. 86 * 87 * @return A tuple with <Start,End,Strides> 88 */ 89 std::tuple<Coordinates, Coordinates, Coordinates> calculate_strided_slice_coords(TensorShape input_shape, 90 Coordinates starts, Coordinates ends, Coordinates strides, 91 int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0); 92 93 /** Computes output shape of strided slice 94 * 95 * @warning Starts and ends must be non-negative 96 * @warning Starts, ends and final strides should have the same dimensions as the input shape 97 * 98 * @param[in] input_shape Input tensor shape 99 * @param[in] starts Absolute start coordinates 100 * @param[in] ends Absolute end coordinates 101 * @param[in] strides Slice strides 102 * @param[in] begin_mask (Optional) If the ith bit of begin_mask is set, starts[i] is ignored and 103 * the fullest possible range in that dimension is used instead. 104 * @param[in] end_mask (Optional) If the ith bit of end_mask is set, end[i] is ignored and 105 * the fullest possible range in that dimension is used instead. 106 * @param[in] shrink_axis_mask (Optional) If the ith bit of shrink_axis_mask is set, it implies that the ith specification shrinks the dimensionality by 1. 107 * A slice of size 1 starting from starts[i] in the dimension must be preserved. 108 * @param[in] return_unshrinked (Optional) Returns un-shrinked shape 109 * 110 * @return The output tensor shape 111 */ 112 TensorShape compute_strided_slice_output_shape(TensorShape input_shape, Coordinates starts, Coordinates ends, Coordinates strides, 113 int32_t begin_mask = 0, int32_t end_mask = 0, int32_t shrink_axis_mask = 0, 114 bool return_unshrinked = false); 115 116 /** Constructs end mask in case we want to perform a slice operation using the strided slice interface 117 * 118 * @note Ends are inclusive in slice operations that is why construction an end mask is needed 119 * 120 * @param[in] ends End coordinates 121 * 122 * @return End mask 123 */ 124 int32_t construct_slice_end_mask(Coordinates ends); 125 } // namespace tensor_tranform 126 } // namespace helpers 127 } // namespace arm_compute 128 #endif /* ARM_COMPUTE_UTILS_HELPERS_TENSOR_TRANSFORM_H */ 129