1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Example usage:
17 // bazel run -c opt \
18 // tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo \
19 // -- \
20 // --model_path=/path/to/model.tflite \
21 // --image_path=/path/to/image.jpg \
22 // --output_mask_png=/path/to/output/mask.png
23
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
39
40 ABSL_FLAG(std::string, model_path, "",
41 "Absolute path to the '.tflite' image segmenter model.");
42 ABSL_FLAG(std::string, image_path, "",
43 "Absolute path to the image to segment. The image must be RGB or "
44 "RGBA (grayscale is not supported). The image EXIF orientation "
45 "flag, if any, is NOT taken into account.");
46 ABSL_FLAG(std::string, output_mask_png, "",
47 "Absolute path to the output category mask (confidence masks outputs "
48 "are not supported by this tool). Must have a '.png' extension.");
49
50 namespace tflite {
51 namespace task {
52 namespace vision {
53
BuildOptions()54 ImageSegmenterOptions BuildOptions() {
55 ImageSegmenterOptions options;
56 options.mutable_model_file_with_metadata()->set_file_name(
57 absl::GetFlag(FLAGS_model_path));
58 // Confidence masks are not supported by this tool: output_type is set to
59 // CATEGORY_MASK by default.
60 return options;
61 }
62
EncodeMaskToPngFile(const SegmentationResult & result)63 absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {
64 if (result.segmentation_size() != 1) {
65 return absl::UnimplementedError(
66 "Image segmentation models with multiple output segmentations are not "
67 "supported by this tool.");
68 }
69 const Segmentation& segmentation = result.segmentation(0);
70 // Extract raw mask data as a uint8 pointer.
71 const uint8* raw_mask =
72 reinterpret_cast<const uint8*>(segmentation.category_mask().data());
73
74 // Create RgbImageData for the output mask.
75 uint8* pixel_data = static_cast<uint8*>(
76 malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8)));
77 ImageData mask = {.pixel_data = pixel_data,
78 .width = segmentation.width(),
79 .height = segmentation.height(),
80 .channels = 3};
81
82 // Populate RgbImageData from the raw mask and ColoredLabel-s.
83 for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) {
84 Segmentation::ColoredLabel colored_label =
85 segmentation.colored_labels(raw_mask[i]);
86 pixel_data[3 * i] = colored_label.r();
87 pixel_data[3 * i + 1] = colored_label.g();
88 pixel_data[3 * i + 2] = colored_label.b();
89 }
90
91 // Encode mask as PNG.
92 RETURN_IF_ERROR(
93 EncodeImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
94 std::cout << absl::StrFormat("Category mask saved to: %s\n",
95 absl::GetFlag(FLAGS_output_mask_png));
96
97 // Cleanup and return.
98 ImageDataFree(&mask);
99 return absl::OkStatus();
100 }
101
DisplayColorLegend(const SegmentationResult & result)102 absl::Status DisplayColorLegend(const SegmentationResult& result) {
103 if (result.segmentation_size() != 1) {
104 return absl::UnimplementedError(
105 "Image segmentation models with multiple output segmentations are not "
106 "supported by this tool.");
107 }
108 const Segmentation& segmentation = result.segmentation(0);
109 const int num_labels = segmentation.colored_labels_size();
110
111 std::cout << "Color Legend:\n";
112 for (int index = 0; index < num_labels; ++index) {
113 Segmentation::ColoredLabel colored_label =
114 segmentation.colored_labels(index);
115 std::cout << absl::StrFormat(" (r: %03d, g: %03d, b: %03d):\n",
116 colored_label.r(), colored_label.g(),
117 colored_label.b());
118 std::cout << absl::StrFormat(" index : %d\n", index);
119 if (colored_label.has_class_name()) {
120 std::cout << absl::StrFormat(" class name : %s\n",
121 colored_label.class_name());
122 }
123 if (colored_label.has_display_name()) {
124 std::cout << absl::StrFormat(" display name: %s\n",
125 colored_label.display_name());
126 }
127 }
128 std::cout << "Tip: use a color picker on the output PNG file to inspect the "
129 "output mask with this legend.\n";
130
131 return absl::OkStatus();
132 }
133
Segment()134 absl::Status Segment() {
135 // Build ImageClassifier.
136 const ImageSegmenterOptions& options = BuildOptions();
137 ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> image_segmenter,
138 ImageSegmenter::CreateFromOptions(options));
139
140 // Load image in a FrameBuffer.
141 ASSIGN_OR_RETURN(ImageData image,
142 DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
143 std::unique_ptr<FrameBuffer> frame_buffer;
144 if (image.channels == 3) {
145 frame_buffer =
146 CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
147 } else if (image.channels == 4) {
148 frame_buffer =
149 CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
150 } else {
151 return absl::InvalidArgumentError(absl::StrFormat(
152 "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
153 image.channels));
154 }
155
156 // Run segmentation and save category mask.
157 ASSIGN_OR_RETURN(SegmentationResult result,
158 image_segmenter->Segment(*frame_buffer));
159 RETURN_IF_ERROR(EncodeMaskToPngFile(result));
160
161 // Display the legend.
162 RETURN_IF_ERROR(DisplayColorLegend(result));
163
164 // Cleanup and return.
165 ImageDataFree(&image);
166 return absl::OkStatus();
167 }
168
169 } // namespace vision
170 } // namespace task
171 } // namespace tflite
172
main(int argc,char ** argv)173 int main(int argc, char** argv) {
174 // Parse command line arguments and perform sanity checks.
175 absl::ParseCommandLine(argc, argv);
176 if (absl::GetFlag(FLAGS_model_path).empty()) {
177 std::cerr << "Missing mandatory 'model_path' argument.\n";
178 return 1;
179 }
180 if (absl::GetFlag(FLAGS_image_path).empty()) {
181 std::cerr << "Missing mandatory 'image_path' argument.\n";
182 return 1;
183 }
184 if (absl::GetFlag(FLAGS_output_mask_png).empty()) {
185 std::cerr << "Missing mandatory 'output_mask_png' argument.\n";
186 return 1;
187 }
188 if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_mask_png), ".png")) {
189 std::cerr << "Argument 'output_mask_png' must end with '.png' or '.PNG'\n";
190 return 1;
191 }
192
193 // Run segmentation.
194 absl::Status status = tflite::task::vision::Segment();
195 if (status.ok()) {
196 return 0;
197 } else {
198 std::cerr << "Segmentation failed: " << status.message() << "\n";
199 return 1;
200 }
201 }
202