xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/expand_dims.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <stdint.h>
#include <string.h>

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace expand_dims {
28 
// Positions of this op's inputs within the node: the tensor being expanded,
// followed by the single-element tensor holding the axis to insert.
enum { kInput = 0, kAxis = 1 };
31 
32 namespace {
ExpandTensorDim(TfLiteContext * context,const TfLiteTensor & input,int axis,TfLiteTensor * output)33 TfLiteStatus ExpandTensorDim(TfLiteContext* context, const TfLiteTensor& input,
34                              int axis, TfLiteTensor* output) {
35   const TfLiteIntArray& input_dims = *input.dims;
36   if (axis < 0) {
37     axis = input_dims.size + 1 + axis;
38   }
39   TF_LITE_ENSURE(context, axis <= input_dims.size);
40   TF_LITE_ENSURE(context, axis >= 0);
41 
42   TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_dims.size + 1);
43   for (int i = 0; i < output_dims->size; ++i) {
44     if (i < axis) {
45       output_dims->data[i] = input_dims.data[i];
46     } else if (i == axis) {
47       output_dims->data[i] = 1;
48     } else {
49       output_dims->data[i] = input_dims.data[i - 1];
50     }
51   }
52 
53   return context->ResizeTensor(context, output, output_dims);
54 }
55 
GetAxisValueFromTensor(TfLiteContext * context,const TfLiteTensor & axis,int * axis_value)56 TfLiteStatus GetAxisValueFromTensor(TfLiteContext* context,
57                                     const TfLiteTensor& axis, int* axis_value) {
58   TF_LITE_ENSURE_EQ(context, NumElements(&axis), 1);
59   switch (axis.type) {
60     case kTfLiteInt32:
61       *axis_value = *GetTensorData<int32_t>(&axis);
62       return kTfLiteOk;
63     case kTfLiteInt64:
64       *axis_value = *GetTensorData<int64_t>(&axis);
65       return kTfLiteOk;
66     default:
67       return kTfLiteError;
68   }
69 }
70 
71 }  // namespace
72 
Prepare(TfLiteContext * context,TfLiteNode * node)73 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
74   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
75   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
76 
77   const TfLiteTensor* input;
78   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
79   const TfLiteTensor* axis;
80   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
81   TfLiteTensor* output;
82   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
83 
84   output->type = input->type;
85   TF_LITE_ENSURE_EQ(context, input->params.scale, output->params.scale);
86   TF_LITE_ENSURE_EQ(context, input->params.zero_point,
87                     output->params.zero_point);
88   if (input->type == kTfLiteInt16) {
89     TF_LITE_ENSURE_EQ(context, input->params.zero_point, 0);
90   }
91 
92   if (IsConstantTensor(axis)) {
93     int axis_value;
94     TF_LITE_ENSURE_OK(context,
95                       GetAxisValueFromTensor(context, *axis, &axis_value));
96     return ExpandTensorDim(context, *input, axis_value, output);
97   }
98   SetTensorToDynamic(output);
99 
100   return kTfLiteOk;
101 }
102 
Eval(TfLiteContext * context,TfLiteNode * node)103 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
104   // Just copy input to output.
105   const TfLiteTensor* input;
106   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
107   TfLiteTensor* output;
108   TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, 0, &output));
109   const TfLiteTensor* axis;
110   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxis, &axis));
111   if (IsDynamicTensor(output)) {
112     int axis_value;
113     TF_LITE_ENSURE_OK(context,
114                       GetAxisValueFromTensor(context, *axis, &axis_value));
115     TF_LITE_ENSURE_OK(context,
116                       ExpandTensorDim(context, *input, axis_value, output));
117   }
118   if (output->type == kTfLiteString) {
119     TfLiteTensorRealloc(input->bytes, output);
120   }
121   memcpy(output->data.raw, input->data.raw, input->bytes);
122   return kTfLiteOk;
123 }
124 
125 }  // namespace expand_dims
Register_EXPAND_DIMS()126 TfLiteRegistration* Register_EXPAND_DIMS() {
127   static TfLiteRegistration r = {nullptr, nullptr, expand_dims::Prepare,
128                                  expand_dims::Eval};
129   return &r;
130 }
131 }  // namespace builtin
132 }  // namespace ops
133 }  // namespace tflite
134