1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/lite/delegates/flex/util.h"

#include <cstdint>
#include <cstring>
#include <limits>
#include <string>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/string_util.h"
26
27 namespace tflite {
28 namespace flex {
29
namespace {
// Prefix of the names given to TF Lite resource variables when they are
// exchanged with TensorFlow as resource handles. Full names have the form
// "<prefix>:<resource id>"; see TfLiteResourceIdentifier and
// GetTfLiteResourceTensorFromResourceHandle.
constexpr char kResourceVariablePrefix[] = "tflite_resource_variable";
}  // namespace
31
ConvertStatus(TfLiteContext * context,const tensorflow::Status & status)32 TfLiteStatus ConvertStatus(TfLiteContext* context,
33 const tensorflow::Status& status) {
34 if (!status.ok()) {
35 TF_LITE_KERNEL_LOG(context, "%s", status.error_message().c_str());
36 return kTfLiteError;
37 }
38 return kTfLiteOk;
39 }
40
CopyShapeAndType(TfLiteContext * context,const tensorflow::Tensor & src,TfLiteTensor * tensor)41 TfLiteStatus CopyShapeAndType(TfLiteContext* context,
42 const tensorflow::Tensor& src,
43 TfLiteTensor* tensor) {
44 tensor->type = GetTensorFlowLiteType(static_cast<TF_DataType>(src.dtype()));
45 if (tensor->type == kTfLiteNoType) {
46 TF_LITE_KERNEL_LOG(context,
47 "TF Lite does not support TensorFlow data type: %s",
48 DataTypeString(src.dtype()).c_str());
49 return kTfLiteError;
50 }
51
52 int num_dims = src.dims();
53 TfLiteIntArray* shape = TfLiteIntArrayCreate(num_dims);
54 for (int j = 0; j < num_dims; ++j) {
55 // We need to cast from TensorFlow's int64 to TF Lite's int32. Let's
56 // make sure there's no overflow.
57 if (src.dim_size(j) >= std::numeric_limits<int>::max()) {
58 TF_LITE_KERNEL_LOG(context,
59 "Dimension value in TensorFlow shape is larger than "
60 "supported by TF Lite");
61 TfLiteIntArrayFree(shape);
62 return kTfLiteError;
63 }
64 shape->data[j] = static_cast<int>(src.dim_size(j));
65 }
66 return context->ResizeTensor(context, tensor, shape);
67 }
68
GetTensorFlowDataType(TfLiteType type)69 TF_DataType GetTensorFlowDataType(TfLiteType type) {
70 switch (type) {
71 case kTfLiteNoType:
72 return TF_FLOAT;
73 case kTfLiteFloat32:
74 return TF_FLOAT;
75 case kTfLiteFloat16:
76 return TF_HALF;
77 case kTfLiteFloat64:
78 return TF_DOUBLE;
79 case kTfLiteInt16:
80 return TF_INT16;
81 case kTfLiteUInt16:
82 return TF_UINT16;
83 case kTfLiteInt32:
84 return TF_INT32;
85 case kTfLiteUInt32:
86 return TF_UINT32;
87 case kTfLiteUInt8:
88 return TF_UINT8;
89 case kTfLiteInt8:
90 return TF_INT8;
91 case kTfLiteInt64:
92 return TF_INT64;
93 case kTfLiteUInt64:
94 return TF_UINT64;
95 case kTfLiteComplex64:
96 return TF_COMPLEX64;
97 case kTfLiteComplex128:
98 return TF_COMPLEX128;
99 case kTfLiteString:
100 return TF_STRING;
101 case kTfLiteBool:
102 return TF_BOOL;
103 case kTfLiteResource:
104 return TF_RESOURCE;
105 case kTfLiteVariant:
106 return TF_VARIANT;
107 }
108 }
109
GetTensorFlowLiteType(TF_DataType type)110 TfLiteType GetTensorFlowLiteType(TF_DataType type) {
111 switch (type) {
112 case TF_FLOAT:
113 return kTfLiteFloat32;
114 case TF_HALF:
115 return kTfLiteFloat16;
116 case TF_DOUBLE:
117 return kTfLiteFloat64;
118 case TF_INT16:
119 return kTfLiteInt16;
120 case TF_INT32:
121 return kTfLiteInt32;
122 case TF_UINT8:
123 return kTfLiteUInt8;
124 case TF_INT8:
125 return kTfLiteInt8;
126 case TF_INT64:
127 return kTfLiteInt64;
128 case TF_UINT64:
129 return kTfLiteUInt64;
130 case TF_COMPLEX64:
131 return kTfLiteComplex64;
132 case TF_COMPLEX128:
133 return kTfLiteComplex128;
134 case TF_STRING:
135 return kTfLiteString;
136 case TF_BOOL:
137 return kTfLiteBool;
138 case TF_RESOURCE:
139 return kTfLiteResource;
140 case TF_VARIANT:
141 return kTfLiteVariant;
142 default:
143 return kTfLiteNoType;
144 }
145 }
146
147 // Returns the TF data type name to be stored in the FunctionDef.
TfLiteTypeToTfTypeName(TfLiteType type)148 const char* TfLiteTypeToTfTypeName(TfLiteType type) {
149 switch (type) {
150 case kTfLiteNoType:
151 return "invalid";
152 case kTfLiteFloat32:
153 return "float";
154 case kTfLiteInt16:
155 return "int16";
156 case kTfLiteUInt16:
157 return "uint16";
158 case kTfLiteInt32:
159 return "int32";
160 case kTfLiteUInt32:
161 return "uint32";
162 case kTfLiteUInt8:
163 return "uint8";
164 case kTfLiteInt8:
165 return "int8";
166 case kTfLiteInt64:
167 return "int64";
168 case kTfLiteUInt64:
169 return "uint64";
170 case kTfLiteBool:
171 return "bool";
172 case kTfLiteComplex64:
173 return "complex64";
174 case kTfLiteComplex128:
175 return "complex128";
176 case kTfLiteString:
177 return "string";
178 case kTfLiteFloat16:
179 return "float16";
180 case kTfLiteFloat64:
181 return "float64";
182 case kTfLiteResource:
183 return "resource";
184 case kTfLiteVariant:
185 return "variant";
186 }
187 return "invalid";
188 }
189
TfLiteResourceIdentifier(const TfLiteTensor * tensor)190 std::string TfLiteResourceIdentifier(const TfLiteTensor* tensor) {
191 // TODO(b/199782192): Create a util function to get Resource ID from a TF Lite
192 // resource tensor.
193 const int resource_id = tensor->data.i32[0];
194 return absl::StrFormat("%s:%d", kResourceVariablePrefix, resource_id);
195 }
196
GetTfLiteResourceTensorFromResourceHandle(const tensorflow::ResourceHandle & resource_handle,TfLiteTensor * tensor)197 bool GetTfLiteResourceTensorFromResourceHandle(
198 const tensorflow::ResourceHandle& resource_handle, TfLiteTensor* tensor) {
199 std::vector<std::string> parts = absl::StrSplit(resource_handle.name(), ':');
200 if (parts.size() != 2) {
201 return false;
202 }
203 const int kBytesRequired = sizeof(int32_t);
204 TfLiteTensorRealloc(kBytesRequired, tensor);
205 int resource_id;
206 if (parts[0] == kResourceVariablePrefix &&
207 absl::SimpleAtoi<int32_t>(parts[1], &resource_id)) {
208 // TODO(b/199782192): Create a util function to set the Resource ID of
209 // a TF Lite resource tensor.
210 GetTensorData<int32_t>(tensor)[0] = resource_id;
211 return true;
212 }
213 return false;
214 }
215
CreateTfTensorFromTfLiteTensor(const TfLiteTensor * tflite_tensor)216 tensorflow::StatusOr<tensorflow::Tensor> CreateTfTensorFromTfLiteTensor(
217 const TfLiteTensor* tflite_tensor) {
218 if (IsResourceOrVariant(tflite_tensor)) {
219 // Returns error if the input tflite tensor has variant or resource type.
220 return tensorflow::Status(tensorflow::error::INVALID_ARGUMENT,
221 "Input tensor has resource or variant type.");
222 }
223
224 tensorflow::TensorShape shape;
225 int num_dims = tflite_tensor->dims->size;
226 for (int i = 0; i < num_dims; ++i) {
227 shape.AddDim(tflite_tensor->dims->data[i]);
228 }
229
230 tensorflow::Tensor tf_tensor(
231 tensorflow::DataType(GetTensorFlowDataType(tflite_tensor->type)), shape);
232 if (tf_tensor.dtype() == tensorflow::DataType::DT_STRING &&
233 tf_tensor.data()) {
234 tensorflow::tstring* buf =
235 static_cast<tensorflow::tstring*>(tf_tensor.data());
236 for (int i = 0; i < tflite::GetStringCount(tflite_tensor); ++buf, ++i) {
237 auto ref = GetString(tflite_tensor, i);
238 buf->assign(ref.str, ref.len);
239 }
240 } else {
241 if (tf_tensor.tensor_data().size() != tflite_tensor->bytes) {
242 return tensorflow::Status(
243 tensorflow::error::INTERNAL,
244 "TfLiteTensor's size doesn't match the TF tensor's size.");
245 }
246 if (!tflite_tensor->data.raw) {
247 return tensorflow::Status(tensorflow::error::INTERNAL,
248 "TfLiteTensor's data field is null.");
249 }
250 std::memcpy(tf_tensor.data(), tflite_tensor->data.raw,
251 tflite_tensor->bytes);
252 }
253
254 return tf_tensor;
255 }
256
257 } // namespace flex
258 } // namespace tflite
259