1 #pragma once 2 3 #include <memory> 4 #include <string> 5 #include <utility> 6 #include <vector> 7 8 #include <ATen/core/ivalue.h> 9 #include <c10/core/ScalarType.h> 10 11 namespace torch { 12 namespace jit { 13 namespace mobile { 14 namespace nnc { 15 16 // Specify the requirements on an input tensor. 17 // TODO: support input tensor with dynamic shape (PR #54982) 18 struct TORCH_API InputSpec { 19 InputSpec() = default; 20 21 // Deserialize the spec from an IValue. 22 explicit InputSpec(const c10::IValue& value); 23 24 // Serialize the spec into an IValue. 25 C10_NODISCARD c10::IValue serialize() const; 26 27 // Check whether the input tensor adheres to the spec. 28 C10_NODISCARD bool validate(const at::Tensor& input) const; 29 30 std::vector<int64_t> sizes_; 31 c10::ScalarType dtype_{c10::ScalarType::Undefined}; 32 }; 33 34 // Specify the sizes/dtype/... of output tensor to preallocate the output. 35 // TODO: support the case where kernel allocates output tensors dynamically. 36 struct TORCH_API OutputSpec { 37 OutputSpec() = default; 38 39 // Deserialize the spec from an IValue. 40 explicit OutputSpec(const c10::IValue& value); 41 42 // Serialize the spec into an IValue. 43 C10_NODISCARD c10::IValue serialize() const; 44 45 // Allocate an output tensor in accordance with the spec. 46 C10_NODISCARD at::Tensor allocate() const; 47 48 std::vector<int64_t> sizes_; 49 c10::ScalarType dtype_{c10::ScalarType::Undefined}; 50 std::optional<double> qscale_; 51 std::optional<int64_t> qzero_; 52 }; 53 54 // Hold the temporary buffers / states needed during the execution. 55 struct TORCH_API ExecutionState { 56 ExecutionState() = default; 57 ExecutionState(const ExecutionState&) = delete; 58 ExecutionState(ExecutionState&&) = default; 59 ExecutionState& operator=(const ExecutionState&) = delete; 60 ExecutionState& operator=(ExecutionState&&) = default; 61 62 // Preallocated buffers needed by the NNC kernel. 63 std::vector<c10::DataPtr> preallocations_; 64 65 // The NNC kernel expects the following arguments layout: 66 // input tensor 1 67 // ... 68 // input tensor INPUT_NUM 69 // output tensor 1 70 // ... 71 // output tensor OUTPUT_NUM 72 // parameter tensor 1 73 // ... 74 // parameter tensor PARAM_NUM 75 // temporary buffer 1 76 // ... 77 // temporary buffer BUFFER_NUM 78 std::vector<void*> arguments_; 79 }; 80 81 // Specify how to allocate temporary buffers at initialization. 82 struct TORCH_API MemoryPlan { 83 MemoryPlan() = default; 84 85 explicit MemoryPlan(const c10::IValue& value); 86 87 C10_NODISCARD c10::IValue serialize() const; 88 89 void allocate(ExecutionState* state) const; 90 91 std::vector<int64_t> buffer_sizes_; 92 }; 93 94 // Location of a symbolic shape among dimensions of the inputs 95 struct TORCH_API SymbolicShapePosition { 96 SymbolicShapePosition() = default; SymbolicShapePositionSymbolicShapePosition97 SymbolicShapePosition(int64_t input_idx, int64_t dim_idx) 98 : input_idx_(input_idx), dim_idx_(dim_idx) {} 99 100 int64_t input_idx_; 101 int64_t dim_idx_; 102 }; 103 104 // Represents a compiled NNC function which has a 1-1 correspondence with a 105 // `Method` (e.g. `forward`). It's similar as torch::jit::mobile::Function. 106 class TORCH_API Function { 107 public: 108 explicit Function() = default; 109 110 // Deserialize from an IValue that is generated by the 'serialize()' method. 111 explicit Function(const c10::IValue& value); 112 113 // Serialize into an IValue. 114 c10::IValue serialize() const; 115 116 // Execute the compiled NNC function. 117 c10::impl::GenericList run(const c10::impl::GenericList& inputs) const; 118 119 // The name of the function as specified in the model code. name()120 c10::QualifiedName name() const { 121 return name_; 122 } 123 set_name(const c10::QualifiedName & name)124 void set_name(const c10::QualifiedName& name) { 125 name_ = name; 126 } 127 128 // The unique id of the generated NNC kernel corresponding to the function. nnc_kernel_id()129 const std::string& nnc_kernel_id() const { 130 return nnc_kernel_id_; 131 } 132 set_nnc_kernel_id(const std::string & name)133 void set_nnc_kernel_id(const std::string& name) { 134 nnc_kernel_id_ = name; 135 } 136 137 // The parameters (e.g. weights / bias tensors) to be passed to the generated 138 // NNC kernel. parameters()139 const c10::impl::GenericList& parameters() const { 140 return parameters_; 141 } 142 set_parameters(const c10::impl::GenericList & parameters)143 void set_parameters(const c10::impl::GenericList& parameters) { 144 parameters_ = parameters; 145 } 146 input_specs()147 const std::vector<InputSpec>& input_specs() const { 148 return input_specs_; 149 } 150 set_input_specs(const std::vector<InputSpec> & input_specs)151 void set_input_specs(const std::vector<InputSpec>& input_specs) { 152 input_specs_ = input_specs; 153 } 154 output_specs()155 const std::vector<OutputSpec>& output_specs() const { 156 return output_specs_; 157 } 158 set_output_specs(const std::vector<OutputSpec> & output_specs)159 void set_output_specs(const std::vector<OutputSpec>& output_specs) { 160 output_specs_ = output_specs; 161 } 162 memory_plan()163 const MemoryPlan& memory_plan() const { 164 return memory_plan_; 165 } 166 set_memory_plan(const MemoryPlan & memory_plan)167 void set_memory_plan(const MemoryPlan& memory_plan) { 168 memory_plan_ = memory_plan; 169 } 170 sym_shape_positions()171 const std::vector<SymbolicShapePosition>& sym_shape_positions() const { 172 return sym_shape_positions_; 173 } 174 set_sym_shape_positions(const std::vector<SymbolicShapePosition> & sym_shape_pos)175 void set_sym_shape_positions( 176 const std::vector<SymbolicShapePosition>& sym_shape_pos) { 177 sym_shape_positions_ = sym_shape_pos; 178 } 179 180 private: 181 void init_execution_state() const; 182 183 c10::QualifiedName name_; 184 std::string nnc_kernel_id_; 185 c10::impl::GenericList parameters_{at::AnyType::get()}; 186 std::vector<InputSpec> input_specs_; 187 std::vector<OutputSpec> output_specs_; 188 std::vector<SymbolicShapePosition> sym_shape_positions_; 189 MemoryPlan memory_plan_; 190 mutable std::unique_ptr<ExecutionState> execution_state_; 191 }; 192 193 // CompilationUnit consists of a set of compiled NNC functions. It has a 1-1 194 // correspondence with a `Module`. 195 // It's similar as torch::jit::mobile::CompilationUnit. 196 class TORCH_API CompilationUnit { 197 public: 198 CompilationUnit() = default; 199 CompilationUnit(const CompilationUnit&) = delete; 200 CompilationUnit(CompilationUnit&&) = default; 201 CompilationUnit& operator=(const CompilationUnit&) = delete; 202 CompilationUnit& operator=(CompilationUnit&&) = default; 203 204 // Deserialize from an IValue that is generated by the 'serialize()' method. 205 explicit CompilationUnit(const c10::IValue& value); 206 207 // Serialize all registered functions into an IValue. The IValue will be save 208 // into the compiled TorchScript model file ahead-of-time on the host, and 209 // will be deserialized at runtime on the target device. 210 C10_NODISCARD c10::IValue serialize() const; 211 212 // Execute a registered function. 213 C10_NODISCARD c10::impl::GenericList run( 214 const c10::QualifiedName& function_name, 215 const c10::impl::GenericList& inputs) const; 216 217 // Register a function to the compilation unit. 218 void register_function(std::unique_ptr<Function> fn); 219 220 private: 221 C10_NODISCARD Function* find_function(const c10::QualifiedName& qn) const; 222 223 std::unordered_map<c10::QualifiedName, std::unique_ptr<Function>> functions_; 224 }; 225 226 } // namespace nnc 227 } // namespace mobile 228 } // namespace jit 229 } // namespace torch 230