xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/flex/util.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/util.h"
16 
17 #include <string>
18 
19 #include "absl/strings/str_format.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/platform/status.h"
22 #include "tensorflow/core/platform/statusor.h"
23 #include "tensorflow/core/protobuf/error_codes.pb.h"
24 #include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
25 #include "tensorflow/lite/string_util.h"
26 
27 namespace tflite {
28 namespace flex {
29 
// Prefix of TensorFlow resource-handle names that refer to TF Lite resource
// variables. Such names are formatted as "<prefix>:<resource id>".
static constexpr char kResourceVariablePrefix[]{"tflite_resource_variable"};
31 
ConvertStatus(TfLiteContext * context,const tensorflow::Status & status)32 TfLiteStatus ConvertStatus(TfLiteContext* context,
33                            const tensorflow::Status& status) {
34   if (!status.ok()) {
35     TF_LITE_KERNEL_LOG(context, "%s", status.error_message().c_str());
36     return kTfLiteError;
37   }
38   return kTfLiteOk;
39 }
40 
CopyShapeAndType(TfLiteContext * context,const tensorflow::Tensor & src,TfLiteTensor * tensor)41 TfLiteStatus CopyShapeAndType(TfLiteContext* context,
42                               const tensorflow::Tensor& src,
43                               TfLiteTensor* tensor) {
44   tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
45   if (tensor->type == kTfLiteNoType) {
46     TF_LITE_KERNEL_LOG(context,
47                        "TF Lite does not support TensorFlow data type: %s",
48                        DataTypeString(src.dtype()).c_str());
49     return kTfLiteError;
50   }
51 
52   int num_dims = src.dims();
53   TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
54   for (int j = 0; j < num_dims; ++j) {
55     // We need to cast from TensorFlow's int64 to TF Lite's int32. Let's
56     // make sure there's no overflow.
57     if (src.dim_size(j) >= std::numeric_limits<int>::max()) {
58       TF_LITE_KERNEL_LOG(context,
59                          "Dimension value in TensorFlow shape is larger than "
60                          "supported by TF Lite");
61       TfLiteIntArrayFree(shape);
62       return kTfLiteError;
63     }
64     shape->data[j] = static_cast<int>(src.dim_size(j));
65   }
66   return context->ResizeTensor(context, tensor, shape);
67 }
68 
GetTensorFlowDataType(TfLiteType type)69 TF_DataType GetTensorFlowDataType(TfLiteType type) {
70   switch (type) {
71     case kTfLiteNoType:
72       return TF_FLOAT;
73     case kTfLiteFloat32:
74       return TF_FLOAT;
75     case kTfLiteFloat16:
76       return TF_HALF;
77     case kTfLiteFloat64:
78       return TF_DOUBLE;
79     case kTfLiteInt16:
80       return TF_INT16;
81     case kTfLiteUInt16:
82       return TF_UINT16;
83     case kTfLiteInt32:
84       return TF_INT32;
85     case kTfLiteUInt32:
86       return TF_UINT32;
87     case kTfLiteUInt8:
88       return TF_UINT8;
89     case kTfLiteInt8:
90       return TF_INT8;
91     case kTfLiteInt64:
92       return TF_INT64;
93     case kTfLiteUInt64:
94       return TF_UINT64;
95     case kTfLiteComplex64:
96       return TF_COMPLEX64;
97     case kTfLiteComplex128:
98       return TF_COMPLEX128;
99     case kTfLiteString:
100       return TF_STRING;
101     case kTfLiteBool:
102       return TF_BOOL;
103     case kTfLiteResource:
104       return TF_RESOURCE;
105     case kTfLiteVariant:
106       return TF_VARIANT;
107   }
108 }
109 
GetTensorFlowLiteType(TF_DataType type)110 TfLiteType GetTensorFlowLiteType(TF_DataType type) {
111   switch (type) {
112     case TF_FLOAT:
113       return kTfLiteFloat32;
114     case TF_HALF:
115       return kTfLiteFloat16;
116     case TF_DOUBLE:
117       return kTfLiteFloat64;
118     case TF_INT16:
119       return kTfLiteInt16;
120     case TF_INT32:
121       return kTfLiteInt32;
122     case TF_UINT8:
123       return kTfLiteUInt8;
124     case TF_INT8:
125       return kTfLiteInt8;
126     case TF_INT64:
127       return kTfLiteInt64;
128     case TF_UINT64:
129       return kTfLiteUInt64;
130     case TF_COMPLEX64:
131       return kTfLiteComplex64;
132     case TF_COMPLEX128:
133       return kTfLiteComplex128;
134     case TF_STRING:
135       return kTfLiteString;
136     case TF_BOOL:
137       return kTfLiteBool;
138     case TF_RESOURCE:
139       return kTfLiteResource;
140     case TF_VARIANT:
141       return kTfLiteVariant;
142     default:
143       return kTfLiteNoType;
144   }
145 }
146 
147 // Returns the TF data type name to be stored in the FunctionDef.
TfLiteTypeToTfTypeName(TfLiteType type)148 const char* TfLiteTypeToTfTypeName(TfLiteType type) {
149   switch (type) {
150     case kTfLiteNoType:
151       return "invalid";
152     case kTfLiteFloat32:
153       return "float";
154     case kTfLiteInt16:
155       return "int16";
156     case kTfLiteUInt16:
157       return "uint16";
158     case kTfLiteInt32:
159       return "int32";
160     case kTfLiteUInt32:
161       return "uint32";
162     case kTfLiteUInt8:
163       return "uint8";
164     case kTfLiteInt8:
165       return "int8";
166     case kTfLiteInt64:
167       return "int64";
168     case kTfLiteUInt64:
169       return "uint64";
170     case kTfLiteBool:
171       return "bool";
172     case kTfLiteComplex64:
173       return "complex64";
174     case kTfLiteComplex128:
175       return "complex128";
176     case kTfLiteString:
177       return "string";
178     case kTfLiteFloat16:
179       return "float16";
180     case kTfLiteFloat64:
181       return "float64";
182     case kTfLiteResource:
183       return "resource";
184     case kTfLiteVariant:
185       return "variant";
186   }
187   return "invalid";
188 }
189 
TfLiteResourceIdentifier(const TfLiteTensor * tensor)190 std::string TfLiteResourceIdentifier(const TfLiteTensor* tensor) {
191   // TODO(b/199782192): Create a util function to get Resource ID from a TF Lite
192   // resource tensor.
193   const int resource_id = tensor->data.i32[0];
194   return absl::StrFormat("%s:%d", kResourceVariablePrefix, resource_id);
195 }
196 
GetTfLiteResourceTensorFromResourceHandle(const tensorflow::ResourceHandle & resource_handle,TfLiteTensor * tensor)197 bool GetTfLiteResourceTensorFromResourceHandle(
198     const tensorflow::ResourceHandle& resource_handle, TfLiteTensor* tensor) {
199   std::vector<std::string> parts = absl::StrSplit(resource_handle.name(), ':');
200   if (parts.size() != 2) {
201     return false;
202   }
203   const int kBytesRequired = sizeof(int32_t);
204   TfLiteTensorRealloc(kBytesRequired, tensor);
205   int resource_id;
206   if (parts[0] == kResourceVariablePrefix &&
207       absl::SimpleAtoi<int32_t>(parts[1], &resource_id)) {
208     // TODO(b/199782192): Create a util function to set the Resource ID of
209     // a TF Lite resource tensor.
210     GetTensorData<int32_t>(tensor)[0] = resource_id;
211     return true;
212   }
213   return false;
214 }
215 
CreateTfTensorFromTfLiteTensor(const TfLiteTensor * tflite_tensor)216 tensorflow::StatusOr<tensorflow::Tensor> CreateTfTensorFromTfLiteTensor(
217     const TfLiteTensor* tflite_tensor) {
218   if (IsResourceOrVariant(tflite_tensor)) {
219     // Returns error if the input tflite tensor has variant or resource type.
220     return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
221                               "Input tensor has resource or variant type.");
222   }
223 
224   tensorflow::TensorShape shape;
225   int num_dims = tflite_tensor->dims->size;
226   for (int i = 0; i < num_dims; ++i) {
227     shape.AddDim(tflite_tensor->dims->data[i]);
228   }
229 
230   tensorflow::Tensor tf_tensor(
231       tensorflow::DataType(GetTensorFlowDataType(tflite_tensor->type)), shape);
232   if (tf_tensor.dtype() == tensorflow::DataType::DT_STRING &&
233       tf_tensor.data()) {
234     tensorflow::tstring* buf =
235         static_cast<tensorflow::tstring*>(tf_tensor.data());
236     for (int i = 0; i < tflite::GetStringCount(tflite_tensor); ++buf, ++i) {
237       auto ref = GetString(tflite_tensor, i);
238       buf->assign(ref.str, ref.len);
239     }
240   } else {
241     if (tf_tensor.tensor_data().size() != tflite_tensor->bytes) {
242       return tensorflow::Status(
243           tensorflow::error::INTERNAL,
244           "TfLiteTensor's size doesn't match the TF tensor's size.");
245     }
246     if (!tflite_tensor->data.raw) {
247       return tensorflow::Status(tensorflow::error::INTERNAL,
248                                 "TfLiteTensor's data field is null.");
249     }
250     std::memcpy(tf_tensor.data(), tflite_tensor->data.raw,
251                 tflite_tensor->bytes);
252   }
253 
254   return tf_tensor;
255 }
256 
257 }  // namespace flex
258 }  // namespace tflite
259