// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/function.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/core/dynamic_type.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/interpreter.h>
#include <torch/csrc/jit/mobile/parse_bytecode.h>
#include <torch/csrc/jit/mobile/parse_operators.h>
#include <torch/csrc/jit/mobile/prim_ops_registery.h>
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/operator.h>

namespace torch::jit {

char const* toString(OpCode op);
namespace mobile {
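// A mobile::Function is identified by its qualified name and, for fully
// deserialized functions, carries the bytecode (Code) and schema produced
// when the module was exported.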
Function::Function(c10::QualifiedName name) : name_(std::move(name)) {}

Function::Function(
    c10::QualifiedName name,
    Code code,
    std::optional<c10::FunctionSchema> schema)
    : name_(std::move(name)),
      code_(std::move(code)),
      schema_(std::move(schema)) {}

const c10::QualifiedName& Function::qualname() const {
  return name_;
}

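// Appends one instruction plus its debug handle to the function's bytecode,
// rejecting opcodes the mobile interpreter cannot execute.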
void Function::append_instruction(
    OpCode op,
    int64_t X,
    int64_t N,
    int64_t dbg_handle) {
  TORCH_CHECK(
      isOpSupportedInMobile(op),
      toString(op),
      " is not supported in the mobile module.");
  code_.instructions_.emplace_back(op, X, N);
  code_.debug_handles_.emplace_back(dbg_handle);
}

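// Same as above, for callers that have no debug handle to record.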
void Function::append_instruction(OpCode op, int64_t X, int64_t N) {
  TORCH_CHECK(
      isOpSupportedInMobile(op),
      toString(op),
      " is not supported in the mobile module.");
  code_.instructions_.emplace_back(op, X, N);
}

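// Records an operator reference by name; resolving it to a callable is
// deferred to initialize_operators(). A missing num_specified_args is
// stored as -1.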
void Function::append_operator(
    const std::string& name,
    const std::string& overload_name,
    const std::optional<int>& num_specified_args) {
  // Keep the original opname in code_
  code_.op_names_.emplace_back(name, overload_name);
  code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
}

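// Renders an operator name as "name.overload_name", or just "name" when
// there is no overload, e.g. "aten::add.Tensor".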
std::string operator_str(const c10::OperatorName& opname) {
  std::string result = opname.name;
  if (!opname.overload_name.empty()) {
    result += "." + opname.overload_name;
  }
  return result;
}

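// Resolves every recorded operator name to a callable, in op_names_ order.
// Marks the code initialized and returns true only if all ops resolve;
// when should_check_operators is set, unresolved ops raise instead.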
bool Function::initialize_operators(bool should_check_operators) {
  if (code_.initialized) {
    return true;
  }
  std::unordered_set<std::string> unsupported_op_names;
  code_.operators_.resize(code_.op_names_.size());
  bool all_ops_supported = true;
  for (unsigned i = 0; i < code_.op_names_.size(); i++) {
    const auto& opname = code_.op_names_[i];
    int num_args = code_.operator_input_sizes_[i];
    std::optional<int> num_specified_args =
        num_args < 0 ? std::nullopt : std::optional<int>(num_args);
    auto func = makeOperatorFunction(opname, num_specified_args);
    if (!func.has_value()) {
      unsupported_op_names.insert(operator_str(opname));
      all_ops_supported = false;
    } else {
      code_.operators_[i] = *func;
    }
  }
  if (should_check_operators) {
    TORCH_CHECK(
        unsupported_op_names.empty(),
        "The following ops cannot be found: [",
        c10::Join(", ", unsupported_op_names),
        "]. Please check if the operator library is included in the build. If the build uses selected ops, check that these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post in https://discuss.pytorch.org/c/mobile/");
  }
  code_.initialized = all_ops_supported;
  return all_ops_supported;
}

void Function::append_constant(const c10::IValue& constant) {
  code_.constants_.push_back(constant);
}

void Function::append_type(const at::TypePtr& type) {
  code_.types_.push_back(type);
}

void Function::append_function(mobile::Function& function) {
  code_.functions_.push_back(&function);
}

void Function::set_register_size(size_t size) {
  code_.register_size_ = size;
}

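// Maps a program counter back to the debug handle recorded for that
// instruction, e.g. to attach source-level information to error reports.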
int64_t Function::get_debug_handle(size_t pc) const {
  TORCH_CHECK(
      pc < code_.debug_handles_.size(),
      "Module debug info index out of bounds.");
  return code_.debug_handles_[pc];
}

torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
  schema_ = std::move(schema);
  return *this;
}

bool Function::hasSchema() const {
  return schema_.has_value();
}

const c10::FunctionSchema& Function::getSchema() const {
  return *schema_;
}

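// Executes the function on the given stack: lazily resolves operators,
// normalizes inputs against the schema (filling defaults for optional
// arguments) when one is present, then hands the bytecode to the
// mobile interpreter.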
void Function::run(Stack& stack) {
  initialize_operators(/* should_check_operators */ true);
  if (hasSchema()) { // if we have a schema then resolve optional args if any
    getSchema().checkAndNormalizeInputs<c10::DynamicType>(
        stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
  }
  InterpreterState interp_state(code_);
  interp_state.run(stack);
}

at::IValue Function::operator()(Stack& stack) {
  run(stack);
  return stack.front();
}

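// Note: assumes a schema is present; calling this on a schema-less
// function dereferences an empty optional.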
size_t Function::num_inputs() const {
  return schema_->arguments().size();
}

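// Ensures operators are resolved (raising if any are missing), then hands
// the Code to the callback so the caller can run or inspect it.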
bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
  initialize_operators(true);
  f(code_);
  return true;
}

const Code& Function::get_code() const {
  return code_;
}

Code& Function::get_code() {
  return code_;
}

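// Debug handles captured by the interpreter at the point where the most
// recent exception was raised.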
const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
  return getInterpretersExceptionDebugHandles();
}

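// Resolves an operator name to a Stack-consuming callable, trying promoted
// prim ops first, then JIT-registered operators, then the c10 dispatcher.
// When the serialized call site supplies fewer arguments than the schema
// declares (an older model running against a newer op), the callable is
// wrapped to push the missing trailing defaults while keeping any out
// arguments at the top of the stack.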
std::optional<std::function<void(Stack&)>> makeOperatorFunction(
    const c10::OperatorName& opname,
    std::optional<int> num_specified_args) {
  std::function<void(Stack&)> fn;
  const auto full_name = c10::toString(opname);
  const std::vector<c10::Argument>* pArgs = nullptr;
  bool promoted_op = mobile::hasPrimOpsFn(full_name);
  if (promoted_op) {
    fn = mobile::getPrimOpsFn(full_name);
  } else {
    std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
    if (jit_op) {
      fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
      pArgs = &jit_op->schema().arguments();
    } else {
      auto op = c10::Dispatcher::singleton().findSchema(opname);
      if (op.has_value()) {
        fn = [op](Stack& stack) { op->callBoxed(&stack); };
        if (op->hasSchema()) {
          pArgs = &op->schema().arguments();
        } else {
          TORCH_CHECK(false, "arguments are missing for operator ", opname);
        }
      } else {
        return std::nullopt;
      }
    }
  }

  if (!promoted_op) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
    const auto& args = *pArgs;
    // num_specified_args >= 0 means the model recorded how many arguments
    // the call site actually passes; use it to handle backward compatibility.
    if (num_specified_args &&
        num_specified_args.value() < static_cast<int64_t>(args.size())) {
      fn = [fn, num_specified_args, &args](Stack& stack) {
        std::vector<IValue> out_args;
        // Pop and temporarily stash all out arguments from the stack (zero
        // or more; the schema always places them last) so that the missing
        // default values can be pushed, then push the out arguments back.
        for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
          out_args.push_back(stack.back());
          stack.pop_back();
        }
        TORCH_CHECK(
            static_cast<size_t>(num_specified_args.value()) >= out_args.size(),
            "The number of output arguments is: ",
            out_args.size(),
            ", which is more than the number of specified arguments: ",
            num_specified_args.value());
        size_t start_index = num_specified_args.value() - out_args.size();
        for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) {
          TORCH_CHECK(
              args[i].default_value().has_value(),
              "Failed to prepare default values: argument ",
              i,
              " (",
              args[i].name(),
              ") has neither a specified value nor a default value.");

          stack.emplace_back(args[i].default_value());
        }
        stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
        fn(stack);
      };
    }
  }
  return fn;
}

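// Registers (or fetches, if already present) an upgrader function by
// qualified name in a process-wide holder, assembling its bytecode from
// raw instructions, constants, and types on first registration.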
Function& Function::registerFunc(
    const std::string& qualified_name,
    const std::vector<Instruction>& instructions,
    const std::vector<c10::IValue>& constants,
    const std::vector<c10::TypePtr>& types,
    const size_t register_size) {
  static std::unordered_map<c10::QualifiedName, Function>
      upgrader_function_holder;
  c10::QualifiedName name = c10::QualifiedName(qualified_name);
  auto found = upgrader_function_holder.find(name);
  // Register the function if it's not found in the map.
  if (found == upgrader_function_holder.end()) {
    auto name_function_pair =
        upgrader_function_holder.emplace(name, Function(name));
    auto& func = name_function_pair.first->second;
    for (auto const& inst : instructions) {
      func.append_instruction(inst.op, inst.X, inst.N);
    }
    for (auto const& constant : constants) {
      func.append_constant(constant);
    }
    for (auto const& type : types) {
      func.append_type(type);
    }
    func.set_register_size(register_size);
    return func;
  }
  auto& upgrader_function_in_holder = found->second;
  return upgrader_function_in_holder;
}

} // namespace mobile
} // namespace torch::jit