This folder contains a convenience library called *tf-shim* over TF and TFLite
op kernel APIs.

## Summary

This library creates a shim over the custom op APIs of TF and TFLite so the
developer can write the custom op once with minimal binary or runtime overhead.

An example usage is an input preprocessing op kernel that can be used in
both TF and TFLite.

## Background

When there is a need to implement logic that is not supported by the TF
builtin ops, the alternative is to build a custom op. If that op needs to
run on-device then it needs to be written in C++ against the client API for
custom ops.

For example, feature processing, especially for textual input in an ML model,
can involve operations that don't lend themselves well to vectorization, and
such code, if written as a C++ function, would be much shorter and more
readable.

However, the TensorFlow and TFLite APIs for creating op kernels are, at the
moment, not identical. This library offers a convenient way to write the
kernel once and adapt it to both TF and TFLite with minimal binary and runtime
overhead.

## Implementation

This folder contains two pieces:

1.  `TensorView`, a shim over `::tensorflow::Tensor` and `TfLiteTensor`

2.  `OpKernelShim`, a class which abstracts the TF and TFLite op kernel APIs.

### TensorView

This class is a *view* over an already allocated tensor in TF or TFLite,
without taking any ownership. In that sense it is similar to
`std::string_view`, but with the difference that the underlying buffer can be
mutable.

Example usage:

```cpp
::tensorflow::Tensor tf_tensor;
auto t = TensorView::New(&tf_tensor);

// Rank-2 string matrix view.
auto t_str_mat = t.As<::tensorflow::tstring, /*RANK=*/ 2>();
t_str_mat(0, 0) = "ab";
t_str_mat(0, 1) = "cde";

// Flat buffer view over the same data.
auto t_buffer = t.Data<::tensorflow::tstring>();
t_buffer[0] = "ab";
t_buffer[1] = "cde";
```

```cpp
TfLiteTensor tflite_tensor;
auto t = TensorView::New(&tflite_tensor);

// Rank-1 int32 vector view.
auto t_int_vec = t.As<int32, /*RANK=*/ 1>();
t_int_vec(0) = 123;
t_int_vec(1) = 456;

// Flat buffer view over the same data.
auto t_buffer = t.Data<int32>();
t_buffer[0] = 123;
t_buffer[1] = 456;
```

`New` is the factory function which, based on the type of its input, returns
either a `TfTensorView` or a `TfLiteTensorView`.

See the unit tests `tf_tensor_view_test.cc` and `tflite_tensor_view_test.cc`
for more usage.
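
The view-without-ownership idea can be sketched in plain C++. The following is
only an illustration of the concept (with a hypothetical `ToyMatrixView`
class), not the actual `TensorView` implementation:

```cpp
#include <cstddef>
#include <vector>

// A toy mutable view over an already allocated buffer, illustrating the
// TensorView idea: the view holds a non-owning pointer, but (unlike
// std::string_view) writes through the view mutate the underlying buffer.
template <typename T>
class ToyMatrixView {
 public:
  ToyMatrixView(T* data, std::size_t rows, std::size_t cols)
      : data_(data), rows_(rows), cols_(cols) {}

  // Element access in row-major order; writes go straight to the
  // underlying buffer.
  T& operator()(std::size_t r, std::size_t c) { return data_[r * cols_ + c]; }

  std::size_t rows() const { return rows_; }
  std::size_t cols() const { return cols_; }

 private:
  T* data_;  // not owned
  std::size_t rows_, cols_;
};
```

A caller keeps ownership of the storage and only hands the view a pointer, so
destroying the view leaves the data intact:

```cpp
std::vector<int> storage(6, 0);
ToyMatrixView<int> view(storage.data(), 2, 3);
view(1, 2) = 42;  // writes storage[5]
```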

The string tensor in `TfLiteTensorView` is a bit of a special case: since
string tensors in TFLite are serialized in a specific format, writes to those
tensors go through an intermediate buffer that holds the values until all the
strings get serialized. The intermediate string buffers are serialized back to
the TFLite string format once the last remaining `TfLiteTensorView` goes out
of scope. Only then can the user see the string values in the underlying
`TfLiteTensor`. That said, when implementing an op kernel there is rarely a
need to read back the contents of a mutable output `TfLiteTensor` within the
same code block.
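
The flush-on-last-view-destruction behavior can be pictured with a
scope-based sketch. The classes below (`ToyStringTensor`, `ToyStringView`)
are hypothetical stand-ins, not the real TFLite string serialization:

```cpp
#include <cstddef>
#include <string>
#include <utility>
#include <vector>

// Toy stand-in for a TfLiteTensor holding strings.
struct ToyStringTensor {
  std::vector<std::string> serialized;  // visible contents of the "tensor"
};

// Toy stand-in for TfLiteTensorView's string handling: writes land in an
// intermediate buffer, and the tensor only sees them once the view is
// destroyed (mirroring "the last remaining view goes out of scope").
class ToyStringView {
 public:
  explicit ToyStringView(ToyStringTensor* tensor) : tensor_(tensor) {}

  // Writes go to the intermediate buffer, not the tensor itself.
  void Set(std::size_t i, std::string value) {
    if (buffer_.size() <= i) buffer_.resize(i + 1);
    buffer_[i] = std::move(value);
  }

  // On destruction the buffered strings are "serialized" into the tensor.
  ~ToyStringView() { tensor_->serialized = buffer_; }

 private:
  ToyStringTensor* tensor_;  // not owned
  std::vector<std::string> buffer_;
};
```

With this sketch, `tensor.serialized` stays empty while the view is alive and
is populated only when the view's scope ends, which is the behavior described
above.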

### OpKernelShim

*WARNING: Experimental interface, subject to change*

This class defines the interface which, when implemented, allows for
convenient adaptation to TF and TFLite op kernels.

Here is an example op kernel implementing this interface:

```cpp
template<TfRuntime R>
class MyOp : public OpKernelShim<MyOp, R> {
 public:
  // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op)
  static std::vector<std::string> Attrs();

  // Input tensors declaration (syntax: https://www.tensorflow.org/guide/create_op)
  static std::vector<std::string> Inputs();

  // Output tensors declaration (syntax: https://www.tensorflow.org/guide/create_op)
  static std::vector<std::string> Outputs();

  // Initializes the op
  absl::Status Init(InitContext* ctx);

  // Runs the operation
  absl::Status Invoke(InvokeContext* ctx);

  // Shape inference
  static absl::Status ShapeInference(ShapeInferenceContext* ctx);
};
```

The class `MyOp` passes itself to `OpKernelShim` as a template parameter. This
is because `OpKernelShim` is a static interface using the CRTP (curiously
recurring template pattern). Similarly, the context classes `InitContext`,
`InvokeContext` and `ShapeInferenceContext` are all static interfaces in the
same way.
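
The CRTP static-interface mechanism itself can be shown with a minimal,
self-contained sketch. The names here (`ShimBase`, `MyToyOp`) are
hypothetical and unrelated to the real `OpKernelShim`:

```cpp
#include <string>

// A CRTP "static interface": the base class calls into the derived class
// through a static_cast resolved at compile time, so there is no virtual
// dispatch and no vtable overhead.
template <typename SubType>
class ShimBase {
 public:
  // Dispatches to the derived class's Name() at compile time.
  std::string Describe() {
    return "op: " + static_cast<SubType*>(this)->Name();
  }
};

class MyToyOp : public ShimBase<MyToyOp> {
 public:
  std::string Name() { return "my_toy_op"; }
};
```

Calling `Describe()` on a `MyToyOp` resolves the `Name()` call statically,
which is the same reason the shim can adapt one kernel to two runtimes with
minimal binary and runtime overhead.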

### Context Interfaces

An op kernel written using this library has access to a number of *context*
objects at various stages of its lifecycle. These context objects are
effectively shims over the existing context objects in TF and TFLite.

#### InitContext
An instance of this class is passed to the op kernel during its initialization.

```cpp
template <typename SubType>
class InitContext {
 public:
  // Read the given attribute and populate the given value.
  template <typename AttrType>
  absl::Status GetAttr(const std::string& attr_name, AttrType* value) const;
};
```

#### InvokeContext
An instance of this class is passed to the op kernel during its invocation.

```cpp
template <typename SubType>
class InvokeContext {
 public:
  // Read an input tensor
  ConstTensorViewOr GetInput(const int idx) const;
  // Get a mutable output tensor
  TensorViewOr GetOutput(const int idx, const Shape& shape) const;
};
```

#### ShapeInferenceContext
An instance of this class is passed to the op kernel during its shape inference.

```cpp
template <typename SubType>
class ShapeInferenceContext {
 public:
  // Read an input tensor shape
  ShapeOr GetInputShape(const int idx) const;
  // Set an output tensor shape
  absl::Status SetOutputShape(const int idx, const Shape& shape);
  // Read an input tensor during shape inference
  ConstTensorViewOr GetInputTensor(const int idx) const;
};
```