1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Example usage:
17 // bazel run -c opt \
18 // tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo \
19 // -- \
20 // --model_path=/path/to/model.tflite \
21 // --image_path=/path/to/image.jpg
22
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/image_classifier.h"
#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
38
39 ABSL_FLAG(std::string, model_path, "",
40 "Absolute path to the '.tflite' image classifier model.");
41 ABSL_FLAG(std::string, image_path, "",
42 "Absolute path to the image to classify. The image must be RGB or "
43 "RGBA (grayscale is not supported). The image EXIF orientation "
44 "flag, if any, is NOT taken into account.");
45 ABSL_FLAG(int32, max_results, 5,
46 "Maximum number of classification results to display.");
47 ABSL_FLAG(float, score_threshold, 0,
48 "Classification results with a confidence score below this value are "
49 "rejected. If >= 0, overrides the score threshold(s) provided in the "
50 "TFLite Model Metadata. Ignored otherwise.");
51 ABSL_FLAG(
52 std::vector<std::string>, class_name_whitelist, {},
53 "Comma-separated list of class names that acts as a whitelist. If "
54 "non-empty, classification results whose 'class_name' is not in this list "
55 "are filtered out. Mutually exclusive with 'class_name_blacklist'.");
56 ABSL_FLAG(
57 std::vector<std::string>, class_name_blacklist, {},
58 "Comma-separated list of class names that acts as a blacklist. If "
59 "non-empty, classification results whose 'class_name' is in this list "
60 "are filtered out. Mutually exclusive with 'class_name_whitelist'.");
61
62 namespace tflite {
63 namespace task {
64 namespace vision {
65
BuildOptions()66 ImageClassifierOptions BuildOptions() {
67 ImageClassifierOptions options;
68 options.mutable_model_file_with_metadata()->set_file_name(
69 absl::GetFlag(FLAGS_model_path));
70 options.set_max_results(absl::GetFlag(FLAGS_max_results));
71 if (absl::GetFlag(FLAGS_score_threshold) >= 0) {
72 options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold));
73 }
74 for (const std::string& class_name :
75 absl::GetFlag(FLAGS_class_name_whitelist)) {
76 options.add_class_name_whitelist(class_name);
77 }
78 for (const std::string& class_name :
79 absl::GetFlag(FLAGS_class_name_blacklist)) {
80 options.add_class_name_blacklist(class_name);
81 }
82 return options;
83 }
84
DisplayResult(const ClassificationResult & result)85 void DisplayResult(const ClassificationResult& result) {
86 std::cout << "Results:\n";
87 for (int head = 0; head < result.classifications_size(); ++head) {
88 if (result.classifications_size() > 1) {
89 std::cout << absl::StrFormat(" Head index %d:\n", head);
90 }
91 const Classifications& classifications = result.classifications(head);
92 for (int rank = 0; rank < classifications.classes_size(); ++rank) {
93 const Class& classification = classifications.classes(rank);
94 std::cout << absl::StrFormat(" Rank #%d:\n", rank);
95 std::cout << absl::StrFormat(" index : %d\n",
96 classification.index());
97 std::cout << absl::StrFormat(" score : %.5f\n",
98 classification.score());
99 if (classification.has_class_name()) {
100 std::cout << absl::StrFormat(" class name : %s\n",
101 classification.class_name());
102 }
103 if (classification.has_display_name()) {
104 std::cout << absl::StrFormat(" display name: %s\n",
105 classification.display_name());
106 }
107 }
108 }
109 }
110
Classify()111 absl::Status Classify() {
112 // Build ImageClassifier.
113 const ImageClassifierOptions& options = BuildOptions();
114 ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> image_classifier,
115 ImageClassifier::CreateFromOptions(options));
116
117 // Load image in a FrameBuffer.
118 ASSIGN_OR_RETURN(ImageData image,
119 DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
120 std::unique_ptr<FrameBuffer> frame_buffer;
121 if (image.channels == 3) {
122 frame_buffer =
123 CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
124 } else if (image.channels == 4) {
125 frame_buffer =
126 CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
127 } else {
128 return absl::InvalidArgumentError(absl::StrFormat(
129 "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
130 image.channels));
131 }
132
133 // Run classification and display results.
134 ASSIGN_OR_RETURN(ClassificationResult result,
135 image_classifier->Classify(*frame_buffer));
136 DisplayResult(result);
137
138 // Cleanup and return.
139 ImageDataFree(&image);
140 return absl::OkStatus();
141 }
142
143 } // namespace vision
144 } // namespace task
145 } // namespace tflite
146
main(int argc,char ** argv)147 int main(int argc, char** argv) {
148 // Parse command line arguments and perform sanity checks.
149 absl::ParseCommandLine(argc, argv);
150 if (absl::GetFlag(FLAGS_model_path).empty()) {
151 std::cerr << "Missing mandatory 'model_path' argument.\n";
152 return 1;
153 }
154 if (absl::GetFlag(FLAGS_image_path).empty()) {
155 std::cerr << "Missing mandatory 'image_path' argument.\n";
156 return 1;
157 }
158 if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() &&
159 !absl::GetFlag(FLAGS_class_name_blacklist).empty()) {
160 std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments "
161 "are mutually exclusive.\n";
162 return 1;
163 }
164
165 // Run classification.
166 absl::Status status = tflite::task::vision::Classify();
167 if (status.ok()) {
168 return 0;
169 } else {
170 std::cerr << "Classification failed: " << status.message() << "\n";
171 return 1;
172 }
173 }
174