1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <stdint.h>
17
18 #include "tensorflow/lite/c/common.h"
19 #include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
20 #include "tensorflow/lite/kernels/internal/tensor.h"
21 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23
24 namespace tflite {
25 namespace ops {
26 namespace builtin {
27 namespace segment_sum {
28
// Tensor indices used by this kernel: inputs are (data, segment_ids) and there
// is a single output.
constexpr int kInputDataTensor = 0;
constexpr int kInputSegmentIdsTensor = 1;
constexpr int kOutputTensor = 0;
32
ResizeOutputTensor(TfLiteContext * context,const TfLiteTensor * data,const TfLiteTensor * segment_ids,TfLiteTensor * output)33 TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
34 const TfLiteTensor* data,
35 const TfLiteTensor* segment_ids,
36 TfLiteTensor* output) {
37 // Segment ids should be of same cardinality as first input dimension and they
38 // should be increasing by at most 1, from 0 (e.g., [0, 0, 1, 2, 3] is valid)
39 const int segment_id_size = segment_ids->dims->data[0];
40 TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
41 int previous_segment_id = -1;
42 for (int i = 0; i < segment_id_size; i++) {
43 const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
44 if (i == 0) {
45 TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
46 } else {
47 int delta = current_segment_id - previous_segment_id;
48 TF_LITE_ENSURE(context, delta == 0 || delta == 1);
49 }
50 previous_segment_id = current_segment_id;
51 }
52
53 const int max_index = previous_segment_id;
54
55 const int data_rank = NumDimensions(data);
56 TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
57 output_shape->data[0] = max_index + 1;
58 for (int i = 1; i < data_rank; ++i) {
59 output_shape->data[i] = data->dims->data[i];
60 }
61 return context->ResizeTensor(context, output, output_shape);
62 }
63
Prepare(TfLiteContext * context,TfLiteNode * node)64 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
65 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
66 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
67 const TfLiteTensor* data;
68 TF_LITE_ENSURE_OK(context,
69 GetInputSafe(context, node, kInputDataTensor, &data));
70 const TfLiteTensor* segment_ids;
71 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
72 &segment_ids));
73 TfLiteTensor* output;
74 TF_LITE_ENSURE_OK(context,
75 GetOutputSafe(context, node, kOutputTensor, &output));
76 TF_LITE_ENSURE(context,
77 data->type == kTfLiteInt32 || data->type == kTfLiteFloat32);
78 TF_LITE_ENSURE_EQ(context, segment_ids->type, kTfLiteInt32);
79
80 if (!IsConstantTensor(data) || !IsConstantTensor(segment_ids)) {
81 SetTensorToDynamic(output);
82 return kTfLiteOk;
83 }
84
85 return ResizeOutputTensor(context, data, segment_ids, output);
86 }
87
Eval(TfLiteContext * context,TfLiteNode * node)88 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
89 const TfLiteTensor* data;
90 TF_LITE_ENSURE_OK(context,
91 GetInputSafe(context, node, kInputDataTensor, &data));
92 const TfLiteTensor* segment_ids;
93 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputSegmentIdsTensor,
94 &segment_ids));
95 TfLiteTensor* output;
96 TF_LITE_ENSURE_OK(context,
97 GetOutputSafe(context, node, kOutputTensor, &output));
98
99 if (IsDynamicTensor(output)) {
100 TF_LITE_ENSURE_OK(context,
101 ResizeOutputTensor(context, data, segment_ids, output));
102 }
103
104 #define TF_LITE_SEGMENT_SUM(dtype) \
105 reference_ops::SegmentSum<dtype>( \
106 GetTensorShape(data), GetTensorData<dtype>(data), \
107 GetTensorShape(segment_ids), GetTensorData<int32_t>(segment_ids), \
108 GetTensorShape(output), GetTensorData<dtype>(output));
109 switch (data->type) {
110 case kTfLiteInt32:
111 TF_LITE_SEGMENT_SUM(int32_t);
112 break;
113 case kTfLiteFloat32:
114 TF_LITE_SEGMENT_SUM(float);
115 break;
116 default:
117 TF_LITE_KERNEL_LOG(context,
118 "Currently SegmentSum doesn't support type: %s",
119 TfLiteTypeGetName(data->type));
120 return kTfLiteError;
121 }
122 #undef TF_LITE_SEGMENT_SUM
123 return kTfLiteOk;
124 }
125
126 } // namespace segment_sum
127
Register_SEGMENT_SUM()128 TfLiteRegistration* Register_SEGMENT_SUM() {
129 static TfLiteRegistration r = {nullptr, nullptr, segment_sum::Prepare,
130 segment_sum::Eval};
131 return &r;
132 }
133
134 } // namespace builtin
135 } // namespace ops
136 } // namespace tflite
137