xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/nnc/context.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/mobile/nnc/context.h>
2 
3 #include <ATen/Functions.h>
4 #include <ATen/core/functional.h>
5 #include <c10/core/CPUAllocator.h>
6 #include <c10/util/irange.h>
7 
8 #include <torch/csrc/jit/mobile/nnc/registry.h>
9 
10 namespace torch {
11 namespace jit {
12 namespace mobile {
13 namespace nnc {
14 
15 constexpr int64_t kProducedNNCFileFormatVersion = 0x1L;
16 
17 namespace {
18 
Tup(std::initializer_list<c10::IValue> ivalues)19 c10::IValue Tup(std::initializer_list<c10::IValue> ivalues) {
20   return c10::ivalue::Tuple::create(ivalues);
21 }
22 
Tup(std::vector<c10::IValue> && ivalues)23 c10::IValue Tup(std::vector<c10::IValue>&& ivalues) {
24   return c10::ivalue::Tuple::create(ivalues);
25 }
26 
27 } // namespace
28 
InputSpec(const c10::IValue & value)29 InputSpec::InputSpec(const c10::IValue& value) {
30   auto dict = value.toGenericDict();
31   sizes_ = dict.at("sizes").toIntVector();
32   dtype_ = dict.at("dtype").toScalarType();
33 }
34 
serialize() const35 c10::IValue InputSpec::serialize() const {
36   c10::Dict<c10::IValue, c10::IValue> dict(
37       at::StringType::get(), at::AnyType::get());
38   dict.insert("sizes", sizes_);
39   dict.insert("dtype", dtype_);
40   return dict;
41 }
42 
validate(const at::Tensor & input) const43 bool InputSpec::validate(const at::Tensor& input) const {
44   if (sizes_.size() != input.sizes().size() || input.scalar_type() != dtype_) {
45     return false;
46   }
47   auto spec_sizes = sizes_;
48   for (const auto i : c10::irange(spec_sizes.size())) {
49     // InputSpec size 0 means that the dimension is dynamic
50     if (spec_sizes[i] != 0 && spec_sizes[i] != input.sizes()[i]) {
51       return false;
52     }
53   }
54   return true;
55 }
56 
OutputSpec(const c10::IValue & value)57 OutputSpec::OutputSpec(const c10::IValue& value) {
58   auto dict = value.toGenericDict();
59   sizes_ = dict.at("sizes").toIntVector();
60   dtype_ = dict.at("dtype").toScalarType();
61   if (dict.contains("qscale")) {
62     qscale_ = dict.at("qscale").toDouble();
63   }
64   if (dict.contains("qzero")) {
65     qzero_ = dict.at("qzero").toInt();
66   }
67 }
68 
serialize() const69 c10::IValue OutputSpec::serialize() const {
70   c10::Dict<c10::IValue, c10::IValue> dict(
71       at::StringType::get(), at::AnyType::get());
72   dict.insert("sizes", sizes_);
73   dict.insert("dtype", dtype_);
74   if (qscale_) {
75     dict.insert("qscale", *qscale_);
76   }
77   if (qzero_) {
78     dict.insert("qzero", *qzero_);
79   }
80   return dict;
81 }
82 
allocate() const83 at::Tensor OutputSpec::allocate() const {
84   if (isQIntType(dtype_)) {
85     TORCH_CHECK(
86         qscale_ && qzero_,
87         "Quantized output tensor must have qscale_ and qzero_");
88     return at::_empty_affine_quantized(
89         sizes_,
90         at::TensorOptions()
91             .dtype(dtype_)
92             .layout(at::kStrided)
93             .device(at::kCPU)
94             .requires_grad(false),
95         *qscale_,
96         *qzero_);
97   }
98   return at::empty(
99       sizes_,
100       at::TensorOptions()
101           .dtype(dtype_)
102           .layout(at::kStrided)
103           .device(at::kCPU)
104           .requires_grad(false));
105 }
106 
MemoryPlan(const c10::IValue & value)107 MemoryPlan::MemoryPlan(const c10::IValue& value) {
108   auto dict = value.toGenericDict();
109   buffer_sizes_ = dict.at("buffer_sizes").toIntVector();
110 }
111 
serialize() const112 c10::IValue MemoryPlan::serialize() const {
113   c10::Dict<c10::IValue, c10::IValue> dict(
114       at::StringType::get(), at::AnyType::get());
115   dict.insert("buffer_sizes", buffer_sizes_);
116   return dict;
117 }
118 
allocate(ExecutionState * state) const119 void MemoryPlan::allocate(ExecutionState* state) const {
120   auto& allocations = state->preallocations_;
121   allocations.clear();
122   allocations.reserve(buffer_sizes_.size());
123   for (int64_t buffer_size : buffer_sizes_) {
124     at::DataPtr buffer = c10::GetCPUAllocator()->allocate(buffer_size);
125     allocations.emplace_back(std::move(buffer));
126   }
127 }
128 
Function(const c10::IValue & value)129 Function::Function(const c10::IValue& value) {
130   auto dict = value.toGenericDict();
131   name_ = c10::QualifiedName(dict.at("name").toStringRef());
132   nnc_kernel_id_ = dict.at("nnc_kernel_id").toStringRef();
133   parameters_ = dict.at("parameters").toList();
134 
135   // input_specs_
136   for (const auto& input_value :
137        dict.at("input_specs").toTupleRef().elements()) {
138     input_specs_.emplace_back(input_value);
139   }
140 
141   // output_specs_
142   for (const auto& output_value :
143        dict.at("output_specs").toTupleRef().elements()) {
144     output_specs_.emplace_back(output_value);
145   }
146 
147   // memory_plan_
148   memory_plan_ = MemoryPlan(dict.at("memory_plan"));
149 
150   // symbolic shape positions
151   for (const auto& sym_shape_pos :
152        dict.at("sym_shape_pos").toTupleRef().elements()) {
153     auto sym_shape_elements = sym_shape_pos.toTupleRef().elements();
154     sym_shape_positions_.emplace_back(
155         sym_shape_elements[0].toInt(), sym_shape_elements[1].toInt());
156   }
157 }
158 
serialize() const159 c10::IValue Function::serialize() const {
160   c10::Dict<c10::IValue, c10::IValue> dict(
161       at::StringType::get(), at::AnyType::get());
162 
163   dict.insert("name", name_.qualifiedName());
164   dict.insert("nnc_kernel_id", nnc_kernel_id_);
165   // TODO: should serialize parameters with Module instead of with each Method.
166   // And ideally the parameters should be shared between the compiled model
167   // and the original model if we can serialize both in the same model file.
168   dict.insert("parameters", parameters_);
169 
170   // input_specs_
171   std::vector<c10::IValue> input_specs;
172   input_specs.reserve(input_specs_.size());
173   for (const auto& input_spec : input_specs_) {
174     input_specs.emplace_back(input_spec.serialize());
175   }
176   dict.insert("input_specs", Tup(std::move(input_specs)));
177 
178   // output_specs_
179   std::vector<c10::IValue> output_specs;
180   output_specs.reserve(output_specs_.size());
181   for (const auto& output_spec : output_specs_) {
182     output_specs.emplace_back(output_spec.serialize());
183   }
184   dict.insert("output_specs", Tup(std::move(output_specs)));
185 
186   // memory_plan_
187   dict.insert("memory_plan", memory_plan_.serialize());
188 
189   // sym_shape_positions_
190   std::vector<c10::IValue> sym_shape_pos_vec;
191   sym_shape_pos_vec.reserve(sym_shape_positions_.size());
192   for (const auto& sym_shape_pos : sym_shape_positions_) {
193     sym_shape_pos_vec.emplace_back(
194         Tup({sym_shape_pos.input_idx_, sym_shape_pos.dim_idx_}));
195   }
196   dict.insert("sym_shape_pos", Tup(std::move(sym_shape_pos_vec)));
197 
198   return dict;
199 }
200 
init_execution_state() const201 void Function::init_execution_state() const {
202   if (execution_state_.get() != nullptr) {
203     return;
204   }
205 
206   ExecutionState state;
207   memory_plan_.allocate(&state);
208 
209   // The arguments vector consists of 5 sections: inputs, symbolic shapes,
210   // outputs, parameters and buffers.
211   auto input_args = input_specs_.size();
212   auto sym_shape_args = sym_shape_positions_.size();
213   auto output_args = output_specs_.size();
214   auto param_args = parameters_.size();
215   auto buffer_args = state.preallocations_.size();
216 
217   auto& arguments = state.arguments_;
218   arguments.reserve(
219       input_args + sym_shape_args + output_args + param_args + buffer_args);
220 
221   // Keep empty slots to fill in inputs/outputs pointers at execution time.
222   arguments.resize(input_args + sym_shape_args + output_args);
223 
224   // Fill in parameters as untyped raw pointers.
225   // The underlying storage of the parameters should be owned by `parameters_`,
226   // which should be alive when `execution_state_` is being used.
227   for (const auto& param : parameters_) {
228     const c10::IValue& ivalue = (c10::IValue)param;
229     if (ivalue.isTensor()) {
230       arguments.emplace_back(ivalue.toTensor().data_ptr());
231     } else if (torch::isCustomClass(ivalue)) {
232       arguments.emplace_back(ivalue.toObjectRef().getSlot(0).toCapsule().get());
233     } else {
234       TORCH_CHECK(false, "Invalid parameter: ", ivalue);
235     }
236   }
237 
238   // Fill in preallocated buffer pointers.
239   for (const auto& preallocation : state.preallocations_) {
240     arguments.emplace_back(preallocation.get());
241   }
242 
243   execution_state_ = std::make_unique<ExecutionState>(std::move(state));
244 }
245 
// Execute the compiled NNC kernel on `inputs` and return the outputs.
// The argument vector layout (inputs, then symbolic-shape scalars, then
// outputs; parameters and buffers were appended by init_execution_state)
// must match what the kernel was compiled against.
c10::impl::GenericList Function::run(
    const c10::impl::GenericList& inputs) const {
  TORCH_CHECK(
      registry::has_nnc_kernel(nnc_kernel_id_),
      "Cannot find NNC kernel: ",
      nnc_kernel_id_);

  init_execution_state();

  std::vector<void*>& args = execution_state_->arguments_;

  // Fill in input tensors.
  TORCH_CHECK(
      input_specs_.size() == inputs.size(),
      "Input size doesn't match the spec, expect: ",
      input_specs_.size(),
      " actual: ",
      inputs.size());
  // Holds the symbolic-shape scalar values; the kernel receives pointers
  // into this vector, so it must outlive the execute() call below.
  std::vector<int64_t> scalar_values;
  // Running start index of the current section within `args`.
  int offset = 0;
  for (const auto i : c10::irange(inputs.size())) {
    const c10::IValue& input = inputs[i];
    const auto& spec = input_specs_[i];
    const auto& input_tensor = input.toTensor();
    TORCH_CHECK(spec.validate(input_tensor), "Invalid input at pos: ", i);
    args[i] = input_tensor.data_ptr();
  }
  offset += inputs.size();

  // Reserving up front is required for correctness, not just speed: the
  // loop stores &scalar_values[...] into args, and a push_back-triggered
  // reallocation would invalidate those pointers.
  scalar_values.reserve(sym_shape_positions_.size());
  for (const auto i : c10::irange(sym_shape_positions_.size())) {
    const auto& sym_shape_pos = sym_shape_positions_[i];
    // Read the concrete size of the referenced dynamic dimension from the
    // corresponding input tensor.
    const c10::IValue& input = inputs[sym_shape_pos.input_idx_];
    auto dim = input.toTensor().size(sym_shape_pos.dim_idx_);
    scalar_values.push_back(dim);
    args[i + offset] = &scalar_values[scalar_values.size() - 1];
  }
  offset += sym_shape_positions_.size();

  // Preallocate and fill in output tensors.
  c10::List<at::Tensor> outputs;
  outputs.reserve(output_specs_.size());
  for (const auto i : c10::irange(output_specs_.size())) {
    at::Tensor output = output_specs_[i].allocate();
    outputs.emplace_back(output);
    args[i + offset] = output.data_ptr();
  }

  // TODO: check consistency, e.g.: code version, input shape and compiled
  // shape, etc.
  auto kernel = registry::get_nnc_kernel(nnc_kernel_id_);
  kernel->execute(args.data());

  return c10::impl::toList(outputs);
}
301 
CompilationUnit(const c10::IValue & value)302 CompilationUnit::CompilationUnit(const c10::IValue& value) {
303   const auto& root = value.toTupleRef().elements();
304   const auto& functions = root[1].toTupleRef().elements();
305   for (const auto& function : functions) {
306     register_function(std::make_unique<Function>(function));
307   }
308 }
309 
serialize() const310 c10::IValue CompilationUnit::serialize() const {
311   auto functions =
312       c10::fmap(functions_, [](decltype(functions_)::const_reference func) {
313         return func.second->serialize();
314       });
315   return Tup({kProducedNNCFileFormatVersion, Tup(std::move(functions))});
316 }
317 
// Look up the function registered under `name` and execute it on `inputs`;
// throws if no such function exists.
c10::impl::GenericList CompilationUnit::run(
    const c10::QualifiedName& name,
    const c10::impl::GenericList& inputs) const {
  auto* fn = find_function(name);
  TORCH_CHECK(
      fn != nullptr, "Function '", name.qualifiedName(), "' is not defined.");
  return fn->run(inputs);
}
326 
register_function(std::unique_ptr<Function> fn)327 void CompilationUnit::register_function(std::unique_ptr<Function> fn) {
328   TORCH_CHECK(
329       0 == functions_.count(fn->name()),
330       "method '",
331       fn->name().qualifiedName(),
332       "' already defined.");
333   const auto& name = fn->name();
334   functions_.emplace(name, std::move(fn));
335 }
336 
find_function(const c10::QualifiedName & name) const337 Function* CompilationUnit::find_function(const c10::QualifiedName& name) const {
338   auto it = functions_.find(name);
339   if (it == functions_.end()) {
340     return nullptr;
341   }
342   return it->second.get();
343 }
344 
345 } // namespace nnc
346 } // namespace mobile
347 } // namespace jit
348 } // namespace torch
349