xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/internal/optimized/im2col_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
16 #define TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
17 
18 #include <algorithm>
19 #include <cassert>
20 
21 #include "ruy/profiler/instrumentation.h"  // from @ruy
22 #include "tensorflow/lite/kernels/internal/types.h"
23 
24 namespace tflite {
25 namespace optimized_ops {
26 
27 template <typename T>
ExtractPatchIntoBufferColumn(const RuntimeShape & input_shape,int w,int h,int b,int kheight,int kwidth,int stride_width,int stride_height,int pad_width,int pad_height,int in_width,int in_height,int in_depth,int single_buffer_length,int buffer_id,const T * in_data,T * conv_buffer_data,uint8 zero_byte)28 inline void ExtractPatchIntoBufferColumn(const RuntimeShape& input_shape, int w,
29                                          int h, int b, int kheight, int kwidth,
30                                          int stride_width, int stride_height,
31                                          int pad_width, int pad_height,
32                                          int in_width, int in_height,
33                                          int in_depth, int single_buffer_length,
34                                          int buffer_id, const T* in_data,
35                                          T* conv_buffer_data, uint8 zero_byte) {
36   ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn");
37   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
38   // This chunk of code reshapes all the inputs corresponding to
39   // output (b, h, w) to a column vector in conv_buffer(:, buffer_id).
40   const int kwidth_times_indepth = kwidth * in_depth;
41   const int inwidth_times_indepth = in_width * in_depth;
42   const int ih_ungated_start = h * stride_height - pad_height;
43   const int ih_ungated_end = (ih_ungated_start + kheight);
44   const int ih_end = std::min(ih_ungated_end, in_height);
45   const int iw_ungated_start = w * stride_width - pad_width;
46   const int iw_ungated_end = (iw_ungated_start + kwidth);
47   const int iw_end = std::min(iw_ungated_end, in_width);
48   // If the patch is off the edge of the input image, skip writing those rows
49   // and columns from the patch into the output array.
50   const int h_offset = std::max(0, -ih_ungated_start);
51   const int w_offset = std::max(0, -iw_ungated_start);
52   const int ih_start = std::max(0, ih_ungated_start);
53   const int iw_start = std::max(0, iw_ungated_start);
54   const int single_row_num =
55       std::max(0, std::min(kwidth - w_offset, in_width - iw_start)) * in_depth;
56   const int output_row_offset = (buffer_id * single_buffer_length);
57   int out_offset =
58       output_row_offset + (h_offset * kwidth + w_offset) * in_depth;
59   int in_offset = Offset(input_shape, b, ih_start, iw_start, 0);
60 
61   // Express all of the calculations as padding around the input patch.
62   const int top_padding = h_offset;
63   const int bottom_padding = (ih_ungated_end - ih_end);
64   const int left_padding = w_offset;
65   const int right_padding = (iw_ungated_end - iw_end);
66   assert(single_row_num ==
67          ((kwidth - (left_padding + right_padding)) * in_depth));
68 
69   // Write out zeroes to the elements representing the top rows of the input
70   // patch that are off the edge of the input image.
71   if (top_padding > 0) {
72     const int top_row_elements = (top_padding * kwidth * in_depth);
73     memset(conv_buffer_data + output_row_offset, zero_byte,
74            (top_row_elements * sizeof(T)));
75   }
76 
77   // If the patch is on the interior of the input image horizontally, just copy
78   // over the rows sequentially, otherwise add zero padding at the start or end.
79   if ((left_padding == 0) && (right_padding == 0)) {
80     for (int ih = ih_start; ih < ih_end; ++ih) {
81       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
82              single_row_num * sizeof(T));
83       out_offset += kwidth_times_indepth;
84       in_offset += inwidth_times_indepth;
85     }
86   } else {
87     for (int ih = ih_start; ih < ih_end; ++ih) {
88       if (left_padding > 0) {
89         const int left_start = (out_offset - (left_padding * in_depth));
90         memset(conv_buffer_data + left_start, zero_byte,
91                (left_padding * in_depth * sizeof(T)));
92       }
93       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
94              single_row_num * sizeof(T));
95       if (right_padding > 0) {
96         const int right_start = (out_offset + single_row_num);
97         memset(conv_buffer_data + right_start, zero_byte,
98                (right_padding * in_depth * sizeof(T)));
99       }
100       out_offset += kwidth_times_indepth;
101       in_offset += inwidth_times_indepth;
102     }
103   }
104 
105   // If the bottom of the patch falls off the input image, pad the values
106   // representing those input rows with zeroes.
107   if (bottom_padding > 0) {
108     const int bottom_row_elements = (bottom_padding * kwidth * in_depth);
109     const int bottom_start =
110         output_row_offset +
111         ((top_padding + (ih_end - ih_start)) * kwidth * in_depth);
112     memset(conv_buffer_data + bottom_start, zero_byte,
113            (bottom_row_elements * sizeof(T)));
114   }
115 }
116 
117 // Supports per-batch zero_byte for per-batch asymmetric quantized inputs.
118 template <typename T>
DilatedIm2col(const ConvParams & params,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data,const int32_t * zero_bytes,const int zero_bytes_len)119 void DilatedIm2col(const ConvParams& params, const RuntimeShape& input_shape,
120                    const T* input_data, const RuntimeShape& filter_shape,
121                    const RuntimeShape& output_shape, T* im2col_data,
122                    const int32_t* zero_bytes, const int zero_bytes_len) {
123   const int stride_width = params.stride_width;
124   const int stride_height = params.stride_height;
125   const int dilation_width_factor = params.dilation_width_factor;
126   const int dilation_height_factor = params.dilation_height_factor;
127   const int pad_width = params.padding_values.width;
128   const int pad_height = params.padding_values.height;
129   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
130   TFLITE_DCHECK_EQ(filter_shape.DimensionsCount(), 4);
131   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
132 
133   // For dilated convolution, the input pixels are not contiguous therefore we
134   // can't use the same optimizations as Im2Col(). Though note this code would
135   // work fine for the non-dilated case too (though likely a bit slower).
136   ruy::profiler::ScopeLabel label("DilatedIm2col");
137   TFLITE_DCHECK(dilation_width_factor != 1 || dilation_height_factor != 1);
138   TFLITE_DCHECK(im2col_data);
139   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
140   const int input_height = input_shape.Dims(1);
141   const int input_width = input_shape.Dims(2);
142   const int input_depth = MatchingDim(input_shape, 3, filter_shape, 3);
143   const int filter_height = filter_shape.Dims(1);
144   const int filter_width = filter_shape.Dims(2);
145   const int output_height = output_shape.Dims(1);
146   const int output_width = output_shape.Dims(2);
147   MatchingDim(output_shape, 3, filter_shape, 0);
148 
149   // Construct the MxN sized im2col matrix.
150   // The rows M, are sub-ordered B x H x W
151   const RuntimeShape row_shape({1, batches, output_height, output_width});
152   // The columns, N, are sub-ordered Kh x Kw x Din
153   const RuntimeShape col_shape({1, filter_height, filter_width, input_depth});
154   // Use dimensions M and N to construct dims for indexing directly into im2col
155   const RuntimeShape im2col_shape(
156       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
157 
158   // Loop through the output rows (B x H x W)
159   for (int batch = 0; batch < batches; ++batch) {
160     const T zero_byte = zero_bytes_len > 1 ? static_cast<T>(zero_bytes[batch])
161                                            : static_cast<T>(zero_bytes[0]);
162     for (int out_y = 0; out_y < output_height; ++out_y) {
163       for (int out_x = 0; out_x < output_width; ++out_x) {
164         // Each im2col row is an output pixel. Arrange the input data in this
165         // row in an order we can conveniently multiply with the filter data.
166         int row_offset = Offset(row_shape, 0, batch, out_y, out_x);
167         const int in_x_origin = (out_x * stride_width) - pad_width;
168         const int in_y_origin = (out_y * stride_height) - pad_height;
169         // Loop through all the pixels of the filter (Kh x Kw)
170         for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
171           const int in_y = in_y_origin + dilation_height_factor * filter_y;
172           if ((in_y >= 0) && (in_y < input_height)) {
173             // Filter row is within the input data.
174             // Loop through all the filter pixels in this row.
175             for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
176               const int in_x = in_x_origin + dilation_width_factor * filter_x;
177               int col_offset = Offset(col_shape, 0, filter_y, filter_x, 0);
178               T* dst = im2col_data +
179                        Offset(im2col_shape, 0, 0, row_offset, col_offset);
180               if ((in_x >= 0) && (in_x < input_width)) {
181                 // Filter pixel is within the input, copy the input data.
182                 T const* src =
183                     input_data + Offset(input_shape, batch, in_y, in_x, 0);
184                 memcpy(dst, src, input_depth * sizeof(T));
185               } else {
186                 // Filter pixel is outside the input, zero it out.
187                 memset(dst, zero_byte, input_depth * sizeof(T));
188               }
189             }
190           } else {
191             // Filter row is outside the input, zero out the entire filter row.
192             int col_offset = Offset(col_shape, 0, filter_y, 0, 0);
193             T* dst = im2col_data +
194                      Offset(im2col_shape, 0, 0, row_offset, col_offset);
195             memset(dst, zero_byte, filter_width * input_depth * sizeof(T));
196           }
197         }
198       }
199     }
200   }
201 }
202 
203 template <typename T>
DilatedIm2col(const ConvParams & params,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & filter_shape,const RuntimeShape & output_shape,T * im2col_data)204 void DilatedIm2col(const ConvParams& params, uint8 zero_byte,
205                    const RuntimeShape& input_shape, const T* input_data,
206                    const RuntimeShape& filter_shape,
207                    const RuntimeShape& output_shape, T* im2col_data) {
208   const int32_t zero_point = static_cast<int32_t>(zero_byte);
209   DilatedIm2col<T>(params, input_shape, input_data, filter_shape, output_shape,
210                    im2col_data, &zero_point, 1);
211 }
212 
213 template <typename T>
Im2col(const ConvParams & params,int kheight,int kwidth,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)214 void Im2col(const ConvParams& params, int kheight, int kwidth, uint8 zero_byte,
215             const RuntimeShape& input_shape, const T* input_data,
216             const RuntimeShape& output_shape, T* output_data) {
217   ruy::profiler::ScopeLabel label("Im2col");
218   const int stride_width = params.stride_width;
219   const int stride_height = params.stride_height;
220   const int pad_width = params.padding_values.width;
221   const int pad_height = params.padding_values.height;
222   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
223   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
224 
225   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
226   const int input_depth = input_shape.Dims(3);
227   const int input_width = input_shape.Dims(2);
228   const int input_height = input_shape.Dims(1);
229   const int output_depth = output_shape.Dims(3);
230   const int output_width = output_shape.Dims(2);
231   const int output_height = output_shape.Dims(1);
232 
233   int buffer_id = 0;
234   // Loop over the output nodes.
235   for (int b = 0; b < batches; ++b) {
236     for (int h = 0; h < output_height; ++h) {
237       for (int w = 0; w < output_width; ++w) {
238         ExtractPatchIntoBufferColumn(
239             input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
240             pad_width, pad_height, input_width, input_height, input_depth,
241             output_depth, buffer_id, input_data, output_data, zero_byte);
242         ++buffer_id;
243       }
244     }
245   }
246 }
247 
248 template <typename T>
Im2col(const ConvParams & params,int kheight,int kwidth,const int32_t * input_offsets,const int input_offsets_size,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & output_shape,T * output_data)249 void Im2col(const ConvParams& params, int kheight, int kwidth,
250             const int32_t* input_offsets, const int input_offsets_size,
251             const RuntimeShape& input_shape, const T* input_data,
252             const RuntimeShape& output_shape, T* output_data) {
253   ruy::profiler::ScopeLabel label("Im2col");
254   const int stride_width = params.stride_width;
255   const int stride_height = params.stride_height;
256   const int pad_width = params.padding_values.width;
257   const int pad_height = params.padding_values.height;
258   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 4);
259   TFLITE_DCHECK_EQ(output_shape.DimensionsCount(), 4);
260 
261   const int batches = MatchingDim(input_shape, 0, output_shape, 0);
262   TFLITE_DCHECK_EQ(batches, input_offsets_size);
263   const int input_depth = input_shape.Dims(3);
264   const int input_width = input_shape.Dims(2);
265   const int input_height = input_shape.Dims(1);
266   const int output_depth = output_shape.Dims(3);
267   const int output_width = output_shape.Dims(2);
268   const int output_height = output_shape.Dims(1);
269 
270   int buffer_id = 0;
271   // Loop over the output nodes.
272   for (int b = 0; b < batches; ++b) {
273     uint8_t zero_byte = static_cast<uint8_t>(input_offsets[b]);
274     for (int h = 0; h < output_height; ++h) {
275       for (int w = 0; w < output_width; ++w) {
276         ExtractPatchIntoBufferColumn(
277             input_shape, w, h, b, kheight, kwidth, stride_width, stride_height,
278             pad_width, pad_height, input_width, input_height, input_depth,
279             output_depth, buffer_id, input_data, output_data, zero_byte);
280         ++buffer_id;
281       }
282     }
283   }
284 }
285 
286 template <typename T>
ExtractPatchIntoBufferColumn3D(int b,int d,int h,int w,int kdepth,int kheight,int kwidth,int stride_depth,int stride_height,int stride_width,int pad_depth,int pad_height,int pad_width,int in_depth,int in_height,int in_width,int in_channel,int output_row_offset,const T * in_data,T * conv_buffer_data,uint8 zero_byte)287 inline void ExtractPatchIntoBufferColumn3D(
288     int b, int d, int h, int w,                             // Output indexes.
289     int kdepth, int kheight, int kwidth,                    // Kernel params.
290     int stride_depth, int stride_height, int stride_width,  // Stride params.
291     int pad_depth, int pad_height, int pad_width,           // Padding params.
292     int in_depth, int in_height, int in_width, int in_channel,  // Input shape.
293     int output_row_offset, const T* in_data, T* conv_buffer_data,
294     uint8 zero_byte) {
295   ruy::profiler::ScopeLabel label("ExtractPatchIntoBufferColumn3D");
296 
297   // This chunk of code reshapes all the inputs corresponding to
298   // output (b, d, h, w) to a column vector in conv_buffer(:, buffer_id).
299   const int id_ungated_start = d * stride_depth - pad_depth;
300   const int id_start = std::max(0, id_ungated_start);
301   const int id_ungated_end = (id_ungated_start + kdepth);
302   const int id_end = std::min(id_ungated_end, in_depth);
303 
304   const int ih_ungated_start = h * stride_height - pad_height;
305   const int ih_start = std::max(0, ih_ungated_start);
306   const int ih_ungated_end = (ih_ungated_start + kheight);
307   const int ih_end = std::min(ih_ungated_end, in_height);
308 
309   const int iw_ungated_start = w * stride_width - pad_width;
310   const int iw_start = std::max(0, iw_ungated_start);
311   const int iw_ungated_end = (iw_ungated_start + kwidth);
312   const int iw_end = std::min(iw_ungated_end, in_width);
313 
314   // Calculate the padding sizes.
315   const int d_padding_before = std::max(0, -id_ungated_start);
316   const int d_padding_after = (id_ungated_end - id_end);
317   const int h_padding_before = std::max(0, -ih_ungated_start);
318   const int h_padding_after = (ih_ungated_end - ih_end);
319   const int w_padding_before = std::max(0, -iw_ungated_start);
320   const int w_padding_after = (iw_ungated_end - iw_end);
321 
322   // Memset if there are paddings in the depth dimension.
323   const int kd_stride_size = kheight * kwidth * in_channel;
324   const int id_stride_size = in_height * in_width * in_channel;
325 
326   if (d_padding_before > 0) {
327     const int d_padding_before_elements = (d_padding_before * kd_stride_size);
328     memset(conv_buffer_data + output_row_offset, zero_byte,
329            (d_padding_before_elements * sizeof(T)));
330   }
331 
332   if (d_padding_after > 0) {
333     const int d_padding_after_elements = (d_padding_after * kd_stride_size);
334     const int bottom_start =
335         output_row_offset + (kdepth - d_padding_after) * kd_stride_size;
336     memset(conv_buffer_data + bottom_start, zero_byte,
337            (d_padding_after_elements * sizeof(T)));
338   }
339 
340   // If there are paddings in height or width dimension, menset the entire area
341   // to take advantage of sequential memory handling performance.
342   int out_offset = output_row_offset + d_padding_before * kd_stride_size;
343   if (h_padding_before > 0 || h_padding_after > 0 || w_padding_before > 0 ||
344       w_padding_after > 0) {
345     const int middle_elements = (id_end - id_start) * kd_stride_size;
346     memset(conv_buffer_data + out_offset, zero_byte,
347            (middle_elements * sizeof(T)));
348   }
349 
350   // Copy the valid data from the input tensor.
351   const int kh_stride_size = kwidth * in_channel;
352   const int ih_stride_size = in_width * in_channel;
353   const int h_padding = h_padding_before + h_padding_after;
354   const int w_padding = w_padding_before + w_padding_after;
355   const int single_row_num = (kwidth - w_padding) * in_channel;
356   out_offset +=
357       h_padding_before * kh_stride_size + w_padding_before * in_channel;
358   const int in_offset_without_d = b * in_depth * id_stride_size +
359                                   ih_start * ih_stride_size +
360                                   iw_start * in_channel;
361   for (int id = id_start; id < id_end; ++id) {
362     int in_offset = in_offset_without_d + id * id_stride_size;
363     for (int ih = ih_start; ih < ih_end; ++ih) {
364       memcpy(conv_buffer_data + out_offset, in_data + in_offset,
365              single_row_num * sizeof(T));
366       out_offset += kh_stride_size;
367       in_offset += ih_stride_size;
368     }
369     out_offset += h_padding * kh_stride_size;
370   }
371 }
372 
373 template <typename T>
Im2col3D(const Conv3DParams & params,int kdepth,int kheight,int kwidth,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & im2col_shape,T * im2col_data)374 void Im2col3D(const Conv3DParams& params, int kdepth, int kheight, int kwidth,
375               uint8 zero_byte, const RuntimeShape& input_shape,
376               const T* input_data, const RuntimeShape& im2col_shape,
377               T* im2col_data) {
378   ruy::profiler::ScopeLabel label("Im2col3D");
379   const int stride_depth = params.stride_depth;
380   const int stride_width = params.stride_width;
381   const int stride_height = params.stride_height;
382   const int pad_depth = params.padding_values.depth;
383   const int pad_width = params.padding_values.width;
384   const int pad_height = params.padding_values.height;
385   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
386   TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 5);
387 
388   const int batches = MatchingDim(input_shape, 0, im2col_shape, 0);
389   const int input_depth = input_shape.Dims(1);
390   const int input_height = input_shape.Dims(2);
391   const int input_width = input_shape.Dims(3);
392   const int input_channel = input_shape.Dims(4);
393 
394   const int output_depth = im2col_shape.Dims(1);
395   const int output_height = im2col_shape.Dims(2);
396   const int output_width = im2col_shape.Dims(3);
397   const int output_channel = im2col_shape.Dims(4);
398 
399   int buffer_id = 0;
400   // Loop over the output nodes.
401   for (int b = 0; b < batches; ++b) {
402     for (int d = 0; d < output_depth; ++d) {
403       for (int h = 0; h < output_height; ++h) {
404         for (int w = 0; w < output_width; ++w) {
405           ExtractPatchIntoBufferColumn3D(
406               b, d, h, w, kdepth, kheight, kwidth, stride_depth, stride_height,
407               stride_width, pad_depth, pad_height, pad_width, input_depth,
408               input_height, input_width, input_channel, buffer_id, input_data,
409               im2col_data, zero_byte);
410           buffer_id += output_channel;
411         }
412       }
413     }
414   }
415 }
416 
417 template <typename T>
DilatedIm2col3D(const Conv3DParams & params,int filter_depth,int filter_height,int filter_width,uint8 zero_byte,const RuntimeShape & input_shape,const T * input_data,const RuntimeShape & im2col_shape,T * im2col_data)418 inline void DilatedIm2col3D(const Conv3DParams& params, int filter_depth,
419                             int filter_height, int filter_width,
420                             uint8 zero_byte, const RuntimeShape& input_shape,
421                             const T* input_data,
422                             const RuntimeShape& im2col_shape, T* im2col_data) {
423   ruy::profiler::ScopeLabel label("DilatedIm2col3D");
424   TFLITE_DCHECK_EQ(input_shape.DimensionsCount(), 5);
425   TFLITE_DCHECK_EQ(im2col_shape.DimensionsCount(), 5);
426 
427   // Only NDHWC format is currently supported.
428   const int batches = MatchingDim(input_shape, 0, im2col_shape, 0);
429   const int input_channels = input_shape.Dims(4);
430   const int input_width = input_shape.Dims(3);
431   const int input_height = input_shape.Dims(2);
432   const int input_depth = input_shape.Dims(1);
433 
434   const int output_width = im2col_shape.Dims(3);
435   const int output_height = im2col_shape.Dims(2);
436   const int output_depth = im2col_shape.Dims(1);
437 
438   const int pad_width = params.padding_values.width;
439   const int pad_height = params.padding_values.height;
440   const int pad_depth = params.padding_values.depth;
441 
442   // Construct the MxN sized im2col matrix.
443   // The rows M, are sub-ordered B x D x H x W.
444   const RuntimeShape row_shape(
445       {1, batches, output_depth, output_height, output_width});
446   // The columns, N, are sub-ordered Kd x Kh x Kw x Din.
447   const RuntimeShape col_shape(
448       {1, filter_depth, filter_height, filter_width, input_channels});
449   // Use dimensions M and N to construct dims for indexing directly into im2col.
450   const RuntimeShape im2col_reshaped(
451       {1, 1, row_shape.FlatSize(), col_shape.FlatSize()});
452 
453   for (int batch = 0; batch < batches; ++batch) {
454     for (int out_d = 0; out_d < output_depth; ++out_d) {
455       const int in_d_origin = (out_d * params.stride_depth) - pad_depth;
456       for (int out_y = 0; out_y < output_height; ++out_y) {
457         const int in_y_origin = (out_y * params.stride_height) - pad_height;
458         for (int out_x = 0; out_x < output_width; ++out_x) {
459           const int in_x_origin = (out_x * params.stride_width) - pad_width;
460           const int row_offset =
461               Offset(row_shape, 0, batch, out_d, out_y, out_x);
462           for (int filter_d = 0; filter_d < filter_depth; ++filter_d) {
463             const int in_d = in_d_origin + params.dilation_depth * filter_d;
464             if ((in_d >= 0) && (in_d < input_depth)) {
465               for (int filter_y = 0; filter_y < filter_height; ++filter_y) {
466                 const int in_y =
467                     in_y_origin + params.dilation_height * filter_y;
468                 if ((in_y >= 0) && (in_y < input_height)) {
469                   for (int filter_x = 0; filter_x < filter_width; ++filter_x) {
470                     const int in_x =
471                         in_x_origin + params.dilation_width * filter_x;
472                     int col_offset =
473                         Offset(col_shape, 0, filter_d, filter_y, filter_x, 0);
474                     T* dst = im2col_data + Offset(im2col_reshaped, 0, 0,
475                                                   row_offset, col_offset);
476                     if ((in_x >= 0) && (in_x < input_width)) {
477                       // Filter pixel is within the input, copy the input data.
478                       T const* src = input_data + Offset(input_shape, batch,
479                                                          in_d, in_y, in_x, 0);
480                       memcpy(dst, src, input_depth * sizeof(T));
481                     } else {
482                       // Filter pixel is outside the input, zero it out.
483                       memset(dst, zero_byte, input_depth * sizeof(T));
484                     }
485                   }
486                 } else {
487                   const int col_offset =
488                       Offset(col_shape, 0, filter_d, filter_y, 0, 0);
489                   T* dst = im2col_data + Offset(im2col_reshaped, 0, 0,
490                                                 row_offset, col_offset);
491                   memset(dst, zero_byte,
492                          filter_width * input_depth * sizeof(T));
493                 }
494               }
495             } else {
496               const int col_offset = Offset(col_shape, 0, filter_d, 0, 0, 0);
497               T* dst = im2col_data +
498                        Offset(im2col_reshaped, 0, 0, row_offset, col_offset);
499               memset(dst, zero_byte,
500                      filter_height * filter_width * input_depth * sizeof(T));
501             }
502           }
503         }
504       }
505     }
506   }
507 }
508 
509 }  // namespace optimized_ops
510 }  // namespace tflite
511 
512 #endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_OPTIMIZED_IM2COL_UTILS_H_
513