1 #include <ATen/core/dynamic_type.h>
2 #include <torch/csrc/jit/mobile/function.h>
3 #include <torch/csrc/jit/mobile/interpreter.h>
4 #include <torch/csrc/jit/mobile/parse_bytecode.h>
5 #include <torch/csrc/jit/mobile/parse_operators.h>
6 #include <torch/csrc/jit/mobile/prim_ops_registery.h>
7 #include <torch/csrc/jit/mobile/type_parser.h>
8 #include <torch/csrc/jit/runtime/instruction.h>
9 #include <torch/csrc/jit/runtime/operator.h>
10
11 namespace torch::jit {
12
13 char const* toString(OpCode op);
14 namespace mobile {
Function(c10::QualifiedName name)15 Function::Function(c10::QualifiedName name) : name_(std::move(name)) {}
16
Function(c10::QualifiedName name,Code code,std::optional<c10::FunctionSchema> schema)17 Function::Function(
18 c10::QualifiedName name,
19 Code code,
20 std::optional<c10::FunctionSchema> schema)
21 : name_(std::move(name)),
22 code_(std::move(code)),
23 schema_(std::move(schema)) {}
24
qualname() const25 const c10::QualifiedName& Function::qualname() const {
26 return name_;
27 }
28
append_instruction(OpCode op,int64_t X,int64_t N,int64_t dbg_handle)29 void Function::append_instruction(
30 OpCode op,
31 int64_t X,
32 int64_t N,
33 int64_t dbg_handle) {
34 TORCH_CHECK(
35 isOpSupportedInMobile(op),
36 toString(op),
37 " is not supported in mobile module.");
38 code_.instructions_.emplace_back(op, X, N);
39 code_.debug_handles_.emplace_back(dbg_handle);
40 }
41
append_instruction(OpCode op,int64_t X,int64_t N)42 void Function::append_instruction(OpCode op, int64_t X, int64_t N) {
43 TORCH_CHECK(
44 isOpSupportedInMobile(op),
45 toString(op),
46 " is not supported in mobile module.");
47 code_.instructions_.emplace_back(op, X, N);
48 }
49
append_operator(const std::string & name,const std::string & overload_name,const std::optional<int> & num_specified_args)50 void Function::append_operator(
51 const std::string& name,
52 const std::string& overload_name,
53 const std::optional<int>& num_specified_args) {
54 // Keep the original opname in code_
55 code_.op_names_.emplace_back(name, overload_name);
56 code_.operator_input_sizes_.emplace_back(num_specified_args.value_or(-1));
57 }
58
operator_str(const c10::OperatorName & opname)59 std::string operator_str(const c10::OperatorName& opname) {
60 std::string result = opname.name;
61 if (!opname.overload_name.empty()) {
62 result += "." + opname.overload_name;
63 }
64 return result;
65 }
66
initialize_operators(bool should_check_operators)67 bool Function::initialize_operators(bool should_check_operators) {
68 if (code_.initialized) {
69 return true;
70 }
71 std::unordered_set<std::string> unsupported_op_names;
72 code_.operators_.resize(code_.op_names_.size());
73 bool all_ops_supported = true;
74 for (unsigned i = 0; i < code_.op_names_.size(); i++) {
75 const auto& opname = code_.op_names_[i];
76 int num_args = code_.operator_input_sizes_[i];
77 std::optional<int> num_specified_args =
78 num_args < 0 ? std::nullopt : std::optional<int>(num_args);
79 auto func = makeOperatorFunction(opname, num_specified_args);
80 if (!func.has_value()) {
81 unsupported_op_names.insert(operator_str(opname));
82 all_ops_supported = false;
83 } else {
84 code_.operators_[i] = *func;
85 }
86 }
87 if (should_check_operators) {
88 TORCH_CHECK(
89 unsupported_op_names.empty(),
90 "Following ops cannot be found: [",
91 c10::Join(", ", unsupported_op_names),
92 "]. Please check if the operator library is included in the build. If built with selected ops, check if these ops are in the list. If you are a Meta employee, please see fburl.com/missing_ops for a fix. Or post it in https://discuss.pytorch.org/c/mobile/");
93 }
94 code_.initialized = all_ops_supported;
95 return all_ops_supported;
96 }
97
append_constant(const c10::IValue & constant)98 void Function::append_constant(const c10::IValue& constant) {
99 code_.constants_.push_back(constant);
100 }
101
append_type(const at::TypePtr & type)102 void Function::append_type(const at::TypePtr& type) {
103 code_.types_.push_back(type);
104 }
105
append_function(mobile::Function & function)106 void Function::append_function(mobile::Function& function) {
107 code_.functions_.push_back(&function);
108 }
109
set_register_size(size_t size)110 void Function::set_register_size(size_t size) {
111 code_.register_size_ = size;
112 }
113
get_debug_handle(size_t pc) const114 int64_t Function::get_debug_handle(size_t pc) const {
115 TORCH_CHECK(
116 pc < code_.debug_handles_.size(),
117 "Module debug info index out of boundary.");
118 return code_.debug_handles_[pc];
119 }
120
setSchema(c10::FunctionSchema schema)121 torch::jit::Function& Function::setSchema(c10::FunctionSchema schema) {
122 schema_ = std::move(schema);
123 return *this;
124 }
125
hasSchema() const126 bool Function::hasSchema() const {
127 return schema_.has_value();
128 }
129
getSchema() const130 const c10::FunctionSchema& Function::getSchema() const {
131 return *schema_;
132 }
133
run(Stack & stack)134 void Function::run(Stack& stack) {
135 initialize_operators(/* should_check_operators */ true);
136 if (hasSchema()) { // if we have a schema then resolve optional args if any
137 getSchema().checkAndNormalizeInputs<c10::DynamicType>(
138 stack, std::unordered_map<std::string, IValue>{} /*kwargs*/);
139 }
140 InterpreterState interp_state(code_);
141 interp_state.run(stack);
142 }
143
operator ()(Stack & stack)144 at::IValue Function::operator()(Stack& stack) {
145 run(stack);
146 return stack.front();
147 }
148
num_inputs() const149 size_t Function::num_inputs() const {
150 return schema_->arguments().size();
151 }
152
call(Stack &,c10::function_ref<void (const mobile::Code &)> f)153 bool Function::call(Stack&, c10::function_ref<void(const mobile::Code&)> f) {
154 initialize_operators(true);
155 f(code_);
156 return true;
157 }
158
get_code() const159 const Code& Function::get_code() const {
160 return code_;
161 }
162
get_code()163 Code& Function::get_code() {
164 return code_;
165 }
166
getExceptionDebugHandles() const167 const std::vector<int64_t>& Function::getExceptionDebugHandles() const {
168 return getInterpretersExceptionDebugHandles();
169 }
170
// Builds a boxed callable for operator `opname`.
//
// Lookup order:
//   1. mobile prim-ops registry (hasPrimOpsFn / getPrimOpsFn),
//   2. JIT operator registry (findOperatorFor),
//   3. c10 dispatcher (findSchema).
// Returns std::nullopt if none of them knows the operator. Throws
// (TORCH_CHECK) if the dispatcher has the operator but no schema.
//
// For non-prim ops, if `num_specified_args` is set and smaller than the
// schema's argument count, the returned callable wraps the operator. Before
// calling it, the wrapper pushes default values for the trailing arguments the
// model did not specify, keeping any out-arguments last on the stack. Prim ops
// are never wrapped.
std::optional<std::function<void(Stack&)>> makeOperatorFunction(
    const c10::OperatorName& opname,
    std::optional<int> num_specified_args) {
  std::function<void(Stack&)> fn;
  const auto full_name = c10::toString(opname);
  // Points at the schema's argument list; stays null for prim ops.
  const std::vector<c10::Argument>* pArgs = nullptr;
  bool promoted_op = mobile::hasPrimOpsFn(full_name);
  if (promoted_op) {
    fn = mobile::getPrimOpsFn(full_name);
  } else {
    std::shared_ptr<Operator> jit_op = findOperatorFor(opname);
    if (jit_op) {
      // The shared_ptr capture keeps the Operator (and its schema) alive for
      // the lifetime of the callable.
      fn = [jit_op](Stack& stack) { jit_op->getOperation()(stack); };
      pArgs = &jit_op->schema().arguments();
    } else {
      auto op = c10::Dispatcher::singleton().findSchema(opname);
      if (op.has_value()) {
        fn = [op](Stack& stack) { op->callBoxed(&stack); };
        if (op->hasSchema()) {
          pArgs = &op->schema().arguments();
        } else {
          TORCH_CHECK(false, "arguments are missing for operator ", opname);
        }
      } else {
        return std::nullopt;
      }
    }
  }

  if (!promoted_op) {
    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(pArgs);
    const auto& args = *pArgs;
    // num_specified_args >= 0 indicates number of arguments are available
    // from model. We can use it to handle backward compatibility.
    if (num_specified_args &&
        num_specified_args.value() < static_cast<int64_t>(args.size())) {
      // NOTE(review): `args` is captured by reference into the schema. That
      // schema is kept alive by `jit_op` in the JIT case; for the dispatcher
      // case it is assumed to outlive the callable (TODO confirm).
      // The previous `fn` is captured by copy and invoked at the end.
      fn = [fn, num_specified_args, &args](Stack& stack) {
        std::vector<IValue> out_args;
        // The following logic pops and temporarily stores all out arguments
        // from the stack (which can be 0 or more, and always appended to the
        // schema), in order to push the necessary default values. Finally,
        // the out arguments are pushed back into the stack.
        // Popping yields out args in reverse order; they are re-inserted with
        // rbegin()/rend() below to restore the original order.
        // NOTE(review): the scan stops before index 0, and `args.size() - 1`
        // would wrap if `args` were empty (only reachable with a negative
        // num_specified_args, which initialize_operators never passes).
        for (size_t i = args.size() - 1; i > 0 && args.at(i).is_out(); i--) {
          out_args.push_back(stack.back());
          stack.pop_back();
        }
        TORCH_CHECK(
            static_cast<size_t>(num_specified_args.value()) >= out_args.size(),
            "The number of output arguments is: ",
            out_args.size(),
            ", which is more then the number of specified arguments: ",
            num_specified_args.value());
        // First non-out argument the model did not provide.
        size_t start_index = num_specified_args.value() - out_args.size();
        for (size_t i = start_index; i < (args.size() - out_args.size()); ++i) {
          TORCH_CHECK(
              args[i].default_value().has_value(),
              "Error happened at preparing for default values for the argument. The ",
              i,
              "th argument ",
              args[i].name(),
              " does not have a specified value or default value. ");

          stack.emplace_back(args[i].default_value());
        }
        stack.insert(stack.end(), out_args.rbegin(), out_args.rend());
        fn(stack);
      };
    }
  }
  return fn;
}
242
registerFunc(const std::string & qualified_name,const std::vector<Instruction> & instructions,const std::vector<c10::IValue> & constants,const std::vector<c10::TypePtr> & types,const size_t register_size)243 Function& Function::registerFunc(
244 const std::string& qualified_name,
245 const std::vector<Instruction>& instructions,
246 const std::vector<c10::IValue>& constants,
247 const std::vector<c10::TypePtr>& types,
248 const size_t register_size) {
249 static std::unordered_map<c10::QualifiedName, Function>
250 upgrader_function_holder;
251 c10::QualifiedName name = c10::QualifiedName(qualified_name);
252 auto found = upgrader_function_holder.find(name);
253 // Register the function if it's not found in the map.
254 if (found == upgrader_function_holder.end()) {
255 auto name_function_pair =
256 upgrader_function_holder.emplace(name, Function(name));
257 auto& func = name_function_pair.first->second;
258 for (auto const& inst : instructions) {
259 func.append_instruction(inst.op, inst.X, inst.N);
260 }
261 for (auto const& constant : constants) {
262 func.append_constant(constant);
263 }
264 for (auto const& type : types) {
265 func.append_type(type);
266 }
267 func.set_register_size(register_size);
268 return func;
269 }
270 auto& upgrader_function_in_holder = found->second;
271 return upgrader_function_in_holder;
272 }
273
274 } // namespace mobile
275 } // namespace torch::jit
276