xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/utils/image_tensor_specs.h (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
16 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
17 
18 #include <array>
19 
20 #include "absl/types/optional.h"
21 #include "tensorflow/lite/c/common.h"
22 #include "tensorflow_lite_support/cc/port/statusor.h"
23 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
24 #include "tensorflow_lite_support/metadata/cc/metadata_extractor.h"
25 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
26 
27 namespace tflite {
28 namespace task {
29 namespace vision {
30 
// Parameters used for input image normalization when the input tensor has
// kTfLiteFloat32 type.
//
// Exactly 1 or 3 values are expected for `mean_values` and `std_values`. If
// only 1 value is specified, it is used for all channels. E.g. for an RGB
// image, the normalization is done as follows:
//
//   (R - mean_values[0]) / std_values[0]
//   (G - mean_values[1]) / std_values[1]
//   (B - mean_values[2]) / std_values[2]
//
// `num_values` keeps track of how many values have been provided, which should
// be 1 or 3 (see above). In particular, single-channel grayscale images expect
// only 1 value. When `num_values` is 1, only index 0 of each array is
// meaningful.
//
// All members have default member initializers so that a default-constructed
// instance holds zeros instead of indeterminate values (reading those would be
// undefined behavior). The struct is still an aggregate (C++14 and later), so
// aggregate / brace initialization by callers keeps working unchanged.
struct NormalizationOptions {
  std::array<float, 3> mean_values{};
  std::array<float, 3> std_values{};
  int num_values = 0;
};
50 
51 // Parameters related to the expected tensor specifications when the tensor
52 // represents an image.
53 //
54 // E.g. input tensor specifications expected by the model at Invoke() time. In
55 // such a case, and before running inference with the TF Lite interpreter, the
56 // caller must use these values and perform image preprocessing and/or
57 // normalization so as to fill the actual input tensor appropriately.
58 struct ImageTensorSpecs {
59   // Expected image dimensions, e.g. image_width=224, image_height=224.
60   int image_width;
61   int image_height;
62   // Expected color space, e.g. color_space=RGB.
63   tflite::ColorSpaceType color_space;
64   // Expected input tensor type, e.g. if tensor_type=kTfLiteFloat32 the caller
65   // should usually perform some normalization to convert the uint8 pixels into
66   // floats (see NormalizationOptions in TF Lite Metadata for more details).
67   TfLiteType tensor_type;
68   // Optional normalization parameters read from TF Lite Metadata. Those are
69   // mandatory when tensor_type=kTfLiteFloat32 in order to convert the input
70   // image data into the expected range of floating point values, an error is
71   // returned otherwise (see sanity checks below). They should be ignored for
72   // other tensor input types, e.g. kTfLiteUInt8.
73   absl::optional<NormalizationOptions> normalization_options;
74 };
75 
76 // Performs sanity checks on the expected input tensor including consistency
77 // checks against model metadata, if any. For now, a single RGB input with BHWD
78 // layout, where B = 1 and D = 3, is expected. Returns the corresponding input
79 // specifications if they pass, or an error otherwise (too many input tensors,
80 // etc).
81 // Note: both interpreter and metadata extractor *must* be successfully
82 // initialized before calling this function by means of (respectively):
83 // - `tflite::InterpreterBuilder`,
84 // - `tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer`.
85 tflite::support::StatusOr<ImageTensorSpecs> BuildInputImageTensorSpecs(
86     const tflite::task::core::TfLiteEngine::Interpreter& interpreter,
87     const tflite::metadata::ModelMetadataExtractor& metadata_extractor);
88 
89 }  // namespace vision
90 }  // namespace task
91 }  // namespace tflite
92 
93 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_UTILS_IMAGE_TENSOR_SPECS_H_
94