#include <iostream>
#include <sstream>
#include <string>
#include <vector>

#include <c10/util/Exception.h>
#include <c10/util/Flags.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_detail.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/mobile/nnc/aot_compiler.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/serialization/import.h>
#include <torch/script.h>

C10_DEFINE_string(model, "", "The TorchScript model to optimize.");
C10_DEFINE_string(model_name, "", "The name of the model.");
C10_DEFINE_string(model_version, "", "The version of the model.");
C10_DEFINE_string(
    input_dims,
    "",
    "The dimensions of input TensorCPUs as comma separated numbers. "
    "If multiple inputs are needed, use a semicolon to separate "
    "the dimensions of different tensors.");
C10_DEFINE_string(
    input_types,
    "float",
    "The dtypes of input TensorCPUs. "
    "If multiple inputs are needed, use a semicolon to separate "
    "the dtypes of different tensors. "
    "Supported dtypes: float, int64, uint8");
C10_DEFINE_string(
    input_memory_formats,
    "",
    "Input memory formats. "
    "If multiple inputs are needed, use a semicolon to separate. "
    "Supported values: contiguous, channels_last");
C10_DEFINE_string(
    dynamic_dims,
    "",
    "Comma separated dimensions of input tensors that can be dynamic");
C10_DEFINE_string(method_name, "forward", "The name of the method.");
C10_DEFINE_string(
    output_llvm,
    "",
    "Name of the output LLVM assembly file to be saved.");
C10_DEFINE_string(output_model, "", "Name of the output model to be saved.");

namespace {

// Splits `string` on `separator`, dropping empty pieces unless
// `ignore_empty` is false.
std::vector<std::string> split(
    char separator,
    const std::string& string,
    bool ignore_empty = true) {
  std::vector<std::string> pieces;
  std::stringstream ss(string);
  std::string item;
  while (getline(ss, item, separator)) {
    if (!ignore_empty || !item.empty()) {
      pieces.push_back(std::move(item));
    }
  }
  return pieces;
}

// Packs the command-line flags into the compile spec expected by the NNC
// backend: a dict mapping the method name to its per-method options.
c10::Dict<c10::IValue, c10::IValue> createCompileSpec() {
  c10::Dict<c10::IValue, c10::IValue> compile_spec(
      c10::StringType::get(), c10::AnyType::get());
  c10::Dict<c10::IValue, c10::IValue> method_spec(
      c10::StringType::get(), c10::AnyType::get());
  method_spec.insert("sizes", FLAGS_input_dims);
  method_spec.insert("types", FLAGS_input_types);
  method_spec.insert("memory_formats", FLAGS_input_memory_formats);
  method_spec.insert("dynamic_sizes", FLAGS_dynamic_dims);
  method_spec.insert("asmfile", FLAGS_output_llvm);
  method_spec.insert("model_name", FLAGS_model_name);
  method_spec.insert("model_version", FLAGS_model_version);
  compile_spec.insert(FLAGS_method_name, method_spec);
  return compile_spec;
}

} // namespace
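// Example invocation (the model file, name, and shapes below are
// illustrative, not shipped with this tool):
//
//   build/bin/aot_model_compiler \
//     --model=mobilenetv2.pt \
//     --model_name=mobilenetv2 \
//     --model_version=v1 \
//     --input_dims="1,3,224,224" \
//     --input_types="float" \
//     --input_memory_formats="contiguous"
//
// With no --output_llvm/--output_model given, this writes
// mobilenetv2.compiled.ll and mobilenetv2.compiled.pt, per the
// default-naming logic in main() below.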
int main(int argc, char** argv) {
  c10::SetUsageMessage(
      "Run the NNC AOT compiler for a TorchScript model. Example usage:\n"
      "build/bin/aot_model_compiler"
      " --model=<model file>"
      " --model_name=<model name>"
      " --model_version=<model version>"
      " --input_dims=<input dimensions, e.g. '1,3,224,224;2,2'>"
      " --input_types=<input dtypes, e.g. 'float;float'>"
      " --input_memory_formats=<memory formats, e.g. 'channels_last;contiguous'>"
      " [--method_name=<method name>]"
      " [--output_llvm=<output llvm assembly file>]"
      " [--output_model=<output model file>]");

  if (!c10::ParseCommandLineFlags(&argc, &argv)) {
    std::cerr << "Failed to parse command line flags!" << std::endl;
    std::cout << c10::UsageMessage() << std::endl;
    return 1;
  }

  CAFFE_ENFORCE(!FLAGS_model.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_name.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_model_version.empty(), c10::UsageMessage());
  CAFFE_ENFORCE(!FLAGS_input_dims.empty(), c10::UsageMessage());

  // Each input is described by one ';'-separated field, so the per-input
  // flags must agree in count (memory formats may be omitted entirely).
  const auto dims_size = split(';', FLAGS_input_dims).size();
  CAFFE_ENFORCE(
      dims_size == split(';', FLAGS_input_types).size(),
      "Number of input_dims and input_types should be the same");
  const auto mem_formats_size = split(';', FLAGS_input_memory_formats).size();
  CAFFE_ENFORCE(
      mem_formats_size == 0 || mem_formats_size == dims_size,
      "Number of input_memory_formats should be 0 (default contiguous) "
      "or the same as number of input_dims");

  // Default output names: keep the model path up to its first '.' and
  // append the compiled-artifact suffixes.
  if (FLAGS_output_llvm.empty()) {
    FLAGS_output_llvm =
        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.ll";
  }
  std::string output_model_name = FLAGS_output_model;
  if (output_model_name.empty()) {
    output_model_name =
        FLAGS_model.substr(0, FLAGS_model.find('.')) + ".compiled.pt";
  }

  // Load the TorchScript model and freeze it (inline submodules and fold
  // attributes into constants) so the backend sees a self-contained graph.
  auto m = torch::jit::load(FLAGS_model);
  m.eval();
  auto frozen_m = torch::jit::freeze_module(m.clone());

  // Lower the frozen module to the "nnc" backend using the compile spec
  // built from the command-line flags, then save it in mobile format.
  auto compile_spec = createCompileSpec();
  auto any_dict_ty =
      c10::DictType::create(c10::StringType::get(), c10::AnyType::get());
  auto compiled_module = torch::jit::detail::codegen_backend_module(
      "nnc", frozen_m, compile_spec, any_dict_ty);
  compiled_module._save_for_mobile(output_model_name);
  std::cout << "The compiled model was saved to " << output_model_name
            << std::endl;
  return 0;
}
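// A minimal sketch of consuming the saved artifact (assumes a runtime built
// with the NNC mobile backend linked in; the file name and input shape come
// from the hypothetical example invocation above):
//
//   #include <torch/csrc/jit/mobile/import.h>
//
//   auto module = torch::jit::_load_for_mobile("mobilenetv2.compiled.pt");
//   auto output = module.forward({torch::ones({1, 3, 224, 224})});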