1 /*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9 #include <executorch/kernels/portable/cpu/util/slice_util.h>
10 #include <executorch/runtime/kernel/kernel_includes.h>
11 #include <cstring>
12
13 namespace torch {
14 namespace executor {
15
16 using Tensor = exec_aten::Tensor;
17
check_narrow_copy_args(const Tensor & in,int64_t dim,int64_t start,int64_t lenth,Tensor & out)18 bool check_narrow_copy_args(
19 const Tensor& in,
20 int64_t dim,
21 int64_t start,
22 int64_t lenth,
23 Tensor& out) {
24 ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0);
25 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
26 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
27 ET_LOG_MSG_AND_RETURN_IF_FALSE(lenth >= 0, "lenth must be non-negative");
28 ET_LOG_AND_RETURN_IF_FALSE(start >= -in.size(dim));
29 ET_LOG_AND_RETURN_IF_FALSE(start <= in.size(dim));
30 if (start < 0) {
31 start += in.size(dim);
32 }
33 ET_LOG_AND_RETURN_IF_FALSE(start + lenth <= in.size(dim));
34 return true;
35 }
36
get_narrow_copy_out_target_size(const Tensor & in,int64_t dim,int64_t length,exec_aten::SizesType * out_sizes,size_t * out_ndim)37 void get_narrow_copy_out_target_size(
38 const Tensor& in,
39 int64_t dim,
40 int64_t length,
41 exec_aten::SizesType* out_sizes,
42 size_t* out_ndim) {
43 *out_ndim = in.dim();
44
45 for (size_t d = 0; d < in.dim(); ++d) {
46 out_sizes[d] = in.size(d);
47 }
48 out_sizes[dim] = length;
49 }
50
check_slice_copy_args(const Tensor & in,int64_t dim,int64_t step,Tensor & out)51 bool check_slice_copy_args(
52 const Tensor& in,
53 int64_t dim,
54 int64_t step,
55 Tensor& out) {
56 ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0);
57 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
58 ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
59 ET_LOG_MSG_AND_RETURN_IF_FALSE(
60 step > 0, "slice step must be greater than zero");
61 return true;
62 }
63
get_slice_copy_out_target_size(const Tensor & in,int64_t dim,int64_t length,exec_aten::SizesType * out_sizes,size_t * out_ndim)64 void get_slice_copy_out_target_size(
65 const Tensor& in,
66 int64_t dim,
67 int64_t length,
68 exec_aten::SizesType* out_sizes,
69 size_t* out_ndim) {
70 get_narrow_copy_out_target_size(in, dim, length, out_sizes, out_ndim);
71 }
72
check_slice_scatter_args(const Tensor & input,const Tensor & src,int64_t dim,int64_t num_values,int64_t step,Tensor output)73 bool check_slice_scatter_args(
74 const Tensor& input,
75 const Tensor& src,
76 int64_t dim,
77 int64_t num_values,
78 int64_t step,
79 Tensor output) {
80 ET_LOG_AND_RETURN_IF_FALSE(input.dim() > 0);
81
82 // Check dim. The dim planed to be selected on shall exist in input
83 ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, input.dim()));
84
85 // Input and output tensors should be the same shape and dtype
86 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape_and_dtype(input, output));
87
88 // The input.dim() shall equal to src.dim()
89 ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(input, src));
90
91 // Check step. Step must be greater than zero
92 ET_LOG_MSG_AND_RETURN_IF_FALSE(
93 step > 0, "slice step must be greater than zero");
94
95 // The size of src tensor should follow these rules:
96 // - src.size(i) shall equal to input.size(i) if i != dim,
97 // - src.size(dim) shall equal to num_values
98 for (size_t d = 0; d < input.dim() - 1; d++) {
99 if (d != dim) {
100 ET_LOG_AND_RETURN_IF_FALSE(
101 tensors_have_same_size_at_dims(input, d, src, d));
102 } else {
103 ET_LOG_MSG_AND_RETURN_IF_FALSE(
104 src.size(d) == num_values,
105 "input.size(%zu) %zd != num_values %" PRId64 " | dim = %" PRId64 ")",
106 d,
107 input.size(d),
108 num_values,
109 dim);
110 }
111 }
112
113 return true;
114 }
115
/**
 * Normalizes a Python-style [start, end) slice over a dimension of
 * `dim_length` elements and returns how many elements it selects.
 *
 * `*start` and `*end` are updated in place:
 *   1. a negative value is wrapped once by adding `dim_length`
 *      (e.g. with dim_length == 4, index -1 becomes 3);
 *   2. a value that is still negative is raised to 0;
 *   3. a value greater than `dim_length` is lowered to `dim_length`.
 *
 * @param[in] dim_length Extent of the dimension being sliced.
 * @param[in,out] start First selected index; normalized on return.
 * @param[in,out] end One past the last candidate index; normalized on return.
 * @param[in] step Stride between selected indices. Must be positive to select
 *     anything; a non-positive step yields 0 rather than dividing by zero.
 * @returns Number of indices in {start, start + step, ...} that are < end,
 *     or 0 when the normalized range is empty or `step` is not positive.
 */
int64_t adjust_slice_indices(
    int64_t dim_length,
    int64_t* start,
    int64_t* end,
    int64_t step) {
  // First, convert Python-style negative indices to C++ style ones.
  if (*start < 0) {
    *start += dim_length;
  }
  if (*end < 0) {
    *end += dim_length;
  }

  // Second, an index that is still negative means "from the very beginning",
  // so pin it to zero.
  if (*start < 0) {
    *start = 0;
  }
  if (*end < 0) {
    *end = 0;
  }

  // Last, an index past the end of the dimension means "until the end", so
  // pin it to dim_length.
  if (*start > dim_length) {
    *start = dim_length;
  }
  if (*end > dim_length) {
    *end = dim_length;
  }

  // A non-positive step cannot be used as a divisor below (zero would be
  // undefined behaviour, negative would produce a meaningless count), and an
  // interval that is empty or does not overlap [0, dim_length) selects
  // nothing.
  if (step <= 0 || *start >= dim_length || *end <= 0 || *start >= *end) {
    return 0;
  }

  // Count of start, start + step, ... that stay strictly below end.
  return (*end - 1 - *start) / step + 1;
}
149
compute_slice(const Tensor & in,int64_t dim,int64_t start,int64_t length,int64_t step,Tensor & out)150 void compute_slice(
151 const Tensor& in,
152 int64_t dim,
153 int64_t start,
154 int64_t length,
155 int64_t step,
156 Tensor& out) {
157 size_t dim_length = in.size(dim);
158
159 size_t leading_dims = getLeadingDims(in, dim);
160 size_t trailing_dims = getTrailingDims(in, dim);
161
162 if (trailing_dims == 0) {
163 return;
164 }
165
166 size_t length_per_step = trailing_dims * in.element_size();
167
168 const char* input_data = in.const_data_ptr<char>();
169 char* dest = out.mutable_data_ptr<char>();
170
171 for (int i = 0; i < leading_dims; i++) {
172 const char* src = input_data + (i * dim_length + start) * length_per_step;
173 for (int j = 0; j < length; j++) {
174 memcpy(dest, src, length_per_step);
175 src += step * length_per_step;
176 dest += length_per_step;
177 }
178 }
179 }
180
181 } // namespace executor
182 } // namespace torch
183