xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/common/lstm_parser.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/lite/delegates/gpu/common/lstm_parser.h"

#include <cstring>
#include <optional>
#include <string>
#include <utility>

#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "absl/types/any.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder_helper.h"
#include "tensorflow/lite/delegates/gpu/common/object_reader.h"
#include "tensorflow/lite/delegates/gpu/common/operations.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"
#include "tensorflow/lite/kernels/internal/quantization_util.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_utils.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/lstm_shared.h"
#include "tensorflow/lite/string_type.h"
41 
42 namespace tflite {
43 namespace gpu {
44 namespace {
45 
CreateNewSimilarValue(GraphFloat32 * graph,const Value * old_value)46 Value* CreateNewSimilarValue(GraphFloat32* graph, const Value* old_value) {
47   Value* new_value = graph->NewValue();
48   new_value->quant_params = old_value->quant_params;
49   new_value->tensor.shape = old_value->tensor.shape;
50   new_value->tensor.type = old_value->tensor.type;
51   new_value->tensor.ref = -1;
52   return new_value;
53 }
54 
GetFullyConnectedNode(int weights_tensor_id,int bias_tensor_id,ObjectReader * reader,Node * node)55 absl::Status GetFullyConnectedNode(int weights_tensor_id, int bias_tensor_id,
56                                    ObjectReader* reader, Node* node) {
57   const TfLiteTensor* weights_tensor =
58       reader->GetInputTensor(weights_tensor_id);
59   TfLiteAffineQuantization* quant_params =
60       static_cast<TfLiteAffineQuantization*>(
61           weights_tensor->quantization.params);
62   if (weights_tensor->type == kTfLiteInt8 && quant_params->scale->size == 1) {
63     // uniform int8 quantization
64     node->operation.type = ToString(OperationType::FULLY_CONNECTED_INT8);
65     FullyConnectedInt8Attributes fc_attr;
66     fc_attr.scale = weights_tensor->params.scale;
67     fc_attr.zero_point = weights_tensor->params.zero_point;
68     fc_attr.weights.data.resize(weights_tensor->bytes);
69     std::memcpy(fc_attr.weights.data.data(), weights_tensor->data.int8,
70                 weights_tensor->bytes);
71     int tensor_id;
72     RETURN_IF_ERROR(reader->GetTensorId(weights_tensor_id, &tensor_id));
73     fc_attr.weights.id = tensor_id;
74     fc_attr.weights.shape.o = weights_tensor->dims->data[0];
75     fc_attr.weights.shape.h = 1;
76     fc_attr.weights.shape.w = 1;
77     fc_attr.weights.shape.i = weights_tensor->dims->data[1];
78     if (bias_tensor_id != -1) {
79       reader->ReadTensor(bias_tensor_id, &(fc_attr.bias)).IgnoreError();
80     }
81     node->operation.attributes = std::move(fc_attr);
82   } else {
83     node->operation.type = ToString(OperationType::FULLY_CONNECTED);
84     FullyConnectedAttributes fc_attr;
85     Tensor<HW, DataType::FLOAT32> weights;
86     RETURN_IF_ERROR(reader->ReadTensor(weights_tensor_id, &weights));
87     fc_attr.weights.data = std::move(weights.data);
88     fc_attr.weights.id = weights.id;
89     fc_attr.weights.shape.o = weights.shape.h;
90     fc_attr.weights.shape.h = 1;
91     fc_attr.weights.shape.w = 1;
92     fc_attr.weights.shape.i = weights.shape.w;
93     if (bias_tensor_id != -1) {
94       reader->ReadTensor(bias_tensor_id, &(fc_attr.bias)).IgnoreError();
95     }
96     node->operation.attributes = std::move(fc_attr);
97   }
98   return absl::OkStatus();
99 }
100 
HasTensor(const TfLiteNode * node,const int index)101 bool HasTensor(const TfLiteNode* node, const int index) {
102   return (index < node->inputs->size) &&
103          (node->inputs->data[index] != kTfLiteOptionalTensor);
104 }
105 
HasCifg(const TfLiteNode * node)106 bool HasCifg(const TfLiteNode* node) {
107   return !HasTensor(
108       node, tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor);
109 }
110 
HasPeephole(const TfLiteNode * node)111 bool HasPeephole(const TfLiteNode* node) {
112   // Use forget weights to detect peephole instead of input weights as input
113   // weights may be missing for cifg.
114   return HasTensor(
115       node, tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor);
116 }
117 
HasNormalization(const TfLiteNode * node)118 bool HasNormalization(const TfLiteNode* node) {
119   return HasTensor(
120       node,
121       tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor);
122 }
123 
HasProjection(const TfLiteNode * node)124 bool HasProjection(const TfLiteNode* node) {
125   return HasTensor(node,
126                    tflite::ops::builtin::lstm::full::kProjectionWeightsTensor);
127 }
128 
129 // Builds subgraph for a single LSTM gate.
130 // Returns a Value representing the gate's output.
131 // High-level parameters:
132 //   - Has normalization (if true: provide normalization weights).
133 //   - Has peephole connection (if true: provide peephole weights).
134 //   - Which activation function to use.
135 // Note: no support for aux input.
136 //
137 // Implements the following:
138 // (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
139 //   temp = input_weights * input_tensor + recurrent_weights * output_state;
140 //   if (peephole):
141 //     temp += peephole_weights .* cell_state;
142 //   if (layer normalization):
143 //     gate = activate(normalization_weights .* mean_stddev_norm(temp) + bias);
144 //   else:
145 //     gate = activate(temp + bias);
146 //
BuildLstmGate(GraphFloat32 * graph,ObjectReader * reader,Value * output_state,Value * cell_state,int input_weight_id,int recurrent_weight_id,int cell_weight_id,int bias_id,int normalization_weight_id,const TfLiteFusedActivation activation,bool has_peephole,bool has_normalization,Value ** gate_out)147 absl::Status BuildLstmGate(GraphFloat32* graph, ObjectReader* reader,
148                            Value* output_state, Value* cell_state,
149                            int input_weight_id, int recurrent_weight_id,
150                            int cell_weight_id, int bias_id,
151                            int normalization_weight_id,
152                            const TfLiteFusedActivation activation,
153                            bool has_peephole, bool has_normalization,
154                            Value** gate_out) {
155   Value* input_times_weights = CreateNewSimilarValue(graph, cell_state);
156   {
157     // #1 matrix multiplication: input_weights * input_tensor
158     // If has no normalization, also adds bias.
159     Node* node = graph->NewNode();
160     int input_bias_id = !has_normalization ? bias_id : -1;
161     RETURN_IF_ERROR(
162         GetFullyConnectedNode(input_weight_id, input_bias_id, reader, node));
163     RETURN_IF_ERROR(
164         reader->AddInput(node, tflite::ops::builtin::lstm::full::kInputTensor));
165     RETURN_IF_ERROR(graph->SetProducer(node->id, input_times_weights->id));
166   }
167 
168   Value* output_state_times_weights = CreateNewSimilarValue(graph, cell_state);
169   {
170     // #2 matrix multiplication: recurrent_weights * output_state
171     Node* node = graph->NewNode();
172     RETURN_IF_ERROR(
173         GetFullyConnectedNode(recurrent_weight_id, -1, reader, node));
174     RETURN_IF_ERROR(graph->AddConsumer(node->id, output_state->id));
175     RETURN_IF_ERROR(
176         graph->SetProducer(node->id, output_state_times_weights->id));
177   }
178 
179   Value* cell_state_times_weights;
180   if (has_peephole) {
181     // #3 elementwise multiplication: cell_weight .* cell_state
182     cell_state_times_weights = CreateNewSimilarValue(graph, cell_state);
183     Node* node = graph->NewNode();
184     node->operation.type = ToString(OperationType::MUL);
185     ElementwiseAttributes attr;
186     Tensor<Linear, DataType::FLOAT32> weights;
187     RETURN_IF_ERROR(reader->ReadTensor(cell_weight_id, &weights));
188     attr.param = std::move(weights);
189     node->operation.attributes = std::move(attr);
190     RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
191     RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_times_weights->id));
192   }
193 
194   Value* gate_before_normalization = CreateNewSimilarValue(graph, cell_state);
195   Node* add_node = graph->NewNode();
196   {
197     // #4 elementwise addition: #1 + #2 + #3
198     add_node->operation.type = ToString(OperationType::ADD);
199     RETURN_IF_ERROR(graph->AddConsumer(add_node->id, input_times_weights->id));
200     RETURN_IF_ERROR(
201         graph->AddConsumer(add_node->id, output_state_times_weights->id));
202     if (has_peephole) {
203       RETURN_IF_ERROR(
204           graph->AddConsumer(add_node->id, cell_state_times_weights->id));
205     }
206     RETURN_IF_ERROR(
207         graph->SetProducer(add_node->id, gate_before_normalization->id));
208   }
209 
210   if (!has_normalization) {
211     // #5 Activation function: activate(temp + bias)
212     // Bias is added in node #1.
213     RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, add_node));
214     *gate_out = gate_before_normalization;
215     return absl::OkStatus();
216   }
217 
218   Value* normalized_gate =
219       CreateNewSimilarValue(graph, gate_before_normalization);
220   {
221     // #6 Normalization: normalize(temp)
222     Node* node = graph->NewNode();
223     node->operation.type = ToString(OperationType::MEAN_STDDEV_NORMALIZATION);
224     RETURN_IF_ERROR(
225         graph->AddConsumer(node->id, gate_before_normalization->id));
226     RETURN_IF_ERROR(graph->SetProducer(node->id, normalized_gate->id));
227   }
228   Value* reweighted_normalized_gate =
229       CreateNewSimilarValue(graph, normalized_gate);
230   {
231     // #7 Elementwise multiplication: norm_weights .* #6
232     Node* node = graph->NewNode();
233     node->operation.type = ToString(OperationType::MUL);
234     ElementwiseAttributes attr;
235     Tensor<Linear, DataType::FLOAT32> norm_weights;
236     RETURN_IF_ERROR(reader->ReadTensor(normalization_weight_id, &norm_weights));
237     attr.param = std::move(norm_weights);
238     node->operation.attributes = std::move(attr);
239     RETURN_IF_ERROR(graph->AddConsumer(node->id, normalized_gate->id));
240     RETURN_IF_ERROR(
241         graph->SetProducer(node->id, reweighted_normalized_gate->id));
242   }
243   Value* gate = CreateNewSimilarValue(graph, reweighted_normalized_gate);
244   {
245     // #8 Elementwise add: #7 + bias
246     Node* node = graph->NewNode();
247     node->operation.type = ToString(OperationType::ADD);
248     ElementwiseAttributes attr;
249     Tensor<Linear, DataType::FLOAT32> bias;
250     RETURN_IF_ERROR(reader->ReadTensor(bias_id, &bias));
251     attr.param = std::move(bias);
252     node->operation.attributes = std::move(attr);
253     RETURN_IF_ERROR(
254         graph->AddConsumer(node->id, reweighted_normalized_gate->id));
255     RETURN_IF_ERROR(graph->SetProducer(node->id, gate->id));
256 
257     // #9: Activation function
258     RETURN_IF_ERROR(MaybeFuseActivation(activation, graph, node));
259   }
260   *gate_out = gate;
261   return absl::OkStatus();
262 }
263 
264 // Builds subgraph for LSTM cell state update.
265 // Returns a Value representing the updated cell state.
266 // High-level parameters:
267 //  - clip: if > 0, clamp the resulting cell state to [-clip, +clip].
268 //
269 // Implements the following:
270 // (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
271 //
272 //   cell_state_new = clip(forget_gate .* cell_state + input_gate .* cell_gate);
273 //
BuildCellStateUpdate(GraphFloat32 * graph,ObjectReader * reader,Value * forget_gate,Value * input_gate,Value * cell_gate,float cell_clip,Value ** cell_state_new)274 absl::Status BuildCellStateUpdate(GraphFloat32* graph, ObjectReader* reader,
275                                   Value* forget_gate, Value* input_gate,
276                                   Value* cell_gate, float cell_clip,
277                                   Value** cell_state_new) {
278   Value* cell_state;
279   RETURN_IF_ERROR(reader->ReadValue(
280       tflite::ops::builtin::lstm::full::kCellStateTensor, &cell_state));
281   Value* cell_state_contrib = CreateNewSimilarValue(graph, cell_gate);
282   {
283     // #1 elementwise multiplication: forget_gate .* cell_state
284     Node* node = graph->NewNode();
285     node->operation.type = ToString(OperationType::MUL);
286     RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
287     RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
288     RETURN_IF_ERROR(graph->SetProducer(node->id, cell_state_contrib->id));
289   }
290   Value* cell_gate_contrib = CreateNewSimilarValue(graph, cell_gate);
291   {
292     // #2 elementwise multiplication: input_gate .* cell_gate
293     // Note, with CIFG input_gate is equal to 1-forget_gate.
294     Node* node = graph->NewNode();
295     node->operation.type = ToString(OperationType::MUL);
296     RETURN_IF_ERROR(graph->AddConsumer(node->id, input_gate->id));
297     RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate->id));
298     RETURN_IF_ERROR(graph->SetProducer(node->id, cell_gate_contrib->id));
299   }
300   Value* new_cell_state = CreateNewSimilarValue(graph, cell_gate);
301   {
302     // #3 elementwise add: #1 + #2
303     Node* node = graph->NewNode();
304     node->operation.type = ToString(OperationType::ADD);
305     RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state_contrib->id));
306     RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_gate_contrib->id));
307     RETURN_IF_ERROR(graph->SetProducer(node->id, new_cell_state->id));
308   }
309 
310   if (cell_clip <= 0.0f) {
311     *cell_state_new = new_cell_state;
312     return absl::OkStatus();
313   }
314 
315   Value* max_clipped_state = CreateNewSimilarValue(graph, new_cell_state);
316   {
317     // #4 elementwise minimum: min(#3, clip)
318     Node* node = graph->NewNode();
319     node->operation.type = ToString(OperationType::MINIMUM);
320     ElementwiseAttributes attr;
321     attr.param = cell_clip;
322     node->operation.attributes = std::move(attr);
323     RETURN_IF_ERROR(graph->AddConsumer(node->id, new_cell_state->id));
324     RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
325   }
326   Value* clipped_cell_state = CreateNewSimilarValue(graph, max_clipped_state);
327   {
328     // #5 elementwise maximum: max(#4, -clip)
329     Node* node = graph->NewNode();
330     node->operation.type = ToString(OperationType::MAXIMUM);
331     ElementwiseAttributes attr;
332     attr.param = -cell_clip;
333     node->operation.attributes = std::move(attr);
334     RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
335     RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_cell_state->id));
336   }
337   *cell_state_new = clipped_cell_state;
338   return absl::OkStatus();
339 }
340 
341 // Build subgraph for LSTM output state update.
342 // Returns value representing the updated output state.
343 // High-level parameters:
344 //   - Has projection (if true, provide projection_weights).
345 //   - Has projection bias (only with projection).
346 //   - clip: clamp the projection output to [-clip, clip].
347 //   - Which activation function to use.
348 // Note the updated output state does not depend on the old output state
349 // directly, only through the output gate.
350 //
351 // Implements the following:
352 // (*: matrix multiply, .*: elementwise multiply, +: elementwise add):
353 //
354 //   temp = output_gate .* activate(cell_state);
355 //   if (projection):
356 //     output_state_new = clip(projection_weights * temp + projection_bias);
357 //   else:
358 //     output_state_new = temp;
359 //
BuildOutputStateUpdate(GraphFloat32 * graph,ObjectReader * reader,Value * output_state,Value * output_gate,Value * cell_state,TfLiteFusedActivation activation,bool has_projection,float proj_clip,Value ** output_state_new)360 absl::Status BuildOutputStateUpdate(GraphFloat32* graph, ObjectReader* reader,
361                                     Value* output_state, Value* output_gate,
362                                     Value* cell_state,
363                                     TfLiteFusedActivation activation,
364                                     bool has_projection, float proj_clip,
365                                     Value** output_state_new) {
366   Value* activated_state = CreateNewSimilarValue(graph, cell_state);
367   {
368     // #1 activation: activate(cell_state)
369     Node* node = graph->NewNode();
370     switch (activation) {
371       case kTfLiteActTanh:
372         node->operation.type = ToString(OperationType::TANH);
373         break;
374       case kTfLiteActSigmoid:
375         node->operation.type = ToString(OperationType::SIGMOID);
376         break;
377       default:
378         return absl::InvalidArgumentError(
379             absl::StrCat("Unsupported activation: ", activation));
380     }
381     RETURN_IF_ERROR(graph->AddConsumer(node->id, cell_state->id));
382     RETURN_IF_ERROR(graph->SetProducer(node->id, activated_state->id));
383   }
384 
385   Value* new_output_state = CreateNewSimilarValue(graph, cell_state);
386   {
387     // #2 elementwise multiplication: output_gate .* #1
388     Node* node = graph->NewNode();
389     node->operation.type = ToString(OperationType::MUL);
390     RETURN_IF_ERROR(graph->AddConsumer(node->id, activated_state->id));
391     RETURN_IF_ERROR(graph->AddConsumer(node->id, output_gate->id));
392     RETURN_IF_ERROR(graph->SetProducer(node->id, new_output_state->id));
393   }
394 
395   if (!has_projection) {
396     *output_state_new = new_output_state;
397     return absl::OkStatus();
398   }
399 
400   Value* projected_output_state = CreateNewSimilarValue(graph, output_state);
401   {
402     // #3 matrix multiplication: projection_weights * #2 + projection_bias
403     Node* node = graph->NewNode();
404 
405     RETURN_IF_ERROR(GetFullyConnectedNode(
406         tflite::ops::builtin::lstm::full::kProjectionWeightsTensor,
407         tflite::ops::builtin::lstm::full::kProjectionBiasTensor, reader, node));
408 
409     RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
410     RETURN_IF_ERROR(graph->SetProducer(node->id, projected_output_state->id));
411   }
412 
413   if (proj_clip <= 0.0f) {
414     *output_state_new = projected_output_state;
415     return absl::OkStatus();
416   }
417 
418   Value* max_clipped_state =
419       CreateNewSimilarValue(graph, projected_output_state);
420   {
421     // #4 elementwise minimum: min(#3, clip)
422     Node* node = graph->NewNode();
423     node->operation.type = ToString(OperationType::MINIMUM);
424     ElementwiseAttributes attr;
425     attr.param = proj_clip;
426     node->operation.attributes = std::move(attr);
427     RETURN_IF_ERROR(graph->AddConsumer(node->id, projected_output_state->id));
428     RETURN_IF_ERROR(graph->SetProducer(node->id, max_clipped_state->id));
429   }
430   Value* clipped_output_state = CreateNewSimilarValue(graph, max_clipped_state);
431   {
432     // #5 elementwise maximum: max(#4, -clip)
433     Node* node = graph->NewNode();
434     node->operation.type = ToString(OperationType::MAXIMUM);
435     ElementwiseAttributes attr;
436     attr.param = -proj_clip;
437     node->operation.attributes = std::move(attr);
438     RETURN_IF_ERROR(graph->AddConsumer(node->id, max_clipped_state->id));
439     RETURN_IF_ERROR(graph->SetProducer(node->id, clipped_output_state->id));
440   }
441   *output_state_new = clipped_output_state;
442   return absl::OkStatus();
443 }
444 
445 }  // namespace
446 
447 // Build subgraph for a single LSTM OP.
448 // Returns a mapping for the used variable tensors' updated Values.
449 //
450 // High-level parameters:
451 //   - Has CIFG:
452 //       If false, calculate input_gate regularly.
453 //       If true, calculate input_gate to 1-forget_gate.
454 //   - Has peephole: see BuildLstmGate. Applies to all gates.
455 //   - Has normalization: see BuildLstmGate. Applies to all gates.
456 //   - Has projection, projection_bias, proj_clip: see BuildOutputStateUpdate
457 //   - Which activation to use:
458 //       Applies to only cell gate and output state update.
459 //       Other gates always use Sigmoid.
460 //
ParseLSTMAttributes(const TfLiteNode * tflite_node,const TfLiteRegistration * registration,GraphFloat32 * graph,ObjectReader * reader,const TfLiteLSTMParams * params,absl::flat_hash_map<int,ValueId> * new_variable_input_values)461 absl::Status ParseLSTMAttributes(
462     const TfLiteNode* tflite_node, const TfLiteRegistration* registration,
463     GraphFloat32* graph, ObjectReader* reader, const TfLiteLSTMParams* params,
464     absl::flat_hash_map<int, ValueId>* new_variable_input_values) {
465   const bool has_cifg = HasCifg(tflite_node);
466   const bool has_peephole = HasPeephole(tflite_node);
467   const bool has_normalization = HasNormalization(tflite_node);
468   const bool has_projection = HasProjection(tflite_node);
469 
470   Value* old_cell_state;
471   RETURN_IF_ERROR(reader->ReadValue(
472       tflite::ops::builtin::lstm::full::kCellStateTensor, &old_cell_state));
473 
474   if (old_cell_state->tensor.shape.b != 1) {
475     return absl::InvalidArgumentError(
476         "Batched execution is not supported for LSTM");
477   }
478 
479   Value* old_output_state;
480   RETURN_IF_ERROR(reader->ReadValue(
481       tflite::ops::builtin::lstm::full::kOutputStateTensor, &old_output_state));
482 
483   Value* forget_gate;
484   RETURN_IF_ERROR(BuildLstmGate(
485       graph, reader, old_output_state, old_cell_state,
486       tflite::ops::builtin::lstm::full::kInputToForgetWeightsTensor,
487       tflite::ops::builtin::lstm::full::kRecurrentToForgetWeightsTensor,
488       tflite::ops::builtin::lstm::full::kCellToForgetWeightsTensor,
489       tflite::ops::builtin::lstm::full::kForgetGateBiasTensor,
490       tflite::ops::builtin::lstm::full::kForgetLayerNormCoefficientsTensor,
491       kTfLiteActSigmoid, has_peephole, has_normalization, &forget_gate));
492 
493   Value* input_gate;
494   if (has_cifg) {
495     // When using cifg, input_gate is computed as (1 - forget_gate).
496     Node* node = graph->NewNode();
497     input_gate = CreateNewSimilarValue(graph, forget_gate);
498 
499     node->operation.type = ToString(OperationType::SUB);
500     ElementwiseAttributes attr;
501     attr.param = 1.0f;
502     attr.runtime_tensor_is_second = true;
503     node->operation.attributes = std::move(attr);
504     RETURN_IF_ERROR(graph->AddConsumer(node->id, forget_gate->id));
505     RETURN_IF_ERROR(graph->SetProducer(node->id, input_gate->id));
506   } else {
507     RETURN_IF_ERROR(BuildLstmGate(
508         graph, reader, old_output_state, old_cell_state,
509         tflite::ops::builtin::lstm::full::kInputToInputWeightsTensor,
510         tflite::ops::builtin::lstm::full::kRecurrentToInputWeightsTensor,
511         tflite::ops::builtin::lstm::full::kCellToInputWeightsTensor,
512         tflite::ops::builtin::lstm::full::kInputGateBiasTensor,
513         tflite::ops::builtin::lstm::full::kInputLayerNormCoefficientsTensor,
514         kTfLiteActSigmoid, has_peephole, has_normalization, &input_gate));
515   }
516 
517   // Cell state will not have peephole connections to itself
518   Value* cell_gate;
519   RETURN_IF_ERROR(BuildLstmGate(
520       graph, reader, old_output_state, old_cell_state,
521       tflite::ops::builtin::lstm::full::kInputToCellWeightsTensor,
522       tflite::ops::builtin::lstm::full::kRecurrentToCellWeightsTensor,
523       /*cell_weight_id=*/-1,
524       tflite::ops::builtin::lstm::full::kCellGateBiasTensor,
525       tflite::ops::builtin::lstm::full::kCellLayerNormCoefficientsTensor,
526       params->activation, /*has_peephole=*/false, has_normalization,
527       &cell_gate));
528 
529   Value* new_cell_state;
530   RETURN_IF_ERROR(BuildCellStateUpdate(graph, reader, forget_gate, input_gate,
531                                        cell_gate, params->cell_clip,
532                                        &new_cell_state));
533 
534   Value* output_gate;
535   RETURN_IF_ERROR(BuildLstmGate(
536       graph, reader, old_output_state, new_cell_state,
537       tflite::ops::builtin::lstm::full::kInputToOutputWeightsTensor,
538       tflite::ops::builtin::lstm::full::kRecurrentToOutputWeightsTensor,
539       tflite::ops::builtin::lstm::full::kCellToOutputWeightsTensor,
540       tflite::ops::builtin::lstm::full::kOutputGateBiasTensor,
541       tflite::ops::builtin::lstm::full::kOutputLayerNormCoefficientsTensor,
542       kTfLiteActSigmoid, has_peephole, has_normalization, &output_gate));
543 
544   Value* new_output_state;
545   RETURN_IF_ERROR(BuildOutputStateUpdate(graph, reader, old_output_state,
546                                          output_gate, new_cell_state,
547                                          params->activation, has_projection,
548                                          params->proj_clip, &new_output_state));
549 
550   {
551     // Copy updated output state to output.
552     Node* node = graph->NewNode();
553     node->operation.type = ToString(OperationType::COPY);
554     RETURN_IF_ERROR(graph->AddConsumer(node->id, new_output_state->id));
555     RETURN_IF_ERROR(reader->AddOutput(
556         node, tflite::ops::builtin::lstm::full::kOutputTensor));
557   }
558 
559   new_variable_input_values->clear();
560   new_variable_input_values->emplace(
561       tflite::ops::builtin::lstm::full::kCellStateTensor, new_cell_state->id);
562   new_variable_input_values->emplace(
563       tflite::ops::builtin::lstm::full::kOutputStateTensor,
564       new_output_state->id);
565   return absl::OkStatus();
566 }
567 
568 }  // namespace gpu
569 }  // namespace tflite
570