1 /*
2  * Copyright (C) 2019 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "Operations"
18 
19 #include "UnidirectionalSequenceLSTM.h"
20 
21 #include <vector>
22 
23 #include "IndexedShapeWrapper.h"
24 #include "OperationResolver.h"
25 #include "OperationsExecutionUtils.h"
26 
27 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
28 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
29 
30 #include "LSTM.h"
31 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
32 
33 namespace android {
34 namespace nn {
35 namespace unidirectional_sequence_lstm {
36 
37 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
38 namespace {
39 
hasTensor(IOperationExecutionContext * context,const uint32_t tensor)40 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
41     return context->getInputBuffer(tensor) != nullptr;
42 }
43 
isTimeMajor(IOperationExecutionContext * context)44 inline bool isTimeMajor(IOperationExecutionContext* context) {
45     return context->getInputValue<bool>(kTimeMajorParam);
46 }
47 
48 template <typename T>
getLSTMParams(IOperationExecutionContext * context)49 inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
50     LSTMParams params;
51     params.activation =
52             static_cast<ActivationFn>(context->getInputValue<int32_t>(kActivationParam));
53     params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
54     params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
55     params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
56     params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
57     params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
58     params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
59     params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
60     return params;
61 }
62 
63 }  // namespace
64 
// Validates all operand shapes and derives the output shape(s).
//
// The input tensor is rank 3: [maxTime, batchSize, inputSize] when the
// time-major flag is set, [batchSize, maxTime, inputSize] otherwise.
// numCells is taken from dim 0 of the input-to-output weights, and
// outputSize from dim 1 of the recurrent-to-output weights; every other
// weight/bias tensor is checked for consistency against those. Optional
// operand groups (CIFG input gate, peephole, projection, layer norm) must
// be all-present or all-absent within each group.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3u) << "Invalid input tensor rank: " << inputRank;

    // Which of the first two dimensions is time vs. batch depends on the
    // time-major flag. maxTime itself is not needed for validation here.
    [[maybe_unused]] const uint32_t maxTime =
            getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    // numCells is defined by the input-to-output weights: [numCells, inputSize].
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    // outputSize is defined by the recurrent-to-output weights:
    // [numCells, outputSize].
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Optional input gate weights (absent in CIFG mode): [numCells, inputSize].
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    // Remaining input weights: [numCells, inputSize].
    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Recurrent weights: [numCells, outputSize].
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Peephole weights are rank-1 vectors of length numCells.
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none.
    // (In CIFG mode there is no input gate, so the cell-to-input weights are
    // allowed to be absent even when the other two peephole weights exist.)
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // Input gate bias is required in regular mode, forbidden in CIFG mode.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    // Gate biases: rank-1 vectors of length numCells.
    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    // Optional projection: weights [outputSize, numCells], bias [outputSize].
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Input state tensors: [batchSize, outputSize] and [batchSize, numCells].
    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    // Optional layer-normalization weights: rank-1 vectors of length numCells.
    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    // Layer-norm weights must be all-present or all-absent; in CIFG mode the
    // input-layer-norm weights must not be provided (there is no input gate).
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Main output keeps the input's [time/batch] layout, with the innermost
    // dimension replaced by outputSize.
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    // Optional state outputs (when the model requests them): the final
    // output state [batchSize, outputSize] and cell state [batchSize, numCells].
    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}
290 
// Runs the LSTM over the full sequence on the CPU by delegating to
// LSTMCell::LSTMEvalFloat32/LSTMEvalFloat16, dispatched on the input operand
// type. Shapes are assumed to have been validated by prepare(). Returns false
// for unsupported input types.
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    // CIFG removes the input gate, so the kernel needs scratch space for only
    // 3 gates instead of 4.
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                // Write the final states directly into the model's outputs.
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                // No state outputs requested: give the kernel throwaway
                // buffers so it still has somewhere to write the states.
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            // Optional tensors resolve to nullptr buffers when omitted; the
            // kernel handles nullptr for the optional operands. Aux inputs are
            // unused by this (unidirectional) op.
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Same as the FLOAT32 case, but with _Float16 buffers.
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
412 #endif  // NN_INCLUDE_CPU_IMPLEMENTATION
413 
414 }  // namespace unidirectional_sequence_lstm
415 
// Register the operation with the resolver. allowOmittedOperand is set
// because several operands are optional (CIFG/peephole/projection/layer-norm
// weight tensors may be absent).
NN_REGISTER_OPERATION_DEFAULT_VALIDATION(UNIDIRECTIONAL_SEQUENCE_LSTM,
                                         unidirectional_sequence_lstm::prepare,
                                         unidirectional_sequence_lstm::execute,
                                         .allowOmittedOperand = true);
420 
421 }  // namespace nn
422 }  // namespace android
423