#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <pybind11/pybind11.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>
6
7 namespace py = pybind11;
8
9 // Converts model to Android NNAPI backend and serializes it for mobile
10 // Returns a dictionary with preprocessed items:
// "shape_compute_module": std::string (the module serialized with _save_for_mobile),
12 // "ser_model": at::Tensor,
13 // "weights": List[torch.Tensor],
14 // "inp_mem_fmts": List[int],
15 // "out_mem_fmts": List[int]
16 //
17 // method_compile_spec should contain a Tensor or
18 // Tensor List which bundles several input parameters:
19 // shape, dtype, quantization, and dimorder (NHWC/NCHW)
20 // For input shapes, use 0 for run/load time flexible input
21 //
22 // The compile_spec should include the format:
23 // {"forward": {"inputs": at::Tensor}}
24 // OR {"forward": {"inputs": c10::List<at::Tensor>}}
25 // Example input Tensor:
26 // torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
27 //
// In the future, preprocess is expected to accept a dedicated object
// (presumably replacing the method_compile_spec dictionary described above).
preprocess(const torch::jit::Module & mod,const c10::Dict<c10::IValue,c10::IValue> & method_compile_spec,const torch::jit::BackendDebugHandleGenerator & generate_debug_handles)29 c10::IValue preprocess(
30 const torch::jit::Module& mod,
31 const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
32 const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
33 // Import the python function for processing modules to Android NNAPI backend
34 py::gil_scoped_acquire gil;
35 py::object pyModule = py::module_::import("torch.backends._nnapi.prepare");
36 py::object pyMethod = pyModule.attr("process_for_nnapi");
37
38 // Wrap the c module in a RecursiveScriptModule
39 auto wrapped_mod =
40 py::module::import("torch.jit._recursive").attr("wrap_cpp_module")(mod);
41 wrapped_mod.attr("eval")();
42
43 // Test that method_compile_spec contains the necessary keys and
44 // Tensor/TensorList input
45 c10::IValue inp;
46 std::string error = "";
47 if (!method_compile_spec.contains("forward")) {
48 error = R"(method_compile_spec does not contain the "forward" key.)";
49 } else {
50 auto innerDict = method_compile_spec.at("forward");
51 if (!innerDict.isGenericDict() ||
52 !innerDict.toGenericDict().contains("inputs")) {
53 error =
54 R"(method_compile_spec does not contain a dictionary with an "inputs" key, under it's "forward" key.)";
55 } else {
56 inp = innerDict.toGenericDict().at("inputs");
57 if (!inp.isTensor() && !inp.isTensorList()) {
58 error =
59 R"(method_compile_spec does not contain either a Tensor or TensorList, under it's "inputs" key.)";
60 }
61 }
62 }
63 if (!error.empty()) {
64 throw std::runtime_error(
65 error +
66 "\nmethod_compile_spec should contain a Tensor or Tensor List which bundles input parameters:"
67 " shape, dtype, quantization, and dimorder."
68 "\nFor input shapes, use 0 for run/load time flexible input."
69 "\nmethod_compile_spec must use the following format:"
70 "\n{\"forward\": {\"inputs\": at::Tensor}} OR {\"forward\": {\"inputs\": c10::List<at::Tensor>}}");
71 }
72
73 // Convert input to a Tensor or a python list of Tensors
74 py::list nnapi_processed;
75 if (inp.isTensor()) {
76 nnapi_processed = pyMethod(wrapped_mod, inp.toTensor());
77 } else {
78 py::list pyInp;
79 for (at::Tensor inpElem : inp.toTensorList()) {
80 pyInp.append(inpElem);
81 }
82 nnapi_processed = pyMethod(wrapped_mod, pyInp);
83 }
84
85 // Cast and insert processed items into dict
86 c10::Dict<c10::IValue, c10::IValue> dict(
87 c10::StringType::get(), c10::AnyType::get());
88 dict.insert("ser_model", py::cast<at::Tensor>(nnapi_processed[1]));
89
90 // Serialize shape_compute_module for mobile
91 auto shape_compute_module =
92 py::cast<torch::jit::Module>(nnapi_processed[0].attr("_c"));
93 std::stringstream ss;
94 shape_compute_module._save_for_mobile(ss);
95 dict.insert("shape_compute_module", ss.str());
96
97 // transform Python lists to C++ c10::List
98 c10::List<at::Tensor> weights(
99 py::cast<std::vector<at::Tensor>>(nnapi_processed[2]));
100 for (auto i = 0U; i < weights.size(); i++) {
101 weights.set(i, weights.get(i).contiguous());
102 }
103 c10::List<int64_t> inp_mem_fmts(
104 py::cast<std::vector<int64_t>>(nnapi_processed[3]));
105 c10::List<int64_t> out_mem_fmts(
106 py::cast<std::vector<int64_t>>(nnapi_processed[4]));
107 dict.insert("weights", weights);
108 dict.insert("inp_mem_fmts", inp_mem_fmts);
109 dict.insert("out_mem_fmts", out_mem_fmts);
110
111 return dict;
112 }
113
114 constexpr auto backend_name = "nnapi";
115 static auto pre_reg =
116 torch::jit::backend_preprocess_register(backend_name, preprocess);
117