/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"

#include <cstring>  // std::memcpy
#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/string_type.h"

namespace tflite {
namespace gpu {
namespace {

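// Creates a fresh internal Value mirroring the shape, type, and quantization
// parameters of `old_value`, but not tied to any TFLite tensor (ref == -1).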
Value* CreateNewSimilarValue(GraphFloat32* graph, const Value* old_value) {
  Value* new_value = graph->NewValue();
  new_value->quant_params = old_value->quant_params;
  new_value->tensor.shape = old_value->tensor.shape;
  new_value->tensor.type = old_value->tensor.type;
  new_value->tensor.ref = -1;
  return new_value;
}

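// Configures `node` as a fully connected layer using the weights at
// `weights_tensor_id` and, if `bias_tensor_id` != -1, the bias at
// `bias_tensor_id`. Uniformly (per-tensor) quantized int8 weights stay int8
// (FULLY_CONNECTED_INT8); everything else is read as float32
// (FULLY_CONNECTED).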
absl::Status GetFullyConnectedNode(int weights_tensor_id, int bias_tensor_id,
                                   ObjectReader* reader, Node* node) {
  const TfLiteTensor* weights_tensor =
      reader->GetInputTensor(weights_tensor_id);
  TfLiteAffineQuantization* quant_params =
      static_cast<TfLiteAffineQuantization*>(
          weights_tensor->quantization.params);
  if (weights_tensor->type == kTfLiteInt8 && quant_params->scale->size == 1) {
    // Uniform (per-tensor) int8 quantization: keep the weights quantized.
    node->operation.type = ToString(OperationType::FULLY_CONNECTED_INT8);
    FullyConnectedInt8Attributes fc_attr;
    fc_attr.scale = weights_tensor->params.scale;
    fc_attr.zero_point = weights_tensor->params.zero_point;
    fc_attr.weights.data.resize(weights_tensor->bytes);
    std::memcpy(fc_attr.weights.data.data(), weights_tensor->data.int8,
                weights_tensor->bytes);
    int tensor_id;
    RETURN_IF_ERROR(reader->GetTensorId(weights_tensor_id, &tensor_id));
    fc_attr.weights.id = tensor_id;
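    // TFLite FC weights are 2-D (output, input); map them to the GPU OHWI
    // layout with h = w = 1.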
    fc_attr.weights.shape.o = weights_tensor->dims->data[0];
    fc_attr.weights.shape.h = 1;
    fc_attr.weights.shape.w = 1;
    fc_attr.weights.shape.i = weights_tensor->dims->data[1];
    if (bias_tensor_id != -1) {
      reader->ReadTensor(bias_tensor_id, &(fc_attr.bias)).IgnoreError();
    }
    node->operation.attributes = std::move(fc_attr);
  } else {
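    // Otherwise, read the weights as float32 and build a regular float
    // fully connected op.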
    node->operation.type = ToString(OperationType::FULLY_CONNECTED);
    FullyConnectedAttributes fc_attr;
    Tensor<HW, DataType::FLOAT32> weights;
    RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
    fc_attr.weights.data = std::move(weights.data);
    fc_attr.weights.id = weights.id;
    fc_attr.weights.shape.o = weights.shape.h;
    fc_attr.weights.shape.h = 1;
    fc_attr.weights.shape.w = 1;
    fc_attr.weights.shape.i = weights.shape.w;
    if (bias_tensor_id != -1) {
      reader->ReadTensor(bias_tensor_id, &(fc_attr.bias)).IgnoreError();
    }
    node->operation.attributes = std::move(fc_attr);
  }
  return absl::OkStatus();
}

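// Returns true if the node's input at `index` exists and is not the
// optional-tensor placeholder (kTfLiteOptionalTensor).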
bool HasTensor(const TfLiteNode* node, const int index) {
  return (index < node->inputs->size) &&
         (node->inputs->data[index] != kTfLiteOptionalTensor);
}

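// CIFG (coupled input and forget gate) LSTMs omit the input-to-input
// weights; their absence marks the op as CIFG.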
bool HasCifg(const TfLiteNode* node) {
  return !HasTensor(
      node, tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor);
}

bool HasPeephole(const TfLiteNode* node) {
  // Use the forget weights to detect a peephole connection, rather than the
  // input weights, since the input weights may be absent with CIFG.
  return HasTensor(
      node, tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor);
}

bool HasNormalization(const TfLiteNode* node) {
  return HasTensor(
      node,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
}

bool HasProjection(const TfLiteNode* node) {
  return HasTensor(node,
                   tflite::ops::builtin::lstm::full::kProjectionWeightsTensor);
}

// Builds subgraph for a single LSTM gate.
// Returns a Value representing the gate's output.
// High-level parameters:
//   - Has normalization (if true: provide normalization weights).
//   - Has peephole connection (if true: provide peephole weights).
//   - Which activation function to use.
// Note: no support for aux input.
//
// Implements the following:
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//   temp = input_weights * input_tensor + recurrent_weights * output_state;
//   if (peephole):
//     temp += peephole_weights .* cell_state;
//   if (layer normalization):
//     gate = activate(normalization_weights .* mean_stddev_norm(temp) + bias);
//   else:
//     gate = activate(temp + bias);
//
absl::Status BuildLstmGate(GraphFloat32* graph, ObjectReader* reader,
                           Value* output_state, Value* cell_state,
                           int input_weight_id, int recurrent_weight_id,
                           int cell_weight_id, int bias_id,
                           int normalization_weight_id,
                           const TfLiteFusedActivation activation,
                           bool has_peephole, bool has_normalization,
                           Value** gate_out) {
  Value* input_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 matrix multiplication: input_weights * input_tensor
    // If there is no normalization, the bias is also added here.
    Node* node = graph->NewNode();
    int input_bias_id = !has_normalization ? bias_id : -1;
    RETURN_IF_ERROR(
        GetFullyConnectedNode(input_weight_id, input_bias_id, reader, node));
    RETURN_IF_ERROR(
        reader->AddInput(node, tflite::ops::builtin::lstm::full::kInputTensor));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_times_weights->id));
  }

  Value* output_state_times_weights = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 matrix multiplication: recurrent_weights * output_state
    Node* node = graph->NewNode();
    RETURN_IF_ERROR(
        GetFullyConnectedNode(recurrent_weight_id, -1, reader, node));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_state->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, output_state_times_weights->id));
  }

  Value* cell_state_times_weights;
  if (has_peephole) {
    // #3 elementwise multiplication: cell_weight .* cell_state
    cell_state_times_weights = CreateNewSimilarValue(graph, cell_state);
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> weights;
    RETURN_IF_ERROR(reader->ReadTensor(cell_weight_id, &weights));
    attr.param = std::move(weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_times_weights->id));
  }

  Value* gate_before_normalization = CreateNewSimilarValue(graph, cell_state);
  Node* add_node = graph->NewNode();
  {
    // #4 elementwise addition: #1 + #2 + #3
    add_node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(add_node->id, input_times_weights->id));
    RETURN_IF_ERROR(
        graph->AddConsumer(add_node->id, output_state_times_weights->id));
    if (has_peephole) {
      RETURN_IF_ERROR(
          graph->AddConsumer(add_node->id, cell_state_times_weights->id));
    }
    RETURN_IF_ERROR(
        graph->SetProducer(add_node->id, gate_before_normalization->id));
  }

  if (!has_normalization) {
    // #5 Activation function: activate(temp + bias)
    // Bias is added in node #1.
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, add_node));
    *gate_out = gate_before_normalization;
    return absl::OkStatus();
  }

  Value* normalized_gate =
      CreateNewSimilarValue(graph, gate_before_normalization);
  {
    // #6 Normalization: normalize(temp)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MEAN_STDDEV_NORMALIZATION);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, gate_before_normalization->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, normalized_gate->id));
  }
  Value* reweighted_normalized_gate =
      CreateNewSimilarValue(graph, normalized_gate);
  {
    // #7 Elementwise multiplication: norm_weights .* #6
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> norm_weights;
    RETURN_IF_ERROR(reader->ReadTensor(normalization_weight_id, &norm_weights));
    attr.param = std::move(norm_weights);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, normalized_gate->id));
    RETURN_IF_ERROR(
        graph->SetProducer(node->id, reweighted_normalized_gate->id));
  }
  Value* gate = CreateNewSimilarValue(graph, reweighted_normalized_gate);
  {
    // #8 Elementwise add: #7 + bias
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    ElementwiseAttributes attr;
    Tensor<Linear, DataType::FLOAT32> bias;
    RETURN_IF_ERROR(reader->ReadTensor(bias_id, &bias));
    attr.param = std::move(bias);
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(
        graph->AddConsumer(node->id, reweighted_normalized_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, gate->id));

    // #9: Activation function
    RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
  }
  *gate_out = gate;
  return absl::OkStatus();
}

// Builds subgraph for LSTM cell state update.
// Returns a Value representing the updated cell state.
// High-level parameters:
//   - clip: if > 0, clamp the resulting cell state to [-clip, +clip].
//
// Implements the following:
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//
//   cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate);
//
absl::Status BuildCellStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                  Value* forget_gate, Value* input_gate,
                                  Value* cell_gate, float cell_clip,
                                  Value** cell_state_new) {
  Value* cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &cell_state));
  Value* cell_state_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #1 elementwise multiplication: forget_gate .* cell_state
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_contrib->id));
  }
  Value* cell_gate_contrib = CreateNewSimilarValue(graph, cell_gate);
  {
    // #2 elementwise multiplication: input_gate .* cell_gate
    // Note: with CIFG, input_gate equals (1 - forget_gate).
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, input_gate->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, cell_gate_contrib->id));
  }
  Value* new_cell_state = CreateNewSimilarValue(graph, cell_gate);
  {
    // #3 elementwise add: #1 + #2
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::ADD);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state_contrib->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate_contrib->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_cell_state->id));
  }

  if (cell_clip <= 0.0f) {
    *cell_state_new = new_cell_state;
    return absl::OkStatus();
  }

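  // Clamp the new cell state to [-cell_clip, cell_clip], implemented as an
  // elementwise MINIMUM followed by an elementwise MAXIMUM.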
  Value* max_clipped_state = CreateNewSimilarValue(graph, new_cell_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_cell_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -cell_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_cell_state->id));
  }
  *cell_state_new = clipped_cell_state;
  return absl::OkStatus();
}

// Builds subgraph for LSTM output state update.
// Returns a Value representing the updated output state.
// High-level parameters:
//   - Has projection (if true, provide projection_weights).
//   - Has projection bias (only with projection).
//   - clip: if > 0, clamp the projection output to [-clip, +clip].
//   - Which activation function to use.
// Note: the updated output state does not depend on the old output state
// directly, only through the output gate.
//
// Implements the following:
// (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
//
//   temp = output_gate .* activate(cell_state);
//   if (projection):
//     output_state_new = clip(projection_weights * temp + projection_bias);
//   else:
//     output_state_new = temp;
//
absl::Status BuildOutputStateUpdate(GraphFloat32* graph, ObjectReader* reader,
                                    Value* output_state, Value* output_gate,
                                    Value* cell_state,
                                    TfLiteFusedActivation activation,
                                    bool has_projection, float proj_clip,
                                    Value** output_state_new) {
  Value* activated_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #1 activation: activate(cell_state)
    Node* node = graph->NewNode();
    switch (activation) {
      case kTfLiteActTanh:
        node->operation.type = ToString(OperationType::TANH);
        break;
      case kTfLiteActSigmoid:
        node->operation.type = ToString(OperationType::SIGMOID);
        break;
      default:
        return absl::InvalidArgumentError(
            absl::StrCat("Unsupported activation: ", activation));
    }
    RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, activated_state->id));
  }

  Value* new_output_state = CreateNewSimilarValue(graph, cell_state);
  {
    // #2 elementwise multiplication: output_gate .* #1
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MUL);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, activated_state->id));
    RETURN_IF_ERROR(graph->AddConsumer(node->id, output_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, new_output_state->id));
  }

  if (!has_projection) {
    *output_state_new = new_output_state;
    return absl::OkStatus();
  }

  Value* projected_output_state = CreateNewSimilarValue(graph, output_state);
  {
    // #3 matrix multiplication: projection_weights * #2 + projection_bias
    Node* node = graph->NewNode();

    RETURN_IF_ERROR(GetFullyConnectedNode(
        tflite::ops::builtin::lstm::full::kProjectionWeightsTensor,
        tflite::ops::builtin::lstm::full::kProjectionBiasTensor, reader, node));

    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, projected_output_state->id));
  }

  if (proj_clip <= 0.0f) {
    *output_state_new = projected_output_state;
    return absl::OkStatus();
  }

  Value* max_clipped_state =
      CreateNewSimilarValue(graph, projected_output_state);
  {
    // #4 elementwise minimum: min(#3, clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MINIMUM);
    ElementwiseAttributes attr;
    attr.param = proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, projected_output_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
  }
  Value* clipped_output_state = CreateNewSimilarValue(graph, max_clipped_state);
  {
    // #5 elementwise maximum: max(#4, -clip)
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::MAXIMUM);
    ElementwiseAttributes attr;
    attr.param = -proj_clip;
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_output_state->id));
  }
  *output_state_new = clipped_output_state;
  return absl::OkStatus();
}

}  // namespace

// Builds subgraph for a single LSTM op.
// Returns a mapping from the used variable tensors to their updated Values.
//
// High-level parameters:
//   - Has CIFG:
//     If false, calculate input_gate regularly.
//     If true, set input_gate to (1 - forget_gate).
//   - Has peephole: see BuildLstmGate. Applies to all gates.
//   - Has normalization: see BuildLstmGate. Applies to all gates.
//   - Has projection, projection_bias, proj_clip: see BuildOutputStateUpdate.
//   - Which activation to use:
//     Applies only to the cell gate and the output state update.
//     All other gates always use sigmoid.
//
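// For reference (W_*: input weights, R_*: recurrent weights; both are
// illustrative shorthand), the overall step assembled below, ignoring
// peephole, normalization, and projection, is the standard LSTM update:
//   forget_gate = sigmoid(W_f * input + R_f * output_state);
//   input_gate = sigmoid(W_i * input + R_i * output_state);  // CIFG: 1 - f
//   cell_gate = activate(W_c * input + R_c * output_state);
//   output_gate = sigmoid(W_o * input + R_o * output_state);
//   cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate);
//   output_state_new = output_gate .* activate(cell_state_new);
//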
absl::Status ParseLSTMAttributes(
    const TfLiteNode* tflite_node, const TfLiteRegistration* registration,
    GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params,
    absl::flat_hash_map<int, ValueId>* new_variable_input_values) {
  const bool has_cifg = HasCifg(tflite_node);
  const bool has_peephole = HasPeephole(tflite_node);
  const bool has_normalization = HasNormalization(tflite_node);
  const bool has_projection = HasProjection(tflite_node);

  Value* old_cell_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kCellStateTensor, &old_cell_state));

  if (old_cell_state->tensor.shape.b != 1) {
    return absl::InvalidArgumentError(
        "Batched execution is not supported for LSTM");
  }

  Value* old_output_state;
  RETURN_IF_ERROR(reader->ReadValue(
      tflite::ops::builtin::lstm::full::kOutputStateTensor, &old_output_state));

  Value* forget_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor,
      tflite::ops::builtin::lstm::full::kForgetGateBiasTensor,
      tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &forget_gate));

  Value* input_gate;
  if (has_cifg) {
    // When using CIFG, input_gate is computed as (1 - forget_gate).
    Node* node = graph->NewNode();
    input_gate = CreateNewSimilarValue(graph, forget_gate);

    node->operation.type = ToString(OperationType::SUB);
    ElementwiseAttributes attr;
    attr.param = 1.0f;
    attr.runtime_tensor_is_second = true;
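    // With runtime_tensor_is_second set, the SUB computes
    // param - runtime_tensor, i.e. 1.0f - forget_gate.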
    node->operation.attributes = std::move(attr);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
    RETURN_IF_ERROR(graph->SetProducer(node->id, input_gate->id));
  } else {
    RETURN_IF_ERROR(BuildLstmGate(
        graph, reader, old_output_state, old_cell_state,
        tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kRecurrentToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kCellToInputWeightsTensor,
        tflite::ops::builtin::lstm::full::kInputGateBiasTensor,
        tflite::ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor,
        kTfLiteActSigmoid, has_peephole, has_normalization, &input_gate));
  }

  // The cell gate has no peephole connection back to the cell state.
  Value* cell_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, old_cell_state,
      tflite::ops::builtin::lstm::full::kInputToCellWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
      /*cell_weight_id=*/-1,
      tflite::ops::builtin::lstm::full::kCellGateBiasTensor,
      tflite::ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor,
      params->activation, /*has_peephole=*/false, has_normalization,
      &cell_gate));

  Value* new_cell_state;
  RETURN_IF_ERROR(BuildCellStateUpdate(graph, reader, forget_gate, input_gate,
                                       cell_gate, params->cell_clip,
                                       &new_cell_state));

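  // The output gate's peephole connection reads the updated cell state,
  // matching the standard peephole LSTM formulation.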
  Value* output_gate;
  RETURN_IF_ERROR(BuildLstmGate(
      graph, reader, old_output_state, new_cell_state,
      tflite::ops::builtin::lstm::full::kInputToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kCellToOutputWeightsTensor,
      tflite::ops::builtin::lstm::full::kOutputGateBiasTensor,
      tflite::ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor,
      kTfLiteActSigmoid, has_peephole, has_normalization, &output_gate));

  Value* new_output_state;
  RETURN_IF_ERROR(BuildOutputStateUpdate(graph, reader, old_output_state,
                                         output_gate, new_cell_state,
                                         params->activation, has_projection,
                                         params->proj_clip, &new_output_state));

  {
    // Copy updated output state to output.
    Node* node = graph->NewNode();
    node->operation.type = ToString(OperationType::COPY);
    RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
    RETURN_IF_ERROR(reader->AddOutput(
        node, tflite::ops::builtin::lstm::full::kOutputTensor));
  }

  new_variable_input_values->clear();
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kCellStateTensor, new_cell_state->id);
  new_variable_input_values->emplace(
      tflite::ops::builtin::lstm::full::kOutputStateTensor,
      new_output_state->id);
  return absl::OkStatus();
}

}  // namespace gpu
}  // namespace tflite