#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/ATen.h>
#include <ATen/core/interned_strings.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/edit_distance.h>

#include <queue>
#include <utility>
#include <vector>

namespace torch::jit {

namespace {
using OperatorMap =
    std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
struct OperatorRegistry {
 private:
  std::mutex lock;
  OperatorMap operators;
  // list of operators whose schemas have not yet been parsed and which must
  // be registered before any call to look up an operator
  std::vector<std::shared_ptr<Operator>> to_register;
  // These two maps are used to implement lookupByLiteral, which is needed for
  // the n->match(...) calls. Basically, every function schema is assigned a
  // unique string you can use to match it. However, parsing those strings or
  // comparing and hashing them character by character would be very slow, so
  // we use a trick here! Every string literal in your program is guaranteed to
  // have static storage duration, so its address won't change at runtime.
  // This allows us to memoize answers for every pointer, which is done by the
  // operators_by_sig_literal map. Still, this map is initially empty, so we
  // still need to do the full string matching the first time, which is
  // implemented by performing a lookup in the operators_by_sig map.
  std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
  std::unordered_map<const char*, std::shared_ptr<Operator>>
      operators_by_sig_literal;
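  // Illustrative note (a sketch, not exercised by the registry itself): the
  // first lookupByLiteral("aten::relu(Tensor self) -> Tensor") call parses the
  // string and consults operators_by_sig; because the argument is a string
  // literal with a stable address, later calls with the same literal are
  // answered straight from operators_by_sig_literal by pointer comparison,
  // skipping the parse. The schema string above is only an example.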

  // Remember all registered operator names to check that they aren't
  // registered a second time. Registering an op multiple times is
  // fragile because which one is picked at runtime might depend on
  // static initialization order.
#ifdef C10_MOBILE
  std::unordered_set<c10::OperatorName> registered_operator_names;
#endif

  // XXX - caller must be holding lock
  void registerPendingOperators() {
    for (const auto& op : to_register) {
      Symbol sym = Symbol::fromQualString(op->schema().name());
      operators[sym].push_back(op);
      operators_by_sig[canonicalSchemaString(op->schema())] = op;
    }
    to_register.clear();
  }

 public:
  void registerOperator(Operator&& op) {
    std::lock_guard<std::mutex> guard(lock);
#ifdef C10_MOBILE
    TORCH_INTERNAL_ASSERT(
        0 == registered_operator_names.count(op.schema().operator_name()),
        "Tried to register operator \"",
        op.schema(),
        "\" to JIT but the operator name was already registered before. Please add or change the overload name.");
    registered_operator_names.insert(op.schema().operator_name());
#endif
    to_register.push_back(std::make_shared<Operator>(std::move(op)));
  }

  void deregisterOperator(const FunctionSchema& schema) {
    Symbol sym = Symbol::fromQualString(schema.name());
    auto sig = canonicalSchemaString(schema);

    std::lock_guard<std::mutex> guard(lock);
#ifdef C10_MOBILE
    TORCH_INTERNAL_ASSERT(
        1 == registered_operator_names.count(schema.operator_name()),
        "Tried to remove operator ",
        schema,
        " from JIT but it wasn't found.");
    registered_operator_names.erase(schema.operator_name());
#endif
    // Try removing from pending operators list first
    auto pending_it = to_register.begin();
    while (pending_it != to_register.end() && (*pending_it)->schema() != schema)
      ++pending_it;

    if (pending_it != to_register.end()) {
      to_register.erase(pending_it);
      return;
    }

    // Remove operator from signature map
    auto sig_it = operators_by_sig.find(sig);
    if (sig_it == operators_by_sig.end()) {
      return;
    }

    operators_by_sig.erase(sig_it);

    // Remove operator from symbol map
    auto op_it = operators.find(sym);
    TORCH_CHECK(
        op_it != operators.end(),
        "operator with signature ",
        sig,
        " is missing from symbol registry");

    auto& op_vec = op_it->second;
    auto it = op_vec.begin();
    while (it != op_vec.end() && (*it)->schema() != schema)
      ++it;
    if (it != op_vec.end()) {
      op_vec.erase(it);
    }
    if (op_vec.empty()) {
      operators.erase(op_it);
    }
  }

  const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    auto it = operators_by_sig_literal.find(name);
    if (it == operators_by_sig_literal.end()) {
      auto op_ptr_it =
          operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
      // Handy debugging code that dumps all operators we know about on mismatch
#if 0
      if (op_ptr_it == operators_by_sig.end()) {
        for (auto & entry : operators_by_sig) {
          std::cout << entry.first << std::endl;
        }
      }
#endif
      TORCH_CHECK(
          op_ptr_it != operators_by_sig.end(),
          "Couldn't find an operator for ",
          name,
          ". Do you have to update a set of hardcoded JIT ops?");
      it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
    }
    return it->second;
  }

  const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    static std::vector<std::shared_ptr<Operator>> empty;
    auto it = operators.find(name);
    if (it != operators.end())
      return it->second;
    return empty;
  }

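  // Ranks registered operator names by edit distance to input_op and returns
  // the closest matches (distance <= MAX_EDIT_DIST), nearest first. As an
  // illustrative example, a typo like "aten::addd" would typically suggest
  // "aten::add"; the actual suggestions depend on which operators are
  // registered.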
  std::vector<Symbol> findSimilarOperators(Symbol input_op) {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();

    using EntryPair = std::pair<int64_t, Symbol>;
    auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) {
      return lhs.first > rhs.first;
    };

    std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
        rankings(cmp);
    static constexpr size_t MAX_EDIT_DIST = 2u;
    for (const auto& op : operators) {
      auto edit_dist = ComputeEditDistance(
          input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
      if (edit_dist <= MAX_EDIT_DIST) {
        rankings.emplace(edit_dist, op.first);
      }
    }
    std::vector<Symbol> ret;
    while (!rankings.empty()) {
      ret.push_back(rankings.top().second);
      rankings.pop();
    }
    return ret;
  }

  const std::vector<std::shared_ptr<Operator>> getAllOperators() {
    std::lock_guard<std::mutex> guard(lock);
    registerPendingOperators();
    std::vector<std::shared_ptr<Operator>> values;
    for (auto& kv : operators) {
      values.insert(values.end(), kv.second.begin(), kv.second.end());
    }
    return values;
  }
};

OperatorRegistry& getRegistry() {
  static OperatorRegistry r;
  return r;
}

bool printerHasSpecialCaseFor(Symbol sym) {
  using namespace at;
  // WARNING: by adding a value to this set, you are asserting
  // that you have also added special handling of this symbol in
  // python_print.cpp. Not adding handling will cause import and export
  // of modules with this new operator to fail. This is only required
  // for operators without schema. Prefer registering your operator with
  // schema to editing this list here. These cases should only be things
  // that require special handling because they do not fit normal schema.
  const static std::unordered_set<Symbol> handled = {
      prim::Constant, prim::Uninitialized, prim::fork,
      prim::awaitable, prim::ListConstruct, prim::DictConstruct,
      prim::ListUnpack, prim::Print, prim::PythonOp,
      prim::TupleConstruct, prim::TupleIndex, prim::TupleSlice,
      prim::TupleUnpack, prim::CreateObject, prim::GetAttr,
      prim::SetAttr, prim::CallFunction, prim::isinstance,
      prim::unchecked_cast, prim::tolist, prim::rpc_async,
      prim::rpc_sync, prim::rpc_remote};

  // WARNING: by adding a value to this set, you are asserting that your
  // primitive is only ever added during optimization and does not need
  // to be correctly printed for export (a process that happens before
  // optimization passes run)
  const static std::unordered_set<Symbol> unneeded = {
      c10::onnx::Reshape, // only used in onnx
      c10::onnx::Shape, // only used in onnx
      prim::AutogradZero, // temporarily inserted by autograd
      prim::AutogradAnyNonZero, // temporarily inserted by autograd
      prim::AutogradAllNonZero, // temporarily inserted by autograd
      prim::AutogradAllZero, // temporarily inserted by autograd
      prim::AutogradAdd, // temporarily inserted by autograd
      prim::ConstantChunk, // optimization pass adds it
      prim::DifferentiableGraph, // optimization pass adds it
      prim::FunctionalGraph, // optimization pass adds it
      prim::ReductionSizes, // optimization pass (fuser) adds it
      prim::BroadcastSizes, // optimization pass (fuser) adds it
      prim::ChunkSizes, // optimization pass (fuser) adds it
      prim::Drop, // used in interpreter only
      prim::FusedConcat, // optimization pass adds it
      prim::FusionGroup, // optimization pass adds it
      prim::CudaFusionGroup, // optimization pass adds it
      prim::CudaFusionGuard, // optimization pass adds it
      prim::TensorExprGroup, // optimization pass adds it
      prim::TensorExprDynamicGroup, // optimization pass adds it
      prim::StaticSubgraph, // optimization pass adds it
      prim::ConstantMKLDNNTensor, // optimization pass adds it
      prim::BroadcastMKLDNNTensors, // optimization pass adds it
      prim::oneDNNFusionGroup, // optimization pass adds it
      prim::oneDNNFusionGuard, // optimization pass adds it
      prim::StaticRuntimeCopyOuts, // used in SR only
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
      prim::Store, // used in interpreter only
      prim::profile, // used in interpreter only
      prim::profile_ivalue, // used in interpreter only
      prim::TypeCheck, // used in interpreter only
      prim::RequiresGradCheck, // used in interpreter only
      prim::FallbackGraph, // converted into prim::CallFunction
  };

  // These namespaces are required to have Python printers unless
  // otherwise noted in unneeded.
  const static std::unordered_set<Symbol> required_namespaces = {
      c10::namespaces::prim,
      c10::namespaces::aten,
      c10::namespaces::onnx,
  };

  return handled.count(sym) || unneeded.count(sym) ||
      !required_namespaces.count(sym.ns());
}

} // anonymous namespace

bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
  using namespace at;
  // WARNING: by adding a case to this list, you are asserting that you have
  // added a case for the unschematized node in AliasDb::analyze
  const static std::unordered_set<Symbol> handled = {
      prim::If,
      prim::Loop,
      prim::FusionGroup,
      prim::CudaFusionGroup,
      prim::oneDNNFusionGroup,
      prim::DifferentiableGraph,
      prim::TensorExprGroup,
      prim::TensorExprDynamicGroup,
      prim::StaticSubgraph,
      prim::FunctionalGraph,
      prim::Constant,
      prim::Uninitialized,
      prim::DictConstruct,
      prim::ListConstruct,
      prim::TupleConstruct,
      prim::AutogradZero,
      prim::FusedConcat,
      prim::GradOf,
      prim::MMTreeReduce,
      prim::MMBatchSide,
      prim::BroadcastSizes,
      prim::ChunkSizes,
      prim::Closure,
      prim::TupleUnpack,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::ListUnpack,
      prim::PythonOp,
      prim::ConstantChunk,
      prim::BroadcastingChunk,
      prim::MKLDNNGroup,
      prim::ConstantMKLDNNTensor,
      prim::BroadcastMKLDNNTensors,
      prim::fork,
      prim::awaitable,
      prim::awaitable_nowait,
      prim::awaitable_wait,
      prim::CreateObject,
      prim::AutogradAdd,
      prim::GetAttr,
      prim::SetAttr,
      prim::profile,
      prim::profile_ivalue,
      prim::TypeCheck,
      prim::RequiresGradCheck,
      prim::Print,
      prim::CallFunction,
      prim::CallMethod,
      aten::wait,
      prim::isinstance,
      prim::unchecked_cast,
      prim::tolist,
      prim::rpc_async,
      prim::rpc_sync,
      prim::rpc_remote,
      prim::Enter,
      prim::Exit,
      prim::FallbackGraph,
  };

  // Operators that should not be used by alias analysis
  const static std::unordered_set<Symbol> purposefully_not_handled = {
      prim::Load,
      prim::Store,
      prim::Drop,
      at::onnx::Reshape,
      at::onnx::Shape,
      prim::AutogradAdd,
  };

  return handled.count(symbol) || purposefully_not_handled.count(symbol);
}

void registerOperator(Operator&& op) {
  if (op.schema().is_varret()) {
    Symbol s = Symbol::fromQualString(op.schema().name());
    if (!printerHasSpecialCaseFor(s)) {
      AT_ERROR(
          "Missing special case in python printer for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
    if (aliasAnalysisHasSpecialCaseFor(s) &&
        op.aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE) {
      AT_ERROR(
          "Conflict in special casing in alias analysis for non-schematized"
          " operator ",
          op.schema().name(),
          ". File a bug to add a case for this operator.\n");
    }
    if (aliasAnalysisHasSpecialCaseFor(s) &&
        op.aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA) {
      AT_ERROR(
          "The operator ",
          op.schema().name(),
          " is special cased and cannot use explicit alias analysis.");
    }
  }
  getRegistry().registerOperator(std::move(op));
}

void deregisterOperator(const FunctionSchema& schema) {
  getRegistry().deregisterOperator(schema);
}

const std::vector<std::shared_ptr<Operator>> getAllOperators() {
  return getRegistry().getAllOperators();
}

const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
  return getRegistry().getOperators(name);
}

std::vector<std::shared_ptr<Operator>> getAllSortedOperatorsFor(Symbol name) {
  const auto& unsortedOps = getAllOperatorsFor(name);
  // Depending on the order of registration, aten or jit ops may be
  // registered first. This sorting is helpful in cases where
  // deterministic (i.e. not dependent on build config) behavior is
  // desired; e.g. torch.ops.aten.* uses this function, and tries to
  // find the "first" op that matches input args. Without the sorting,
  // the "first" op may change depending on registration order.
  std::vector<std::shared_ptr<Operator>> sortedOps;
  sortedOps.reserve(unsortedOps.size());
  std::copy_if(
      unsortedOps.begin(),
      unsortedOps.end(),
      std::back_inserter(sortedOps),
      [](const std::shared_ptr<Operator>& op) { return op->isC10Op(); });
  std::copy_if(
      unsortedOps.begin(),
      unsortedOps.end(),
      std::back_inserter(sortedOps),
      [](const std::shared_ptr<Operator>& op) { return !op->isC10Op(); });
  return sortedOps;
}

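// Looks up a single operator by exact qualified name and overload name,
// returning nullptr if no such overload is registered. A sketch of typical
// use (the operator and overload shown are only an example):
//   auto op = findOperatorFor(c10::OperatorName("aten::add", "Tensor"));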
std::shared_ptr<Operator> findOperatorFor(const c10::OperatorName& full_name) {
  for (const auto& op :
       getRegistry().getOperators(Symbol::fromQualString(full_name.name))) {
    if (op->schema().overload_name() == full_name.overload_name) {
      return op;
    }
  }
  return nullptr;
}

std::vector<Symbol> findSimilarOperators(Symbol input_op) {
  return getRegistry().findSimilarOperators(input_op);
}

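// Intended to be called with a string literal: the literal's address is used
// as a memoization key in the registry (see OperatorRegistry above), so the
// schema string is only parsed on the first call. A sketch of typical use
// (the schema string is illustrative):
//   auto op = getOperatorForLiteral("aten::relu(Tensor self) -> Tensor");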
std::shared_ptr<Operator> getOperatorForLiteral(const char* signature) {
  return getRegistry().lookupByLiteral(signature);
}

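// Renders a schema in the canonical string form used as the key of
// operators_by_sig: qualified name, argument types and names (with a single
// "*, " marking the start of keyword-only arguments), and return types, but
// no default values and no overload name. For example, a schema such as
// "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
// would come out roughly as
// "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor".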
std::string canonicalSchemaString(const FunctionSchema& schema) {
  std::string out = schema.name();
  out.push_back('(');

  bool seen_kwarg_only = false;
  for (const auto i : c10::irange(schema.arguments().size())) {
    if (i > 0) {
      out += ", ";
    }
    if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
      out += "*, ";
      seen_kwarg_only = true;
    }
    const auto& arg = schema.arguments()[i];
    out += arg.type()->str();
    out.push_back(' ');
    out += arg.name();
  }

  out += ") -> ";
  if (schema.returns().size() == 1) {
    out += schema.returns().at(0).type()->str();
  } else if (schema.returns().size() > 1) {
    out.push_back('(');
    for (const auto i : c10::irange(schema.returns().size())) {
      if (i > 0) {
        out += ", ";
      }
      out += schema.returns()[i].type()->str();
    }
    out.push_back(')');
  }
  return out;
}

} // namespace torch::jit