/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstring>

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/platform.h>

// Debug switch for operator registry
#if defined(ET_OP_REGISTRY_DEBUG)
#include <ostream>
#endif

#define ET_LOG_KERNEL_KEY(k)      \
  ET_LOG(                         \
      Error,                      \
      "key: %s, is_fallback: %s", \
      k.data(),                   \
      k.is_fallback() ? "true" : "false");
#define ET_LOG_TENSOR_META(meta_list)                                 \
  for (const auto& meta : meta_list) {                                \
    ET_LOG(Error, "dtype: %d | dim order: [", int(meta.dtype_));      \
    for (int i = 0; i < meta.dim_order_.size(); i++) {                \
      ET_LOG(Error, "%d,", static_cast<int32_t>(meta.dim_order_[i])); \
    }                                                                 \
    ET_LOG(Error, "]");                                               \
  }

namespace executorch {
namespace runtime {

class KernelRuntimeContext; // Forward declaration
using OpFunction = void (*)(KernelRuntimeContext&, EValue**);

/**
 * Dtype and dim order metadata for a Tensor argument to an operator.
 * Used by the Executor to hold tensor metadata and to look up the matching
 * kernel.
 */
struct TensorMeta {
  executorch::aten::ScalarType dtype_;
  Span<executorch::aten::DimOrderType> dim_order_;

  TensorMeta() = default;
  TensorMeta(
      executorch::aten::ScalarType dtype,
      Span<executorch::aten::DimOrderType> order)
      : dtype_(dtype), dim_order_(order) {}

  bool operator==(const TensorMeta& other) const {
    return this->equals(other);
  }

  bool operator!=(const TensorMeta& other) const {
    return !this->equals(other);
  }

  bool equals(const TensorMeta& other) const {
    if (dtype_ != other.dtype_) {
      return false;
    }
    if (dim_order_.size() != other.dim_order_.size()) {
      return false;
    }
    for (int i = 0; i < dim_order_.size(); i++) {
      if (dim_order_[i] != other.dim_order_[i]) {
        return false;
      }
    }
    return true;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const TensorMeta& meta) {
    os << "dtype: " << int(meta.dtype_) << " | dim order: [";
    for (int i = 0; i < meta.dim_order_.size(); i++) {
      os << static_cast<int32_t>(meta.dim_order_[i]) << ", ";
    }
    os << "]";
    return os;
  }
#endif
};
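
// Illustrative sketch (not part of this header's API): describing a float
// tensor with contiguous dim order. The names below are hypothetical, and the
// dim order storage must outlive the TensorMeta because Span does not own its
// data.
//
//   executorch::aten::DimOrderType example_dim_order[] = {0, 1, 2, 3};
//   TensorMeta example_meta(
//       executorch::aten::ScalarType::Float, {example_dim_order, 4});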

/**
 * Describes which dtype & dim order specialized kernel should be bound to an
 * operator. If `is_fallback_` is true, this kernel can be used as a fallback;
 * if false, it can only be used when all of the `TensorMeta` entries match.
 * A fallback kernel is used for all input tensor dtypes and dim orders when
 * no specialized kernel is registered.
 *
 * The kernel key data is a string of the form:
 *   "v<version>/<tensor_meta>|<tensor_meta>..."
 * Size: up to 691 bytes = 1 ('v') + 1 (version) + 1 ('/') + (42 + 1) * 16,
 * assuming the max number of tensors is 16.
 * The kernel key version is v1 for now. If the kernel key format changes,
 * update the version to avoid breaking pre-existing kernel keys.
 * Example: v1/7;0,1,2,3
 * This kernel key has only one tensor: a double tensor with dim order 0, 1, 2, 3.
 *
 * Each tensor_meta has the following format: "<dtype>;<dim_order,...>"
 * Size: up to 42 bytes = 1-2 (dtype) + 1 (';') + 24 (dim order values, 1 byte
 * for 0-9 and 2 for 10-15) + 15 (commas), assuming the max number of dims
 * is 16.
 * Example: 7;0,1,2,3 for [double; 0, 1, 2, 3]
 *
 * IMPORTANT:
 * Users should not construct a kernel key manually. Instead, it should be
 * generated from kernel yaml.
 */
struct KernelKey {
 public:
  KernelKey() : is_fallback_(true) {}

  /* implicit */ KernelKey(const char* kernel_key_data)
      : kernel_key_data_(kernel_key_data), is_fallback_(false) {}

  constexpr static int MAX_SIZE = 691;

  bool operator==(const KernelKey& other) const {
    return this->equals(other);
  }

  bool operator!=(const KernelKey& other) const {
    return !this->equals(other);
  }

  bool equals(const KernelKey& other) const {
    if (is_fallback_ != other.is_fallback_) {
      return false;
    }
    if (is_fallback_) {
      return true;
    }
    return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0;
  }

  bool is_fallback() const {
    return is_fallback_;
  }

  const char* data() const {
    return kernel_key_data_;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const KernelKey& key) {
    os << key.kernel_key_data_ << std::endl;
    return os;
  }
#endif

 private:
  const char* kernel_key_data_ = nullptr;
  bool is_fallback_;
};
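
// Illustrative sketch: kernel keys are normally generated from kernel yaml,
// but any string with the documented "v1/..." format yields a specialized
// (non-fallback) key, and the default-constructed key is the fallback. The
// literal below is only an example value.
//
//   KernelKey specialized_key("v1/6;0,1,2,3"); // float tensor, dim order 0,1,2,3
//   KernelKey fallback_key; // is_fallback() == true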

/**
 * Struct that bundles a kernel key, a function and an op name together. An
 * `Operator` may have more than one `Kernel` (maximum kMaxNumOfKernelPerOp)
 * and they should have the same op name and different kernel keys. A
 * "fallback" kernel may or may not live in an `Operator`.
 */
struct Kernel {
  const char* name_;
  // String representation of the kernel key, with the same format as
  // KernelKey.to_string_representation().
  // Data is not owned by the Kernel struct.
  KernelKey kernel_key_;
  OpFunction op_;
  /**
   * Because we copy the string pointer instead of duplicating the string
   * itself, the lifetime of the operator name must be at least as long as the
   * operator registry.
   */
  explicit Kernel(const char* name, OpFunction func)
      : name_(name), op_(func) {}

  explicit Kernel(const char* name, KernelKey key, OpFunction func)
      : name_(name), kernel_key_(key), op_(func) {}

  Kernel() {}
};
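
// Illustrative sketch: bundling an op name, kernel key, and function into
// Kernel objects. `my_add_out` is a hypothetical function with the OpFunction
// signature, and the name/key literals are example values; only the pointers
// are stored, so they must outlive the registry.
//
//   void my_add_out(KernelRuntimeContext& ctx, EValue** args);
//
//   Kernel fallback_add("aten::add.out", my_add_out);
//   Kernel specialized_add(
//       "aten::add.out", KernelKey("v1/6;0,1,2,3|6;0,1,2,3"), my_add_out);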

namespace internal {
void make_kernel_key_string(Span<const TensorMeta> key, char* buf);
} // namespace internal

/**
 * Checks whether an operator exists with a given name and TensorMeta list.
 * When the TensorMeta list is empty, the op has no specialized kernels, so
 * this checks whether it has any fallback kernel.
 */
bool registry_has_op_function(
    const char* name,
    Span<const TensorMeta> meta_list = {});

/**
 * Returns the operator with a given name and TensorMeta list, if present.
 */
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
    const char* name,
    Span<const TensorMeta> meta_list = {});
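
// Illustrative sketch: looking up and invoking a registered op function. The
// op name and variables here are hypothetical; `args` must point to EValues
// laid out as the kernel expects and `ctx` is a KernelRuntimeContext.
//
//   Result<OpFunction> fn = get_op_function_from_registry("aten::add.out");
//   if (fn.ok()) {
//     (*fn)(ctx, args);
//   }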

/**
 * Returns all registered kernels.
 */
Span<const Kernel> get_registered_kernels();

/**
 * Registers the provided kernels.
 *
 * @param[in] kernels Kernel objects to register.
 * @retval Error::Ok always. Panics on error. This function needs to return a
 *     non-void type to run at static initialization time.
 */
ET_NODISCARD Error register_kernels(const Span<const Kernel>);
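
// Illustrative sketch: kernels are normally registered by generated code at
// static initialization time. A hand-rolled registration, assuming a
// hypothetical `my_add_out` OpFunction, could look like:
//
//   static const Kernel kMyKernels[] = {
//       Kernel("aten::add.out", my_add_out),
//   };
//   static const auto kSuccess = register_kernels({kMyKernels, 1});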

/**
 * Registers a single kernel.
 *
 * @param[in] kernel Kernel object to register.
 * @retval Error::Ok always. Panics on error. This function needs to return a
 *     non-void type to run at static initialization time.
 */
ET_NODISCARD inline Error register_kernel(const Kernel& kernel) {
  return register_kernels({&kernel, 1});
}

} // namespace runtime
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::Kernel;
using ::executorch::runtime::KernelKey;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::OpFunction;
using ::executorch::runtime::TensorMeta;
using KernelRuntimeContext = ::executorch::runtime::KernelRuntimeContext;

inline ::executorch::runtime::Error register_kernels(
    ArrayRef<Kernel> kernels) {
  return ::executorch::runtime::register_kernels(
      {kernels.data(), kernels.size()});
}
inline OpFunction getOpsFn(
    const char* name,
    ArrayRef<TensorMeta> meta_list = {}) {
  auto result = ::executorch::runtime::get_op_function_from_registry(
      name, {meta_list.data(), meta_list.size()});
  ET_CHECK(result.ok()); // get_op_function_from_registry() logs details.
  return *result;
}
inline bool hasOpsFn(const char* name, ArrayRef<TensorMeta> meta_list = {}) {
  return ::executorch::runtime::registry_has_op_function(
      name, {meta_list.data(), meta_list.size()});
}
inline ArrayRef<Kernel> get_kernels() {
  Span<const Kernel> kernels = ::executorch::runtime::get_registered_kernels();
  return ArrayRef<Kernel>(kernels.data(), kernels.size());
}
} // namespace executor
} // namespace torch