/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_SIMPLE_OP_H_
#define TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_SIMPLE_OP_H_

#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"

namespace tflite {
namespace shim {

// A simple operation for demonstration and testing purposes.
// See the kDoc member for documentation.
34 35 template <Runtime Rt> 36 class SimpleOp : public OpKernelShim<SimpleOp, Rt> { 37 protected: 38 enum Inputs { kInput0 = 0, kInput1 }; 39 enum Outputs { kOutput0 = 0, kOutput1, kOutput2, kOutput3 }; 40 int64_t output1_size_; 41 std::string output2_suffix_; 42 int64_t n_; 43 static constexpr int kOutput0Size = 5; 44 static const char kOutput1SizeAttr[]; 45 46 public: 47 using typename OpKernelShim<SimpleOp, Rt>::InitContext; 48 using typename OpKernelShim<SimpleOp, Rt>::InvokeContext; 49 using typename OpKernelShim<SimpleOp, Rt>::ShapeInferenceContext; 50 51 SimpleOp() = default; 52 static const char kOpName[]; 53 static const char kDoc[]; 54 55 // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op) Attrs()56 static std::vector<std::string> Attrs() { 57 return {absl::StrCat(kOutput1SizeAttr, ": int"), "output2_suffix: string", 58 "N: int >= 0"}; 59 } 60 // Input tensors declaration (syntax: 61 // https://www.tensorflow.org/guide/create_op) Inputs()62 static std::vector<std::string> Inputs() { 63 return {"in0: string", "in1: N*int64"}; 64 } 65 // Output tensors declaration (syntax: 66 // https://www.tensorflow.org/guide/create_op) Outputs()67 static std::vector<std::string> Outputs() { 68 return {"out0: int32", "out1: float", "out2: string", "out3: N*int64"}; 69 } 70 71 // Initializes the op Init(InitContext * ctx)72 absl::Status Init(InitContext* ctx) { 73 SH_RETURN_IF_ERROR(ctx->GetAttr(kOutput1SizeAttr, &output1_size_)); 74 if (output1_size_ < 1) { 75 return absl::InternalError( 76 absl::StrCat(kOutput1SizeAttr, " should be >= 1")); 77 } 78 SH_RETURN_IF_ERROR(ctx->GetAttr("N", &n_)); 79 absl::string_view output2_suffix; 80 SH_RETURN_IF_ERROR(ctx->GetAttr("output2_suffix", &output2_suffix)); 81 output2_suffix_ = std::string(output2_suffix); 82 return absl::OkStatus(); 83 } 84 85 // Runs the operation Invoke(InvokeContext * ctx)86 absl::Status Invoke(InvokeContext* ctx) { 87 using std::int32_t; 88 // read input 89 
SH_ASSIGN_OR_RETURN(const auto input_t, ctx->GetInput(kInput0)); 90 const auto input_str = input_t->template AsScalar<::tensorflow::tstring>(); 91 // output0 whose size is static 92 SH_ASSIGN_OR_RETURN(auto output0_t, 93 ctx->GetOutput(kOutput0, Shape({kOutput0Size}))); 94 auto output0 = output0_t->template As<int32_t, 1>(); 95 for (int i = 0; i < output0.Dim(0); ++i) output0(i) = i; 96 // output1 whose size is based on the attr 97 SH_ASSIGN_OR_RETURN( 98 auto output1_t, 99 ctx->GetOutput(kOutput1, Shape({static_cast<int>(output1_size_)}))); 100 auto output1 = output1_t->template As<float, 1>(); 101 for (int i = 0; i < output1.Dim(0); ++i) output1(i) = 0.5 * i; 102 // output2 whose size is based on input 103 const int output2_size = input_str.length() + 1; 104 SH_ASSIGN_OR_RETURN(auto output2_t, 105 ctx->GetOutput(kOutput2, Shape({output2_size}))); 106 auto output2 = output2_t->template As<tensorflow::tstring, 1>(); 107 for (int i = 0; i < output2.Dim(0) - 1; ++i) output2(i) = std::to_string(i); 108 output2(output2.Dim(0) - 1) = output2_suffix_; 109 // output3 which is a list of length N 110 // The values in output3 are element wise equal to input2 + 1. 
111 if (ctx->NumInputs() < kInput1 + n_) { 112 return absl::InternalError(absl::StrCat( 113 "out of bounds: num_inputs=", ctx->NumInputs(), " N=", n_)); 114 } 115 if (ctx->NumOutputs() < kOutput3 + n_) { 116 return absl::InternalError(absl::StrCat( 117 "out of bounds: num_outputs=", ctx->NumOutputs(), " N=", n_)); 118 } 119 for (int i = 0; i < n_; ++i) { 120 SH_ASSIGN_OR_RETURN(const auto input_t, ctx->GetInput(kInput1 + i)); 121 Shape output_shape(input_t->Shape()); 122 SH_ASSIGN_OR_RETURN(auto output_t, 123 ctx->GetOutput(kOutput3 + i, output_shape)); 124 const auto input_data = input_t->template Data<int64_t>(); 125 auto output_buffer = output_t->template Data<int64_t>().data(); 126 std::copy(input_data.begin(), input_data.end(), output_buffer); 127 // Increment the values of the output 128 for (auto& v : output_t->template Data<int64_t>()) ++v; 129 } 130 return absl::OkStatus(); 131 } 132 133 // Shape inference ShapeInference(ShapeInferenceContext * ctx)134 static absl::Status ShapeInference(ShapeInferenceContext* ctx) { 135 // outpu0 136 SH_RETURN_IF_ERROR(ctx->SetOutputShape(kOutput0, Shape({kOutput0Size}))); 137 // output1 138 SH_RETURN_IF_ERROR( 139 ctx->SetOutputShape(kOutput1, Shape({Shape::kUnknownDim}))); 140 // output2 141 const auto input_t_or = ctx->GetInputTensor(kInput0); 142 Shape output2_shape; 143 if (input_t_or.ok()) { 144 const auto& input_t = input_t_or.value(); 145 const auto input_str = 146 input_t->template AsScalar<::tensorflow::tstring>(); 147 output2_shape = Shape({static_cast<int>(input_str.length() + 1)}); 148 } else { 149 output2_shape = Shape({Shape::kUnknownDim}); 150 } 151 SH_RETURN_IF_ERROR(ctx->SetOutputShape(kOutput2, output2_shape)); 152 // output3 153 for (int i = kOutput3; i < ctx->NumOutputs(); ++i) { 154 SH_RETURN_IF_ERROR(ctx->SetOutputShape(kOutput3, Shape())); 155 } 156 int64_t n; 157 SH_RETURN_IF_ERROR(ctx->GetAttr("N", &n)); 158 if (n + 1 != ctx->NumInputs()) { 159 return absl::InternalError(absl::StrCat("n + 1 != 
num_inputs: ", n + 1, 160 " != ", ctx->NumInputs())); 161 } 162 if (n + 3 != ctx->NumOutputs()) { 163 return absl::InternalError(absl::StrCat("n + 1 != num_inputs: ", n + 1, 164 " != ", ctx->NumOutputs())); 165 } 166 return absl::OkStatus(); 167 } 168 }; 169 170 // Static member definitions. 171 // These can be inlined once the toolchain is bumped up to C++17 172 173 template <Runtime Rt> 174 const char SimpleOp<Rt>::kOutput1SizeAttr[] = "output1_size"; 175 176 template <Runtime Rt> 177 const char SimpleOp<Rt>::kOpName[] = "SimpleOperation"; 178 179 template <Runtime Rt> 180 const char SimpleOp<Rt>::kDoc[] = R"doc( 181 Description: 182 Simple example op for testing and demonstration purposes. 183 184 Attrs 185 output1_size: int - the size of the second output 186 output2_suffix: string - the string value to be appended to the end of out2 187 N: int - the number of tensors for the second input and last output 188 Inputs 189 in0: str, shape=[] - A scalar input 190 in1: int64, list<shape=?> - A list of tensors as input 191 Outputs 192 out0: int, shape=[5] - first output 193 out1: float, shape=[?] - second output 194 out2: string, shape=[?] - third output 195 out3: int64, list<shape=?> - fourth output that is in1 but incremented. 196 )doc"; 197 198 } // namespace shim 199 } // namespace tflite 200 201 #endif // TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_SIMPLE_OP_H_ 202