xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/toco/args.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
// This abstracts command line arguments in toco.
// Arg<T> is a parseable type that can register a default value, parse
// itself, and keep track of whether it was specified.
18 #ifndef TENSORFLOW_LITE_TOCO_ARGS_H_
19 #define TENSORFLOW_LITE_TOCO_ARGS_H_
20 
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/lite/toco/toco_port.h"
#include "tensorflow/lite/toco/toco_types.h"
30 
31 namespace toco {
32 
33 // Since std::vector<int32> is in the std namespace, and we are not allowed
34 // to add ParseFlag/UnparseFlag to std, we introduce a simple wrapper type
35 // to use as the flag type:
36 struct IntList {
37   std::vector<int32> elements;
38 };
// Wraps a list of string-to-string maps so it can serve as a flag type,
// for the same reason as IntList: the underlying std container cannot be
// given ParseFlag/UnparseFlag overloads.
struct StringMapList {
  // One map per list entry.
  std::vector<std::unordered_map<std::string, std::string>> elements;
};
42 
// command_line_flags.h doesn't track whether or not a flag is specified. Arg
// holds the value (which stays at the default if the flag is not specified)
// and also whether the flag was specified.
// TODO(aselle): consider putting doc string and ability to construct the
// tensorflow argument into this, so declaration of parameters can be less
// distributed.
// Every template specialization of Arg is required to implement
// default_value(), specified(), value(), Parse(), bind().
template <class T>
class Arg final {
 public:
  explicit Arg(T default_ = T()) : default_value_(default_), value_(default_) {}

  // Provide default_value() to arg list. This always reports the value given
  // to the constructor, even after Parse() has replaced value().
  T default_value() const { return default_value_; }
  // Return true if the command line argument was specified on the command line.
  bool specified() const { return specified_; }
  // Const reference to the current value: the default until Parse() is called,
  // the parsed value afterwards.
  const T& value() const { return value_; }

  // Parsing callback for the tensorflow::Flags code. Stores `value_in`, marks
  // the argument as specified, and always reports success.
  bool Parse(T value_in) {
    value_ = std::move(value_in);  // Avoid a copy for heavy T (e.g. strings).
    specified_ = true;
    return true;
  }

  // Bind the Parse member function so tensorflow::Flags can call it.
  // The returned callable captures `this`: it always updates this particular
  // object, so this object must outlive the callable.
  std::function<bool(T)> bind() {
    return [this](T value_in) { return Parse(std::move(value_in)); };
  }

 private:
  // Value passed to the constructor; reported by default_value().
  T default_value_;
  // Becomes true after parsing if the value was specified.
  bool specified_ = false;
  // Current value of the argument (initialized to the default in the
  // constructor).
  T value_;
};
82 
83 template <>
84 class Arg<toco::IntList> final {
85  public:
86   // Provide default_value() to arg list
default_value()87   std::string default_value() const { return ""; }
88   // Return true if the command line argument was specified on the command line.
specified()89   bool specified() const { return specified_; }
90   // Bind the parse member function so tensorflow::Flags can call it.
91   bool Parse(std::string text);
92 
bind()93   std::function<bool(std::string)> bind() {
94     return std::bind(&Arg::Parse, this, std::placeholders::_1);
95   }
96 
value()97   const toco::IntList& value() const { return parsed_value_; }
98 
99  private:
100   toco::IntList parsed_value_;
101   bool specified_ = false;
102 };
103 
104 template <>
105 class Arg<toco::StringMapList> final {
106  public:
107   // Provide default_value() to StringMapList
default_value()108   std::string default_value() const { return ""; }
109   // Return true if the command line argument was specified on the command line.
specified()110   bool specified() const { return specified_; }
111   // Bind the parse member function so tensorflow::Flags can call it.
112 
113   bool Parse(std::string text);
114 
bind()115   std::function<bool(std::string)> bind() {
116     return std::bind(&Arg::Parse, this, std::placeholders::_1);
117   }
118 
value()119   const toco::StringMapList& value() const { return parsed_value_; }
120 
121  private:
122   toco::StringMapList parsed_value_;
123   bool specified_ = false;
124 };
125 
126 // Flags that describe a model. See model_cmdline_flags.cc for details.
127 struct ParsedModelFlags {
128   Arg<std::string> input_array;
129   Arg<std::string> input_arrays;
130   Arg<std::string> output_array;
131   Arg<std::string> output_arrays;
132   Arg<std::string> input_shapes;
133   Arg<int> batch_size = Arg<int>(1);
134   Arg<float> mean_value = Arg<float>(0.f);
135   Arg<std::string> mean_values;
136   Arg<float> std_value = Arg<float>(1.f);
137   Arg<std::string> std_values;
138   Arg<std::string> input_data_type;
139   Arg<std::string> input_data_types;
140   Arg<bool> variable_batch = Arg<bool>(false);
141   Arg<toco::IntList> input_shape;
142   Arg<toco::StringMapList> rnn_states;
143   Arg<toco::StringMapList> model_checks;
144   Arg<bool> change_concat_input_ranges = Arg<bool>(true);
145   // Debugging output options.
146   // TODO(benoitjacob): these shouldn't be ModelFlags.
147   Arg<std::string> graphviz_first_array;
148   Arg<std::string> graphviz_last_array;
149   Arg<std::string> dump_graphviz;
150   Arg<bool> dump_graphviz_video = Arg<bool>(false);
151   Arg<std::string> conversion_summary_dir;
152   Arg<bool> allow_nonexistent_arrays = Arg<bool>(false);
153   Arg<bool> allow_nonascii_arrays = Arg<bool>(false);
154   Arg<std::string> arrays_extra_info_file;
155   Arg<std::string> model_flags_file;
156 };
157 
158 // Flags that describe the operation you would like to do (what conversion
159 // you want). See toco_cmdline_flags.cc for details.
160 struct ParsedTocoFlags {
161   Arg<std::string> input_file;
162   Arg<std::string> savedmodel_directory;
163   Arg<std::string> output_file;
164   Arg<std::string> input_format = Arg<std::string>("TENSORFLOW_GRAPHDEF");
165   Arg<std::string> output_format = Arg<std::string>("TFLITE");
166   Arg<std::string> savedmodel_tagset;
167   // TODO(aselle): command_line_flags  doesn't support doubles
168   Arg<float> default_ranges_min = Arg<float>(0.);
169   Arg<float> default_ranges_max = Arg<float>(0.);
170   Arg<float> default_int16_ranges_min = Arg<float>(0.);
171   Arg<float> default_int16_ranges_max = Arg<float>(0.);
172   Arg<std::string> inference_type;
173   Arg<std::string> inference_input_type;
174   Arg<bool> drop_fake_quant = Arg<bool>(false);
175   Arg<bool> reorder_across_fake_quant = Arg<bool>(false);
176   Arg<bool> allow_custom_ops = Arg<bool>(false);
177   Arg<bool> allow_dynamic_tensors = Arg<bool>(true);
178   Arg<std::string> custom_opdefs;
179   Arg<bool> post_training_quantize = Arg<bool>(false);
180   Arg<bool> quantize_to_float16 = Arg<bool>(false);
181   // Deprecated flags
182   Arg<bool> quantize_weights = Arg<bool>(false);
183   Arg<std::string> input_type;
184   Arg<std::string> input_types;
185   Arg<bool> debug_disable_recurrent_cell_fusion = Arg<bool>(false);
186   Arg<bool> drop_control_dependency = Arg<bool>(false);
187   Arg<bool> propagate_fake_quant_num_bits = Arg<bool>(false);
188   Arg<bool> allow_nudging_weights_to_use_fast_gemm_kernel = Arg<bool>(false);
189   Arg<int64_t> dedupe_array_min_size_bytes = Arg<int64_t>(64);
190   Arg<bool> split_tflite_lstm_inputs = Arg<bool>(true);
191   // WARNING: Experimental interface, subject to change
192   Arg<bool> enable_select_tf_ops = Arg<bool>(false);
193   // WARNING: Experimental interface, subject to change
194   Arg<bool> force_select_tf_ops = Arg<bool>(false);
195   // WARNING: Experimental interface, subject to change
196   Arg<bool> unfold_batchmatmul = Arg<bool>(true);
197   // WARNING: Experimental interface, subject to change
198   Arg<std::string> accumulation_type;
199   // WARNING: Experimental interface, subject to change
200   Arg<bool> allow_bfloat16;
201 };
202 
203 }  // namespace toco
204 #endif  // TENSORFLOW_LITE_TOCO_ARGS_H_
205