xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/hexagon/builders/pack_builder.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/hexagon/builders/pack_builder.h"
16 
17 #include <stdint.h>
18 
19 #include <limits>
20 
21 #include "tensorflow/lite/c/builtin_op_data.h"
22 #include "tensorflow/lite/kernels/kernel_util.h"
23 
24 namespace tflite {
25 namespace delegates {
26 namespace hexagon {
27 namespace {
28 
GetAxis(int axis,const TfLiteIntArray * inputs,TfLiteContext * context)29 int GetAxis(int axis, const TfLiteIntArray* inputs, TfLiteContext* context) {
30   auto& input_tensor = context->tensors[inputs->data[0]];
31   // Handle -ve axis.
32   if (axis < 0) {
33     axis += input_tensor.dims->size + 1;
34   }
35   // We need to adjust the axis to be as if the inputs are of rank 4, since
36   // we represent tensors in Hexagon of rank 4.
37   return (4 - input_tensor.dims->size) + axis - 1;
38 }
39 
40 }  // namespace
PopulateSubGraph(const TfLiteIntArray * inputs,const TfLiteIntArray * outputs,TfLiteContext * context)41 TfLiteStatus PackOpBuilder::PopulateSubGraph(const TfLiteIntArray* inputs,
42                                              const TfLiteIntArray* outputs,
43                                              TfLiteContext* context) {
44   auto* params = reinterpret_cast<TfLitePackParams*>(builtin_data_);
45   int axis = GetAxis(params->axis, inputs, context);
46   // Add axis
47   auto* axis_node = graph_builder_->AddConstNodeWithData(
48       kScalarShape, reinterpret_cast<char*>(&axis), sizeof(axis));
49   AddInput(TensorID(axis_node->GetID(), 0));
50 
51   // Add all input tensors.
52   minima_.reserve(inputs->size);
53   maxima_.reserve(inputs->size);
54   int tensor_id = -1;
55   float data_min, data_max;
56   for (int i = 0; i < inputs->size; ++i) {
57     tensor_id = inputs->data[i];
58     auto& input_tensor = context->tensors[tensor_id];
59     AddInput(graph_builder_->GetHexagonTensorId(tensor_id));
60     TF_LITE_ENSURE_STATUS(
61         ComputeMinAndMaxQuantValues(input_tensor, &data_min, &data_max));
62     minima_.push_back(data_min);
63     maxima_.push_back(data_max);
64   }
65 
66   // Minima tensors.
67   for (int i = 0; i < minima_.size(); ++i) {
68     auto* data_min_const = graph_builder_->AddConstNodeWithData(
69         kScalarShape, reinterpret_cast<char*>(&minima_[i]), sizeof(minima_[i]));
70     AddInput(TensorID(data_min_const->GetID(), 0));
71   }
72 
73   // Maxima tensors.
74   for (int i = 0; i < maxima_.size(); ++i) {
75     auto* data_max_const = graph_builder_->AddConstNodeWithData(
76         kScalarShape, reinterpret_cast<char*>(&maxima_[i]), sizeof(maxima_[i]));
77     AddInput(TensorID(data_max_const->GetID(), 0));
78   }
79 
80   // Hexagon outputs for this node.
81   int output_batch_size, output_height_size, output_width_size,
82       output_depth_size;
83   GetDims(&output_batch_size, &output_height_size, &output_width_size,
84           &output_depth_size, context->tensors[outputs->data[0]].dims);
85 
86   TensorID pack_out = AddOutput(sizeof(uint8_t), 4,
87                                 {output_batch_size, output_height_size,
88                                  output_width_size, output_depth_size});
89 
90   // Output min/max for requantization.
91   float output_min, output_max;
92   TF_LITE_ENSURE_STATUS(ComputeMinAndMaxQuantValues(
93       context->tensors[outputs->data[0]], &output_min, &output_max));
94   auto* output_min_const = graph_builder_->AddConstNodeWithData(
95       kScalarShape, reinterpret_cast<char*>(&output_min), sizeof(output_min));
96   auto* output_max_const = graph_builder_->AddConstNodeWithData(
97       kScalarShape, reinterpret_cast<char*>(&output_max), sizeof(output_max));
98 
99   const auto& pack_out_min = AddOutput(sizeof(float), 4, kScalarShape);
100   const auto& pack_out_max = AddOutput(sizeof(float), 4, kScalarShape);
101 
102   // Requantize output to the expected min/max.
103   auto* requantize_op = graph_builder_->AddNode(GetTFLiteNodeID());
104   requantize_op->SetOpType(OP_Requantize_8to8);
105   requantize_op->AddInput(pack_out);
106   requantize_op->AddInput(pack_out_min);
107   requantize_op->AddInput(pack_out_max);
108   requantize_op->AddInput(TensorID(output_min_const->GetID(), 0));
109   requantize_op->AddInput(TensorID(output_max_const->GetID(), 0));
110   node_output_ =
111       requantize_op->AddOutput(sizeof(uint8_t), 4,
112                                {output_batch_size, output_height_size,
113                                 output_width_size, output_depth_size});
114   requantize_op->AddOutput(sizeof(float), 4, kScalarShape);
115   requantize_op->AddOutput(sizeof(float), 4, kScalarShape);
116   return kTfLiteOk;
117 }
118 
RegisterOutputs(const TfLiteIntArray * outputs,TfLiteContext * context)119 TfLiteStatus PackOpBuilder::RegisterOutputs(const TfLiteIntArray* outputs,
120                                             TfLiteContext* context) {
121   // Should be only 1 output.
122   graph_builder_->AddTensorWithID(outputs->data[0], node_output_.first,
123                                   node_output_.second);
124   return kTfLiteOk;
125 }
126 
CreatePackBuilder(GraphBuilder * graph_builder,int op_type)127 OpBuilder* CreatePackBuilder(GraphBuilder* graph_builder, int op_type) {
128   return new PackOpBuilder(graph_builder, op_type);
129 }
130 
131 }  // namespace hexagon
132 }  // namespace delegates
133 }  // namespace tflite
134