1 /*
2 * Copyright (C) 2019 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "Operations"
18
19 #include "UnidirectionalSequenceLSTM.h"
20
21 #include <vector>
22
23 #include "IndexedShapeWrapper.h"
24 #include "OperationResolver.h"
25 #include "OperationsExecutionUtils.h"
26
27 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
28 #include <tensorflow/lite/kernels/internal/tensor_utils.h>
29
30 #include "LSTM.h"
31 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
32
33 namespace android {
34 namespace nn {
35 namespace unidirectional_sequence_lstm {
36
37 #ifdef NN_INCLUDE_CPU_IMPLEMENTATION
38 namespace {
39
hasTensor(IOperationExecutionContext * context,const uint32_t tensor)40 inline bool hasTensor(IOperationExecutionContext* context, const uint32_t tensor) {
41 return context->getInputBuffer(tensor) != nullptr;
42 }
43
isTimeMajor(IOperationExecutionContext * context)44 inline bool isTimeMajor(IOperationExecutionContext* context) {
45 return context->getInputValue<bool>(kTimeMajorParam);
46 }
47
48 template <typename T>
getLSTMParams(IOperationExecutionContext * context)49 inline LSTMParams getLSTMParams(IOperationExecutionContext* context) {
50 LSTMParams params;
51 params.activation =
52 static_cast<ActivationFn>(context->getInputValue<int32_t>(kActivationParam));
53 params.cell_clip = static_cast<float>(context->getInputValue<T>(kCellClipParam));
54 params.proj_clip = static_cast<float>(context->getInputValue<T>(kProjClipParam));
55 params.use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
56 params.use_peephole = hasTensor(context, kCellToOutputWeightsTensor);
57 params.use_layer_norm = hasTensor(context, kOutputLayerNormWeightsTensor);
58 params.use_projection_weight = hasTensor(context, kProjectionWeightsTensor);
59 params.use_projection_bias = hasTensor(context, kProjectionBiasTensor);
60 return params;
61 }
62
63 } // namespace
64
// Validates operand presence, ranks and shape consistency for
// UNIDIRECTIONAL_SEQUENCE_LSTM and sets the output shape(s).
// Enforces that the optional tensor groups (CIFG input gate, peephole,
// projection, layer normalization) are each provided all-or-none, and that
// every weight/bias/state dimension agrees with numCells / inputSize /
// outputSize / batchSize. Returns false on the first violated constraint.
bool prepare(IOperationExecutionContext* context) {
    // Check that none of the required inputs are omitted
    const std::vector<int> requiredInputs = {
            kInputTensor,
            kInputToForgetWeightsTensor,
            kInputToCellWeightsTensor,
            kInputToOutputWeightsTensor,
            kRecurrentToForgetWeightsTensor,
            kRecurrentToCellWeightsTensor,
            kRecurrentToOutputWeightsTensor,
            kForgetGateBiasTensor,
            kCellGateBiasTensor,
            kOutputGateBiasTensor,
            kOutputStateInTensor,
            kCellStateInTensor,
            kActivationParam,
            kCellClipParam,
            kProjClipParam,
            kTimeMajorParam,
    };
    for (const int requiredInput : requiredInputs) {
        NN_RET_CHECK(!context->isOmittedInput(requiredInput))
                << "required input " << requiredInput << " is omitted";
    }

    // Input is rank 3: [maxTime, batch, inputSize] when time-major, otherwise
    // [batch, maxTime, inputSize].
    const Shape inputShape = context->getInputShape(kInputTensor);
    const uint32_t inputRank = getNumberOfDimensions(inputShape);
    NN_RET_CHECK_EQ(inputRank, 3u) << "Invalid input tensor rank: " << inputRank;

    [[maybe_unused]] const uint32_t maxTime =
            getSizeOfDimension(inputShape, isTimeMajor(context) ? 0 : 1);
    const uint32_t batchSize = getSizeOfDimension(inputShape, isTimeMajor(context) ? 1 : 0);
    const uint32_t inputSize = getSizeOfDimension(inputShape, inputRank - 1);

    // numCells is defined by the input-to-output weights [numCells, inputSize];
    // every other gate weight must agree with it.
    const Shape inputToOutputShape = context->getInputShape(kInputToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToOutputShape, 1), inputSize);
    const uint32_t numCells = getSizeOfDimension(inputToOutputShape, 0);

    // outputSize is defined by the recurrent-to-output weights
    // [numCells, outputSize] (outputSize != numCells when projection is used).
    const Shape recurrentToOutputShape = context->getInputShape(kRecurrentToOutputWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToOutputShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToOutputShape, 0), numCells);
    const uint32_t outputSize = getSizeOfDimension(recurrentToOutputShape, 1);

    // Input-to-* weights are all [numCells, inputSize]; the input-gate one is
    // optional (absent means CIFG).
    if (hasTensor(context, kInputToInputWeightsTensor)) {
        const Shape inputToInputShape = context->getInputShape(kInputToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputToInputShape, 1), inputSize);
    }

    const Shape inputToForgetShape = context->getInputShape(kInputToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToForgetShape, 1), inputSize);
    const Shape inputToCellShape = context->getInputShape(kInputToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(inputToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(inputToCellShape, 1), inputSize);

    // Recurrent-to-* weights are all [numCells, outputSize].
    if (hasTensor(context, kRecurrentToInputWeightsTensor)) {
        const Shape recurrentToInputShape = context->getInputShape(kRecurrentToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToInputShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 0), numCells);
        NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToInputShape, 1), outputSize);
    }

    const Shape recurrentToForgetShape = context->getInputShape(kRecurrentToForgetWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToForgetShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToForgetShape, 1), outputSize);
    const Shape recurrentToCellShape = context->getInputShape(kRecurrentToCellWeightsTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(recurrentToCellShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 0), numCells);
    NN_RET_CHECK_EQ(getSizeOfDimension(recurrentToCellShape, 1), outputSize);

    // We make sure the input-gate's parameters are either both present (regular
    // LSTM) or not at all (CIFG-LSTM).
    const bool cifgWeightsAllOrNone = (hasTensor(context, kInputToInputWeightsTensor) &&
                                       hasTensor(context, kRecurrentToInputWeightsTensor)) ||
                                      (!hasTensor(context, kInputToInputWeightsTensor) &&
                                       !hasTensor(context, kRecurrentToInputWeightsTensor));
    NN_RET_CHECK(cifgWeightsAllOrNone);

    // Peephole weights are 1-D of size numCells.
    if (hasTensor(context, kCellToInputWeightsTensor)) {
        const Shape cellToInputShape = context->getInputShape(kCellToInputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToInputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToInputShape, 0), numCells);
    }

    if (hasTensor(context, kCellToForgetWeightsTensor)) {
        const Shape cellToForgetShape = context->getInputShape(kCellToForgetWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToForgetShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToForgetShape, 0), numCells);
    }

    if (hasTensor(context, kCellToOutputWeightsTensor)) {
        const Shape cellToOutputShape = context->getInputShape(kCellToOutputWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellToOutputShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellToOutputShape, 0), numCells);
    }

    // Making sure the peephole weights are there all or none.
    // With CIFG there is no input gate, so the cell-to-input weight is allowed
    // to be absent even when the other peephole weights are present.
    const bool cifgUsed = !hasTensor(context, kInputToInputWeightsTensor);
    const bool peepholeWeightsAllOrNone =
            ((hasTensor(context, kCellToInputWeightsTensor) || cifgUsed) &&
             hasTensor(context, kCellToForgetWeightsTensor) &&
             hasTensor(context, kCellToOutputWeightsTensor)) ||
            (!hasTensor(context, kCellToInputWeightsTensor) &&
             !hasTensor(context, kCellToForgetWeightsTensor) &&
             !hasTensor(context, kCellToOutputWeightsTensor));
    NN_RET_CHECK(peepholeWeightsAllOrNone);

    // Input gate bias is required without CIFG and forbidden with it.
    if (!cifgUsed) {
        NN_RET_CHECK(hasTensor(context, kInputGateBiasTensor));
        const Shape inputGateBiasShape = context->getInputShape(kInputGateBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputGateBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputGateBiasShape, 0), numCells);
    } else {
        NN_RET_CHECK(!hasTensor(context, kInputGateBiasTensor))
                << "Input gate bias tensor is present when CIFG is used";
    }

    // Gate biases are 1-D of size numCells.
    const Shape forgetGateBiasShape = context->getInputShape(kForgetGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(forgetGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(forgetGateBiasShape, 0), numCells);
    const Shape cellGateBiasShape = context->getInputShape(kCellGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellGateBiasShape, 0), numCells);
    const Shape outputGateBiasShape = context->getInputShape(kOutputGateBiasTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputGateBiasShape), 1u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputGateBiasShape, 0), numCells);

    // Projection maps the cell output [numCells] to the output [outputSize].
    if (hasTensor(context, kProjectionWeightsTensor)) {
        const Shape projectionShape = context->getInputShape(kProjectionWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionShape), 2u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 0), outputSize);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionShape, 1), numCells);
    }

    if (hasTensor(context, kProjectionBiasTensor)) {
        const Shape projectionBiasShape = context->getInputShape(kProjectionBiasTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(projectionBiasShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(projectionBiasShape, 0), outputSize);
    }

    // Recurrent state inputs: output state is [batch, outputSize], cell state
    // is [batch, numCells].
    const Shape outputStateShape = context->getInputShape(kOutputStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(outputStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(outputStateShape, 1), outputSize);
    const Shape cellStateShape = context->getInputShape(kCellStateInTensor);
    NN_RET_CHECK_EQ(getNumberOfDimensions(cellStateShape), 2u);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 0), batchSize);
    NN_RET_CHECK_EQ(getSizeOfDimension(cellStateShape, 1), numCells);

    // Layer normalization weights are 1-D of size numCells.
    if (hasTensor(context, kInputLayerNormWeightsTensor)) {
        const Shape inputLayerNormShape = context->getInputShape(kInputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(inputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(inputLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kForgetLayerNormWeightsTensor)) {
        const Shape forgetLayerNormShape = context->getInputShape(kForgetLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(forgetLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(forgetLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kCellLayerNormWeightsTensor)) {
        const Shape cellLayerNormShape = context->getInputShape(kCellLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(cellLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(cellLayerNormShape, 0), numCells);
    }

    if (hasTensor(context, kOutputLayerNormWeightsTensor)) {
        const Shape outputLayerNormShape = context->getInputShape(kOutputLayerNormWeightsTensor);
        NN_RET_CHECK_EQ(getNumberOfDimensions(outputLayerNormShape), 1u);
        NN_RET_CHECK_EQ(getSizeOfDimension(outputLayerNormShape, 0), numCells);
    }

    // Layer-norm weights must be all-or-none; with CIFG the input layer-norm
    // weight is forbidden (there is no input gate to normalize).
    if (cifgUsed) {
        NN_RET_CHECK(!hasTensor(context, kInputLayerNormWeightsTensor))
                << "Input layer norm weights tensor is present when CIFG is used";
        const bool layerNormWeightsAllOrNoneCifg =
                (hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNoneCifg);
    } else {
        const bool layerNormWeightsAllOrNone =
                (hasTensor(context, kInputLayerNormWeightsTensor) &&
                 hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 hasTensor(context, kCellLayerNormWeightsTensor) &&
                 hasTensor(context, kOutputLayerNormWeightsTensor)) ||
                (!hasTensor(context, kInputLayerNormWeightsTensor) &&
                 !hasTensor(context, kForgetLayerNormWeightsTensor) &&
                 !hasTensor(context, kCellLayerNormWeightsTensor) &&
                 !hasTensor(context, kOutputLayerNormWeightsTensor));
        NN_RET_CHECK(layerNormWeightsAllOrNone);
    }

    // Output has the same layout as the input with the last dimension replaced
    // by outputSize.
    Shape outputShape = context->getInputShape(kInputTensor);
    outputShape.dimensions[2] = outputSize;

    // Optional state outputs (output state [batch, outputSize] and cell state
    // [batch, numCells]) are produced only in the with-state variant.
    if (context->getNumOutputs() == kNumOutputsWithState) {
        NN_RET_CHECK(!context->isOmittedOutput(kOutputStateOutTensor));
        NN_RET_CHECK(!context->isOmittedOutput(kCellStateOutTensor));

        Shape outputStateOutTensor = context->getInputShape(kOutputStateInTensor);
        outputStateOutTensor.dimensions.resize(2);
        outputStateOutTensor.dimensions[0] = batchSize;
        outputStateOutTensor.dimensions[1] = outputSize;
        NN_RET_CHECK(context->setOutputShape(kOutputStateOutTensor, outputStateOutTensor));

        Shape cellStateOutTensor = context->getInputShape(kCellStateInTensor);
        cellStateOutTensor.dimensions.resize(2);
        cellStateOutTensor.dimensions[0] = batchSize;
        cellStateOutTensor.dimensions[1] = numCells;
        NN_RET_CHECK(context->setOutputShape(kCellStateOutTensor, cellStateOutTensor));
    }

    return context->setOutputShape(kOutputTensor, outputShape);
}
290
// Runs the LSTM over the whole sequence on the CPU by dispatching to the
// LSTMCell reference kernels (LSTM.h). Supports TENSOR_FLOAT32 and
// TENSOR_FLOAT16 inputs; any other input type fails with an error.
// Optional inputs that were omitted have null buffers, which the kernels
// interpret according to the feature flags in LSTMParams.
bool execute(IOperationExecutionContext* context) {
    const auto outputStateSize = getNumberOfElements(context->getInputShape(kOutputStateInTensor));
    const auto cellStateSize = getNumberOfElements(context->getInputShape(kCellStateInTensor));
    // With CIFG the input gate is coupled to the forget gate, so scratch space
    // for only three gates is needed instead of four.
    const bool use_cifg = !hasTensor(context, kInputToInputWeightsTensor);
    const auto scratchSize = use_cifg ? 3 * cellStateSize : 4 * cellStateSize;
    const bool useStateOutTensors = (context->getNumOutputs() == kNumOutputsWithState);

    const OperandType inputType = context->getInputType(kInputTensor);
    switch (inputType) {
        case OperandType::TENSOR_FLOAT32: {
            // Initialize empty vectors and resize below only if needed
            std::vector<float> outputStateOutBuffer;
            std::vector<float> cellStateOutBuffer;
            float* outputStateOut;
            float* cellStateOut;
            if (useStateOutTensors) {
                // Write the final states directly into the operation outputs.
                outputStateOut = context->getOutputBuffer<float>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<float>(kCellStateOutTensor);
            } else {
                // No state outputs requested: use temporary buffers the kernel
                // still needs for its per-step state updates.
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<float> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat32(
                    getLSTMParams<float>(context), context->getInputBuffer<float>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<float>(kInputToInputWeightsTensor),
                    context->getInputBuffer<float>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<float>(kInputToCellWeightsTensor),
                    context->getInputBuffer<float>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<float>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<float>(kCellToInputWeightsTensor),
                    context->getInputBuffer<float>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<float>(kCellToOutputWeightsTensor),
                    // Auxiliary inputs are only used by the bidirectional
                    // variant; pass nullptr here.
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<float>(kInputGateBiasTensor),
                    context->getInputBuffer<float>(kForgetGateBiasTensor),
                    context->getInputBuffer<float>(kCellGateBiasTensor),
                    context->getInputBuffer<float>(kOutputGateBiasTensor),
                    context->getInputBuffer<float>(kProjectionWeightsTensor),
                    context->getInputBuffer<float>(kProjectionBiasTensor),
                    context->getInputBuffer<float>(kOutputStateInTensor),
                    context->getInputBuffer<float>(kCellStateInTensor),
                    context->getInputBuffer<float>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<float>(kOutputLayerNormWeightsTensor), outputStateOut,
                    cellStateOut, context->getOutputBuffer<float>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        case OperandType::TENSOR_FLOAT16: {
            // Same flow as the float32 case above, with _Float16 buffers.
            // Initialize empty vectors and resize below only if needed
            std::vector<_Float16> outputStateOutBuffer;
            std::vector<_Float16> cellStateOutBuffer;
            _Float16* outputStateOut;
            _Float16* cellStateOut;
            if (useStateOutTensors) {
                outputStateOut = context->getOutputBuffer<_Float16>(kOutputStateOutTensor);
                cellStateOut = context->getOutputBuffer<_Float16>(kCellStateOutTensor);
            } else {
                outputStateOutBuffer.resize(outputStateSize);
                cellStateOutBuffer.resize(cellStateSize);
                outputStateOut = outputStateOutBuffer.data();
                cellStateOut = cellStateOutBuffer.data();
            }
            std::vector<_Float16> scratchBuffer(scratchSize);
            LSTMCell::LSTMEvalFloat16(
                    getLSTMParams<_Float16>(context),
                    context->getInputBuffer<_Float16>(kInputTensor),
                    context->getInputShape(kInputTensor),
                    context->getInputBuffer<_Float16>(kInputToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kInputToOutputWeightsTensor),
                    context->getInputShape(kInputToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToCellWeightsTensor),
                    context->getInputBuffer<_Float16>(kRecurrentToOutputWeightsTensor),
                    context->getInputShape(kRecurrentToOutputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToInputWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToForgetWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellToOutputWeightsTensor),
                    /*aux_input_buffer=*/nullptr,
                    /*aux_input_to_input_weights_buffer=*/nullptr,
                    /*aux_input_to_forget_weights_buffer=*/nullptr,
                    /*aux_input_to_cell_weights_buffer=*/nullptr,
                    /*aux_input_to_output_weights_buffer=*/nullptr,
                    context->getInputBuffer<_Float16>(kInputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kForgetGateBiasTensor),
                    context->getInputBuffer<_Float16>(kCellGateBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputGateBiasTensor),
                    context->getInputBuffer<_Float16>(kProjectionWeightsTensor),
                    context->getInputBuffer<_Float16>(kProjectionBiasTensor),
                    context->getInputBuffer<_Float16>(kOutputStateInTensor),
                    context->getInputBuffer<_Float16>(kCellStateInTensor),
                    context->getInputBuffer<_Float16>(kInputLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kForgetLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kCellLayerNormWeightsTensor),
                    context->getInputBuffer<_Float16>(kOutputLayerNormWeightsTensor),
                    outputStateOut, cellStateOut, context->getOutputBuffer<_Float16>(kOutputTensor),
                    scratchBuffer.data(), isTimeMajor(context));
        } break;
        default: {
            LOG(ERROR) << "Unsupported data type: " << static_cast<int>(inputType);
            return false;
        }
    }
    return true;
}
412 #endif // NN_INCLUDE_CPU_IMPLEMENTATION
413
414 } // namespace unidirectional_sequence_lstm
415
416 NN_REGISTER_OPERATION_DEFAULT_VALIDATION(UNIDIRECTIONAL_SEQUENCE_LSTM,
417 unidirectional_sequence_lstm::prepare,
418 unidirectional_sequence_lstm::execute,
419 .allowOmittedOperand = true);
420
421 } // namespace nn
422 } // namespace android
423