1 #include <sstream>
2 #include <string>
3
4 #include <ATen/core/jit_type.h>
5 #include <c10/core/ScalarType.h>
6 #include <torch/csrc/jit/backends/backend.h>
7 #include <torch/csrc/jit/backends/backend_detail.h>
8 #include <torch/csrc/jit/backends/backend_preprocess.h>
9 #include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
10 #include <torch/csrc/jit/passes/freeze_module.h>
11 #include <torch/csrc/jit/serialization/export.h>
12 #include <torch/csrc/jit/serialization/import.h>
13 #include <torch/csrc/jit/tensorexpr/graph_opt.h>
14 #include <torch/csrc/jit/tensorexpr/kernel.h>
15 #include <torch/script.h>
16
17 C10_DEFINE_string(model, "", "The torch script model to optimize.");
18 C10_DEFINE_string(model_name, "", "The name of the model.");
19 C10_DEFINE_string(model_version, "", "The version of the model.");
20 C10_DEFINE_string(
21 input_dims,
22 "",
23 "The dimensions of input TensorCPUs using comma separated numbers."
24 "If multiple inputs needed, use semicolon to separate "
25 "the dimension of different tensors.");
26 C10_DEFINE_string(
27 input_types,
28 "float",
29 "The dtype of input TensorCPUs."
30 "If multiple inputs needed, use semicolon to separate "
31 "the dtype of different tensors."
32 "Supported dtypes: float, int64, uint8");
33 C10_DEFINE_string(
34 input_memory_formats,
35 "",
36 "Input memory format."
37 "If multiple inputs needed, use semicolon to separate."
38 "Supported values: contiguous, channels_last");
39 C10_DEFINE_string(
40 dynamic_dims,
41 "",
42 "Comma separated dimensions of input tensors that can be dynamic");
43 C10_DEFINE_string(method_name, "forward", "The name of the method.");
44 C10_DEFINE_string(
45 output_llvm,
46 "",
47 "Name of the output llvm assembly to be saved.");
48 C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");
49
50 namespace {
51
// Splits `string` into the pieces delimited by `separator`.
//
// When `ignore_empty` is true (the default) empty pieces are dropped, so
// "a;;b" yields {"a", "b"}. When it is false, empty pieces between or before
// separators are kept ("a;;b" -> {"a", "", "b"}), but no empty piece is
// produced after a trailing separator ("a;" -> {"a"}) and an empty input
// yields an empty vector.
std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> tokens;
  std::string::size_type begin = 0;
  // Stopping once `begin` reaches the end of the input is what suppresses a
  // trailing empty piece after a final separator.
  while (begin < string.size()) {
    std::string::size_type end = string.find(separator, begin);
    if (end == std::string::npos) {
      end = string.size();
    }
    std::string token = string.substr(begin, end - begin);
    const bool keep = !(ignore_empty && token.empty());
    if (keep) {
      tokens.push_back(std::move(token));
    }
    begin = end + 1;
  }
  return tokens;
}
66
createCompileSpec()67 c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
68 c10::Dict<c10::IValue, c10::IValue> compile_spec(
69 c10::StringType::get(), c10::AnyType::get());
70 c10::Dict<c10::IValue, c10::IValue> method_spec(
71 c10::StringType::get(), c10::AnyType::get());
72 method_spec.insert("sizes", FLAGS_input_dims);
73 method_spec.insert("types", FLAGS_input_types);
74 method_spec.insert("memory_formats", FLAGS_input_memory_formats);
75 method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
76 method_spec.insert("asmfile", FLAGS_output_llvm);
77 method_spec.insert("model_name", FLAGS_model_name);
78 method_spec.insert("model_version", FLAGS_model_version);
79 compile_spec.insert(FLAGS_method_name, method_spec);
80 return compile_spec;
81 }
82
83 } // namespace
84
main(int argc,char ** argv)85 int main(int argc, char** argv) {
86 c10::SetUsageMessage(
87 "Run NNC AOT compiler for pytorch model. Example usage:\n"
88 "build/bin/aot_model_compiler"
89 " --model=<model file>"
90 " --model_name=<model name>"
91 " --model_version=<model version>"
92 " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
93 " --input_types=<input dtypes like 'float;float'>"
94 " --input_memory_formats=<input memory formats like 'channels_last;contiguous'>"
95 " [--method_name=<method name>]"
96 " [--output_llvm=<llvm assembly output file path>]"
97 " [--output_model=<output model file path>]");
98
99 if (!c10::ParseCommandLineFlags(&argc, &argv)) {
100 std::cerr << "Failed to parse command line flags!" << std::endl;
101 std::cout << c10::UsageMessage() << std::endl;
102 return 1;
103 }
104
105 CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
106 CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
107 CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
108 CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
109 const auto dims_size = split(';', FLAGS_input_dims).size();
110 CAFFE_ENFORCE(
111 dims_size == split(';', FLAGS_input_types).size(),
112 "Number of input_dims and input_types should be the same");
113 const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
114 CAFFE_ENFORCE(
115 mem_formats_size == 0 || mem_formats_size == dims_size,
116 "Number of input_memory_formats should be 0 (default contiguous) or the same as number of input_dims");
117 if (FLAGS_output_llvm.empty()) {
118 FLAGS_output_llvm =
119 FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
120 }
121
122 std::string output_model_name = FLAGS_output_model;
123 if (output_model_name.empty()) {
124 output_model_name =
125 FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.pt";
126 }
127
128 auto m = torch::jit::load(FLAGS_model);
129 m.eval();
130 auto frozen_m = torch::jit::freeze_module(m.clone());
131
132 auto compile_spec = createCompileSpec();
133 auto any_dict_ty =
134 c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
135 auto compiled_module = torch::jit::detail::codegen_backend_module(
136 "nnc", frozen_m, compile_spec, any_dict_ty);
137 compiled_module._save_for_mobile(output_model_name);
138 std::cout << "The compiled model was saved to " << output_model_name
139 << std::endl;
140 return 0;
141 }
142