// torch/csrc/jit/backends/nnapi/nnapi_backend_preprocess.cpp
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/backends/backend.h>
#include <torch/csrc/jit/backends/backend_preprocess.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/pybind.h>

#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

namespace py = pybind11;

// Converts a model to the Android NNAPI backend and serializes it for mobile.
// Returns a dictionary of preprocessed items:
//    "shape_compute_module": torch::jit::Module,
//    "ser_model": at::Tensor,
//    "weights": List[torch.Tensor],
//    "inp_mem_fmts": List[int],
//    "out_mem_fmts": List[int]
//
// method_compile_spec should contain a Tensor or Tensor List that bundles
// several input parameters: shape, dtype, quantization, and dim order
// (NHWC/NCHW).
// For input shapes, use 0 for inputs that are flexible at load/run time.
//
// The compile spec must use the following format:
// {"forward": {"inputs": at::Tensor}}
// OR {"forward": {"inputs": c10::List<at::Tensor>}}
// Example input Tensor:
// torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
//
// In the future, preprocess will accept a dedicated object.
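//
// A minimal Python-side sketch of how this preprocess step is reached. This
// assumes the generic to_backend lowering entry point
// (torch._C._jit_to_backend); `model` is an illustrative nn.Module:
//
//   inp = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
//   compile_spec = {"forward": {"inputs": inp}}
//   lowered = torch._C._jit_to_backend(
//       "nnapi", torch.jit.script(model), compile_spec)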
c10::IValue preprocess(
    const torch::jit::Module& mod,
    const c10::Dict<c10::IValue, c10::IValue>& method_compile_spec,
    const torch::jit::BackendDebugHandleGenerator& generate_debug_handles) {
  // Import the Python function that converts modules to the Android NNAPI
  // backend
  py::gil_scoped_acquire gil;
  py::object pyModule = py::module_::import("torch.backends._nnapi.prepare");
  py::object pyMethod = pyModule.attr("process_for_nnapi");

  // Wrap the C++ module in a RecursiveScriptModule and switch it to eval mode
  auto wrapped_mod =
      py::module::import("torch.jit._recursive").attr("wrap_cpp_module")(mod);
  wrapped_mod.attr("eval")();

  // Check that method_compile_spec contains the necessary keys and a
  // Tensor/TensorList input
  c10::IValue inp;
  std::string error;
  if (!method_compile_spec.contains("forward")) {
    error = R"(method_compile_spec does not contain the "forward" key.)";
  } else {
    auto innerDict = method_compile_spec.at("forward");
    if (!innerDict.isGenericDict() ||
        !innerDict.toGenericDict().contains("inputs")) {
      error =
          R"(method_compile_spec does not contain a dictionary with an "inputs" key, under its "forward" key.)";
    } else {
      inp = innerDict.toGenericDict().at("inputs");
      if (!inp.isTensor() && !inp.isTensorList()) {
        error =
            R"(method_compile_spec does not contain either a Tensor or TensorList, under its "inputs" key.)";
      }
    }
  }
  if (!error.empty()) {
    throw std::runtime_error(
        error +
        "\nmethod_compile_spec should contain a Tensor or Tensor List which bundles input parameters:"
        " shape, dtype, quantization, and dimorder."
        "\nFor input shapes, use 0 for run/load time flexible input."
        "\nmethod_compile_spec must use the following format:"
        "\n{\"forward\": {\"inputs\": at::Tensor}} OR {\"forward\": {\"inputs\": c10::List<at::Tensor>}}");
  }

  // Convert the input to a Tensor or a Python list of Tensors
  py::list nnapi_processed;
  if (inp.isTensor()) {
    nnapi_processed = pyMethod(wrapped_mod, inp.toTensor());
  } else {
    py::list pyInp;
    for (at::Tensor inpElem : inp.toTensorList()) {
      pyInp.append(inpElem);
    }
    nnapi_processed = pyMethod(wrapped_mod, pyInp);
  }

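  // nnapi_processed is expected to be a five-element list:
  // [shape_compute_module, ser_model, weights, inp_mem_fmts, out_mem_fmts],
  // matching the items documented at the top of this file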
  // Cast and insert processed items into the dict
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::StringType::get(), c10::AnyType::get());
  dict.insert("ser_model", py::cast<at::Tensor>(nnapi_processed[1]));

  // Serialize shape_compute_module for mobile; the bytes are stored as a
  // string value in the dict
  auto shape_compute_module =
      py::cast<torch::jit::Module>(nnapi_processed[0].attr("_c"));
  std::stringstream ss;
  shape_compute_module._save_for_mobile(ss);
  dict.insert("shape_compute_module", ss.str());

  // Transform Python lists into C++ c10::Lists, making each weight tensor
  // contiguous before it is serialized
  c10::List<at::Tensor> weights(
      py::cast<std::vector<at::Tensor>>(nnapi_processed[2]));
  for (size_t i = 0; i < weights.size(); i++) {
    weights.set(i, weights.get(i).contiguous());
  }
  c10::List<int64_t> inp_mem_fmts(
      py::cast<std::vector<int64_t>>(nnapi_processed[3]));
  c10::List<int64_t> out_mem_fmts(
      py::cast<std::vector<int64_t>>(nnapi_processed[4]));
  dict.insert("weights", weights);
  dict.insert("inp_mem_fmts", inp_mem_fmts);
  dict.insert("out_mem_fmts", out_mem_fmts);

  return dict;
}

constexpr auto backend_name = "nnapi";
static auto pre_reg =
    torch::jit::backend_preprocess_register(backend_name, preprocess);
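
// The dict returned by preprocess() is what the NNAPI backend sees at runtime:
// its compile()/execute() implementation (nnapi_backend_lib.cpp in the
// upstream PyTorch layout) reloads "shape_compute_module" from the serialized
// bytes and hands "ser_model" and "weights" to the NNAPI model at load time.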