xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/nnc/context.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/core/ivalue.h>
#include <c10/core/ScalarType.h>
11 namespace torch {
12 namespace jit {
13 namespace mobile {
14 namespace nnc {
15 
16 // Specify the requirements on an input tensor.
17 // TODO: support input tensor with dynamic shape (PR #54982)
18 struct TORCH_API InputSpec {
19   InputSpec() = default;
20 
21   // Deserialize the spec from an IValue.
22   explicit InputSpec(const c10::IValue& value);
23 
24   // Serialize the spec into an IValue.
25   C10_NODISCARD c10::IValue serialize() const;
26 
27   // Check whether the input tensor adheres to the spec.
28   C10_NODISCARD bool validate(const at::Tensor& input) const;
29 
30   std::vector<int64_t> sizes_;
31   c10::ScalarType dtype_{c10::ScalarType::Undefined};
32 };
33 
34 // Specify the sizes/dtype/... of output tensor to preallocate the output.
35 // TODO: support the case where kernel allocates output tensors dynamically.
36 struct TORCH_API OutputSpec {
37   OutputSpec() = default;
38 
39   // Deserialize the spec from an IValue.
40   explicit OutputSpec(const c10::IValue& value);
41 
42   // Serialize the spec into an IValue.
43   C10_NODISCARD c10::IValue serialize() const;
44 
45   // Allocate an output tensor in accordance with the spec.
46   C10_NODISCARD at::Tensor allocate() const;
47 
48   std::vector<int64_t> sizes_;
49   c10::ScalarType dtype_{c10::ScalarType::Undefined};
50   std::optional<double> qscale_;
51   std::optional<int64_t> qzero_;
52 };
53 
54 // Hold the temporary buffers / states needed during the execution.
55 struct TORCH_API ExecutionState {
56   ExecutionState() = default;
57   ExecutionState(const ExecutionState&) = delete;
58   ExecutionState(ExecutionState&&) = default;
59   ExecutionState& operator=(const ExecutionState&) = delete;
60   ExecutionState& operator=(ExecutionState&&) = default;
61 
62   // Preallocated buffers needed by the NNC kernel.
63   std::vector<c10::DataPtr> preallocations_;
64 
65   // The NNC kernel expects the following arguments layout:
66   //   input tensor 1
67   //   ...
68   //   input tensor INPUT_NUM
69   //   output tensor 1
70   //   ...
71   //   output tensor OUTPUT_NUM
72   //   parameter tensor 1
73   //   ...
74   //   parameter tensor PARAM_NUM
75   //   temporary buffer 1
76   //   ...
77   //   temporary buffer BUFFER_NUM
78   std::vector<void*> arguments_;
79 };
80 
81 // Specify how to allocate temporary buffers at initialization.
82 struct TORCH_API MemoryPlan {
83   MemoryPlan() = default;
84 
85   explicit MemoryPlan(const c10::IValue& value);
86 
87   C10_NODISCARD c10::IValue serialize() const;
88 
89   void allocate(ExecutionState* state) const;
90 
91   std::vector<int64_t> buffer_sizes_;
92 };
93 
94 // Location of a symbolic shape among dimensions of the inputs
95 struct TORCH_API SymbolicShapePosition {
96   SymbolicShapePosition() = default;
SymbolicShapePositionSymbolicShapePosition97   SymbolicShapePosition(int64_t input_idx, int64_t dim_idx)
98       : input_idx_(input_idx), dim_idx_(dim_idx) {}
99 
100   int64_t input_idx_;
101   int64_t dim_idx_;
102 };
103 
104 // Represents a compiled NNC function which has a 1-1 correspondence with a
105 // `Method` (e.g. `forward`). It's similar as torch::jit::mobile::Function.
106 class TORCH_API Function {
107  public:
108   explicit Function() = default;
109 
110   // Deserialize from an IValue that is generated by the 'serialize()' method.
111   explicit Function(const c10::IValue& value);
112 
113   // Serialize into an IValue.
114   c10::IValue serialize() const;
115 
116   // Execute the compiled NNC function.
117   c10::impl::GenericList run(const c10::impl::GenericList& inputs) const;
118 
119   // The name of the function as specified in the model code.
name()120   c10::QualifiedName name() const {
121     return name_;
122   }
123 
set_name(const c10::QualifiedName & name)124   void set_name(const c10::QualifiedName& name) {
125     name_ = name;
126   }
127 
128   // The unique id of the generated NNC kernel corresponding to the function.
nnc_kernel_id()129   const std::string& nnc_kernel_id() const {
130     return nnc_kernel_id_;
131   }
132 
set_nnc_kernel_id(const std::string & name)133   void set_nnc_kernel_id(const std::string& name) {
134     nnc_kernel_id_ = name;
135   }
136 
137   // The parameters (e.g. weights / bias tensors) to be passed to the generated
138   // NNC kernel.
parameters()139   const c10::impl::GenericList& parameters() const {
140     return parameters_;
141   }
142 
set_parameters(const c10::impl::GenericList & parameters)143   void set_parameters(const c10::impl::GenericList& parameters) {
144     parameters_ = parameters;
145   }
146 
input_specs()147   const std::vector<InputSpec>& input_specs() const {
148     return input_specs_;
149   }
150 
set_input_specs(const std::vector<InputSpec> & input_specs)151   void set_input_specs(const std::vector<InputSpec>& input_specs) {
152     input_specs_ = input_specs;
153   }
154 
output_specs()155   const std::vector<OutputSpec>& output_specs() const {
156     return output_specs_;
157   }
158 
set_output_specs(const std::vector<OutputSpec> & output_specs)159   void set_output_specs(const std::vector<OutputSpec>& output_specs) {
160     output_specs_ = output_specs;
161   }
162 
memory_plan()163   const MemoryPlan& memory_plan() const {
164     return memory_plan_;
165   }
166 
set_memory_plan(const MemoryPlan & memory_plan)167   void set_memory_plan(const MemoryPlan& memory_plan) {
168     memory_plan_ = memory_plan;
169   }
170 
sym_shape_positions()171   const std::vector<SymbolicShapePosition>& sym_shape_positions() const {
172     return sym_shape_positions_;
173   }
174 
set_sym_shape_positions(const std::vector<SymbolicShapePosition> & sym_shape_pos)175   void set_sym_shape_positions(
176       const std::vector<SymbolicShapePosition>& sym_shape_pos) {
177     sym_shape_positions_ = sym_shape_pos;
178   }
179 
180  private:
181   void init_execution_state() const;
182 
183   c10::QualifiedName name_;
184   std::string nnc_kernel_id_;
185   c10::impl::GenericList parameters_{at::AnyType::get()};
186   std::vector<InputSpec> input_specs_;
187   std::vector<OutputSpec> output_specs_;
188   std::vector<SymbolicShapePosition> sym_shape_positions_;
189   MemoryPlan memory_plan_;
190   mutable std::unique_ptr<ExecutionState> execution_state_;
191 };
192 
193 // CompilationUnit consists of a set of compiled NNC functions. It has a 1-1
194 // correspondence with a `Module`.
195 // It's similar as torch::jit::mobile::CompilationUnit.
196 class TORCH_API CompilationUnit {
197  public:
198   CompilationUnit() = default;
199   CompilationUnit(const CompilationUnit&) = delete;
200   CompilationUnit(CompilationUnit&&) = default;
201   CompilationUnit& operator=(const CompilationUnit&) = delete;
202   CompilationUnit& operator=(CompilationUnit&&) = default;
203 
204   // Deserialize from an IValue that is generated by the 'serialize()' method.
205   explicit CompilationUnit(const c10::IValue& value);
206 
207   // Serialize all registered functions into an IValue. The IValue will be save
208   // into the compiled TorchScript model file ahead-of-time on the host, and
209   // will be deserialized at runtime on the target device.
210   C10_NODISCARD c10::IValue serialize() const;
211 
212   // Execute a registered function.
213   C10_NODISCARD c10::impl::GenericList run(
214       const c10::QualifiedName& function_name,
215       const c10::impl::GenericList& inputs) const;
216 
217   // Register a function to the compilation unit.
218   void register_function(std::unique_ptr<Function> fn);
219 
220  private:
221   C10_NODISCARD Function* find_function(const c10::QualifiedName& qn) const;
222 
223   std::unordered_map<c10::QualifiedName, std::unique_ptr<Function>> functions_;
224 };
225 
226 } // namespace nnc
227 } // namespace mobile
228 } // namespace jit
229 } // namespace torch
230