1 #include <c10/macros/Macros.h>
2
3 #include <ATen/core/dispatch/Dispatcher.h>
4 #include <ATen/core/op_registration/op_allowlist.h>
5 #include <ATen/core/op_registration/op_registration.h>
6 #if !defined(CAFFE2_IS_XPLAT_BUILD)
7 #include <torch/csrc/jit/frontend/function_schema_parser.h>
8 #endif
9
10 namespace c10 {
11 namespace impl {
build_feature_required_feature_not_available(const char * feature)12 void build_feature_required_feature_not_available(const char* feature) {
13 TORCH_CHECK(
14 false,
15 "Required feature '" + std::string(feature) + "' is not available");
16 }
17 } // namespace impl
18
19 static_assert(std::is_nothrow_move_constructible<
20 std::optional<RegistrationHandleRAII>>::value);
21 static_assert(std::is_nothrow_move_assignable<
22 std::optional<RegistrationHandleRAII>>::value);
23
checkSchemaAndRegisterOp_(Options && options)24 void RegisterOperators::checkSchemaAndRegisterOp_(Options&& options) {
25 TORCH_CHECK(
26 options.schemaOrName_.has_value(),
27 "In operator registration: Tried to register an operator without specifying a schema or operator name.");
28 if (options.schemaOrName_->index() == 1) {
29 // schema was explicitly specified.
30
31 checkNoDuplicateKernels_(options);
32
33 registerOp_(std::move(options));
34 } else {
35 // schema wasn't explicitly specified. Take the inferred schema for
36 // registering the op.
37
38 OperatorName name =
39 std::get<OperatorName>(std::move(*options.schemaOrName_));
40 FunctionSchema inferred_schema = inferSchemaFromKernels_(name, options);
41
42 options.schemaOrName_ = FunctionSchema(
43 std::move(name.name),
44 std::move(name.overload_name),
45 inferred_schema.arguments(),
46 inferred_schema.returns(),
47 inferred_schema.is_vararg(),
48 inferred_schema.is_varret());
49
50 checkNoDuplicateKernels_(options);
51
52 // This would have unexpected behavior since an inferred schema will not
53 // have aliasing annotations.
54 TORCH_CHECK(
55 options.aliasAnalysisKind_ != AliasAnalysisKind::FROM_SCHEMA,
56 "In operator registration: Tried to register operator ",
57 std::get<FunctionSchema>(options.schemaOrName_.value()),
58 " with AliasAnalysisKind::FROM_SCHEMA, but the schema is inferred.");
59
60 // Register all kernels with the schema we inferred
61 registerOp_(std::move(options));
62 }
63 }
64
inferSchemaFromKernels_(const OperatorName & opName,const RegisterOperators::Options & options)65 c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(
66 const OperatorName& opName,
67 const RegisterOperators::Options& options) {
68 TORCH_CHECK(
69 !options.kernels.empty(),
70 "Cannot infer operator schema in registration of operator ",
71 opName,
72 " because there is no kernel specified.");
73
74 std::optional<FunctionSchema> inferred_schema = std::nullopt;
75 for (const auto& kernel : options.kernels) {
76 if (nullptr != kernel.inferred_function_schema.get()) {
77 if (!inferred_schema.has_value()) {
78 inferred_schema = *kernel.inferred_function_schema;
79 break;
80 }
81 }
82 }
83 TORCH_CHECK(
84 inferred_schema.has_value(),
85 "Cannot infer operator schema for this kind of kernel in registration of operator ",
86 opName,
87 ". Please explicitly specify the operator schema or specify at least one kernel for which we can infer the schema.");
88
89 return *inferred_schema;
90 }
91
checkNoDuplicateKernels_(const Options & options)92 void RegisterOperators::checkNoDuplicateKernels_(const Options& options) {
93 std::unordered_set<DispatchKey> dispatch_keys;
94 bool has_catchall_kernel = false;
95
96 for (const auto& kernel : options.kernels) {
97 if (kernel.dispatch_key.has_value()) {
98 TORCH_CHECK(
99 0 == dispatch_keys.count(*kernel.dispatch_key),
100 "In operator registration: Tried to register multiple kernels with same dispatch key ",
101 *kernel.dispatch_key,
102 " for operator schema ",
103 toString(std::get<FunctionSchema>(options.schemaOrName_.value())));
104 dispatch_keys.insert(*kernel.dispatch_key);
105 } else {
106 TORCH_CHECK(
107 !has_catchall_kernel,
108 "In operator registration: Tried to register multiple catch-all kernels for operator schema ",
109 toString(std::get<FunctionSchema>(options.schemaOrName_.value())));
110 has_catchall_kernel = true;
111 }
112 }
113 }
114
registerOp_(Options && options)115 void RegisterOperators::registerOp_(Options&& options) {
116 FunctionSchema schema =
117 std::get<FunctionSchema>(std::move(options.schemaOrName_.value()));
118
119 // HACK: bong in the alias analysis kind from the legacy API directly
120 // into schema
121 if (options.aliasAnalysisKind_.has_value()) {
122 schema.setAliasAnalysis(*options.aliasAnalysisKind_);
123 }
124
125 OperatorName op_name = schema.operator_name();
126
127 registrars_.emplace_back(Dispatcher::singleton().registerDef(
128 std::move(schema), "registered by RegisterOperators"));
129
130 for (auto& kernel : options.kernels) {
131 registrars_.emplace_back(Dispatcher::singleton().registerImpl(
132 op_name,
133 kernel.dispatch_key,
134 std::move(kernel.func),
135 kernel.cpp_signature,
136 std::move(kernel.inferred_function_schema),
137 "registered by RegisterOperators"));
138 }
139 }
140
141 } // namespace c10
142