xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/vision/image_classifier.h (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
17 #define TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/status/status.h"
24 #include "tensorflow/lite/c/common.h"
25 #include "tensorflow/lite/core/api/op_resolver.h"
26 #include "tensorflow_lite_support/cc/port/integral_types.h"
27 #include "tensorflow_lite_support/cc/port/statusor.h"
28 #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
29 #include "tensorflow_lite_support/cc/task/vision/core/base_vision_task_api.h"
30 #include "tensorflow_lite_support/cc/task/vision/core/classification_head.h"
31 #include "tensorflow_lite_support/cc/task/vision/core/frame_buffer.h"
32 #include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h"
33 #include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
34 #include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
35 #include "tensorflow_lite_support/cc/task/vision/utils/score_calibration.h"
36 
37 namespace tflite {
38 namespace task {
39 namespace vision {
40 
41 // Performs classification on images.
42 //
43 // The API expects a TFLite model with optional, but strongly recommended,
44 // TFLite Model Metadata.
45 //
46 // Input tensor:
47 //   (kTfLiteUInt8/kTfLiteFloat32)
48 //    - image input of size `[batch x height x width x channels]`.
49 //    - batch inference is not supported (`batch` is required to be 1).
50 //    - only RGB inputs are supported (`channels` is required to be 3).
51 //    - if type is kTfLiteFloat32, NormalizationOptions are required to be
52 //      attached to the metadata for input normalization.
53 // At least one output tensor with:
54 //   (kTfLiteUInt8/kTfLiteFloat32)
55 //    -  `N `classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
56 //       `[1 x 1 x 1 x N]`
57 //    - optional (but recommended) label map(s) as AssociatedFile-s with type
58 //      TENSOR_AXIS_LABELS, containing one label per line. The first such
59 //      AssociatedFile (if any) is used to fill the `class_name` field of the
60 //      results. The `display_name` field is filled from the AssociatedFile (if
61 //      any) whose locale matches the `display_names_locale` field of the
62 //      `ImageClassifierOptions` used at creation time ("en" by default, i.e.
63 //      English). If none of these are available, only the `index` field of the
64 //      results will be filled.
65 //
66 // An example of such model can be found at:
67 // https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1
68 //
69 // A CLI demo tool is available for easily trying out this API, and provides
70 // example usage. See:
71 // examples/task/vision/desktop/image_classifier_demo.cc
72 class ImageClassifier : public BaseVisionTaskApi<ClassificationResult> {
73  public:
74   using BaseVisionTaskApi::BaseVisionTaskApi;
75 
76   // Creates an ImageClassifier from the provided options. A non-default
77   // OpResolver can be specified in order to support custom Ops or specify a
78   // subset of built-in Ops.
79   static tflite::support::StatusOr<std::unique_ptr<ImageClassifier>>
80   CreateFromOptions(
81       const ImageClassifierOptions& options,
82       std::unique_ptr<tflite::OpResolver> resolver =
83           absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
84 
85   // Performs actual classification on the provided FrameBuffer.
86   //
87   // The FrameBuffer can be of any size and any of the supported formats, i.e.
88   // RGBA, RGB, NV12, NV21, YV12, YV21. It is automatically pre-processed before
89   // inference in order to (and in this order):
90   // - resize it (with bilinear interpolation, aspect-ratio *not* preserved) to
91   //   the dimensions of the model input tensor,
92   // - convert it to the colorspace of the input tensor (i.e. RGB, which is the
93   //   only supported colorspace for now),
94   // - rotate it according to its `Orientation` so that inference is performed
95   //   on an "upright" image.
96   tflite::support::StatusOr<ClassificationResult> Classify(
97       const FrameBuffer& frame_buffer);
98 
99   // Same as above, except that the classification is performed based on the
100   // input region of interest. Cropping according to this region of interest is
101   // prepended to the pre-processing operations.
102   //
103   // IMPORTANT: as a consequence of cropping occurring first, the provided
104   // region of interest is expressed in the unrotated frame of reference
105   // coordinates system, i.e. in `[0, frame_buffer.width) x [0,
106   // frame_buffer.height)`, which are the dimensions of the underlying
107   // `frame_buffer` data before any `Orientation` flag gets applied. Also, the
108   // region of interest is not clamped, so this method will return a non-ok
109   // status if the region is out of these bounds.
110   tflite::support::StatusOr<ClassificationResult> Classify(
111       const FrameBuffer& frame_buffer, const BoundingBox& roi);
112 
113  protected:
114   // The options used to build this ImageClassifier.
115   std::unique_ptr<ImageClassifierOptions> options_;
116 
117   // The list of classification heads associated with the corresponding output
118   // tensors. Built from TFLite Model Metadata.
119   std::vector<ClassificationHead> classification_heads_;
120 
121   // Post-processing to transform the raw model outputs into classification
122   // results.
123   tflite::support::StatusOr<ClassificationResult> Postprocess(
124       const std::vector<const TfLiteTensor*>& output_tensors,
125       const FrameBuffer& frame_buffer, const BoundingBox& roi) override;
126 
127   // Performs sanity checks on the provided ImageClassifierOptions.
128   static absl::Status SanityCheckOptions(const ImageClassifierOptions& options);
129 
130   // Initializes the ImageClassifier from the provided ImageClassifierOptions,
131   // whose ownership is transferred to this object.
132   absl::Status Init(std::unique_ptr<ImageClassifierOptions> options);
133 
134   // Performs pre-initialization actions.
135   virtual absl::Status PreInit();
136   // Performs post-initialization actions.
137   virtual absl::Status PostInit();
138 
139  private:
140   // Performs sanity checks on the model outputs and extracts their metadata.
141   absl::Status CheckAndSetOutputs();
142 
143   // Performs sanity checks on the class whitelist/blacklist and forms the class
144   // name set.
145   absl::Status CheckAndSetClassNameSet();
146 
147   // Initializes the score calibration parameters based on corresponding TFLite
148   // Model Metadata, if any.
149   absl::Status InitScoreCalibrations();
150 
151   // Given a ClassificationResult object containing class indices, fills the
152   // name and display name from the label map(s).
153   absl::Status FillResultsFromLabelMaps(ClassificationResult* result);
154 
155   // The number of output tensors. This corresponds to the number of
156   // classification heads.
157   int num_outputs_;
158   // Whether the model features quantized inference type (QUANTIZED_UINT8). This
159   // is currently detected by checking if all output tensors data type is uint8.
160   bool has_uint8_outputs_;
161 
162   // Set of whitelisted or blacklisted class names.
163   struct ClassNameSet {
164     absl::flat_hash_set<std::string> values;
165     bool is_whitelist;
166   };
167 
168   // Whitelisted or blacklisted class names based on provided options at
169   // construction time. These are used to filter out results during
170   // post-processing.
171   ClassNameSet class_name_set_;
172 
173   // List of score calibration parameters, if any. Built from TFLite Model
174   // Metadata.
175   std::vector<std::unique_ptr<ScoreCalibration>> score_calibrations_;
176 };
177 
178 }  // namespace vision
179 }  // namespace task
180 }  // namespace tflite
181 
182 #endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_VISION_IMAGE_CLASSIFIER_H_
183