xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/op_registration/op_registration.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/macros/Macros.h>
2 
3 #include <ATen/core/dispatch/Dispatcher.h>
4 #include <ATen/core/op_registration/op_allowlist.h>
5 #include <ATen/core/op_registration/op_registration.h>
6 #if !defined(CAFFE2_IS_XPLAT_BUILD)
7 #include <torch/csrc/jit/frontend/function_schema_parser.h>
8 #endif
9 
10 namespace c10 {
11 namespace impl {
build_feature_required_feature_not_available(const char * feature)12 void build_feature_required_feature_not_available(const char* feature) {
13   TORCH_CHECK(
14       false,
15       "Required feature '" + std::string(feature) + "' is not available");
16 }
17 } // namespace impl
18 
19 static_assert(std::is_nothrow_move_constructible<
20               std::optional<RegistrationHandleRAII>>::value);
21 static_assert(std::is_nothrow_move_assignable<
22               std::optional<RegistrationHandleRAII>>::value);
23 
checkSchemaAndRegisterOp_(Options && options)24 void RegisterOperators::checkSchemaAndRegisterOp_(Options&& options) {
25   TORCH_CHECK(
26       options.schemaOrName_.has_value(),
27       "In operator registration: Tried to register an operator without specifying a schema or operator name.");
28   if (options.schemaOrName_->index() == 1) {
29     // schema was explicitly specified.
30 
31     checkNoDuplicateKernels_(options);
32 
33     registerOp_(std::move(options));
34   } else {
35     // schema wasn't explicitly specified. Take the inferred schema for
36     // registering the op.
37 
38     OperatorName name =
39         std::get<OperatorName>(std::move(*options.schemaOrName_));
40     FunctionSchema inferred_schema = inferSchemaFromKernels_(name, options);
41 
42     options.schemaOrName_ = FunctionSchema(
43         std::move(name.name),
44         std::move(name.overload_name),
45         inferred_schema.arguments(),
46         inferred_schema.returns(),
47         inferred_schema.is_vararg(),
48         inferred_schema.is_varret());
49 
50     checkNoDuplicateKernels_(options);
51 
52     // This would have unexpected behavior since an inferred schema will not
53     // have aliasing annotations.
54     TORCH_CHECK(
55         options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
56         "In operator registration: Tried to register operator ",
57         std::get<FunctionSchema>(options.schemaOrName_.value()),
58         " with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");
59 
60     // Register all kernels with the schema we inferred
61     registerOp_(std::move(options));
62   }
63 }
64 
inferSchemaFromKernels_(const OperatorName & opName,const RegisterOperators::Options & options)65 c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(
66     const OperatorName& opName,
67     const RegisterOperators::Options& options) {
68   TORCH_CHECK(
69       !options.kernels.empty(),
70       "Cannot infer operator schema in registration of operator ",
71       opName,
72       " because there is no kernel specified.");
73 
74   std::optional<FunctionSchema> inferred_schema = std::nullopt;
75   for (const auto& kernel : options.kernels) {
76     if (nullptr != kernel.inferred_function_schema.get()) {
77       if (!inferred_schema.has_value()) {
78         inferred_schema = *kernel.inferred_function_schema;
79         break;
80       }
81     }
82   }
83   TORCH_CHECK(
84       inferred_schema.has_value(),
85       "Cannot infer operator schema for this kind of kernel in registration of operator ",
86       opName,
87       ". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema.");
88 
89   return *inferred_schema;
90 }
91 
checkNoDuplicateKernels_(const Options & options)92 void RegisterOperators::checkNoDuplicateKernels_(const Options& options) {
93   std::unordered_set<DispatchKey> dispatch_keys;
94   bool has_catchall_kernel = false;
95 
96   for (const auto& kernel : options.kernels) {
97     if (kernel.dispatch_key.has_value()) {
98       TORCH_CHECK(
99           0 == dispatch_keys.count(*kernel.dispatch_key),
100           "In operator registration: Tried to register multiple kernels with same dispatch key ",
101           *kernel.dispatch_key,
102           " for operator schema ",
103           toString(std::get<FunctionSchema>(options.schemaOrName_.value())));
104       dispatch_keys.insert(*kernel.dispatch_key);
105     } else {
106       TORCH_CHECK(
107           !has_catchall_kernel,
108           "In operator registration: Tried to register multiple catch-all kernels for operator schema ",
109           toString(std::get<FunctionSchema>(options.schemaOrName_.value())));
110       has_catchall_kernel = true;
111     }
112   }
113 }
114 
registerOp_(Options && options)115 void RegisterOperators::registerOp_(Options&& options) {
116   FunctionSchema schema =
117       std::get<FunctionSchema>(std::move(options.schemaOrName_.value()));
118 
119   // HACK: bong in the alias analysis kind from the legacy API directly
120   // into schema
121   if (options.aliasAnalysisKind_.has_value()) {
122     schema.setAliasAnalysis(*options.aliasAnalysisKind_);
123   }
124 
125   OperatorName op_name = schema.operator_name();
126 
127   registrars_.emplace_back(Dispatcher::singleton().registerDef(
128       std::move(schema), "registered by RegisterOperators"));
129 
130   for (auto& kernel : options.kernels) {
131     registrars_.emplace_back(Dispatcher::singleton().registerImpl(
132         op_name,
133         kernel.dispatch_key,
134         std::move(kernel.func),
135         kernel.cpp_signature,
136         std::move(kernel.inferred_function_schema),
137         "registered by RegisterOperators"));
138   }
139 }
140 
141 } // namespace c10
142