xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/conv3d.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/kernels/internal/reference/conv3d.h"
17 
18 #include <cstddef>
19 #include <cstdint>
20 #include <vector>
21 
22 #include "tensorflow/lite/c/builtin_op_data.h"
23 #include "tensorflow/lite/c/common.h"
24 #include "tensorflow/lite/kernels/cpu_backend_context.h"
25 #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
26 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
27 #include "tensorflow/lite/kernels/internal/types.h"
28 #include "tensorflow/lite/kernels/kernel_util.h"
29 #include "tensorflow/lite/kernels/padding.h"
30 #include "tensorflow/lite/util.h"
31 
32 namespace tflite {
33 namespace ops {
34 namespace builtin {
35 namespace conv3d {
36 
37 enum KernelType {
38   kReference,
39   kGenericOptimized,
40 };
41 
42 // Struct to carry data from Prepare to Eval.
43 const int kTensorNotAllocated = -1;
44 static constexpr size_t kMaxIm2colBufferSizeMobile = 1024 * 1024 * 1024;  // 1GB
45 
46 struct OpData {
47   Padding3DValues padding;
48   int im2col_tensor_id = kTensorNotAllocated;
49   int transposed_filter_tensor_id = kTensorNotAllocated;
50 
51   bool need_im2col = false;
52   bool need_transposed_filter = false;
53 
54   // Disable im2col if the temporary im2col tensor requires too much memory
55   // (i.e. >= kMaxIm2colBufferSizeMobile).
56   bool im2col_oversized = false;
57 
58   int32_t im2col_index;
59   int32_t transposed_filter_index;
60 };
61 
Init(TfLiteContext * context,const char * buffer,size_t length)62 void* Init(TfLiteContext* context, const char* buffer, size_t length) {
63   auto* opdata = new OpData;
64   return opdata;
65 }
66 
Free(TfLiteContext * context,void * buffer)67 void Free(TfLiteContext* context, void* buffer) {
68   delete static_cast<OpData*>(buffer);
69 }
70 
AllocateTemporaryTensorsIfRequired(KernelType kernel_type,TfLiteContext * context,TfLiteNode * node,OpData * opdata,TfLiteConv3DParams * params,const TfLiteTensor * filter,size_t im2col_bytes)71 TfLiteStatus AllocateTemporaryTensorsIfRequired(
72     KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
73     OpData* opdata, TfLiteConv3DParams* params, const TfLiteTensor* filter,
74     size_t im2col_bytes) {
75   int temporaries_count = 0;
76   const bool need_dilated_im2col = params->dilation_width_factor != 1 ||
77                                    params->dilation_height_factor != 1 ||
78                                    params->dilation_depth_factor != 1;
79   const bool need_non_dilated_im2col =
80       params->stride_depth != 1 || params->stride_width != 1 ||
81       params->stride_height != 1 || filter->dims->data[2] != 1 ||
82       filter->dims->data[1] != 1 || filter->dims->data[0] != 1;
83 
84   opdata->need_im2col = (kernel_type == kGenericOptimized) &&
85                         (need_dilated_im2col || need_non_dilated_im2col);
86   // TODO(b/183455632): Add transposing logic in converter so constant folding
87   // might work on constant filter tensor.
88   opdata->need_transposed_filter = (kernel_type == kGenericOptimized);
89 
90   // On mobile platforms, the generic optimized kernel will not be used if the
91   // temporary im2col tensor requires too much memory.
92   if (IsMobilePlatform() && opdata->need_im2col &&
93       im2col_bytes >= kMaxIm2colBufferSizeMobile) {
94     opdata->need_im2col = false;
95     opdata->need_transposed_filter = false;
96     opdata->im2col_oversized = true;
97   }
98 
99   if (opdata->need_im2col) {
100     if (opdata->im2col_tensor_id == kTensorNotAllocated) {
101       TF_LITE_ENSURE_OK(
102           context, context->AddTensors(context, 1, &opdata->im2col_tensor_id));
103     }
104     opdata->im2col_index = temporaries_count++;
105   }
106 
107   if (opdata->need_transposed_filter) {
108     if (opdata->transposed_filter_tensor_id == kTensorNotAllocated) {
109       TF_LITE_ENSURE_OK(
110           context, context->AddTensors(context, 1,
111                                        &opdata->transposed_filter_tensor_id));
112     }
113     opdata->transposed_filter_index = temporaries_count++;
114   }
115 
116   TfLiteIntArrayFree(node->temporaries);
117   node->temporaries = TfLiteIntArrayCreate(temporaries_count);
118   return kTfLiteOk;
119 }
120 
// Validates the inputs, computes padding and the output shape, resizes the
// output tensor, and requests/sizes any temporary tensors the kernel needs.
TfLiteStatus Prepare(KernelType kernel_type, TfLiteContext* context,
                     TfLiteNode* node) {
  auto* params = static_cast<TfLiteConv3DParams*>(node->builtin_data);
  OpData* opdata = reinterpret_cast<OpData*>(node->user_data);

  // Check number of inputs/outputs. The third input (bias) is optional.
  TF_LITE_ENSURE(context, node->inputs->size == 2 || node->inputs->size == 3);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, 1);
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
  const TfLiteTensor* filter;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));

  // Check dimensionality of input, filter. Both are 5-D:
  // input is [batch, depth, height, width, channels].
  TF_LITE_ENSURE_EQ(context, input->dims->size, 5);
  TF_LITE_ENSURE_EQ(context, filter->dims->size, 5);

  // Check input channels matching filter.
  TF_LITE_ENSURE_EQ(context, input->dims->data[4], filter->dims->data[3]);

  // Check types. Only float32 is supported by this kernel.
  TfLiteType input_type = input->type;
  TF_LITE_ENSURE_TYPES_EQ(context, input_type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, filter->type, kTfLiteFloat32);
  TF_LITE_ENSURE_TYPES_EQ(context, output->type, input_type);

  // Check bias. If present it must be float32 with one element per output
  // channel.
  const TfLiteTensor* bias = GetInput(context, node, 2);
  if (bias) {
    TF_LITE_ENSURE_TYPES_EQ(context, bias->type, input_type);
    TF_LITE_ENSURE_EQ(context, NumElements(bias), SizeOfDimension(filter, 4));
  }

  // Filter has shape of [filter_depth, filter_height, filter_width,
  // in_channels, out_channels].
  int batches = input->dims->data[0];
  int channels_out = filter->dims->data[4];
  int depth = input->dims->data[1];
  int height = input->dims->data[2];
  int width = input->dims->data[3];
  int filter_depth = filter->dims->data[0];
  int filter_height = filter->dims->data[1];
  int filter_width = filter->dims->data[2];
  int input_channel = filter->dims->data[3];

  // Matching GetWindowedOutputSize in TensorFlow: computes both the padding
  // values (stored in opdata) and the output spatial dimensions.
  int out_width, out_height, out_depth;
  opdata->padding = ComputePadding3DValues(
      params->stride_height, params->stride_width, params->stride_depth,
      params->dilation_height_factor, params->dilation_width_factor,
      params->dilation_depth_factor, height, width, depth, filter_height,
      filter_width, filter_depth, params->padding, &out_height, &out_width,
      &out_depth);

  // Output is [batch, out_depth, out_height, out_width, out_channels].
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(5);
  output_size->data[0] = batches;
  output_size->data[1] = out_depth;
  output_size->data[2] = out_height;
  output_size->data[3] = out_width;
  output_size->data[4] = channels_out;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // Allocate temporary tensors. The im2col byte count is estimated first so
  // the helper can disable im2col on mobile when it would be oversized.
  size_t input_type_size;
  TF_LITE_ENSURE_STATUS(GetSizeOfType(context, input->type, &input_type_size));
  const size_t im2col_bytes = batches * out_depth * out_height * out_width *
                              input_channel * filter_depth * filter_height *
                              filter_width * input_type_size;
  TF_LITE_ENSURE_OK(context, AllocateTemporaryTensorsIfRequired(
                                 kernel_type, context, node, opdata, params,
                                 filter, im2col_bytes));

  if (opdata->need_im2col) {
    // im2col shape: same batch/spatial dims as the output, with the channel
    // dim flattened to in_channels * filter volume.
    TfLiteIntArray* im2col_size = TfLiteIntArrayCreate(5);
    im2col_size->data[0] = output_size->data[0];
    im2col_size->data[1] = output_size->data[1];
    im2col_size->data[2] = output_size->data[2];
    im2col_size->data[3] = output_size->data[3];
    im2col_size->data[4] =
        input_channel * filter_depth * filter_height * filter_width;

    TfLiteTensor* im2col;
    node->temporaries->data[opdata->im2col_index] = opdata->im2col_tensor_id;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
                                                opdata->im2col_index, &im2col));
    im2col->type = input->type;
    // Arena-managed so the planner can reuse the memory between ops.
    im2col->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context,
                      context->ResizeTensor(context, im2col, im2col_size));
  }

  if (opdata->need_transposed_filter) {
    // Transposed filter layout moves out_channels to the front:
    // [out_channels, filter_depth, filter_height, filter_width, in_channels].
    TfLiteIntArray* transposed_filter_size = TfLiteIntArrayCreate(5);
    transposed_filter_size->data[0] = filter->dims->data[4];
    transposed_filter_size->data[1] = filter->dims->data[0];
    transposed_filter_size->data[2] = filter->dims->data[1];
    transposed_filter_size->data[3] = filter->dims->data[2];
    transposed_filter_size->data[4] = filter->dims->data[3];

    TfLiteTensor* transposed_filter;
    node->temporaries->data[opdata->transposed_filter_index] =
        opdata->transposed_filter_tensor_id;
    TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node,
                                                opdata->transposed_filter_index,
                                                &transposed_filter));
    transposed_filter->type = filter->type;
    transposed_filter->allocation_type = kTfLiteArenaRw;
    TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, transposed_filter,
                                                     transposed_filter_size));
  }
  return kTfLiteOk;
}
236 
237 template <KernelType kernel_type>
Prepare(TfLiteContext * context,TfLiteNode * node)238 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
239   return Prepare(kernel_type, context, node);
240 }
241 
EvalFloat(KernelType kernel_type,TfLiteContext * context,TfLiteNode * node,TfLiteConv3DParams * params,OpData * opdata,const TfLiteTensor * input,const TfLiteTensor * filter,const TfLiteTensor * bias,TfLiteTensor * im2col,TfLiteTensor * tranposed_filter,TfLiteTensor * output)242 void EvalFloat(KernelType kernel_type, TfLiteContext* context, TfLiteNode* node,
243                TfLiteConv3DParams* params, OpData* opdata,
244                const TfLiteTensor* input, const TfLiteTensor* filter,
245                const TfLiteTensor* bias, TfLiteTensor* im2col,
246                TfLiteTensor* tranposed_filter, TfLiteTensor* output) {
247   float output_activation_min, output_activation_max;
248   CalculateActivationRange(params->activation, &output_activation_min,
249                            &output_activation_max);
250 
251   Conv3DParams runtime_params;
252   runtime_params.padding_values = opdata->padding;
253   runtime_params.stride_depth = params->stride_depth;
254   runtime_params.stride_height = params->stride_height;
255   runtime_params.stride_width = params->stride_width;
256   runtime_params.dilation_depth = params->dilation_depth_factor;
257   runtime_params.dilation_height = params->dilation_height_factor;
258   runtime_params.dilation_width = params->dilation_width_factor;
259   runtime_params.float_activation_min = output_activation_min;
260   runtime_params.float_activation_max = output_activation_max;
261   switch (kernel_type) {
262     case kReference: {
263       reference_ops::Conv3D(runtime_params, GetTensorShape(input),
264                             GetTensorData<float>(input), GetTensorShape(filter),
265                             GetTensorData<float>(filter), GetTensorShape(bias),
266                             GetTensorData<float>(bias), GetTensorShape(output),
267                             GetTensorData<float>(output));
268       break;
269     }
270     case kGenericOptimized: {
271       optimized_ops::Conv3D(
272           runtime_params, GetTensorShape(input), GetTensorData<float>(input),
273           GetTensorShape(filter), GetTensorData<float>(filter),
274           GetTensorShape(bias), GetTensorData<float>(bias),
275           GetTensorShape(output), GetTensorData<float>(output),
276           GetTensorShape(im2col), GetTensorData<float>(im2col),
277           GetTensorShape(tranposed_filter),
278           GetTensorData<float>(tranposed_filter),
279           CpuBackendContext::GetFromContext(context));
280     } break;
281   }
282 }
283 
Eval(KernelType kernel_type,TfLiteContext * context,TfLiteNode * node)284 TfLiteStatus Eval(KernelType kernel_type, TfLiteContext* context,
285                   TfLiteNode* node) {
286   auto* params = reinterpret_cast<TfLiteConv3DParams*>(node->builtin_data);
287   OpData* opdata = reinterpret_cast<OpData*>(node->user_data);
288 
289   TfLiteTensor* output;
290   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
291   const TfLiteTensor* input;
292   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 0, &input));
293   const TfLiteTensor* filter;
294   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, 1, &filter));
295   const TfLiteTensor* bias = GetInput(context, node, 2);
296 
297   TfLiteTensor* im2col = opdata->need_im2col
298                              ? &context->tensors[opdata->im2col_tensor_id]
299                              : nullptr;
300   TfLiteTensor* transposed_filter =
301       opdata->need_transposed_filter
302           ? &context->tensors[opdata->transposed_filter_tensor_id]
303           : nullptr;
304 
305   // Fallback to reference execution path when im2col is needed but disabled.
306   if (opdata->im2col_oversized) {
307     kernel_type = kReference;
308   }
309 
310   switch (input->type) {
311     case kTfLiteFloat32:
312       EvalFloat(kernel_type, context, node, params, opdata, input, filter, bias,
313                 im2col, transposed_filter, output);
314       break;
315     default:
316       TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
317                          TfLiteTypeGetName(input->type));
318       return kTfLiteError;
319   }
320   return kTfLiteOk;
321 }
322 
323 template <KernelType kernel_type>
Eval(TfLiteContext * context,TfLiteNode * node)324 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
325   return Eval(kernel_type, context, node);
326 }
327 
328 }  // namespace conv3d
329 
Register_CONV_3D_REF()330 TfLiteRegistration* Register_CONV_3D_REF() {
331   static TfLiteRegistration r = {conv3d::Init, conv3d::Free,
332                                  conv3d::Prepare<conv3d::kReference>,
333                                  conv3d::Eval<conv3d::kReference>};
334   return &r;
335 }
336 
Register_CONV_3D_GENERIC_OPT()337 TfLiteRegistration* Register_CONV_3D_GENERIC_OPT() {
338   static TfLiteRegistration r = {conv3d::Init, conv3d::Free,
339                                  conv3d::Prepare<conv3d::kGenericOptimized>,
340                                  conv3d::Eval<conv3d::kGenericOptimized>};
341   return &r;
342 }
343 
Register_CONV_3D()344 TfLiteRegistration* Register_CONV_3D() {
345   return Register_CONV_3D_GENERIC_OPT();
346 }
347 
348 }  // namespace builtin
349 }  // namespace ops
350 }  // namespace tflite
351