This folder contains a convenience library called *tf-shim* over the TF and
TFLite op kernel APIs.

## Summary

This library creates a shim over the custom op APIs of TF and TFLite so the
developer can write the custom op once with minimal binary or runtime overhead.

An example usage is an input preprocessing op kernel that can be used in
both TF and TFLite.

## Background

When there is a need to implement logic that is not supported by the TF
builtin ops, the alternative is to build a custom op. If that op needs to
run on-device then it needs to be written in C++ against the client API for
custom ops.

For example, feature processing, especially for textual input in an ML model,
can involve operations that don't lend themselves well to vectorization, and
the code, if written as a C++ function, would be much shorter and more
readable.

However, the TensorFlow and TFLite APIs for creating op kernels are, at the
moment, not identical. This library offers a convenient way to write the
kernel once and adapt it to both TF and TFLite with minimal binary and runtime
overhead.

## Implementation

This folder contains two pieces:

1. `TensorView`, a shim over `::tensorflow::Tensor` and `TfLiteTensor`

2. `OpKernelShim`, a class which abstracts the TF and TFLite op kernel APIs

### TensorView

This class is a *view* over an already allocated tensor in TF or TFLite
without taking any ownership. In that sense it is similar to
`std::string_view`, but with the difference that the underlying buffer can be
mutable.

Example usage:

```
::tensorflow::Tensor tf_tensor;
auto t = TensorView::New(&tf_tensor);

// Rank-2 (matrix) view over a string tensor.
auto t_str_mat = t.As<::tensorflow::tstring, /*RANK=*/ 2>();
t_str_mat(0, 0) = "ab";
t_str_mat(0, 1) = "cde";

// Flat view over the tensor's buffer.
auto t_buffer = t.Data<::tensorflow::tstring>();
t_buffer[0] = "ab";
t_buffer[1] = "cde";
```

```
TfLiteTensor tflite_tensor;
auto t = TensorView::New(&tflite_tensor);

// Rank-1 (vector) view over an int32 tensor.
auto t_int_vec = t.As<int32, /*RANK=*/ 1>();
t_int_vec(0) = 123;
t_int_vec(1) = 456;

// Flat view over the tensor's buffer.
auto t_buffer = t.Data<int32>();
t_buffer[0] = 123;
t_buffer[1] = 456;
```

`New` is the factory function which, based on the type of its input, returns
either a `TfTensorView` or a `TfLiteTensorView`.

See the unit tests `tf_tensor_view_test.cc` and `tflite_tensor_view_test.cc`
for more usage examples.

String tensors in `TfLiteTensorView` are a bit of a special case. Since string
tensors in TFLite are serialized in a specific format, an intermediate buffer
is needed to hold the string values while writing to those tensors, before all
the strings get serialized. The intermediate string buffers are serialized
back to the TFLite string format once the last remaining `TfLiteTensorView`
goes out of scope. Only then can the user see the string values in the
underlying `TfLiteTensor`. That said, when implementing an op kernel there is
rarely a need to read back the contents of a mutable output `TfLiteTensor`
within the same code block.
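To make that scoping rule concrete, here is a minimal sketch, following the
same accessor pattern as the examples above (the rank-1 `tstring` view is an
assumption chosen for illustration): the writes land in intermediate buffers,
and the underlying `TfLiteTensor` only holds the serialized strings after the
last view is destroyed.

```
TfLiteTensor tflite_tensor;
{
  auto t = TensorView::New(&tflite_tensor);
  auto t_str_vec = t.As<::tensorflow::tstring, /*RANK=*/ 1>();
  // These writes go into intermediate string buffers, not yet into the
  // TFLite-serialized representation.
  t_str_vec(0) = "ab";
  t_str_vec(1) = "cde";
}  // The last TfLiteTensorView goes out of scope here; the buffered strings
   // are serialized into the TFLite string format inside `tflite_tensor`.

// Only from this point on does `tflite_tensor` hold the final string values.
```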
### OpKernelShim

*WARNING: Experimental interface, subject to change.*

This class defines the interface which, when implemented, allows for
convenient adaptation to TF and TFLite op kernels.

Here is an example op kernel implementing this interface:

```
template <TfRuntime R>
class MyOp : public OpKernelShim<MyOp, R> {
 public:
  // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op)
  static std::vector<std::string> Attrs();

  // Input tensors declaration (syntax: https://www.tensorflow.org/guide/create_op)
  static std::vector<std::string> Inputs();

  // Output tensors declaration (syntax: https://www.tensorflow.org/guide/create_op)
  static std::vector<std::string> Outputs();

  // Initializes the op
  absl::Status Init(InitContext* ctx);

  // Runs the operation
  absl::Status Invoke(InvokeContext* ctx);

  // Shape inference
  static absl::Status ShapeInference(ShapeInferenceContext* ctx);
};
```

The class `MyOp` passes itself to `OpKernelShim` as a template parameter,
because `OpKernelShim` is a static interface using the curiously recurring
template pattern (CRTP). Similarly, the context classes `InitContext`,
`InvokeContext` and `ShapeInferenceContext` are all static interfaces in the
same way.

### Context Interfaces

An op kernel written using this library has access to a number of *context*
objects at various stages of its lifecycle. These context objects are
effectively shims over the existing context objects in TF and TFLite.

#### InitContext

An instance of this class is passed to the op kernel during its
initialization.

```
template <typename SubType>
class InitContext {
 public:
  // Reads the given attribute and populates the given value.
  template <typename AttrType>
  absl::Status GetAttr(const std::string& attr_name, AttrType* value) const;
};
```

#### InvokeContext

An instance of this class is passed to the op kernel during its invocation.

```
template <typename SubType>
class InvokeContext {
 public:
  // Reads an input tensor.
  ConstTensorViewOr GetInput(const int idx) const;
  // Gets a mutable output tensor.
  TensorViewOr GetOutput(const int idx, const Shape& shape) const;
};
```

#### ShapeInferenceContext

An instance of this class is passed to the op kernel during its shape
inference.

```
template <typename SubType>
class ShapeInferenceContext {
 public:
  // Reads an input tensor shape.
  ShapeOr GetInputShape(const int idx) const;
  // Sets an output tensor shape.
  absl::Status SetOutputShape(const int idx, const Shape& shape);
  // Reads an input tensor during shape inference.
  ConstTensorViewOr GetInputTensor(const int idx) const;
};
```
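To show how these contexts fit together, here is a hedged sketch of a trivial
copy op built on the interfaces above. The calls to `GetInput`, `GetOutput`,
`GetInputShape` and `SetOutputShape` follow the declarations shown earlier;
everything else, in particular the assumption that the `...Or` return types
behave like `absl::StatusOr` and that views expose `Shape()` and the flat
`Data<T>()` accessor from the `TensorView` examples, is illustrative rather
than the library's confirmed API.

```
// Illustrative sketch only. Assumes:
//  - ConstTensorViewOr / TensorViewOr / ShapeOr behave like absl::StatusOr
//    (ok(), status(), operator*).
//  - Tensor views expose Shape() and a flat Data<T>() accessor returning an
//    iterable span, as in the TensorView examples above.
template <TfRuntime R>
class MyCopyOp : public OpKernelShim<MyCopyOp, R> {
 public:
  static std::vector<std::string> Attrs() { return {}; }
  static std::vector<std::string> Inputs() { return {"input: int32"}; }
  static std::vector<std::string> Outputs() { return {"output: int32"}; }

  absl::Status Init(InitContext* ctx) { return absl::OkStatus(); }

  // Copies input tensor 0 into output tensor 0.
  absl::Status Invoke(InvokeContext* ctx) {
    const auto input_or = ctx->GetInput(0);
    if (!input_or.ok()) return input_or.status();
    const auto& input = *input_or;
    // Allocate the output with the same shape as the input.
    auto output_or = ctx->GetOutput(0, input.Shape());
    if (!output_or.ok()) return output_or.status();
    auto& output = *output_or;
    // Element-wise copy through the flat buffer views.
    const auto src = input.Data<int32>();
    auto dst = output.Data<int32>();
    std::copy(src.begin(), src.end(), dst.begin());
    return absl::OkStatus();
  }

  // The output shape is the same as the input shape.
  static absl::Status ShapeInference(ShapeInferenceContext* ctx) {
    const auto shape_or = ctx->GetInputShape(0);
    if (!shape_or.ok()) return shape_or.status();
    return ctx->SetOutputShape(0, *shape_or);
  }
};
```

Written once like this, the same kernel body can then be adapted to both
runtimes through the `TfRuntime` template parameter, which is the point of the
shim.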