xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/toco_cmdline_flags.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/toco/toco_cmdline_flags.h"
17 
18 #include <optional>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_join.h"
24 #include "absl/strings/str_split.h"
25 #include "absl/strings/strip.h"
26 #include "absl/types/optional.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/util/command_line_flags.h"
29 #include "tensorflow/lite/toco/toco_port.h"
30 
31 namespace toco {
32 
ParseTocoFlagsFromCommandLineFlags(int * argc,char * argv[],std::string * msg,ParsedTocoFlags * parsed_toco_flags_ptr)33 bool ParseTocoFlagsFromCommandLineFlags(
34     int* argc, char* argv[], std::string* msg,
35     ParsedTocoFlags* parsed_toco_flags_ptr) {
36   using tensorflow::Flag;
37   ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
38   std::vector<tensorflow::Flag> flags = {
39       Flag("input_file", parsed_flags.input_file.bind(),
40            parsed_flags.input_file.default_value(),
41            "Input file (model of any supported format). For Protobuf "
42            "formats, both text and binary are supported regardless of file "
43            "extension."),
44       Flag("savedmodel_directory", parsed_flags.savedmodel_directory.bind(),
45            parsed_flags.savedmodel_directory.default_value(),
46            "Deprecated. Full path to the directory containing the SavedModel."),
47       Flag("output_file", parsed_flags.output_file.bind(),
48            parsed_flags.output_file.default_value(),
49            "Output file. "
50            "For Protobuf formats, the binary format will be used."),
51       Flag("input_format", parsed_flags.input_format.bind(),
52            parsed_flags.input_format.default_value(),
53            "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
54       Flag("output_format", parsed_flags.output_format.bind(),
55            parsed_flags.output_format.default_value(),
56            "Output file format. "
57            "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
58       Flag("savedmodel_tagset", parsed_flags.savedmodel_tagset.bind(),
59            parsed_flags.savedmodel_tagset.default_value(),
60            "Deprecated. Comma-separated set of tags identifying the "
61            "MetaGraphDef within the SavedModel to analyze. All tags in the tag "
62            "set must be specified."),
63       Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
64            parsed_flags.default_ranges_min.default_value(),
65            "If defined, will be used as the default value for the min bound "
66            "of min/max ranges used for quantization of uint8 arrays."),
67       Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
68            parsed_flags.default_ranges_max.default_value(),
69            "If defined, will be used as the default value for the max bound "
70            "of min/max ranges used for quantization of uint8 arrays."),
71       Flag("default_int16_ranges_min",
72            parsed_flags.default_int16_ranges_min.bind(),
73            parsed_flags.default_int16_ranges_min.default_value(),
74            "If defined, will be used as the default value for the min bound "
75            "of min/max ranges used for quantization of int16 arrays."),
76       Flag("default_int16_ranges_max",
77            parsed_flags.default_int16_ranges_max.bind(),
78            parsed_flags.default_int16_ranges_max.default_value(),
79            "If defined, will be used as the default value for the max bound "
80            "of min/max ranges used for quantization of int16 arrays."),
81       Flag("inference_type", parsed_flags.inference_type.bind(),
82            parsed_flags.inference_type.default_value(),
83            "Target data type of arrays in the output file (for input_arrays, "
84            "this may be overridden by inference_input_type). "
85            "One of FLOAT, QUANTIZED_UINT8."),
86       Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
87            parsed_flags.inference_input_type.default_value(),
88            "Target data type of input arrays. "
89            "If not specified, inference_type is used. "
90            "One of FLOAT, QUANTIZED_UINT8."),
91       Flag("input_type", parsed_flags.input_type.bind(),
92            parsed_flags.input_type.default_value(),
93            "Deprecated ambiguous flag that set both --input_data_types and "
94            "--inference_input_type."),
95       Flag("input_types", parsed_flags.input_types.bind(),
96            parsed_flags.input_types.default_value(),
97            "Deprecated ambiguous flag that set both --input_data_types and "
98            "--inference_input_type. Was meant to be a "
99            "comma-separated list, but this was deprecated before "
100            "multiple-input-types was ever properly supported."),
101 
102       Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
103            parsed_flags.drop_fake_quant.default_value(),
104            "Ignore and discard FakeQuant nodes. For instance, to "
105            "generate plain float code without fake-quantization from a "
106            "quantized graph."),
107       Flag(
108           "reorder_across_fake_quant",
109           parsed_flags.reorder_across_fake_quant.bind(),
110           parsed_flags.reorder_across_fake_quant.default_value(),
111           "Normally, FakeQuant nodes must be strict boundaries for graph "
112           "transformations, in order to ensure that quantized inference has "
113           "the exact same arithmetic behavior as quantized training --- which "
114           "is the whole point of quantized training and of FakeQuant nodes in "
115           "the first place. "
116           "However, that entails subtle requirements on where exactly "
117           "FakeQuant nodes must be placed in the graph. Some quantized graphs "
118           "have FakeQuant nodes at unexpected locations, that prevent graph "
119           "transformations that are necessary in order to generate inference "
120           "code for these graphs. Such graphs should be fixed, but as a "
121           "temporary work-around, setting this reorder_across_fake_quant flag "
122           "allows TOCO to perform necessary graph transformaitons on them, "
123           "at the cost of no longer faithfully matching inference and training "
124           "arithmetic."),
125       Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
126            parsed_flags.allow_custom_ops.default_value(),
127            "If true, allow TOCO to create TF Lite Custom operators for all the "
128            "unsupported TensorFlow ops."),
129       Flag("custom_opdefs", parsed_flags.custom_opdefs.bind(),
130            parsed_flags.custom_opdefs.default_value(),
131            "List of strings representing custom ops OpDefs that are included "
132            "in the GraphDef."),
133       Flag("allow_dynamic_tensors", parsed_flags.allow_dynamic_tensors.bind(),
134            parsed_flags.allow_dynamic_tensors.default_value(),
135            "Boolean flag indicating whether the converter should allow models "
136            "with dynamic Tensor shape. When set to False, the converter will "
137            "generate runtime memory offsets for activation Tensors (with 128 "
138            "bits alignment) and error out on models with undetermined Tensor "
139            "shape. (Default: True)"),
140       Flag(
141           "drop_control_dependency",
142           parsed_flags.drop_control_dependency.bind(),
143           parsed_flags.drop_control_dependency.default_value(),
144           "If true, ignore control dependency requirements in input TensorFlow "
145           "GraphDef. Otherwise an error will be raised upon control dependency "
146           "inputs."),
147       Flag("debug_disable_recurrent_cell_fusion",
148            parsed_flags.debug_disable_recurrent_cell_fusion.bind(),
149            parsed_flags.debug_disable_recurrent_cell_fusion.default_value(),
150            "If true, disable fusion of known identifiable cell subgraphs into "
151            "cells. This includes, for example, specific forms of LSTM cell."),
152       Flag("propagate_fake_quant_num_bits",
153            parsed_flags.propagate_fake_quant_num_bits.bind(),
154            parsed_flags.propagate_fake_quant_num_bits.default_value(),
155            "If true, use FakeQuant* operator num_bits attributes to adjust "
156            "array data_types."),
157       Flag("allow_nudging_weights_to_use_fast_gemm_kernel",
158            parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel.bind(),
159            parsed_flags.allow_nudging_weights_to_use_fast_gemm_kernel
160                .default_value(),
161            "Some fast uint8 GEMM kernels require uint8 weights to avoid the "
162            "value 0. This flag allows nudging them to 1 to allow proceeding, "
163            "with moderate inaccuracy."),
164       Flag("dedupe_array_min_size_bytes",
165            parsed_flags.dedupe_array_min_size_bytes.bind(),
166            parsed_flags.dedupe_array_min_size_bytes.default_value(),
167            "Minimum size of constant arrays to deduplicate; arrays smaller "
168            "will not be deduplicated."),
169       Flag("split_tflite_lstm_inputs",
170            parsed_flags.split_tflite_lstm_inputs.bind(),
171            parsed_flags.split_tflite_lstm_inputs.default_value(),
172            "Split the LSTM inputs from 5 tensors to 18 tensors for TFLite. "
173            "Ignored if the output format is not TFLite."),
174       Flag("quantize_to_float16", parsed_flags.quantize_to_float16.bind(),
175            parsed_flags.quantize_to_float16.default_value(),
176            "Used in conjunction with post_training_quantize. Specifies that "
177            "the weights should be quantized to fp16 instead of the default "
178            "(int8)"),
179       Flag("quantize_weights", parsed_flags.quantize_weights.bind(),
180            parsed_flags.quantize_weights.default_value(),
181            "Deprecated. Please use --post_training_quantize instead."),
182       Flag("post_training_quantize", parsed_flags.post_training_quantize.bind(),
183            parsed_flags.post_training_quantize.default_value(),
184            "Boolean indicating whether to quantize the weights of the "
185            "converted float model. Model size will be reduced and there will "
186            "be latency improvements (at the cost of accuracy)."),
187       // TODO(b/118822804): Unify the argument definition with `tflite_convert`.
188       // WARNING: Experimental interface, subject to change
189       Flag("enable_select_tf_ops", parsed_flags.enable_select_tf_ops.bind(),
190            parsed_flags.enable_select_tf_ops.default_value(), ""),
191       // WARNING: Experimental interface, subject to change
192       Flag("force_select_tf_ops", parsed_flags.force_select_tf_ops.bind(),
193            parsed_flags.force_select_tf_ops.default_value(), ""),
194       // WARNING: Experimental interface, subject to change
195       Flag("unfold_batchmatmul", parsed_flags.unfold_batchmatmul.bind(),
196            parsed_flags.unfold_batchmatmul.default_value(), ""),
197       // WARNING: Experimental interface, subject to change
198       Flag("accumulation_type", parsed_flags.accumulation_type.bind(),
199            parsed_flags.accumulation_type.default_value(),
200            "Accumulation type to use with quantize_to_float16"),
201       // WARNING: Experimental interface, subject to change
202       Flag("allow_bfloat16", parsed_flags.allow_bfloat16.bind(),
203            parsed_flags.allow_bfloat16.default_value(), "")};
204 
205   bool asked_for_help =
206       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
207   if (asked_for_help) {
208     *msg += tensorflow::Flags::Usage(argv[0], flags);
209     return false;
210   } else {
211     return tensorflow::Flags::Parse(argc, argv, flags);
212   }
213 }
214 
215 namespace {
216 
// Describes what is demanded of a single command-line flag when its value is
// read back.
enum class FlagRequirement {
  // Nothing is enforced; the flag is only read if the user specified it.
  kNone = 0,
  // The user must have specified the flag.
  kMustBeSpecified = 1,
  // The user must not have specified the flag.
  kMustNotBeSpecified = 2,
  // The flag's default value is used when the user did not specify it.
  kUseDefault = 3,
};
225 
226 // Enforces the FlagRequirements are met for a given flag.
227 template <typename T>
EnforceFlagRequirement(const T & flag,const std::string & flag_name,FlagRequirement requirement)228 void EnforceFlagRequirement(const T& flag, const std::string& flag_name,
229                             FlagRequirement requirement) {
230   if (requirement == FlagRequirement::kMustBeSpecified) {
231     QCHECK(flag.specified()) << "Missing required flag " << flag_name;
232   }
233   if (requirement == FlagRequirement::kMustNotBeSpecified) {
234     QCHECK(!flag.specified())
235         << "Given other flags, this flag should not have been specified: "
236         << flag_name;
237   }
238 }
239 
240 // Gets the value from the flag if specified. Returns default if the
241 // FlagRequirement is kUseDefault.
242 template <typename T>
GetFlagValue(const Arg<T> & flag,FlagRequirement requirement)243 std::optional<T> GetFlagValue(const Arg<T>& flag, FlagRequirement requirement) {
244   if (flag.specified()) return flag.value();
245   if (requirement == FlagRequirement::kUseDefault) return flag.default_value();
246   return std::optional<T>();
247 }
248 
249 }  // namespace
250 
ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags & parsed_toco_flags,TocoFlags * toco_flags)251 void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
252                                        TocoFlags* toco_flags) {
253   namespace port = toco::port;
254   port::CheckInitGoogleIsDone("InitGoogle is not done yet");
255 
256 #define READ_TOCO_FLAG(name, requirement)                                \
257   do {                                                                   \
258     EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement);  \
259     auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
260     if (flag_value.has_value()) {                                        \
261       toco_flags->set_##name(flag_value.value());                        \
262     }                                                                    \
263   } while (false)
264 
265 #define PARSE_TOCO_FLAG(Type, name, requirement)                         \
266   do {                                                                   \
267     EnforceFlagRequirement(parsed_toco_flags.name, #name, requirement);  \
268     auto flag_value = GetFlagValue(parsed_toco_flags.name, requirement); \
269     if (flag_value.has_value()) {                                        \
270       Type x;                                                            \
271       QCHECK(Type##_Parse(flag_value.value(), &x))                       \
272           << "Unrecognized " << #Type << " value "                       \
273           << parsed_toco_flags.name.value();                             \
274       toco_flags->set_##name(x);                                         \
275     }                                                                    \
276   } while (false)
277 
278   PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kUseDefault);
279   PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kUseDefault);
280   PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
281   PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
282   READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
283   READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
284   READ_TOCO_FLAG(default_int16_ranges_min, FlagRequirement::kNone);
285   READ_TOCO_FLAG(default_int16_ranges_max, FlagRequirement::kNone);
286   READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
287   READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
288   READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
289   READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
290   READ_TOCO_FLAG(debug_disable_recurrent_cell_fusion, FlagRequirement::kNone);
291   READ_TOCO_FLAG(propagate_fake_quant_num_bits, FlagRequirement::kNone);
292   READ_TOCO_FLAG(allow_nudging_weights_to_use_fast_gemm_kernel,
293                  FlagRequirement::kNone);
294   READ_TOCO_FLAG(dedupe_array_min_size_bytes, FlagRequirement::kNone);
295   READ_TOCO_FLAG(split_tflite_lstm_inputs, FlagRequirement::kNone);
296   READ_TOCO_FLAG(quantize_weights, FlagRequirement::kNone);
297   READ_TOCO_FLAG(quantize_to_float16, FlagRequirement::kNone);
298   READ_TOCO_FLAG(post_training_quantize, FlagRequirement::kNone);
299   READ_TOCO_FLAG(enable_select_tf_ops, FlagRequirement::kNone);
300   READ_TOCO_FLAG(force_select_tf_ops, FlagRequirement::kNone);
301   READ_TOCO_FLAG(unfold_batchmatmul, FlagRequirement::kNone);
302   PARSE_TOCO_FLAG(IODataType, accumulation_type, FlagRequirement::kNone);
303   READ_TOCO_FLAG(allow_bfloat16, FlagRequirement::kNone);
304 
305   if (parsed_toco_flags.force_select_tf_ops.value() &&
306       !parsed_toco_flags.enable_select_tf_ops.value()) {
307     // TODO(ycling): Consider to enforce `enable_select_tf_ops` when
308     // `force_select_tf_ops` is true.
309     LOG(WARNING) << "--force_select_tf_ops should always be used with "
310                     "--enable_select_tf_ops.";
311   }
312 
313   // Deprecated flag handling.
314   if (parsed_toco_flags.input_type.specified()) {
315     LOG(WARNING)
316         << "--input_type is deprecated. It was an ambiguous flag that set both "
317            "--input_data_types and --inference_input_type. If you are trying "
318            "to complement the input file with information about the type of "
319            "input arrays, use --input_data_type. If you are trying to control "
320            "the quantization/dequantization of real-numbers input arrays in "
321            "the output file, use --inference_input_type.";
322     toco::IODataType input_type;
323     QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
324                                   &input_type));
325     toco_flags->set_inference_input_type(input_type);
326   }
327   if (parsed_toco_flags.input_types.specified()) {
328     LOG(WARNING)
329         << "--input_types is deprecated. It was an ambiguous flag that set "
330            "both --input_data_types and --inference_input_type. If you are "
331            "trying to complement the input file with information about the "
332            "type of input arrays, use --input_data_type. If you are trying to "
333            "control the quantization/dequantization of real-numbers input "
334            "arrays in the output file, use --inference_input_type.";
335     std::vector<std::string> input_types =
336         absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
337     QCHECK(!input_types.empty());
338     for (size_t i = 1; i < input_types.size(); i++) {
339       QCHECK_EQ(input_types[i], input_types[0]);
340     }
341     toco::IODataType input_type;
342     QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
343     toco_flags->set_inference_input_type(input_type);
344   }
345   if (parsed_toco_flags.quantize_weights.value()) {
346     LOG(WARNING)
347         << "--quantize_weights is deprecated. Falling back to "
348            "--post_training_quantize. Please switch --post_training_quantize.";
349     toco_flags->set_post_training_quantize(
350         parsed_toco_flags.quantize_weights.value());
351   }
352   if (parsed_toco_flags.quantize_weights.value()) {
353     if (toco_flags->inference_type() == IODataType::QUANTIZED_UINT8) {
354       LOG(WARNING)
355           << "--post_training_quantize quantizes a graph of inference_type "
356              "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.";
357       toco_flags->set_inference_type(IODataType::FLOAT);
358     }
359   }
360 
361 #undef READ_TOCO_FLAG
362 #undef PARSE_TOCO_FLAG
363 }
364 }  // namespace toco
365