/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/util/slice_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <cstring>

namespace torch {
namespace executor {

using Tensor = exec_aten::Tensor;

bool check_narrow_copy_args(
    const Tensor& in,
    int64_t dim,
    int64_t start,
    int64_t length,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0);
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(length >= 0, "length must be non-negative");
  ET_LOG_AND_RETURN_IF_FALSE(start >= -in.size(dim));
  ET_LOG_AND_RETURN_IF_FALSE(start <= in.size(dim));
  if (start < 0) {
    start += in.size(dim);
  }
  ET_LOG_AND_RETURN_IF_FALSE(start + length <= in.size(dim));
  return true;
}

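// Illustrative example (not part of the original file): for an input with
// in.size(dim) == 5, a start of -2 normalizes to 3, so the largest length
// that passes the final bounds check is 2, since 3 + 2 <= 5.
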
void get_narrow_copy_out_target_size(
    const Tensor& in,
    int64_t dim,
    int64_t length,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  *out_ndim = in.dim();

  for (size_t d = 0; d < in.dim(); ++d) {
    out_sizes[d] = in.size(d);
  }
  out_sizes[dim] = length;
}

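// Illustrative example (not part of the original file): for an input of
// shape {2, 3, 4} with dim = 1 and length = 2, this writes
// out_sizes = {2, 2, 4} and *out_ndim = 3.
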
bool check_slice_copy_args(
    const Tensor& in,
    int64_t dim,
    int64_t step,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(in.dim() > 0);
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, out));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_has_dim(in, dim));
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      step > 0, "slice step must be greater than zero");
  return true;
}

void get_slice_copy_out_target_size(
    const Tensor& in,
    int64_t dim,
    int64_t length,
    exec_aten::SizesType* out_sizes,
    size_t* out_ndim) {
  get_narrow_copy_out_target_size(in, dim, length, out_sizes, out_ndim);
}

bool check_slice_scatter_args(
    const Tensor& input,
    const Tensor& src,
    int64_t dim,
    int64_t num_values,
    int64_t step,
    Tensor output) {
  ET_LOG_AND_RETURN_IF_FALSE(input.dim() > 0);

  // Check dim. The dim to be sliced along must exist in the input tensor.
  ET_LOG_AND_RETURN_IF_FALSE(dim_is_valid(dim, input.dim()));

  // Input and output tensors should have the same shape and dtype.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape_and_dtype(input, output));

  // input.dim() must equal src.dim().
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_rank(input, src));

  // Check step. Step must be greater than zero.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      step > 0, "slice step must be greater than zero");

  // The sizes of the src tensor must follow these rules:
  // - src.size(i) must equal input.size(i) if i != dim,
  // - src.size(dim) must equal num_values.
  for (size_t d = 0; d < input.dim(); d++) {
    if (d != dim) {
      ET_LOG_AND_RETURN_IF_FALSE(
          tensors_have_same_size_at_dims(input, d, src, d));
    } else {
      ET_LOG_MSG_AND_RETURN_IF_FALSE(
          src.size(d) == num_values,
          "src.size(%zu) %zd != num_values %" PRId64 " | dim = %" PRId64,
          d,
          src.size(d),
          num_values,
          dim);
    }
  }

  return true;
}

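// Illustrative example (not part of the original file): scattering into an
// input of shape {2, 3, 4} along dim = 1 with num_values = 2 requires src to
// have shape {2, 2, 4}: every dimension matches input except dim, which must
// equal num_values.
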
int64_t adjust_slice_indices(
    int64_t dim_length,
    int64_t* start,
    int64_t* end,
    int64_t step) {
  int64_t num_values = 0;

  // Update the start and end indices.
  // First, convert them from Python-style to C++-style if needed. Python-style
  // indices may be negative: e.g., for a dimension of length 4, index -1
  // refers to element 3, index -2 to element 2, and so on.
  *start = *start < 0 ? *start + dim_length : *start;
  *end = *end < 0 ? *end + dim_length : *end;
  // Second, if start or end is still negative, the caller wants to start or
  // end the slice at the very beginning, so clamp it to zero.
  *start = *start < 0 ? 0 : *start;
  *end = *end < 0 ? 0 : *end;
  // Last, if start or end is larger than the maximum valid index
  // (dim_length - 1), the caller wants to start the slice past the end or end
  // it at the end, so clamp it to dim_length.
  *start = *start > dim_length ? dim_length : *start;
  *end = *end > dim_length ? dim_length : *end;

  if (*start >= dim_length || *end <= 0 || *start >= *end) {
    // The interval [start, end) is empty or does not overlap with
    // [0, dim_length), so there are no values to copy.
    num_values = 0;
  } else {
    // Count the elements selected when stepping through [start, end).
    num_values = (*end - 1 - *start) / step + 1;
  }
  return num_values;
}

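// Illustrative worked example (not part of the original file): with
// dim_length = 10, start = -3, end = 100, and step = 2, start normalizes to
// 7, end clamps to 10, and the function returns
// (10 - 1 - 7) / 2 + 1 = 2, i.e. the slice selects indices 7 and 9.
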
void compute_slice(
    const Tensor& in,
    int64_t dim,
    int64_t start,
    int64_t length,
    int64_t step,
    Tensor& out) {
  size_t dim_length = in.size(dim);

  size_t leading_dims = getLeadingDims(in, dim);
  size_t trailing_dims = getTrailingDims(in, dim);

  if (trailing_dims == 0) {
    return;
  }

  size_t length_per_step = trailing_dims * in.element_size();

  const char* input_data = in.const_data_ptr<char>();
  char* dest = out.mutable_data_ptr<char>();

  for (size_t i = 0; i < leading_dims; i++) {
    const char* src = input_data + (i * dim_length + start) * length_per_step;
    for (int64_t j = 0; j < length; j++) {
      memcpy(dest, src, length_per_step);
      src += step * length_per_step;
      dest += length_per_step;
    }
  }
}
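
// Illustrative example (not part of the original file): for an input of
// shape {2, 5, 3} sliced along dim = 1 with start = 1, length = 2, and
// step = 2, leading_dims = 2, trailing_dims = 3, and length_per_step is
// 3 * element_size bytes. Each outer iteration copies rows 1 and 3 of one
// leading slice into contiguous rows of the output.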

} // namespace executor
} // namespace torch