1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Example usage:
17 // bazel run -c opt \
18 //  tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo \
19 //  -- \
20 //  --model_path=/path/to/model.tflite \
21 //  --image_path=/path/to/image.jpg
22 
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
38 
39 ABSL_FLAG(std::string, model_path, "",
40           "Absolute path to the '.tflite' image classifier model.");
41 ABSL_FLAG(std::string, image_path, "",
42           "Absolute path to the image to classify. The image must be RGB or "
43           "RGBA (grayscale is not supported). The image EXIF orientation "
44           "flag, if any, is NOT taken into account.");
45 ABSL_FLAG(int32, max_results, 5,
46           "Maximum number of classification results to display.");
47 ABSL_FLAG(float, score_threshold, 0,
48           "Classification results with a confidence score below this value are "
49           "rejected. If >= 0, overrides the score threshold(s) provided in the "
50           "TFLite Model Metadata. Ignored otherwise.");
51 ABSL_FLAG(
52     std::vector<std::string>, class_name_whitelist, {},
53     "Comma-separated list of class names that acts as a whitelist. If "
54     "non-empty, classification results whose 'class_name' is not in this list "
55     "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
56 ABSL_FLAG(
57     std::vector<std::string>, class_name_blacklist, {},
58     "Comma-separated list of class names that acts as a blacklist. If "
59     "non-empty, classification results whose 'class_name' is in this list "
60     "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
61 
62 namespace tflite {
63 namespace task {
64 namespace vision {
65 
BuildOptions()66 ImageClassifierOptions BuildOptions() {
67   ImageClassifierOptions options;
68   options.mutable_model_file_with_metadata()->set_file_name(
69       absl::GetFlag(FLAGS_model_path));
70   options.set_max_results(absl::GetFlag(FLAGS_max_results));
71   if (absl::GetFlag(FLAGS_score_threshold) >= 0) {
72     options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold));
73   }
74   for (const std::string& class_name :
75        absl::GetFlag(FLAGS_class_name_whitelist)) {
76     options.add_class_name_whitelist(class_name);
77   }
78   for (const std::string& class_name :
79        absl::GetFlag(FLAGS_class_name_blacklist)) {
80     options.add_class_name_blacklist(class_name);
81   }
82   return options;
83 }
84 
DisplayResult(const ClassificationResult & result)85 void DisplayResult(const ClassificationResult& result) {
86   std::cout << "Results:\n";
87   for (int head = 0; head < result.classifications_size(); ++head) {
88     if (result.classifications_size() > 1) {
89       std::cout << absl::StrFormat(" Head index %d:\n", head);
90     }
91     const Classifications& classifications = result.classifications(head);
92     for (int rank = 0; rank < classifications.classes_size(); ++rank) {
93       const Class& classification = classifications.classes(rank);
94       std::cout << absl::StrFormat("  Rank #%d:\n", rank);
95       std::cout << absl::StrFormat("   index       : %d\n",
96                                    classification.index());
97       std::cout << absl::StrFormat("   score       : %.5f\n",
98                                    classification.score());
99       if (classification.has_class_name()) {
100         std::cout << absl::StrFormat("   class name  : %s\n",
101                                      classification.class_name());
102       }
103       if (classification.has_display_name()) {
104         std::cout << absl::StrFormat("   display name: %s\n",
105                                      classification.display_name());
106       }
107     }
108   }
109 }
110 
Classify()111 absl::Status Classify() {
112   // Build ImageClassifier.
113   const ImageClassifierOptions& options = BuildOptions();
114   ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> image_classifier,
115                    ImageClassifier::CreateFromOptions(options));
116 
117   // Load image in a FrameBuffer.
118   ASSIGN_OR_RETURN(ImageData image,
119                    DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
120   std::unique_ptr<FrameBuffer> frame_buffer;
121   if (image.channels == 3) {
122     frame_buffer =
123         CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
124   } else if (image.channels == 4) {
125     frame_buffer =
126         CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
127   } else {
128     return absl::InvalidArgumentError(absl::StrFormat(
129         "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
130         image.channels));
131   }
132 
133   // Run classification and display results.
134   ASSIGN_OR_RETURN(ClassificationResult result,
135                    image_classifier->Classify(*frame_buffer));
136   DisplayResult(result);
137 
138   // Cleanup and return.
139   ImageDataFree(&image);
140   return absl::OkStatus();
141 }
142 
143 }  // namespace vision
144 }  // namespace task
145 }  // namespace tflite
146 
main(int argc,char ** argv)147 int main(int argc, char** argv) {
148   // Parse command line arguments and perform sanity checks.
149   absl::ParseCommandLine(argc, argv);
150   if (absl::GetFlag(FLAGS_model_path).empty()) {
151     std::cerr << "Missing mandatory 'model_path' argument.\n";
152     return 1;
153   }
154   if (absl::GetFlag(FLAGS_image_path).empty()) {
155     std::cerr << "Missing mandatory 'image_path' argument.\n";
156     return 1;
157   }
158   if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() &&
159       !absl::GetFlag(FLAGS_class_name_blacklist).empty()) {
160     std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments "
161                  "are mutually exclusive.\n";
162     return 1;
163   }
164 
165   // Run classification.
166   absl::Status status = tflite::task::vision::Classify();
167   if (status.ok()) {
168     return 0;
169   } else {
170     std::cerr << "Classification failed: " << status.message() << "\n";
171     return 1;
172   }
173 }
174