1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Example usage:
17 // bazel run -c opt \
18 //  tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo \
19 //  -- \
20 //  --model_path=/path/to/model.tflite \
21 //  --image_path=/path/to/image.jpg \
22 //  --output_mask_png=/path/to/output/mask.png
23 
#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h"
#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h"
#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h"
#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h"
39 
// Mandatory: model file loaded into ImageSegmenterOptions.model_file_with_metadata.
// main() exits with an error if this is left empty.
ABSL_FLAG(std::string, model_path, "",
          "Absolute path to the '.tflite' image segmenter model.");
// Mandatory: input image. The segmentation pipeline only accepts images that
// decode to 3 (RGB) or 4 (RGBA) channels and returns InvalidArgumentError
// otherwise.
ABSL_FLAG(std::string, image_path, "",
          "Absolute path to the image to segment. The image must be RGB or "
          "RGBA (grayscale is not supported). The image EXIF orientation "
          "flag, if any, is NOT taken into account.");
// Mandatory: destination of the rendered category mask. main() rejects empty
// values and values that do not end with ".png" (compared case-insensitively).
ABSL_FLAG(std::string, output_mask_png, "",
          "Absolute path to the output category mask (confidence masks outputs "
          "are not supported by this tool). Must have a '.png' extension.");
49 
50 namespace tflite {
51 namespace task {
52 namespace vision {
53 
BuildOptions()54 ImageSegmenterOptions BuildOptions() {
55   ImageSegmenterOptions options;
56   options.mutable_model_file_with_metadata()->set_file_name(
57       absl::GetFlag(FLAGS_model_path));
58   // Confidence masks are not supported by this tool: output_type is set to
59   // CATEGORY_MASK by default.
60   return options;
61 }
62 
EncodeMaskToPngFile(const SegmentationResult & result)63 absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {
64   if (result.segmentation_size() != 1) {
65     return absl::UnimplementedError(
66         "Image segmentation models with multiple output segmentations are not "
67         "supported by this tool.");
68   }
69   const Segmentation& segmentation = result.segmentation(0);
70   // Extract raw mask data as a uint8 pointer.
71   const uint8* raw_mask =
72       reinterpret_cast<const uint8*>(segmentation.category_mask().data());
73 
74   // Create RgbImageData for the output mask.
75   uint8* pixel_data = static_cast<uint8*>(
76       malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8)));
77   ImageData mask = {.pixel_data = pixel_data,
78                     .width = segmentation.width(),
79                     .height = segmentation.height(),
80                     .channels = 3};
81 
82   // Populate RgbImageData from the raw mask and ColoredLabel-s.
83   for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) {
84     Segmentation::ColoredLabel colored_label =
85         segmentation.colored_labels(raw_mask[i]);
86     pixel_data[3 * i] = colored_label.r();
87     pixel_data[3 * i + 1] = colored_label.g();
88     pixel_data[3 * i + 2] = colored_label.b();
89   }
90 
91   // Encode mask as PNG.
92   RETURN_IF_ERROR(
93       EncodeImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
94   std::cout << absl::StrFormat("Category mask saved to: %s\n",
95                                absl::GetFlag(FLAGS_output_mask_png));
96 
97   // Cleanup and return.
98   ImageDataFree(&mask);
99   return absl::OkStatus();
100 }
101 
DisplayColorLegend(const SegmentationResult & result)102 absl::Status DisplayColorLegend(const SegmentationResult& result) {
103   if (result.segmentation_size() != 1) {
104     return absl::UnimplementedError(
105         "Image segmentation models with multiple output segmentations are not "
106         "supported by this tool.");
107   }
108   const Segmentation& segmentation = result.segmentation(0);
109   const int num_labels = segmentation.colored_labels_size();
110 
111   std::cout << "Color Legend:\n";
112   for (int index = 0; index < num_labels; ++index) {
113     Segmentation::ColoredLabel colored_label =
114         segmentation.colored_labels(index);
115     std::cout << absl::StrFormat(" (r: %03d, g: %03d, b: %03d):\n",
116                                  colored_label.r(), colored_label.g(),
117                                  colored_label.b());
118     std::cout << absl::StrFormat("  index       : %d\n", index);
119     if (colored_label.has_class_name()) {
120       std::cout << absl::StrFormat("  class name  : %s\n",
121                                    colored_label.class_name());
122     }
123     if (colored_label.has_display_name()) {
124       std::cout << absl::StrFormat("  display name: %s\n",
125                                    colored_label.display_name());
126     }
127   }
128   std::cout << "Tip: use a color picker on the output PNG file to inspect the "
129                "output mask with this legend.\n";
130 
131   return absl::OkStatus();
132 }
133 
Segment()134 absl::Status Segment() {
135   // Build ImageClassifier.
136   const ImageSegmenterOptions& options = BuildOptions();
137   ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> image_segmenter,
138                    ImageSegmenter::CreateFromOptions(options));
139 
140   // Load image in a FrameBuffer.
141   ASSIGN_OR_RETURN(ImageData image,
142                    DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
143   std::unique_ptr<FrameBuffer> frame_buffer;
144   if (image.channels == 3) {
145     frame_buffer =
146         CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
147   } else if (image.channels == 4) {
148     frame_buffer =
149         CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
150   } else {
151     return absl::InvalidArgumentError(absl::StrFormat(
152         "Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
153         image.channels));
154   }
155 
156   // Run segmentation and save category mask.
157   ASSIGN_OR_RETURN(SegmentationResult result,
158                    image_segmenter->Segment(*frame_buffer));
159   RETURN_IF_ERROR(EncodeMaskToPngFile(result));
160 
161   // Display the legend.
162   RETURN_IF_ERROR(DisplayColorLegend(result));
163 
164   // Cleanup and return.
165   ImageDataFree(&image);
166   return absl::OkStatus();
167 }
168 
169 }  // namespace vision
170 }  // namespace task
171 }  // namespace tflite
172 
main(int argc,char ** argv)173 int main(int argc, char** argv) {
174   // Parse command line arguments and perform sanity checks.
175   absl::ParseCommandLine(argc, argv);
176   if (absl::GetFlag(FLAGS_model_path).empty()) {
177     std::cerr << "Missing mandatory 'model_path' argument.\n";
178     return 1;
179   }
180   if (absl::GetFlag(FLAGS_image_path).empty()) {
181     std::cerr << "Missing mandatory 'image_path' argument.\n";
182     return 1;
183   }
184   if (absl::GetFlag(FLAGS_output_mask_png).empty()) {
185     std::cerr << "Missing mandatory 'output_mask_png' argument.\n";
186     return 1;
187   }
188   if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_mask_png), ".png")) {
189     std::cerr << "Argument 'output_mask_png' must end with '.png' or '.PNG'\n";
190     return 1;
191   }
192 
193   // Run segmentation.
194   absl::Status status = tflite::task::vision::Segment();
195   if (status.ok()) {
196     return 0;
197   } else {
198     std::cerr << "Segmentation failed: " << status.message() << "\n";
199     return 1;
200   }
201 }
202