xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/kernels/shim/test_op/simple_op.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_SIMPLE_OP_H_
16 #define TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_SIMPLE_OP_H_
17 
#include <algorithm>
#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/lite/kernels/shim/op_kernel.h"
#include "tensorflow/lite/kernels/shim/status_macros.h"
#include "tensorflow/lite/kernels/shim/tensor_view.h"
28 
29 namespace tflite {
30 namespace shim {
31 
32 // A simple operation for demonstration and testing purposes.
33 // See the kDoc member for documentation.
34 
35 template <Runtime Rt>
36 class SimpleOp : public OpKernelShim<SimpleOp, Rt> {
37  protected:
38   enum Inputs { kInput0 = 0, kInput1 };
39   enum Outputs { kOutput0 = 0, kOutput1, kOutput2, kOutput3 };
40   int64_t output1_size_;
41   std::string output2_suffix_;
42   int64_t n_;
43   static constexpr int kOutput0Size = 5;
44   static const char kOutput1SizeAttr[];
45 
46  public:
47   using typename OpKernelShim<SimpleOp, Rt>::InitContext;
48   using typename OpKernelShim<SimpleOp, Rt>::InvokeContext;
49   using typename OpKernelShim<SimpleOp, Rt>::ShapeInferenceContext;
50 
51   SimpleOp() = default;
52   static const char kOpName[];
53   static const char kDoc[];
54 
55   // Attributes declaration (syntax: https://www.tensorflow.org/guide/create_op)
Attrs()56   static std::vector<std::string> Attrs() {
57     return {absl::StrCat(kOutput1SizeAttr, ": int"), "output2_suffix: string",
58             "N: int >= 0"};
59   }
60   // Input tensors declaration (syntax:
61   // https://www.tensorflow.org/guide/create_op)
Inputs()62   static std::vector<std::string> Inputs() {
63     return {"in0: string", "in1: N*int64"};
64   }
65   // Output tensors declaration (syntax:
66   // https://www.tensorflow.org/guide/create_op)
Outputs()67   static std::vector<std::string> Outputs() {
68     return {"out0: int32", "out1: float", "out2: string", "out3: N*int64"};
69   }
70 
71   // Initializes the op
Init(InitContext * ctx)72   absl::Status Init(InitContext* ctx) {
73     SH_RETURN_IF_ERROR(ctx->GetAttr(kOutput1SizeAttr, &output1_size_));
74     if (output1_size_ < 1) {
75       return absl::InternalError(
76           absl::StrCat(kOutput1SizeAttr, " should be >= 1"));
77     }
78     SH_RETURN_IF_ERROR(ctx->GetAttr("N", &n_));
79     absl::string_view output2_suffix;
80     SH_RETURN_IF_ERROR(ctx->GetAttr("output2_suffix", &output2_suffix));
81     output2_suffix_ = std::string(output2_suffix);
82     return absl::OkStatus();
83   }
84 
85   // Runs the operation
Invoke(InvokeContext * ctx)86   absl::Status Invoke(InvokeContext* ctx) {
87     using std::int32_t;
88     // read input
89     SH_ASSIGN_OR_RETURN(const auto input_t, ctx->GetInput(kInput0));
90     const auto input_str = input_t->template AsScalar<::tensorflow::tstring>();
91     // output0 whose size is static
92     SH_ASSIGN_OR_RETURN(auto output0_t,
93                         ctx->GetOutput(kOutput0, Shape({kOutput0Size})));
94     auto output0 = output0_t->template As<int32_t, 1>();
95     for (int i = 0; i < output0.Dim(0); ++i) output0(i) = i;
96     // output1 whose size is based on the attr
97     SH_ASSIGN_OR_RETURN(
98         auto output1_t,
99         ctx->GetOutput(kOutput1, Shape({static_cast<int>(output1_size_)})));
100     auto output1 = output1_t->template As<float, 1>();
101     for (int i = 0; i < output1.Dim(0); ++i) output1(i) = 0.5 * i;
102     // output2 whose size is based on input
103     const int output2_size = input_str.length() + 1;
104     SH_ASSIGN_OR_RETURN(auto output2_t,
105                         ctx->GetOutput(kOutput2, Shape({output2_size})));
106     auto output2 = output2_t->template As<tensorflow::tstring, 1>();
107     for (int i = 0; i < output2.Dim(0) - 1; ++i) output2(i) = std::to_string(i);
108     output2(output2.Dim(0) - 1) = output2_suffix_;
109     // output3 which is a list of length N
110     // The values in output3 are element wise equal to input2 + 1.
111     if (ctx->NumInputs() < kInput1 + n_) {
112       return absl::InternalError(absl::StrCat(
113           "out of bounds: num_inputs=", ctx->NumInputs(), " N=", n_));
114     }
115     if (ctx->NumOutputs() < kOutput3 + n_) {
116       return absl::InternalError(absl::StrCat(
117           "out of bounds: num_outputs=", ctx->NumOutputs(), " N=", n_));
118     }
119     for (int i = 0; i < n_; ++i) {
120       SH_ASSIGN_OR_RETURN(const auto input_t, ctx->GetInput(kInput1 + i));
121       Shape output_shape(input_t->Shape());
122       SH_ASSIGN_OR_RETURN(auto output_t,
123                           ctx->GetOutput(kOutput3 + i, output_shape));
124       const auto input_data = input_t->template Data<int64_t>();
125       auto output_buffer = output_t->template Data<int64_t>().data();
126       std::copy(input_data.begin(), input_data.end(), output_buffer);
127       // Increment the values of the output
128       for (auto& v : output_t->template Data<int64_t>()) ++v;
129     }
130     return absl::OkStatus();
131   }
132 
133   // Shape inference
ShapeInference(ShapeInferenceContext * ctx)134   static absl::Status ShapeInference(ShapeInferenceContext* ctx) {
135     // outpu0
136     SH_RETURN_IF_ERROR(ctx->SetOutputShape(kOutput0, Shape({kOutput0Size})));
137     // output1
138     SH_RETURN_IF_ERROR(
139         ctx->SetOutputShape(kOutput1, Shape({Shape::kUnknownDim})));
140     // output2
141     const auto input_t_or = ctx->GetInputTensor(kInput0);
142     Shape output2_shape;
143     if (input_t_or.ok()) {
144       const auto& input_t = input_t_or.value();
145       const auto input_str =
146           input_t->template AsScalar<::tensorflow::tstring>();
147       output2_shape = Shape({static_cast<int>(input_str.length() + 1)});
148     } else {
149       output2_shape = Shape({Shape::kUnknownDim});
150     }
151     SH_RETURN_IF_ERROR(ctx->SetOutputShape(kOutput2, output2_shape));
152     // output3
153     for (int i = kOutput3; i < ctx->NumOutputs(); ++i) {
154       SH_RETURN_IF_ERROR(ctx->SetOutputShape(kOutput3, Shape()));
155     }
156     int64_t n;
157     SH_RETURN_IF_ERROR(ctx->GetAttr("N", &n));
158     if (n + 1 != ctx->NumInputs()) {
159       return absl::InternalError(absl::StrCat("n + 1 != num_inputs: ", n + 1,
160                                               " != ", ctx->NumInputs()));
161     }
162     if (n + 3 != ctx->NumOutputs()) {
163       return absl::InternalError(absl::StrCat("n + 1 != num_inputs: ", n + 1,
164                                               " != ", ctx->NumOutputs()));
165     }
166     return absl::OkStatus();
167   }
168 };
169 
// Static member definitions.
// These can become in-class `inline constexpr` definitions once the toolchain
// supports C++17.
172 
173 template <Runtime Rt>
174 const char SimpleOp<Rt>::kOutput1SizeAttr[] = "output1_size";
175 
176 template <Runtime Rt>
177 const char SimpleOp<Rt>::kOpName[] = "SimpleOperation";
178 
179 template <Runtime Rt>
180 const char SimpleOp<Rt>::kDoc[] = R"doc(
181 Description:
182   Simple example op for testing and demonstration purposes.
183 
184 Attrs
185   output1_size: int - the size of the second output
186   output2_suffix: string - the string value to be appended to the end of out2
187   N: int - the number of tensors for the second input and last output
188 Inputs
189   in0: str, shape=[] - A scalar input
190   in1: int64, list<shape=?> - A list of tensors as input
191 Outputs
192   out0: int, shape=[5] - first output
193   out1: float, shape=[?] - second output
194   out2: string, shape=[?] - third output
195   out3: int64, list<shape=?> - fourth output that is in1 but incremented.
196 )doc";
197 
198 }  // namespace shim
199 }  // namespace tflite
200 
201 #endif  // TENSORFLOW_LITE_KERNELS_SHIM_TEST_OP_SIMPLE_OP_H_
202