/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cstring>

#include <executorch/runtime/core/array_ref.h>
#include <executorch/runtime/core/error.h>
#include <executorch/runtime/core/evalue.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/result.h>
#include <executorch/runtime/core/span.h>
#include <executorch/runtime/platform/compiler.h>
#include <executorch/runtime/platform/platform.h>

// Debug switch for the operator registry.
#if defined(ET_OP_REGISTRY_DEBUG)
#include <ostream>
#endif

// Logs a kernel key and whether it is a fallback key, at Error severity.
#define ET_LOG_KERNEL_KEY(k)      \
  ET_LOG(                         \
      Error,                      \
      "key: %s, is_fallback: %s", \
      k.data(),                   \
      k.is_fallback() ? "true" : "false");
// Logs the dtype and dim order of every entry in a TensorMeta list, at Error
// severity.
#define ET_LOG_TENSOR_META(meta_list)                                 \
  for (const auto& meta : meta_list) {                                \
    ET_LOG(Error, "dtype: %d | dim order: [", int(meta.dtype_));      \
    for (int i = 0; i < meta.dim_order_.size(); i++) {                \
      ET_LOG(Error, "%d,", static_cast<int32_t>(meta.dim_order_[i])); \
    }                                                                 \
    ET_LOG(Error, "]");                                               \
  }

namespace executorch {
namespace runtime {

class KernelRuntimeContext; // Forward declaration
using OpFunction = void (*)(KernelRuntimeContext&, EValue**);

/**
 * Dtype and dim order metadata for a Tensor argument to an operator.
 * Used by the Executor to hold tensor metadata and to look up the matching
 * kernel.
 */
struct TensorMeta {
  executorch::aten::ScalarType dtype_;
  Span<executorch::aten::DimOrderType> dim_order_;

  TensorMeta() = default;
  TensorMeta(
      executorch::aten::ScalarType dtype,
      Span<executorch::aten::DimOrderType> order)
      : dtype_(dtype), dim_order_(order) {}

  bool operator==(const TensorMeta& other) const {
    return this->equals(other);
  }

  bool operator!=(const TensorMeta& other) const {
    return !this->equals(other);
  }

  bool equals(const TensorMeta& other) const {
    if (dtype_ != other.dtype_) {
      return false;
    }
    if (dim_order_.size() != other.dim_order_.size()) {
      return false;
    }
    for (int i = 0; i < dim_order_.size(); i++) {
      if (dim_order_[i] != other.dim_order_[i]) {
        return false;
      }
    }
    return true;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const TensorMeta& meta) {
    os << "dtype: " << int(meta.dtype_) << " | dim order: [";
    for (int i = 0; i < meta.dim_order_.size(); i++) {
      os << static_cast<int32_t>(meta.dim_order_[i]) << ", ";
    }
    os << "]";
    return os;
  }
#endif
};
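
// Illustrative usage sketch (editor's example, not part of the original
// header): building a TensorMeta for a 4-dimensional float tensor with a
// contiguous dim order. The dim_order buffer is caller-owned and must outlive
// the TensorMeta, since Span does not copy it.
//
//   executorch::aten::DimOrderType dim_order[] = {0, 1, 2, 3};
//   TensorMeta meta(
//       executorch::aten::ScalarType::Float,
//       Span<executorch::aten::DimOrderType>(dim_order, 4));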

/**
 * Describes which dtype- and dim-order-specialized kernel is bound to an
 * operator. If `is_fallback_` is true, this kernel can be used as a fallback;
 * if false, it can only be used when all of the `TensorMeta` entries match.
 * A fallback kernel is used for all input tensor dtypes and dim orders when
 * no specialized kernel is registered.
 *
 * The kernel key data is a string of the form:
 *     "v<version>/<tensor_meta>|<tensor_meta>..."
 * and is at most 691 bytes long: 1 byte each for 'v', the version digit, and
 * '/', plus (42 + 1) * 16 bytes for the tensor_meta entries and their '|'
 * separators, assuming at most 16 tensors.
 * The kernel key version is v1 for now. If the kernel key format changes,
 * update the version to avoid breaking pre-existing kernel keys.
 * Example: "v1/7;0,1,2,3" is a key with a single tensor: a double tensor with
 * dim order (0, 1, 2, 3).
 *
 * Each tensor_meta has the form "<dtype>;<dim_order,...>" and is at most 42
 * bytes long: 1-2 bytes for the dtype, 1 byte for ';', and 24 bytes budgeted
 * for the dim order digits (1 byte for 0-9, 2 bytes for 10-15) plus 15 commas,
 * assuming at most 16 dims.
 * Example: "7;0,1,2,3" for [double; 0, 1, 2, 3].
 *
 * IMPORTANT:
 * Users should not construct a kernel key manually. Instead, it should be
 * generated from kernel yaml.
 */
struct KernelKey {
 public:
  KernelKey() : is_fallback_(true) {}

  /* implicit */ KernelKey(const char* kernel_key_data)
      : kernel_key_data_(kernel_key_data), is_fallback_(false) {}

  constexpr static int MAX_SIZE = 691;

  bool operator==(const KernelKey& other) const {
    return this->equals(other);
  }

  bool operator!=(const KernelKey& other) const {
    return !this->equals(other);
  }

  bool equals(const KernelKey& other) const {
    if (is_fallback_ != other.is_fallback_) {
      return false;
    }
    if (is_fallback_) {
      return true;
    }
    return strncmp(kernel_key_data_, other.kernel_key_data_, MAX_SIZE) == 0;
  }

  bool is_fallback() const {
    return is_fallback_;
  }

  const char* data() const {
    return kernel_key_data_;
  }

#if defined(ET_OP_REGISTRY_DEBUG)
  friend std::ostream& operator<<(std::ostream& os, const KernelKey& key) {
    os << key.kernel_key_data_ << std::endl;
    return os;
  }
#endif

 private:
  const char* kernel_key_data_ = nullptr;
  bool is_fallback_;
};
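
// Illustrative sketch (editor's example): the two states a KernelKey can be
// in. The specialized key string follows the documented v1 format ("6" is the
// float dtype) and would normally come from kernel yaml codegen rather than
// being written by hand; the string must outlive the key, which only stores
// the pointer.
//
//   KernelKey fallback_key;               // is_fallback() == true
//   KernelKey float_key("v1/6;0,1,2,3");  // specialized, not a fallback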

/**
 * Struct that bundles a kernel key, a function and an op name together. An
 * `Operator` may have more than one `Kernel` (at most kMaxNumOfKernelPerOp),
 * and those kernels should share the same op name but have different kernel
 * keys. A "fallback" kernel may or may not live in an `Operator`.
 */
struct Kernel {
  const char* name_;
  // String representation of the kernel key, with the same format as
  // KernelKey.to_string_representation().
  // Data is not owned by the Kernel struct.
  KernelKey kernel_key_;
  OpFunction op_;
  /**
   * Because we copy the string pointer instead of duplicating the string
   * itself, the lifetime of the operator name must be at least as long as the
   * lifetime of the operator registry.
   */
  explicit Kernel(const char* name, OpFunction func) : name_(name), op_(func) {}

  explicit Kernel(const char* name, KernelKey key, OpFunction func)
      : name_(name), kernel_key_(key), op_(func) {}

  Kernel() {}
};
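
// Illustrative sketch (editor's example): bundling an op name, kernel key, and
// function into Kernel entries. `my_add_out` is a hypothetical OpFunction; the
// name and key strings must outlive the registry.
//
//   void my_add_out(KernelRuntimeContext& ctx, EValue** args);  // hypothetical
//   Kernel fallback_kernel("aten::add.out", my_add_out);
//   Kernel float_kernel(
//       "aten::add.out", KernelKey("v1/6;0,1,2,3|6;0,1,2,3"), my_add_out);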

namespace internal {
void make_kernel_key_string(Span<const TensorMeta> key, char* buf);
} // namespace internal
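
// Editor's note (illustrative): internal::make_kernel_key_string() serializes
// a TensorMeta span into the documented key format; a single float tensor with
// dim order (0, 1, 2, 3) is assumed to produce "v1/6;0,1,2,3". `buf` is a
// caller-provided buffer, typically sized KernelKey::MAX_SIZE.
//
//   char buf[KernelKey::MAX_SIZE] = {};
//   internal::make_kernel_key_string(metas, buf);  // metas: Span<const TensorMeta>
//   KernelKey key(buf);  // buf must outlive `key`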

/**
 * Checks whether an operator exists with a given name and TensorMeta list.
 * When the TensorMeta list is empty, the op has no specialized kernels, so
 * this checks whether any fallback kernel is registered for the name.
 */
bool registry_has_op_function(
    const char* name,
    Span<const TensorMeta> meta_list = {});

/**
 * Returns the operator with a given name and TensorMeta list, if present.
 */
::executorch::runtime::Result<OpFunction> get_op_function_from_registry(
    const char* name,
    Span<const TensorMeta> meta_list = {});
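
// Illustrative usage sketch (editor's example): looking up a registered
// fallback kernel and invoking it. `ctx` and `args` are assumed to be set up
// by the caller the way the executor normally would.
//
//   Result<OpFunction> fn = get_op_function_from_registry("aten::add.out");
//   if (fn.ok()) {
//     (*fn)(ctx, args);
//   }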

/**
 * Returns all registered kernels.
 */
Span<const Kernel> get_registered_kernels();
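
// Illustrative sketch (editor's example): enumerating every registered kernel,
// e.g. to debug which kernels a particular build links in.
//
//   for (const Kernel& kernel : get_registered_kernels()) {
//     ET_LOG(Info, "op: %s, fallback: %d",
//            kernel.name_, int(kernel.kernel_key_.is_fallback()));
//   }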

/**
 * Registers the provided kernels.
 *
 * @param[in] kernels Kernel objects to register.
 * @retval Error::Ok always. Panics on error. This function needs to return a
 *     non-void type to run at static initialization time.
 */
ET_NODISCARD Error register_kernels(const Span<const Kernel>);

/**
 * Registers a single kernel.
 *
 * @param[in] kernel Kernel object to register.
 * @retval Error::Ok always. Panics on error. This function needs to return a
 *     non-void type to run at static initialization time.
 */
ET_NODISCARD inline Error register_kernel(const Kernel& kernel) {
  return register_kernels({&kernel, 1});
}
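
// Illustrative sketch (editor's example): registering kernels at static
// initialization time, which is why these functions return a non-void value
// that a global can capture. `my_add_out` is a hypothetical OpFunction.
//
//   static Kernel kernels[] = {
//       Kernel("aten::add.out", my_add_out),
//   };
//   static Error kernels_registered = register_kernels({kernels, 1});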

} // namespace runtime
} // namespace executorch

namespace torch {
namespace executor {
// TODO(T197294990): Remove these deprecated aliases once all users have moved
// to the new `::executorch` namespaces.
using ::executorch::runtime::Kernel;
using ::executorch::runtime::KernelKey;
using ::executorch::runtime::KernelRuntimeContext;
using ::executorch::runtime::OpFunction;
using ::executorch::runtime::TensorMeta;
using KernelRuntimeContext = ::executorch::runtime::KernelRuntimeContext;

inline ::executorch::runtime::Error register_kernels(ArrayRef<Kernel> kernels) {
  return ::executorch::runtime::register_kernels(
      {kernels.data(), kernels.size()});
}
inline OpFunction getOpsFn(
    const char* name,
    ArrayRef<TensorMeta> meta_list = {}) {
  auto result = ::executorch::runtime::get_op_function_from_registry(
      name, {meta_list.data(), meta_list.size()});
  ET_CHECK(result.ok()); // get_op_function_from_registry() logs details.
  return *result;
}
inline bool hasOpsFn(const char* name, ArrayRef<TensorMeta> meta_list = {}) {
  return ::executorch::runtime::registry_has_op_function(
      name, {meta_list.data(), meta_list.size()});
}
inline ArrayRef<Kernel> get_kernels() {
  Span<const Kernel> kernels = ::executorch::runtime::get_registered_kernels();
  return ArrayRef<Kernel>(kernels.data(), kernels.size());
}
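
// Editor's illustrative sketch of the deprecated aliases above; new code
// should call the ::executorch::runtime functions directly.
//
//   if (torch::executor::hasOpsFn("aten::add.out")) {
//     OpFunction fn = torch::executor::getOpsFn("aten::add.out");
//   }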
} // namespace executor
} // namespace torch