xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/model_cmdline_flags.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/toco/model_cmdline_flags.h"
16 
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/numbers.h"
21 #include "absl/strings/str_join.h"
22 #include "absl/strings/str_split.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/strings/strip.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/util/command_line_flags.h"
27 #include "tensorflow/lite/toco/args.h"
28 #include "tensorflow/lite/toco/toco_graphviz_dump_options.h"
29 #include "tensorflow/lite/toco/toco_port.h"
30 
31 // "batch" flag only exists internally
32 #ifdef PLATFORM_GOOGLE
33 #include "base/commandlineflags.h"
34 #endif
35 
36 namespace toco {
37 
ParseModelFlagsFromCommandLineFlags(int * argc,char * argv[],std::string * msg,ParsedModelFlags * parsed_model_flags_ptr)38 bool ParseModelFlagsFromCommandLineFlags(
39     int* argc, char* argv[], std::string* msg,
40     ParsedModelFlags* parsed_model_flags_ptr) {
41   ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
42   using tensorflow::Flag;
43   std::vector<tensorflow::Flag> flags = {
44       Flag("input_array", parsed_flags.input_array.bind(),
45            parsed_flags.input_array.default_value(),
46            "Deprecated: use --input_arrays instead. Name of the input array. "
47            "If not specified, will try to read "
48            "that information from the input file."),
49       Flag("input_arrays", parsed_flags.input_arrays.bind(),
50            parsed_flags.input_arrays.default_value(),
51            "Names of the input arrays, comma-separated. If not specified, "
52            "will try to read that information from the input file."),
53       Flag("output_array", parsed_flags.output_array.bind(),
54            parsed_flags.output_array.default_value(),
55            "Deprecated: use --output_arrays instead. Name of the output array, "
56            "when specifying a unique output array. "
57            "If not specified, will try to read that information from the "
58            "input file."),
59       Flag("output_arrays", parsed_flags.output_arrays.bind(),
60            parsed_flags.output_arrays.default_value(),
61            "Names of the output arrays, comma-separated. "
62            "If not specified, will try to read "
63            "that information from the input file."),
64       Flag("input_shape", parsed_flags.input_shape.bind(),
65            parsed_flags.input_shape.default_value(),
66            "Deprecated: use --input_shapes instead. Input array shape. For "
67            "many models the shape takes the form "
68            "batch size, input array height, input array width, input array "
69            "depth."),
70       Flag("input_shapes", parsed_flags.input_shapes.bind(),
71            parsed_flags.input_shapes.default_value(),
72            "Shapes corresponding to --input_arrays, colon-separated. For "
73            "many models each shape takes the form batch size, input array "
74            "height, input array width, input array depth."),
75       Flag("batch_size", parsed_flags.batch_size.bind(),
76            parsed_flags.batch_size.default_value(),
77            "Deprecated. Batch size for the model. Replaces the first dimension "
78            "of an input size array if undefined. Use only with SavedModels "
79            "when --input_shapes flag is not specified. Always use "
80            "--input_shapes flag with frozen graphs."),
81       Flag("input_data_type", parsed_flags.input_data_type.bind(),
82            parsed_flags.input_data_type.default_value(),
83            "Deprecated: use --input_data_types instead. Input array type, if "
84            "not already provided in the graph. "
85            "Typically needs to be specified when passing arbitrary arrays "
86            "to --input_arrays."),
87       Flag("input_data_types", parsed_flags.input_data_types.bind(),
88            parsed_flags.input_data_types.default_value(),
89            "Input arrays types, comma-separated, if not already provided in "
90            "the graph. "
91            "Typically needs to be specified when passing arbitrary arrays "
92            "to --input_arrays."),
93       Flag("mean_value", parsed_flags.mean_value.bind(),
94            parsed_flags.mean_value.default_value(),
95            "Deprecated: use --mean_values instead. mean_value parameter for "
96            "image models, used to compute input "
97            "activations from input pixel data."),
98       Flag("mean_values", parsed_flags.mean_values.bind(),
99            parsed_flags.mean_values.default_value(),
100            "mean_values parameter for image models, comma-separated list of "
101            "doubles, used to compute input activations from input pixel "
102            "data. Each entry in the list should match an entry in "
103            "--input_arrays."),
104       Flag("std_value", parsed_flags.std_value.bind(),
105            parsed_flags.std_value.default_value(),
106            "Deprecated: use --std_values instead. std_value parameter for "
107            "image models, used to compute input "
108            "activations from input pixel data."),
109       Flag("std_values", parsed_flags.std_values.bind(),
110            parsed_flags.std_values.default_value(),
111            "std_value parameter for image models, comma-separated list of "
112            "doubles, used to compute input activations from input pixel "
113            "data. Each entry in the list should match an entry in "
114            "--input_arrays."),
115       Flag("variable_batch", parsed_flags.variable_batch.bind(),
116            parsed_flags.variable_batch.default_value(),
117            "If true, the model accepts an arbitrary batch size. Mutually "
118            "exclusive "
119            "with the 'batch' field: at most one of these two fields can be "
120            "set."),
121       Flag("rnn_states", parsed_flags.rnn_states.bind(),
122            parsed_flags.rnn_states.default_value(), ""),
123       Flag("model_checks", parsed_flags.model_checks.bind(),
124            parsed_flags.model_checks.default_value(),
125            "A list of model checks to be applied to verify the form of the "
126            "model.  Applied after the graph transformations after import."),
127       Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
128            parsed_flags.dump_graphviz.default_value(),
129            "Dump graphviz during LogDump call. If string is non-empty then "
130            "it defines path to dump, otherwise will skip dumping."),
131       Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
132            parsed_flags.dump_graphviz_video.default_value(),
133            "If true, will dump graphviz at each "
134            "graph transformation, which may be used to generate a video."),
135       Flag("conversion_summary_dir", parsed_flags.conversion_summary_dir.bind(),
136            parsed_flags.conversion_summary_dir.default_value(),
137            "Local file directory to store the conversion logs."),
138       Flag("allow_nonexistent_arrays",
139            parsed_flags.allow_nonexistent_arrays.bind(),
140            parsed_flags.allow_nonexistent_arrays.default_value(),
141            "If true, will allow passing inexistent arrays in --input_arrays "
142            "and --output_arrays. This makes little sense, is only useful to "
143            "more easily get graph visualizations."),
144       Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
145            parsed_flags.allow_nonascii_arrays.default_value(),
146            "If true, will allow passing non-ascii-printable characters in "
147            "--input_arrays and --output_arrays. By default (if false), only "
148            "ascii printable characters are allowed, i.e. character codes "
149            "ranging from 32 to 127. This is disallowed by default so as to "
150            "catch common copy-and-paste issues where invisible unicode "
151            "characters are unwittingly added to these strings."),
152       Flag(
153           "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
154           parsed_flags.arrays_extra_info_file.default_value(),
155           "Path to an optional file containing a serialized ArraysExtraInfo "
156           "proto allowing to pass extra information about arrays not specified "
157           "in the input model file, such as extra MinMax information."),
158       Flag("model_flags_file", parsed_flags.model_flags_file.bind(),
159            parsed_flags.model_flags_file.default_value(),
160            "Path to an optional file containing a serialized ModelFlags proto. "
161            "Options specified on the command line will override the values in "
162            "the proto."),
163       Flag("change_concat_input_ranges",
164            parsed_flags.change_concat_input_ranges.bind(),
165            parsed_flags.change_concat_input_ranges.default_value(),
166            "Boolean to change the behavior of min/max ranges for inputs and"
167            " output of the concat operators."),
168   };
169   bool asked_for_help =
170       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
171   if (asked_for_help) {
172     *msg += tensorflow::Flags::Usage(argv[0], flags);
173     return false;
174   } else {
175     if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
176   }
177   auto& dump_options = *GraphVizDumpOptions::singleton();
178   dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
179   dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
180 
181   return true;
182 }
183 
ReadModelFlagsFromCommandLineFlags(const ParsedModelFlags & parsed_model_flags,ModelFlags * model_flags)184 void ReadModelFlagsFromCommandLineFlags(
185     const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
186   toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
187 
188   // Load proto containing the initial model flags.
189   // Additional flags specified on the command line will overwrite the values.
190   if (parsed_model_flags.model_flags_file.specified()) {
191     std::string model_flags_file_contents;
192     QCHECK(port::file::GetContents(parsed_model_flags.model_flags_file.value(),
193                                    &model_flags_file_contents,
194                                    port::file::Defaults())
195                .ok())
196         << "Specified --model_flags_file="
197         << parsed_model_flags.model_flags_file.value()
198         << " was not found or could not be read";
199     QCHECK(ParseFromStringEitherTextOrBinary(model_flags_file_contents,
200                                              model_flags))
201         << "Specified --model_flags_file="
202         << parsed_model_flags.model_flags_file.value()
203         << " could not be parsed";
204   }
205 
206 #ifdef PLATFORM_GOOGLE
207   CHECK(!((base::WasPresentOnCommandLine("batch") &&
208            parsed_model_flags.variable_batch.specified())))
209       << "The --batch and --variable_batch flags are mutually exclusive.";
210 #endif
211   CHECK(!(parsed_model_flags.output_array.specified() &&
212           parsed_model_flags.output_arrays.specified()))
213       << "The --output_array and --vs flags are mutually exclusive.";
214 
215   if (parsed_model_flags.output_array.specified()) {
216     model_flags->add_output_arrays(parsed_model_flags.output_array.value());
217   }
218 
219   if (parsed_model_flags.output_arrays.specified()) {
220     std::vector<std::string> output_arrays =
221         absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
222     for (const std::string& output_array : output_arrays) {
223       model_flags->add_output_arrays(output_array);
224     }
225   }
226 
227   const bool uses_single_input_flags =
228       parsed_model_flags.input_array.specified() ||
229       parsed_model_flags.mean_value.specified() ||
230       parsed_model_flags.std_value.specified() ||
231       parsed_model_flags.input_shape.specified();
232 
233   const bool uses_multi_input_flags =
234       parsed_model_flags.input_arrays.specified() ||
235       parsed_model_flags.mean_values.specified() ||
236       parsed_model_flags.std_values.specified() ||
237       parsed_model_flags.input_shapes.specified();
238 
239   QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
240       << "Use either the singular-form input flags (--input_array, "
241          "--input_shape, --mean_value, --std_value) or the plural form input "
242          "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
243          "but not both forms within the same command line.";
244 
245   if (parsed_model_flags.input_array.specified()) {
246     QCHECK(uses_single_input_flags);
247     model_flags->add_input_arrays()->set_name(
248         parsed_model_flags.input_array.value());
249   }
250   if (parsed_model_flags.input_arrays.specified()) {
251     QCHECK(uses_multi_input_flags);
252     for (const auto& input_array :
253          absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
254       model_flags->add_input_arrays()->set_name(std::string(input_array));
255     }
256   }
257   if (parsed_model_flags.mean_value.specified()) {
258     QCHECK(uses_single_input_flags);
259     model_flags->mutable_input_arrays(0)->set_mean_value(
260         parsed_model_flags.mean_value.value());
261   }
262   if (parsed_model_flags.mean_values.specified()) {
263     QCHECK(uses_multi_input_flags);
264     std::vector<std::string> mean_values =
265         absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
266     QCHECK(static_cast<int>(mean_values.size()) ==
267            model_flags->input_arrays_size());
268     for (size_t i = 0; i < mean_values.size(); ++i) {
269       char* last = nullptr;
270       model_flags->mutable_input_arrays(i)->set_mean_value(
271           strtod(mean_values[i].data(), &last));
272       CHECK(last != mean_values[i].data());
273     }
274   }
275   if (parsed_model_flags.std_value.specified()) {
276     QCHECK(uses_single_input_flags);
277     model_flags->mutable_input_arrays(0)->set_std_value(
278         parsed_model_flags.std_value.value());
279   }
280   if (parsed_model_flags.std_values.specified()) {
281     QCHECK(uses_multi_input_flags);
282     std::vector<std::string> std_values =
283         absl::StrSplit(parsed_model_flags.std_values.value(), ',');
284     QCHECK(static_cast<int>(std_values.size()) ==
285            model_flags->input_arrays_size());
286     for (size_t i = 0; i < std_values.size(); ++i) {
287       char* last = nullptr;
288       model_flags->mutable_input_arrays(i)->set_std_value(
289           strtod(std_values[i].data(), &last));
290       CHECK(last != std_values[i].data());
291     }
292   }
293   if (parsed_model_flags.input_data_type.specified()) {
294     QCHECK(uses_single_input_flags);
295     IODataType type;
296     QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
297     model_flags->mutable_input_arrays(0)->set_data_type(type);
298   }
299   if (parsed_model_flags.input_data_types.specified()) {
300     QCHECK(uses_multi_input_flags);
301     std::vector<std::string> input_data_types =
302         absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
303     QCHECK(static_cast<int>(input_data_types.size()) ==
304            model_flags->input_arrays_size());
305     for (size_t i = 0; i < input_data_types.size(); ++i) {
306       IODataType type;
307       QCHECK(IODataType_Parse(input_data_types[i], &type));
308       model_flags->mutable_input_arrays(i)->set_data_type(type);
309     }
310   }
311   if (parsed_model_flags.input_shape.specified()) {
312     QCHECK(uses_single_input_flags);
313     if (model_flags->input_arrays().empty()) {
314       model_flags->add_input_arrays();
315     }
316     auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
317     shape->clear_dims();
318     const IntList& list = parsed_model_flags.input_shape.value();
319     for (auto& dim : list.elements) {
320       shape->add_dims(dim);
321     }
322   }
323   if (parsed_model_flags.input_shapes.specified()) {
324     QCHECK(uses_multi_input_flags);
325     std::vector<std::string> input_shapes =
326         absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
327     QCHECK(static_cast<int>(input_shapes.size()) ==
328            model_flags->input_arrays_size());
329     for (size_t i = 0; i < input_shapes.size(); ++i) {
330       auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
331       shape->clear_dims();
332       // Treat an empty input shape as a scalar.
333       if (input_shapes[i].empty()) {
334         continue;
335       }
336       for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
337         int size;
338         CHECK(absl::SimpleAtoi(dim_str, &size))
339             << "Failed to parse input_shape: " << input_shapes[i];
340         shape->add_dims(size);
341       }
342     }
343   }
344 
345 #define READ_MODEL_FLAG(name)                                   \
346   do {                                                          \
347     if (parsed_model_flags.name.specified()) {                  \
348       model_flags->set_##name(parsed_model_flags.name.value()); \
349     }                                                           \
350   } while (false)
351 
352   READ_MODEL_FLAG(variable_batch);
353 
354 #undef READ_MODEL_FLAG
355 
356   for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
357     auto* rnn_state_proto = model_flags->add_rnn_states();
358     for (const auto& kv_pair : element) {
359       const std::string& key = kv_pair.first;
360       const std::string& value = kv_pair.second;
361       if (key == "state_array") {
362         rnn_state_proto->set_state_array(value);
363       } else if (key == "back_edge_source_array") {
364         rnn_state_proto->set_back_edge_source_array(value);
365       } else if (key == "size") {
366         int32_t size = 0;
367         CHECK(absl::SimpleAtoi(value, &size));
368         CHECK_GT(size, 0);
369         rnn_state_proto->set_size(size);
370       } else if (key == "num_dims") {
371         int32_t size = 0;
372         CHECK(absl::SimpleAtoi(value, &size));
373         CHECK_GT(size, 0);
374         rnn_state_proto->set_num_dims(size);
375       } else {
376         LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
377       }
378     }
379     CHECK(rnn_state_proto->has_state_array() &&
380           rnn_state_proto->has_back_edge_source_array() &&
381           rnn_state_proto->has_size())
382         << "--rnn_states must include state_array, back_edge_source_array and "
383            "size.";
384   }
385 
386   for (const auto& element : parsed_model_flags.model_checks.value().elements) {
387     auto* model_check_proto = model_flags->add_model_checks();
388     for (const auto& kv_pair : element) {
389       const std::string& key = kv_pair.first;
390       const std::string& value = kv_pair.second;
391       if (key == "count_type") {
392         model_check_proto->set_count_type(value);
393       } else if (key == "count_min") {
394         int32_t count = 0;
395         CHECK(absl::SimpleAtoi(value, &count));
396         CHECK_GE(count, -1);
397         model_check_proto->set_count_min(count);
398       } else if (key == "count_max") {
399         int32_t count = 0;
400         CHECK(absl::SimpleAtoi(value, &count));
401         CHECK_GE(count, -1);
402         model_check_proto->set_count_max(count);
403       } else {
404         LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
405       }
406     }
407   }
408 
409   if (!model_flags->has_allow_nonascii_arrays()) {
410     model_flags->set_allow_nonascii_arrays(
411         parsed_model_flags.allow_nonascii_arrays.value());
412   }
413   if (!model_flags->has_allow_nonexistent_arrays()) {
414     model_flags->set_allow_nonexistent_arrays(
415         parsed_model_flags.allow_nonexistent_arrays.value());
416   }
417   if (!model_flags->has_change_concat_input_ranges()) {
418     model_flags->set_change_concat_input_ranges(
419         parsed_model_flags.change_concat_input_ranges.value());
420   }
421 
422   if (parsed_model_flags.arrays_extra_info_file.specified()) {
423     std::string arrays_extra_info_file_contents;
424     CHECK(port::file::GetContents(
425               parsed_model_flags.arrays_extra_info_file.value(),
426               &arrays_extra_info_file_contents, port::file::Defaults())
427               .ok());
428     ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
429                                       model_flags->mutable_arrays_extra_info());
430   }
431 }
432 
UncheckedGlobalParsedModelFlags(bool must_already_exist)433 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
434   static auto* flags = [must_already_exist]() {
435     if (must_already_exist) {
436       fprintf(stderr, __FILE__
437               ":"
438               "GlobalParsedModelFlags() used without initialization\n");
439       fflush(stderr);
440       abort();
441     }
442     return new toco::ParsedModelFlags;
443   }();
444   return flags;
445 }
446 
GlobalParsedModelFlags()447 ParsedModelFlags* GlobalParsedModelFlags() {
448   return UncheckedGlobalParsedModelFlags(true);
449 }
450 
ParseModelFlagsOrDie(int * argc,char * argv[])451 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
452   // TODO(aselle): in the future allow Google version to use
453   // flags, and only use this mechanism for open source
454   auto* flags = UncheckedGlobalParsedModelFlags(false);
455   std::string msg;
456   bool model_success =
457       toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
458   if (!model_success || !msg.empty()) {
459     // Log in non-standard way since this happens pre InitGoogle.
460     fprintf(stderr, "%s", msg.c_str());
461     fflush(stderr);
462     abort();
463   }
464 }
465 
466 }  // namespace toco
467