// xref: /aosp_15_r20/external/XNNPACK/src/normalization.c (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5 
6 #include <stdbool.h>
7 #include <stddef.h>
8 #include <string.h>
9 
10 #include <xnnpack/math.h>
11 
// Reports whether dimension `dim` may be dropped or merged into its outer
// neighbour without contradicting the supplied strides.
//
// A stride array that is NULL imposes no constraint. For a non-NULL array the
// stride of the neighbouring outer entry must equal the stride of this entry
// multiplied by shape[dim], i.e. the two entries are laid out back to back.
//
//   input_stride   - per-dimension input strides, or NULL.
//   output_stride  - per-dimension output strides, or NULL.
//   shape          - dimension sizes, indexed by `dim`.
//   perm           - permutation; perm[dim] selects the output_stride entry.
//   dim            - index of the dimension under consideration.
//
// Returns true when neither stride array forbids the removal.
//
// NOTE(review): output_stride is indexed with perm[dim], i.e. perm is applied
// to the same index that addresses shape and input_stride. Confirm this is the
// intended output position for permutations that are not their own inverse.
static bool can_dimension_be_removed(
    const size_t* input_stride,
    const size_t* output_stride,
    const size_t* shape,
    const size_t* perm,
    size_t dim) {
  // Entry 0 on both sides has no outer neighbour, so there is nothing to check.
  if (dim == 0 && perm[dim] == 0) {
    return true;
  }
  if (input_stride != NULL && dim > 0 &&
      input_stride[dim - 1] != input_stride[dim] * shape[dim]) {
    return false;
  }
  if (output_stride != NULL && perm[dim] > 0 &&
      output_stride[perm[dim] - 1] != output_stride[perm[dim]] * shape[dim]) {
    return false;
  }
  return true;
}
34 
// Deletes values[index] by moving every later entry one slot towards the
// front. The last slot keeps its old contents; callers shrink their logical
// length by one afterwards. Does nothing when index + 1 >= num_values.
static void shift_left_from(size_t* values, size_t num_values, size_t index) {
  for (size_t i = index; i + 1 < num_values; ++i) {
    values[i] = values[i + 1];
  }
}

// Removes the dimension found at position `dim` of `perm` (that is, dimension
// perm[dim]) from shape, perm, and the optional stride arrays, all in place.
//
//   shape          - entry perm[dim] is deleted.
//   perm           - entries greater than perm[dim] are decremented, then
//                    entry `dim` is deleted.
//   input_stride   - may be NULL; otherwise the entry just before perm[dim] is
//                    deleted (entry 0 when perm[dim] is 0).
//   output_stride  - may be NULL; otherwise the entry just before `dim` is
//                    deleted (entry 0 when `dim` is 0).
//   num_dims       - number of valid entries in each array before the call.
//   dim            - position within perm of the dimension to remove.
//
// Each array keeps num_dims slots; only the first num_dims - 1 are meaningful
// after the call.
static void remove_dimension(
    size_t* shape,
    size_t* perm,
    size_t* input_stride,
    size_t* output_stride,
    size_t num_dims,
    size_t dim)
{
  // perm[dim] is not modified until the final shift, so read it once.
  const size_t removed = perm[dim];

  shift_left_from(shape, num_dims, removed);
  if (input_stride != NULL) {
    shift_left_from(input_stride, num_dims, removed > 0 ? removed - 1 : 0);
  }
  if (output_stride != NULL) {
    shift_left_from(output_stride, num_dims, dim > 0 ? dim - 1 : 0);
  }
  // Renumber the dimensions that came after the removed one.
  for (size_t j = 0; j < num_dims; ++j) {
    if (perm[j] > removed) {
      perm[j] -= 1;
    }
  }
  shift_left_from(perm, num_dims, dim);
}
// Rewrites a transpose description into an equivalent one with as few
// dimensions as possible.
//
// Processing steps:
//   1. Copy perm, shape and any supplied strides into the normalized_* buffers.
//   2. Walk the output positions, removing dimensions of size 1 and folding a
//      dimension into the previous output position when its perm value is
//      exactly one more than the previous one, provided
//      can_dimension_be_removed() accepts the strides.
//   3. If every dimension was removed, report a single dimension of size 1.
//   4. If the last position maps to itself, fold its size into the element
//      size; the dimension is dropped when the strides allow it, otherwise it
//      is kept with size 1 and its strides are multiplied by its old size.
//   5. Compute the missing strides from the normalized shape, or multiply the
//      supplied strides by element_size.
//
// Inputs:
//   num_dims       - number of entries in perm, shape and the stride arrays.
//   element_size   - size of one element; supplied strides are multiplied by it.
//   perm           - perm[i] indexes shape for output position i.
//   shape          - dimension sizes.
//   input_stride   - optional input strides (NULL to have them computed).
//   output_stride  - optional output strides (NULL to have them computed).
//
// Outputs (the four arrays receive num_dims entries in step 1 and entry 0 is
// written even when num_dims is 0, so each needs room for max(num_dims, 1)):
//   normalized_num_dims          - resulting number of dimensions.
//   normalized_element_size_out  - resulting element size.
//   normalized_perm, normalized_shape,
//   normalized_input_stride, normalized_output_stride - normalized description.
void xnn_normalize_transpose_permutation(
    const size_t num_dims,
    const size_t element_size,
    const size_t* perm,
    const size_t* shape,
    const size_t* input_stride,
    const size_t* output_stride,
    size_t* normalized_num_dims,
    size_t* normalized_element_size_out,
    size_t* normalized_perm,
    size_t* normalized_shape,
    size_t* normalized_input_stride,
    size_t* normalized_output_stride)
{
  size_t output_dims = num_dims;
  memcpy(normalized_perm, perm, num_dims * sizeof(size_t));
  memcpy(normalized_shape, shape, num_dims * sizeof(size_t));
  // The *_ptr aliases stay NULL when the caller supplied no strides, so the
  // helpers below skip stride checks and stride updates in that case.
  size_t* normalized_input_stride_ptr = NULL;
  size_t* normalized_output_stride_ptr = NULL;
  if (input_stride != NULL) {
    memcpy(normalized_input_stride, input_stride, num_dims * sizeof(size_t));
    normalized_input_stride_ptr = normalized_input_stride;
  }
  if (output_stride != NULL) {
    memcpy(normalized_output_stride, output_stride, num_dims * sizeof(size_t));
    normalized_output_stride_ptr = normalized_output_stride;
  }

  size_t output_pos = 0;
  // Remove dimensions of size 1 and fold dimensions which are adjacent in both input and output tensors.
  while (output_pos < output_dims) {
    if (can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                                 normalized_perm, normalized_perm[output_pos])
        && ((normalized_shape[normalized_perm[output_pos]] == 1)
            || (output_pos > 0 && normalized_perm[output_pos] == normalized_perm[output_pos - 1] + 1))) {
      // Merge the size into the previous output position before the entry is
      // deleted (a no-op multiplication by 1 for size-1 dimensions).
      if (output_pos > 0) {
        normalized_shape[normalized_perm[output_pos - 1]] *= normalized_shape[normalized_perm[output_pos]];
      }
      remove_dimension(normalized_shape, normalized_perm, normalized_input_stride_ptr, normalized_output_stride_ptr,
                       output_dims, output_pos);
      output_dims -= 1;
      // When a dimension has been removed, new folds may be possible so check
      // it again.
      if (output_pos > 0) {
        output_pos -= 1;
      }
    } else {
      output_pos += 1;
    }
  }
  // Every dimension was removed (all sizes were 1): describe a single element.
  if (output_pos == 0) {
    *normalized_num_dims = 1;
    *normalized_element_size_out = element_size;
    normalized_perm[0] = 0;
    normalized_shape[0] = 1;
    normalized_input_stride[0] = element_size;
    normalized_output_stride[0] = element_size;
    return;
  }

  // If the last input and output dimensions are the same, treat it as one large
  // element.
  size_t normalized_element_size = element_size;
  if (normalized_perm[output_dims - 1] == output_dims - 1) {
    normalized_element_size = element_size * normalized_shape[output_dims - 1];
    if (output_dims > 1 && can_dimension_be_removed(normalized_input_stride_ptr, normalized_output_stride_ptr, normalized_shape,
                                 normalized_perm, output_dims - 1)) {
      output_dims -= 1;
    } else {
      // Test the *_ptr aliases rather than the output buffers: when no strides
      // were supplied the buffers hold no data yet and are filled in below.
      if (normalized_input_stride_ptr != NULL) {
        normalized_input_stride[output_dims - 1] *= normalized_shape[output_dims - 1];
      }
      if (normalized_output_stride_ptr != NULL) {
        normalized_output_stride[normalized_perm[output_dims - 1]] *= normalized_shape[output_dims - 1];
      }
      normalized_shape[output_dims - 1] = 1;
    }
  }
  // If input_stride is not provided, calculate it using normalized_shape and normalized_element_size.
  if (input_stride == NULL) {
    normalized_input_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      normalized_input_stride[i - 1] = normalized_input_stride[i] * normalized_shape[i];
    }
  } else {
    // Scale input_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_input_stride[i] *= element_size;
    }
  }
  // If output_stride is not provided, calculate it using normalized_shape and normalized_element_size.
  if (output_stride == NULL) {
    normalized_output_stride[output_dims - 1] = normalized_element_size;
    for (size_t i = output_dims - 1; i > 0; --i) {
      // Output position i has the size of the dimension selected by perm[i].
      normalized_output_stride[i - 1] = normalized_output_stride[i] * normalized_shape[normalized_perm[i]];
    }
  } else {
    // Scale output_stride by element size.
    for (size_t i = 0; i < output_dims; ++i) {
      normalized_output_stride[i] *= element_size;
    }
  }
  *normalized_element_size_out = normalized_element_size;
  *normalized_num_dims = output_dims;
}
172