xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/segment_sum.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <stdint.h>
17 
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace segment_sum {
28 
// Tensor slot indices used by this op: the node takes (data, segment_ids) as
// inputs and produces a single output.
constexpr int kInputDataTensor = 0;
constexpr int kInputSegmentIdsTensor = 1;
constexpr int kOutputTensor = 0;
32 
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * data,const TfLiteTensor * segment_ids,TfLiteTensor * output)33 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
34                                 const TfLiteTensor* data,
35                                 const TfLiteTensor* segment_ids,
36                                 TfLiteTensor* output) {
37   // Segment ids should be of same cardinality as first input dimension and they
38   // should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid)
39   const int segment_id_size = segment_ids->dims->data[0];
40   TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
41   int previous_segment_id = -1;
42   for (int i = 0; i < segment_id_size; i++) {
43     const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
44     if (i == 0) {
45       TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
46     } else {
47       int delta = current_segment_id - previous_segment_id;
48       TF_LITE_ENSURE(context, delta == 0 || delta == 1);
49     }
50     previous_segment_id = current_segment_id;
51   }
52 
53   const int max_index = previous_segment_id;
54 
55   const int data_rank = NumDimensions(data);
56   TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
57   output_shape->data[0] = max_index + 1;
58   for (int i = 1; i < data_rank; ++i) {
59     output_shape->data[i] = data->dims->data[i];
60   }
61   return context->ResizeTensor(context, output, output_shape);
62 }
63 
Prepare(TfLiteContext * context,TfLiteNode * node)64 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
65   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
66   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
67   const TfLiteTensor* data;
68   TF_LITE_ENSURE_OK(context,
69                     GetInputSafe(context, node, kInputDataTensor, &data));
70   const TfLiteTensor* segment_ids;
71   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
72                                           &segment_ids));
73   TfLiteTensor* output;
74   TF_LITE_ENSURE_OK(context,
75                     GetOutputSafe(context, node, kOutputTensor, &output));
76   TF_LITE_ENSURE(context,
77                  data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
78   TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
79 
80   if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) {
81     SetTensorToDynamic(output);
82     return kTfLiteOk;
83   }
84 
85   return ResizeOutputTensor(context, data, segment_ids, output);
86 }
87 
Eval(TfLiteContext * context,TfLiteNode * node)88 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
89   const TfLiteTensor* data;
90   TF_LITE_ENSURE_OK(context,
91                     GetInputSafe(context, node, kInputDataTensor, &data));
92   const TfLiteTensor* segment_ids;
93   TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
94                                           &segment_ids));
95   TfLiteTensor* output;
96   TF_LITE_ENSURE_OK(context,
97                     GetOutputSafe(context, node, kOutputTensor, &output));
98 
99   if (IsDynamicTensor(output)) {
100     TF_LITE_ENSURE_OK(context,
101                       ResizeOutputTensor(context, data, segment_ids, output));
102   }
103 
104 #define TF_LITE_SEGMENT_SUM(dtype)                                      \
105   reference_ops::SegmentSum<dtype>(                                     \
106       GetTensorShape(data), GetTensorData<dtype>(data),                 \
107       GetTensorShape(segment_ids), GetTensorData<int32_t>(segment_ids), \
108       GetTensorShape(output), GetTensorData<dtype>(output));
109   switch (data->type) {
110     case kTfLiteInt32:
111       TF_LITE_SEGMENT_SUM(int32_t);
112       break;
113     case kTfLiteFloat32:
114       TF_LITE_SEGMENT_SUM(float);
115       break;
116     default:
117       TF_LITE_KERNEL_LOG(context,
118                          "Currently SegmentSum doesn't support type: %s",
119                          TfLiteTypeGetName(data->type));
120       return kTfLiteError;
121   }
122 #undef TF_LITE_SEGMENT_SUM
123   return kTfLiteOk;
124 }
125 
126 }  // namespace segment_sum
127 
Register_SEGMENT_SUM()128 TfLiteRegistration* Register_SEGMENT_SUM() {
129   static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare,
130                                  segment_sum::Eval};
131   return &r;
132 }
133 
134 }  // namespace builtin
135 }  // namespace ops
136 }  // namespace tflite
137