xref: /aosp_15_r20/external/pytorch/binaries/aot_model_compiler.cc (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <sstream>
2*da0073e9SAndroid Build Coastguard Worker #include <string>
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/jit_type.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/ScalarType.h>
6*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/backends/backend.h>
7*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/backends/backend_detail.h>
8*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/backends/backend_preprocess.h>
9*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
10*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/passes/freeze_module.h>
11*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/serialization/export.h>
12*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/serialization/import.h>
13*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/tensorexpr/graph_opt.h>
14*da0073e9SAndroid Build Coastguard Worker #include <torch/csrc/jit/tensorexpr/kernel.h>
15*da0073e9SAndroid Build Coastguard Worker #include <torch/script.h>
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(model, "", "The torch script model to optimize.");
18*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(model_name, "", "The name of the model.");
19*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(model_version, "", "The version of the model.");
20*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(
21*da0073e9SAndroid Build Coastguard Worker     input_dims,
22*da0073e9SAndroid Build Coastguard Worker     "",
23*da0073e9SAndroid Build Coastguard Worker     "The dimensions of input TensorCPUs using comma separated numbers."
24*da0073e9SAndroid Build Coastguard Worker     "If multiple inputs needed, use semicolon to separate "
25*da0073e9SAndroid Build Coastguard Worker     "the dimension of different tensors.");
26*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(
27*da0073e9SAndroid Build Coastguard Worker     input_types,
28*da0073e9SAndroid Build Coastguard Worker     "float",
29*da0073e9SAndroid Build Coastguard Worker     "The dtype of input TensorCPUs."
30*da0073e9SAndroid Build Coastguard Worker     "If multiple inputs needed, use semicolon to separate "
31*da0073e9SAndroid Build Coastguard Worker     "the dtype of different tensors."
32*da0073e9SAndroid Build Coastguard Worker     "Supported dtypes: float, int64, uint8");
33*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(
34*da0073e9SAndroid Build Coastguard Worker     input_memory_formats,
35*da0073e9SAndroid Build Coastguard Worker     "",
36*da0073e9SAndroid Build Coastguard Worker     "Input memory format."
37*da0073e9SAndroid Build Coastguard Worker     "If multiple inputs needed, use semicolon to separate."
38*da0073e9SAndroid Build Coastguard Worker     "Supported values: contiguous, channels_last");
39*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(
40*da0073e9SAndroid Build Coastguard Worker     dynamic_dims,
41*da0073e9SAndroid Build Coastguard Worker     "",
42*da0073e9SAndroid Build Coastguard Worker     "Comma separated dimensions of input tensors that can be dynamic");
43*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(method_name, "forward", "The name of the method.");
44*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(
45*da0073e9SAndroid Build Coastguard Worker     output_llvm,
46*da0073e9SAndroid Build Coastguard Worker     "",
47*da0073e9SAndroid Build Coastguard Worker     "Name of the output llvm assembly to be saved.");
48*da0073e9SAndroid Build Coastguard Worker C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");
49*da0073e9SAndroid Build Coastguard Worker 
50*da0073e9SAndroid Build Coastguard Worker namespace {
51*da0073e9SAndroid Build Coastguard Worker 
// Splits `string` into the pieces delimited by `separator`.
//
// @param separator     Delimiter character; it never appears in the output.
// @param string        Text to tokenize.
// @param ignore_empty  When true (the default) empty pieces are dropped.
// @return The pieces in their original order.
//
// A separator at the very end of the input does not produce a trailing empty
// piece, and an empty input yields an empty vector, even when `ignore_empty`
// is false.
std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> pieces;
  std::string::size_type begin = 0;
  // Scan the input in place rather than copying it into a stream first.
  while (begin < string.size()) {
    const std::string::size_type end = string.find(separator, begin);
    const bool is_last = (end == std::string::npos);
    // `npos` as a length means "everything up to the end of the input".
    const std::string::size_type length =
        is_last ? std::string::npos : end - begin;
    if (!ignore_empty || length != 0) {
      pieces.emplace_back(string, begin, length);
    }
    if (is_last) {
      break;
    }
    begin = end + 1;
  }
  return pieces;
}
66*da0073e9SAndroid Build Coastguard Worker 
createCompileSpec()67*da0073e9SAndroid Build Coastguard Worker c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
68*da0073e9SAndroid Build Coastguard Worker   c10::Dict<c10::IValue, c10::IValue> compile_spec(
69*da0073e9SAndroid Build Coastguard Worker       c10::StringType::get(), c10::AnyType::get());
70*da0073e9SAndroid Build Coastguard Worker   c10::Dict<c10::IValue, c10::IValue> method_spec(
71*da0073e9SAndroid Build Coastguard Worker       c10::StringType::get(), c10::AnyType::get());
72*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("sizes", FLAGS_input_dims);
73*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("types", FLAGS_input_types);
74*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("memory_formats", FLAGS_input_memory_formats);
75*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
76*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("asmfile", FLAGS_output_llvm);
77*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("model_name", FLAGS_model_name);
78*da0073e9SAndroid Build Coastguard Worker   method_spec.insert("model_version", FLAGS_model_version);
79*da0073e9SAndroid Build Coastguard Worker   compile_spec.insert(FLAGS_method_name, method_spec);
80*da0073e9SAndroid Build Coastguard Worker   return compile_spec;
81*da0073e9SAndroid Build Coastguard Worker }
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker } // namespace
84*da0073e9SAndroid Build Coastguard Worker 
main(int argc,char ** argv)85*da0073e9SAndroid Build Coastguard Worker int main(int argc, char** argv) {
86*da0073e9SAndroid Build Coastguard Worker   c10::SetUsageMessage(
87*da0073e9SAndroid Build Coastguard Worker       "Run NNC AOT compiler for pytorch model. Example usage:\n"
88*da0073e9SAndroid Build Coastguard Worker       "build/bin/aot_model_compiler"
89*da0073e9SAndroid Build Coastguard Worker       " --model=<model file>"
90*da0073e9SAndroid Build Coastguard Worker       " --model_name=<model name>"
91*da0073e9SAndroid Build Coastguard Worker       " --model_version=<model version>"
92*da0073e9SAndroid Build Coastguard Worker       " --input_dims=<input dimensions like '1,3,224,224;2,2'>"
93*da0073e9SAndroid Build Coastguard Worker       " --input_types=<input dtypes like 'float;float'>"
94*da0073e9SAndroid Build Coastguard Worker       " --input_memory_formats=<input memory formats like 'channels_last;contiguous'>"
95*da0073e9SAndroid Build Coastguard Worker       " [--method_name=<method name>]"
96*da0073e9SAndroid Build Coastguard Worker       " [--output_llvm=<llvm assembly output file path>]"
97*da0073e9SAndroid Build Coastguard Worker       " [--output_model=<output model file path>]");
98*da0073e9SAndroid Build Coastguard Worker 
99*da0073e9SAndroid Build Coastguard Worker   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
100*da0073e9SAndroid Build Coastguard Worker     std::cerr << "Failed to parse command line flags!" << std::endl;
101*da0073e9SAndroid Build Coastguard Worker     std::cout << c10::UsageMessage() << std::endl;
102*da0073e9SAndroid Build Coastguard Worker     return 1;
103*da0073e9SAndroid Build Coastguard Worker   }
104*da0073e9SAndroid Build Coastguard Worker 
105*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
106*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
107*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
108*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());
109*da0073e9SAndroid Build Coastguard Worker   const auto dims_size = split(';', FLAGS_input_dims).size();
110*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE(
111*da0073e9SAndroid Build Coastguard Worker       dims_size == split(';', FLAGS_input_types).size(),
112*da0073e9SAndroid Build Coastguard Worker       "Number of input_dims and input_types should be the same");
113*da0073e9SAndroid Build Coastguard Worker   const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
114*da0073e9SAndroid Build Coastguard Worker   CAFFE_ENFORCE(
115*da0073e9SAndroid Build Coastguard Worker       mem_formats_size == 0 || mem_formats_size == dims_size,
116*da0073e9SAndroid Build Coastguard Worker       "Number of input_memory_formats should be 0 (default contiguous) or the same as number of input_dims");
117*da0073e9SAndroid Build Coastguard Worker   if (FLAGS_output_llvm.empty()) {
118*da0073e9SAndroid Build Coastguard Worker     FLAGS_output_llvm =
119*da0073e9SAndroid Build Coastguard Worker         FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
120*da0073e9SAndroid Build Coastguard Worker   }
121*da0073e9SAndroid Build Coastguard Worker 
122*da0073e9SAndroid Build Coastguard Worker   std::string output_model_name = FLAGS_output_model;
123*da0073e9SAndroid Build Coastguard Worker   if (output_model_name.empty()) {
124*da0073e9SAndroid Build Coastguard Worker     output_model_name =
125*da0073e9SAndroid Build Coastguard Worker         FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.pt";
126*da0073e9SAndroid Build Coastguard Worker   }
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker   auto m = torch::jit::load(FLAGS_model);
129*da0073e9SAndroid Build Coastguard Worker   m.eval();
130*da0073e9SAndroid Build Coastguard Worker   auto frozen_m = torch::jit::freeze_module(m.clone());
131*da0073e9SAndroid Build Coastguard Worker 
132*da0073e9SAndroid Build Coastguard Worker   auto compile_spec = createCompileSpec();
133*da0073e9SAndroid Build Coastguard Worker   auto any_dict_ty =
134*da0073e9SAndroid Build Coastguard Worker       c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
135*da0073e9SAndroid Build Coastguard Worker   auto compiled_module = torch::jit::detail::codegen_backend_module(
136*da0073e9SAndroid Build Coastguard Worker       "nnc", frozen_m, compile_spec, any_dict_ty);
137*da0073e9SAndroid Build Coastguard Worker   compiled_module._save_for_mobile(output_model_name);
138*da0073e9SAndroid Build Coastguard Worker   std::cout << "The compiled model was saved to " << output_model_name
139*da0073e9SAndroid Build Coastguard Worker             << std::endl;
140*da0073e9SAndroid Build Coastguard Worker   return 0;
141*da0073e9SAndroid Build Coastguard Worker }
142