1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_ 16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_ 17 18 #include <array> 19 20 #include "absl/types/optional.h" 21 #include "tensorflow/lite/c/common.h" 22 #include "tensorflow_lite_support/cc/port/statusor.h" 23 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h" 24 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h" 25 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h" 26 27 namespace tflite { 28 namespace task { 29 namespace vision { 30 31 // Parameters used for input image normalization when input tensor has 32 // kTfLiteFloat32 type. 33 // 34 // Exactly 1 or 3 values are expected for `mean_values` and `std_values`. In 35 // case 1 value only is specified, it is used for all channels. E.g. for a RGB 36 // image, the normalization is done as follow: 37 // 38 // (R - mean_values[0]) / std_values[0] 39 // (G - mean_values[1]) / std_values[1] 40 // (B - mean_values[2]) / std_values[2] 41 // 42 // `num_values` keeps track of how many values have been provided, which should 43 // be 1 or 3 (see above). In particular, single-channel grayscale images expect 44 // only 1 value. 45 struct NormalizationOptions { 46 std::array<float, 3> mean_values; 47 std::array<float, 3> std_values; 48 int num_values; 49 }; 50 51 // Parameters related to the expected tensor specifications when the tensor 52 // represents an image. 53 // 54 // E.g. input tensor specifications expected by the model at Invoke() time. In 55 // such a case, and before running inference with the TF Lite interpreter, the 56 // caller must use these values and perform image preprocessing and/or 57 // normalization so as to fill the actual input tensor appropriately. 58 struct ImageTensorSpecs { 59 // Expected image dimensions, e.g. image_width=224, image_height=224. 60 int image_width; 61 int image_height; 62 // Expected color space, e.g. color_space=RGB. 63 tflite::ColorSpaceType color_space; 64 // Expected input tensor type, e.g. if tensor_type=kTfLiteFloat32 the caller 65 // should usually perform some normalization to convert the uint8 pixels into 66 // floats (see NormalizationOptions in TF Lite Metadata for more details). 67 TfLiteType tensor_type; 68 // Optional normalization parameters read from TF Lite Metadata. Those are 69 // mandatory when tensor_type=kTfLiteFloat32 in order to convert the input 70 // image data into the expected range of floating point values, an error is 71 // returned otherwise (see sanity checks below). They should be ignored for 72 // other tensor input types, e.g. kTfLiteUInt8. 73 absl::optional<NormalizationOptions> normalization_options; 74 }; 75 76 // Performs sanity checks on the expected input tensor including consistency 77 // checks against model metadata, if any. For now, a single RGB input with BHWD 78 // layout, where B = 1 and D = 3, is expected. Returns the corresponding input 79 // specifications if they pass, or an error otherwise (too many input tensors, 80 // etc). 81 // Note: both interpreter and metadata extractor *must* be successfully 82 // initialized before calling this function by means of (respectively): 83 // - `tflite::InterpreterBuilder`, 84 // - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`. 85 tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs( 86 const tflite::task::core::TfLiteEngine::Interpreter& interpreter, 87 const tflite::metadata::ModelMetadataExtractor& metadata_extractor); 88 89 } // namespace vision 90 } // namespace task 91 } // namespace tflite 92 93 #endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_ 94