xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/operator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/jit/runtime/operator.h>

#include <ATen/ATen.h>
#include <ATen/core/interned_strings.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/frontend/edit_distance.h>

#include <algorithm>
#include <queue>
#include <utility>
#include <vector>
11 
12 namespace torch::jit {
13 
14 namespace {
15 using OperatorMap =
16     std::unordered_map<Symbol, std::vector<std::shared_ptr<Operator>>>;
17 struct OperatorRegistry {
18  private:
19   std::mutex lock;
20   OperatorMap operators;
21   // list of operators whose schema have not yet been parsed, and must
22   // be registered before any call to lookup an operator
23   std::vector<std::shared_ptr<Operator>> to_register;
24   // Those two maps are used to implement lookupByLiteral, which is needed for
25   // the n->match(...) calls. Basically, every function schema is assigned a
26   // unique string you can use to match it. However, parsing those strings or
27   // comparing and hashing them character by character would be very slow, so we
28   // use a trick here! Every string literal in your program is guaranteed to
29   // have static storage duration and so its address won't change at runtime.
30   // This allows us to memoize answers for every pointer, which is done by the
31   // operators_by_sig_literal map. Still, this map is initially empty, and so we
32   // still need to do the complete string matching at the first time, which is
33   // implemented by performing a lookup in the operators_by_sig map.
34   std::unordered_map<std::string, std::shared_ptr<Operator>> operators_by_sig;
35   std::unordered_map<const char*, std::shared_ptr<Operator>>
36       operators_by_sig_literal;
37 
38   // Remember all registered operator names to check that they aren't
39   // registered a second time. Registering an op multiple times is
40   // fragile because it might depend on static initialization order
41   // which one is picked at runtime.
42 #ifdef C10_MOBILE
43   std::unordered_set<c10::OperatorName> registered_operator_names;
44 #endif
45 
46   // XXX - caller must be holding lock
registerPendingOperatorstorch::jit::__anond11ff3ee0111::OperatorRegistry47   void registerPendingOperators() {
48     for (const auto& op : to_register) {
49       Symbol sym = Symbol::fromQualString(op->schema().name());
50       operators[sym].push_back(op);
51       operators_by_sig[canonicalSchemaString(op->schema())] = op;
52     }
53     to_register.clear();
54   }
55 
56  public:
registerOperatortorch::jit::__anond11ff3ee0111::OperatorRegistry57   void registerOperator(Operator&& op) {
58     std::lock_guard<std::mutex> guard(lock);
59 #ifdef C10_MOBILE
60     TORCH_INTERNAL_ASSERT(
61         0 == registered_operator_names.count(op.schema().operator_name()),
62         "Tried to register operator \"",
63         op.schema(),
64         "\" to JIT but the operator name was already registered before. Please add or change the overload name.");
65     registered_operator_names.insert(op.schema().operator_name());
66 #endif
67     to_register.push_back(std::make_shared<Operator>(std::move(op)));
68   }
69 
deregisterOperatortorch::jit::__anond11ff3ee0111::OperatorRegistry70   void deregisterOperator(const FunctionSchema& schema) {
71     Symbol sym = Symbol::fromQualString(schema.name());
72     auto sig = canonicalSchemaString(schema);
73 
74     std::lock_guard<std::mutex> guard(lock);
75 #ifdef C10_MOBILE
76     TORCH_INTERNAL_ASSERT(
77         1 == registered_operator_names.count(schema.operator_name()),
78         "Tried to remove operator ",
79         schema,
80         " from JIT but it wasn't found.");
81     registered_operator_names.erase(schema.operator_name());
82 #endif
83     // Try removing from pending operators list first
84     auto pending_it = to_register.begin();
85     while (pending_it != to_register.end() && (*pending_it)->schema() != schema)
86       ++pending_it;
87 
88     if (pending_it != to_register.end()) {
89       to_register.erase(pending_it);
90       return;
91     }
92 
93     // Remove operator from signature map
94     auto sig_it = operators_by_sig.find(sig);
95     if (sig_it == operators_by_sig.end()) {
96       return;
97     }
98 
99     operators_by_sig.erase(sig_it);
100 
101     // Remove operator from symbol map
102     auto op_it = operators.find(sym);
103     TORCH_CHECK(
104         op_it != operators.end(),
105         "operator with signature ",
106         sig,
107         " is missing from symbol registry");
108 
109     auto& op_vec = op_it->second;
110     auto it = op_vec.begin();
111     while (it != op_vec.end() && (*it)->schema() != schema)
112       ++it;
113     if (it != op_vec.end()) {
114       op_vec.erase(it);
115     }
116     if (op_vec.empty()) {
117       operators.erase(op_it);
118     }
119   }
120 
lookupByLiteraltorch::jit::__anond11ff3ee0111::OperatorRegistry121   const std::shared_ptr<Operator>& lookupByLiteral(const char* name) {
122     std::lock_guard<std::mutex> guard(lock);
123     registerPendingOperators();
124     auto it = operators_by_sig_literal.find(name);
125     if (it == operators_by_sig_literal.end()) {
126       auto op_ptr_it =
127           operators_by_sig.find(canonicalSchemaString(parseSchema(name)));
128       // Handy debugging code that dumps all operators we know about on mismatch
129 #if 0
130       if (op_ptr_it == operators_by_sig.end()) {
131         for (auto & entry : operators_by_sig) {
132           std::cout << entry.first << std::endl;
133         }
134       }
135 #endif
136       TORCH_CHECK(
137           op_ptr_it != operators_by_sig.end(),
138           "Couldn't find an operator for ",
139           name,
140           ". Do you have to update a set of hardcoded JIT ops?");
141       it = operators_by_sig_literal.emplace_hint(it, name, op_ptr_it->second);
142     }
143     return it->second;
144   }
145 
getOperatorstorch::jit::__anond11ff3ee0111::OperatorRegistry146   const std::vector<std::shared_ptr<Operator>>& getOperators(Symbol name) {
147     std::lock_guard<std::mutex> guard(lock);
148     registerPendingOperators();
149     static std::vector<std::shared_ptr<Operator>> empty;
150     auto it = operators.find(name);
151     if (it != operators.end())
152       return it->second;
153     return empty;
154   }
155 
findSimilarOperatorstorch::jit::__anond11ff3ee0111::OperatorRegistry156   std::vector<Symbol> findSimilarOperators(Symbol input_op) {
157     std::lock_guard<std::mutex> guard(lock);
158     registerPendingOperators();
159 
160     using EntryPair = std::pair<int64_t, Symbol>;
161     auto cmp = [](const EntryPair& lhs, const EntryPair& rhs) {
162       return lhs.first > rhs.first;
163     };
164 
165     std::priority_queue<EntryPair, std::vector<EntryPair>, decltype(cmp)>
166         rankings(cmp);
167     static constexpr size_t MAX_EDIT_DIST = 2u;
168     for (const auto& op : operators) {
169       auto edit_dist = ComputeEditDistance(
170           input_op.toQualString(), op.first.toQualString(), MAX_EDIT_DIST);
171       if (edit_dist <= MAX_EDIT_DIST) {
172         rankings.emplace(edit_dist, op.first);
173       }
174     }
175     std::vector<Symbol> ret;
176     while (!rankings.empty()) {
177       ret.push_back(rankings.top().second);
178       rankings.pop();
179     }
180     return ret;
181   }
182 
getAllOperatorstorch::jit::__anond11ff3ee0111::OperatorRegistry183   const std::vector<std::shared_ptr<Operator>> getAllOperators() {
184     std::lock_guard<std::mutex> guard(lock);
185     registerPendingOperators();
186     std::vector<std::shared_ptr<Operator>> values;
187     values.clear();
188     for (auto& kv : operators) {
189       values.insert(values.end(), kv.second.begin(), kv.second.end());
190     }
191     return values;
192   }
193 };
194 
// Returns the process-wide operator registry (Meyers singleton; the local
// static guarantees thread-safe one-time initialization).
OperatorRegistry& getRegistry() {
  static OperatorRegistry r;
  return r;
}
199 
// Returns true if the python printer either special-cases `sym` (`handled`),
// never needs to print it (`unneeded`), or the symbol lives outside the
// namespaces that require a printer at all.
bool printerHasSpecialCaseFor(Symbol sym) {
  using namespace at;
  // WARNING: by adding a value to this set, you are asserting
  // that you have also added special handling of this symbol to
  // the python_print.cpp. Not adding handling will cause import and export
  // of modules with this new operator to fail. This is only required
  // for operators without schema. Prefer registering your operator with
  // schema to editing this list here. These cases should only be things
  // that require special handling because they do not fit normal schema
  const static std::unordered_set<Symbol> handled = {
      prim::Constant,       prim::Uninitialized, prim::fork,
      prim::awaitable,      prim::ListConstruct, prim::DictConstruct,
      prim::ListUnpack,     prim::Print,         prim::PythonOp,
      prim::TupleConstruct, prim::TupleIndex,    prim::TupleSlice,
      prim::TupleUnpack,    prim::CreateObject,  prim::GetAttr,
      prim::SetAttr,        prim::CallFunction,  prim::isinstance,
      prim::unchecked_cast, prim::tolist,        prim::rpc_async,
      prim::rpc_sync,       prim::rpc_remote};

  // WARNING: by adding a value to this set, you are asserting that your
  // primitive is only ever added during optimization and does not need
  // to be correctly printed for export (a process that happens before
  // optimization passes run)
  const static std::unordered_set<Symbol> unneeded = {
      c10::onnx::Reshape, // only used in onnx
      c10::onnx::Shape, // only used in onnx
      prim::AutogradZero, // temporarily inserted by autograd
      prim::AutogradAnyNonZero, // temporarily inserted by autograd
      prim::AutogradAllNonZero, // temporarily inserted by autograd
      prim::AutogradAllZero, // temporarily inserted by autograd
      prim::AutogradAdd, // temporarily inserted by autograd
      prim::ConstantChunk, // optimization pass adds it
      prim::DifferentiableGraph, // optimization pass adds it,
      prim::FunctionalGraph, // optimization pass adds it,
      prim::ReductionSizes, // optimization pass (fuser) adds it
      prim::BroadcastSizes, // optimization pass (fuser) adds it
      prim::ChunkSizes, // optimization pass (fuser) adds it
      prim::Drop, // used in interpreter only
      prim::FusedConcat, // optimization pass adds it
      prim::FusionGroup, // optimization pass adds it
      prim::CudaFusionGroup, // optimization pass adds it
      prim::CudaFusionGuard, // optimization pass adds it
      prim::TensorExprGroup, // optimization pass adds it
      prim::TensorExprDynamicGroup, // optimization pass adds it
      prim::StaticSubgraph, // optimization pass adds it
      prim::ConstantMKLDNNTensor, // optimization pass adds it
      prim::BroadcastMKLDNNTensors, // optimization pass adds it
      prim::oneDNNFusionGroup, // optimization pass adds it
      prim::oneDNNFusionGuard, // optimization pass adds it
      prim::StaticRuntimeCopyOuts, // used in SR only
      prim::Load, // used in interpreter only
      prim::MMTreeReduce, // used as an optimization
      prim::MMBatchSide, // used as an optimization
      prim::Store, // used in interpreter only
      prim::profile, // used in interpreter only
      prim::profile_ivalue, // used in interpreter only
      prim::TypeCheck, // used in interpreter only
      prim::RequiresGradCheck, // used in interpreter only
      prim::FallbackGraph, // converted into prim::CallFunction

  };

  // These namespaces are required to have Python printers unless
  // otherwise noted in unneeded.
  const static std::unordered_set<Symbol> required_namespaces = {
      c10::namespaces::prim,
      c10::namespaces::aten,
      c10::namespaces::onnx,
  };

  return handled.count(sym) || unneeded.count(sym) ||
      !required_namespaces.count(sym.ns());
}
273 
274 } // anonymous namespace
275 
// Returns true if AliasDb has (or deliberately omits) a dedicated analysis
// for `symbol`, i.e. the symbol does not rely on schema-driven alias info.
bool aliasAnalysisHasSpecialCaseFor(Symbol symbol) {
  using namespace at;
  // WARNING: by adding a case to this list, you are asserting that you have
  // added a case for the unschematized node in AliasDb::analyze
  const static std::unordered_set<Symbol> handled = {
      prim::If,
      prim::Loop,
      prim::FusionGroup,
      prim::CudaFusionGroup,
      prim::oneDNNFusionGroup,
      prim::DifferentiableGraph,
      prim::TensorExprGroup,
      prim::TensorExprDynamicGroup,
      prim::StaticSubgraph,
      prim::FunctionalGraph,
      prim::Constant,
      prim::Uninitialized,
      prim::DictConstruct,
      prim::ListConstruct,
      prim::TupleConstruct,
      prim::AutogradZero,
      prim::FusedConcat,
      prim::GradOf,
      prim::MMTreeReduce,
      prim::MMBatchSide,
      prim::BroadcastSizes,
      prim::ChunkSizes,
      prim::Closure,
      prim::TupleUnpack,
      prim::TupleIndex,
      prim::TupleSlice,
      prim::ListUnpack,
      prim::PythonOp,
      prim::ConstantChunk,
      prim::BroadcastingChunk,
      prim::MKLDNNGroup,
      prim::ConstantMKLDNNTensor,
      prim::BroadcastMKLDNNTensors,
      prim::fork,
      prim::awaitable,
      prim::awaitable_nowait,
      prim::awaitable_wait,
      prim::CreateObject,
      prim::AutogradAdd,
      prim::GetAttr,
      prim::SetAttr,
      prim::profile,
      prim::profile_ivalue,
      prim::TypeCheck,
      prim::RequiresGradCheck,
      prim::Print,
      prim::CallFunction,
      prim::CallMethod,
      aten::wait,
      prim::isinstance,
      prim::unchecked_cast,
      prim::tolist,
      prim::rpc_async,
      prim::rpc_sync,
      prim::rpc_remote,
      prim::Enter,
      prim::Exit,
      prim::FallbackGraph,
  };

  // Operators that should not be used by alias analysis
  // NOTE(review): prim::AutogradAdd is also listed in `handled` above, so its
  // entry here is redundant for the return value — confirm which set it is
  // meant to belong to.
  const static std::unordered_set<Symbol> purposefully_not_handled = {
      prim::Load,
      prim::Store,
      prim::Drop,
      at::onnx::Reshape,
      at::onnx::Shape,
      prim::AutogradAdd,
  };

  return handled.count(symbol) || purposefully_not_handled.count(symbol);
}
353 
registerOperator(Operator && op)354 void registerOperator(Operator&& op) {
355   if (op.schema().is_varret()) {
356     Symbol s = Symbol::fromQualString(op.schema().name());
357     if (!printerHasSpecialCaseFor(s)) {
358       AT_ERROR(
359           "Missing special case in python printer for non-schematized"
360           " operator ",
361           op.schema().name(),
362           ". File a bug to add a case for this operator.\n");
363     }
364     if (aliasAnalysisHasSpecialCaseFor(s) &&
365         op.aliasAnalysisKind() == AliasAnalysisKind::CONSERVATIVE) {
366       AT_ERROR(
367           "Conflict in special casing in alias analysis for non-schematized"
368           " operator ",
369           op.schema().name(),
370           ". File a bug to add a case for this operator.\n");
371     }
372     if (aliasAnalysisHasSpecialCaseFor(s) &&
373         op.aliasAnalysisKind() == AliasAnalysisKind::FROM_SCHEMA) {
374       AT_ERROR(
375           "The operator ",
376           op.schema().name(),
377           " is special cased and cannot use explicit alias analysis.");
378     }
379   }
380   getRegistry().registerOperator(std::move(op));
381 }
382 
// Removes a previously-registered operator matching `schema` from the
// global registry (no-op if it is not found; see OperatorRegistry).
void deregisterOperator(const FunctionSchema& schema) {
  getRegistry().deregisterOperator(schema);
}
386 
// Returns a snapshot copy of every operator in the global registry.
const std::vector<std::shared_ptr<Operator>> getAllOperators() {
  return getRegistry().getAllOperators();
}
390 
// Returns all overloads registered under `name`, in registration order
// (empty if none). The reference stays valid for the registry's lifetime.
const std::vector<std::shared_ptr<Operator>>& getAllOperatorsFor(Symbol name) {
  return getRegistry().getOperators(name);
}
394 
getAllSortedOperatorsFor(Symbol name)395 std::vector<std::shared_ptr<Operator>> getAllSortedOperatorsFor(Symbol name) {
396   const auto& unsortedOps = getAllOperatorsFor(name);
397   // Depending on the order of registration, aten or jit ops may be
398   // registered first. This sorting is helpful in cases where
399   // deterministic (i.e. not dependent on build config) behavior is
400   // desired; e.g. torch.ops.aten.* uses this function, and tries to
401   // find the "first" op that matches input args. Without the sorting,
402   // the "first" op may change depending on registration order.
403   std::vector<std::shared_ptr<Operator>> sortedOps;
404   sortedOps.reserve(unsortedOps.size());
405   std::copy_if(
406       unsortedOps.begin(),
407       unsortedOps.end(),
408       std::back_inserter(sortedOps),
409       [](const std::shared_ptr<Operator>& op) { return op->isC10Op(); });
410   std::copy_if(
411       unsortedOps.begin(),
412       unsortedOps.end(),
413       std::back_inserter(sortedOps),
414       [](const std::shared_ptr<Operator>& op) { return !op->isC10Op(); });
415   return sortedOps;
416 }
417 
findOperatorFor(const c10::OperatorName & full_name)418 std::shared_ptr<Operator> findOperatorFor(const c10::OperatorName& full_name) {
419   for (const auto& op :
420        getRegistry().getOperators(Symbol::fromQualString(full_name.name))) {
421     if (op->schema().overload_name() == full_name.overload_name) {
422       return op;
423     }
424   }
425   return nullptr;
426 }
427 
// Returns registered operator symbols whose qualified name is within a small
// edit distance of `input_op`, closest first (see OperatorRegistry).
std::vector<Symbol> findSimilarOperators(Symbol input_op) {
  return getRegistry().findSimilarOperators(input_op);
}
431 
// Looks up an operator by its exact canonical schema string. `signature`
// should be a string literal: the registry memoizes results by the
// literal's address after the first (string-parsing) lookup.
std::shared_ptr<Operator> getOperatorForLiteral(const char* signature) {
  return getRegistry().lookupByLiteral(signature);
}
435 
canonicalSchemaString(const FunctionSchema & schema)436 std::string canonicalSchemaString(const FunctionSchema& schema) {
437   std::string out = schema.name();
438   out.push_back('(');
439 
440   bool seen_kwarg_only = false;
441   for (const auto i : c10::irange(schema.arguments().size())) {
442     if (i > 0) {
443       out += ", ";
444     }
445     if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
446       out += "*, ";
447       seen_kwarg_only = true;
448     }
449     const auto& arg = schema.arguments()[i];
450     out += arg.type()->str();
451     out.push_back(' ');
452     out += arg.name();
453   }
454 
455   out += ") -> ";
456   if (schema.returns().size() == 1) {
457     out += schema.returns().at(0).type()->str();
458   } else if (schema.returns().size() > 1) {
459     out.push_back('(');
460     for (const auto i : c10::irange(schema.returns().size())) {
461       if (i > 0) {
462         out += ", ";
463       }
464       out += schema.returns()[i].type()->str();
465     }
466     out.push_back(')');
467   }
468   return out;
469 }
470 
471 } // namespace torch::jit
472