// xref: /aosp_15_r20/external/tensorflow/tensorflow/core/ops/cudnn_rnn_ops.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 #include "tensorflow/core/lib/strings/strcat.h"
20 
21 namespace tensorflow {
22 namespace {
23 
24 constexpr auto kRNNModeAttrs =
25     "rnn_mode: {'rnn_relu', 'rnn_tanh', 'lstm', 'gru'} = 'lstm'";
26 
27 constexpr auto kRNNInputModeAttrs =
28     "input_mode: {'linear_input', 'skip_input', 'auto_select'} = "
29     "'linear_input'";
30 
31 constexpr auto kRNNDirectionAttrs =
32     "direction: {'unidirectional', 'bidirectional'} = 'unidirectional'";
33 
34 }  // namespace
35 
36 using shape_inference::DimensionHandle;
37 using shape_inference::InferenceContext;
38 using shape_inference::ShapeHandle;
39 
40 REGISTER_OP("CudnnRNNParamsSize")
41     .Input("num_layers: int32")
42     .Input("num_units: int32")
43     .Input("input_size: int32")
44     .Attr("T: {float16, float32, float64}")
45     .Attr("S: {int32, int64}")
46     .Attr(kRNNModeAttrs)
47     .Attr(kRNNInputModeAttrs)
48     .Attr(kRNNDirectionAttrs)
49     .Attr("dropout: float = 0.0")
50     .Attr("seed: int = 0")
51     .Attr("seed2: int = 0")
52     .Attr("num_proj: int = 0")
53     .Output("params_size: S")
__anonb0821bee0202(InferenceContext* c) 54     .SetShapeFn([](InferenceContext* c) {
55       ShapeHandle unused;
56       // num_layers, num_units, and input_size should be scalars.
57       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
58       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
59       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
60 
61       c->set_output(0, c->Vector(1));
62       return OkStatus();
63     });
64 
65 REGISTER_OP("CudnnRNN")
66     .Input("input: T")
67     .Input("input_h: T")
68     .Input("input_c: T")
69     .Input("params: T")
70     .SetIsStateful()
71     .Output("output: T")
72     .Output("output_h: T")
73     .Output("output_c: T")
74     .Output("reserve_space: T")
75     .Attr("T: {float16, float32, float64}")
76     .Attr(kRNNModeAttrs)
77     .Attr(kRNNInputModeAttrs)
78     .Attr(kRNNDirectionAttrs)
79     .Attr("dropout: float = 0.0")
80     .Attr("seed: int = 0")
81     .Attr("seed2: int = 0")
82     .Attr("is_training: bool = true")
__anonb0821bee0302(InferenceContext* c) 83     .SetShapeFn([](InferenceContext* c) {
84       ShapeHandle unused;
85       auto input_shape = c->input(0);
86       auto input_h_shape = c->input(1);
87       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
88       TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
89       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
90 
91       auto seq_length = c->Dim(input_shape, 0);
92       auto batch_size = c->Dim(input_shape, 1);
93       auto num_units = c->Dim(input_h_shape, 2);
94 
95       string direction;
96       TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
97       string rnn_mode;
98       TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
99       int dir_count = (direction == "bidirectional") ? 2 : 1;
100       DimensionHandle output_size;
101       TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
102       auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
103       auto output_h_shape = input_h_shape;
104       auto output_c_shape TF_ATTRIBUTE_UNUSED =
105           (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
106       c->set_output(0, output_shape);
107       c->set_output(1, output_h_shape);
108       c->set_output(2, output_c_shape);
109       c->set_output(3, c->UnknownShape());
110       return OkStatus();
111     });
112 
113 REGISTER_OP("CudnnRNNV2")
114     .Input("input: T")
115     .Input("input_h: T")
116     .Input("input_c: T")
117     .Input("params: T")
118     .SetIsStateful()
119     .Output("output: T")
120     .Output("output_h: T")
121     .Output("output_c: T")
122     .Output("reserve_space: T")
123     .Output("host_reserved: int8")
124     .Attr("T: {float16, float32, float64}")
125     .Attr(kRNNModeAttrs)
126     .Attr(kRNNInputModeAttrs)
127     .Attr(kRNNDirectionAttrs)
128     .Attr("dropout: float = 0.0")
129     .Attr("seed: int = 0")
130     .Attr("seed2: int = 0")
131     .Attr("is_training: bool = true")
__anonb0821bee0402(InferenceContext* c) 132     .SetShapeFn([](InferenceContext* c) {
133       ShapeHandle unused;
134       auto input_shape = c->input(0);
135       auto input_h_shape = c->input(1);
136       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
137       TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
138       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
139 
140       auto seq_length = c->Dim(input_shape, 0);
141       auto batch_size = c->Dim(input_shape, 1);
142       auto num_units = c->Dim(input_h_shape, 2);
143       string direction;
144       TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
145       string rnn_mode;
146       TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
147       int dir_count = (direction == "bidirectional") ? 2 : 1;
148       DimensionHandle output_size;
149       TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
150       auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
151       auto output_h_shape = input_h_shape;
152       auto output_c_shape TF_ATTRIBUTE_UNUSED =
153           (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
154       c->set_output(0, output_shape);
155       c->set_output(1, output_h_shape);
156       c->set_output(2, output_c_shape);
157       c->set_output(3, c->UnknownShape());
158       c->set_output(4, c->UnknownShape());
159       return OkStatus();
160     });
161 
162 REGISTER_OP("CudnnRNNV3")
163     .Input("input: T")
164     .Input("input_h: T")
165     .Input("input_c: T")
166     .Input("params: T")
167     .Input("sequence_lengths: int32")
168     .SetIsStateful()
169     .Output("output: T")
170     .Output("output_h: T")
171     .Output("output_c: T")
172     .Output("reserve_space: T")
173     .Output("host_reserved: int8")
174     .Attr("T: {float16, float32, float64}")
175     .Attr(kRNNModeAttrs)
176     .Attr(kRNNInputModeAttrs)
177     .Attr(kRNNDirectionAttrs)
178     .Attr("dropout: float = 0.0")
179     .Attr("seed: int = 0")
180     .Attr("seed2: int = 0")
181     .Attr("num_proj: int = 0")
182     .Attr("is_training: bool = true")
183     .Attr("time_major: bool = true")
__anonb0821bee0502(InferenceContext* c) 184     .SetShapeFn([](InferenceContext* c) {
185       ShapeHandle unused;
186       auto input_shape = c->input(0);
187       auto input_h_shape = c->input(1);
188       auto input_c_shape = c->input(2);
189       TF_RETURN_IF_ERROR(c->WithRank(input_shape, 3, &unused));
190       TF_RETURN_IF_ERROR(c->WithRank(input_h_shape, 3, &unused));
191       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
192       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));
193 
194       auto max_seq_length = c->Dim(input_shape, 0);
195       auto batch_size = c->Dim(input_shape, 1);
196       auto num_units = c->Dim(input_h_shape, 2);
197 
198       string direction;
199       TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
200       string rnn_mode;
201       TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
202       if (rnn_mode == "lstm") {
203         TF_RETURN_IF_ERROR(c->WithRank(input_c_shape, 3, &unused));
204       }
205       int dir_count = (direction == "bidirectional") ? 2 : 1;
206       DimensionHandle output_size;
207       TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
208       auto output_shape =
209           c->MakeShape({max_seq_length, batch_size, output_size});
210       auto output_h_shape = input_h_shape;
211       auto output_c_shape TF_ATTRIBUTE_UNUSED =
212           (rnn_mode == "lstm") ? input_c_shape : c->MakeShape({});
213       c->set_output(0, output_shape);
214       c->set_output(1, output_h_shape);
215       c->set_output(2, output_c_shape);
216       c->set_output(3, c->UnknownShape());
217       c->set_output(4, c->UnknownShape());
218       return OkStatus();
219     });
220 
221 REGISTER_OP("CudnnRNNBackprop")
222     .Input("input: T")
223     .Input("input_h: T")
224     .Input("input_c: T")
225     .Input("params: T")
226     .Input("output: T")
227     .Input("output_h: T")
228     .Input("output_c: T")
229     .Input("output_backprop: T")
230     .Input("output_h_backprop: T")
231     .Input("output_c_backprop: T")
232     .Input("reserve_space: T")
233     .SetIsStateful()
234     .Output("input_backprop: T")
235     .Output("input_h_backprop: T")
236     .Output("input_c_backprop: T")
237     .Output("params_backprop: T")
238     .Attr("T: {float16, float32, float64}")
239     .Attr(kRNNModeAttrs)
240     .Attr(kRNNInputModeAttrs)
241     .Attr(kRNNDirectionAttrs)
242     .Attr("dropout: float = 0.0")
243     .Attr("seed: int = 0")
244     .Attr("seed2: int = 0")
__anonb0821bee0602(InferenceContext* c) 245     .SetShapeFn([](InferenceContext* c) {
246       auto input_shape = c->input(0);
247       auto input_h_shape = c->input(1);
248       auto input_c_shape = c->input(2);
249       auto params_shape = c->input(3);
250       c->set_output(0, input_shape);
251       c->set_output(1, input_h_shape);
252       c->set_output(2, input_c_shape);
253       c->set_output(3, params_shape);
254       return OkStatus();
255     });
256 
257 REGISTER_OP("CudnnRNNBackpropV2")
258     .Input("input: T")
259     .Input("input_h: T")
260     .Input("input_c: T")
261     .Input("params: T")
262     .Input("output: T")
263     .Input("output_h: T")
264     .Input("output_c: T")
265     .Input("output_backprop: T")
266     .Input("output_h_backprop: T")
267     .Input("output_c_backprop: T")
268     .Input("reserve_space: T")
269     .Input("host_reserved: int8")
270     .SetIsStateful()
271     .Output("input_backprop: T")
272     .Output("input_h_backprop: T")
273     .Output("input_c_backprop: T")
274     .Output("params_backprop: T")
275     .Attr("T: {float16, float32, float64}")
276     .Attr(kRNNModeAttrs)
277     .Attr(kRNNInputModeAttrs)
278     .Attr(kRNNDirectionAttrs)
279     .Attr("dropout: float = 0.0")
280     .Attr("seed: int = 0")
281     .Attr("seed2: int = 0")
__anonb0821bee0702(InferenceContext* c) 282     .SetShapeFn([](InferenceContext* c) {
283       auto input_shape = c->input(0);
284       auto input_h_shape = c->input(1);
285       auto input_c_shape = c->input(2);
286       auto params_shape = c->input(3);
287       c->set_output(0, input_shape);
288       c->set_output(1, input_h_shape);
289       c->set_output(2, input_c_shape);
290       c->set_output(3, params_shape);
291       return OkStatus();
292     });
293 
294 REGISTER_OP("CudnnRNNBackpropV3")
295     .Input("input: T")
296     .Input("input_h: T")
297     .Input("input_c: T")
298     .Input("params: T")
299     .Input("sequence_lengths: int32")
300     .Input("output: T")
301     .Input("output_h: T")
302     .Input("output_c: T")
303     .Input("output_backprop: T")
304     .Input("output_h_backprop: T")
305     .Input("output_c_backprop: T")
306     .Input("reserve_space: T")
307     .Input("host_reserved: int8")
308     .SetIsStateful()
309     .Output("input_backprop: T")
310     .Output("input_h_backprop: T")
311     .Output("input_c_backprop: T")
312     .Output("params_backprop: T")
313     .Attr("T: {float16, float32, float64}")
314     .Attr(kRNNModeAttrs)
315     .Attr(kRNNInputModeAttrs)
316     .Attr(kRNNDirectionAttrs)
317     .Attr("dropout: float = 0.0")
318     .Attr("seed: int = 0")
319     .Attr("seed2: int = 0")
320     .Attr("num_proj: int = 0")
321     .Attr("time_major: bool = true")
__anonb0821bee0802(InferenceContext* c) 322     .SetShapeFn([](InferenceContext* c) {
323       auto input_shape = c->input(0);
324       auto input_h_shape = c->input(1);
325       auto input_c_shape = c->input(2);
326       auto params_shape = c->input(3);
327       c->set_output(0, input_shape);
328       c->set_output(1, input_h_shape);
329       c->set_output(2, input_c_shape);
330       c->set_output(3, params_shape);
331       return OkStatus();
332     });
333 
334 REGISTER_OP("CudnnRNNParamsToCanonical")
335     .Input("num_layers: int32")
336     .Input("num_units: int32")
337     .Input("input_size: int32")
338     .Input("params: T")
339     .Output("weights: num_params * T")
340     .Output("biases: num_params * T")
341     .Attr("T: {float16, float32, float64}")
342     .Attr("num_params: int")
343     .Attr(kRNNModeAttrs)
344     .Attr(kRNNInputModeAttrs)
345     .Attr(kRNNDirectionAttrs)
346     .Attr("dropout: float = 0.0")
347     .Attr("seed: int = 0")
348     .Attr("seed2: int = 0")
__anonb0821bee0902(InferenceContext* c) 349     .SetShapeFn([](InferenceContext* c) {
350       ShapeHandle unused;
351       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
352       int num_params;
353       TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params));
354       // Set shape for weight matrices
355       for (int i = 0; i < num_params; i++) {
356         c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
357                                    InferenceContext::kUnknownDim));
358       }
359       // Set shape for bias vectors
360       for (int i = 0; i < num_params; i++) {
361         c->set_output(num_params + i, c->Vector(InferenceContext::kUnknownDim));
362       }
363       return OkStatus();
364     });
365 
366 REGISTER_OP("CudnnRNNParamsToCanonicalV2")
367     .Input("num_layers: int32")
368     .Input("num_units: int32")
369     .Input("input_size: int32")
370     .Input("params: T")
371     .Output("weights: num_params_weights * T")
372     .Output("biases: num_params_biases * T")
373     .Attr("T: {float16, float32, float64}")
374     .Attr("num_params_weights: int")
375     .Attr("num_params_biases: int")
376     .Attr(kRNNModeAttrs)
377     .Attr(kRNNInputModeAttrs)
378     .Attr(kRNNDirectionAttrs)
379     .Attr("dropout: float = 0.0")
380     .Attr("seed: int = 0")
381     .Attr("seed2: int = 0")
382     .Attr("num_proj: int = 0")
__anonb0821bee0a02(InferenceContext* c) 383     .SetShapeFn([](InferenceContext* c) {
384       ShapeHandle unused;
385       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &unused));
386       int num_params_weights;
387       int num_params_biases;
388       TF_RETURN_IF_ERROR(c->GetAttr("num_params_weights", &num_params_weights));
389       TF_RETURN_IF_ERROR(c->GetAttr("num_params_biases", &num_params_biases));
390       // Set shape for weight matrices
391       for (int i = 0; i < num_params_weights; i++) {
392         c->set_output(i, c->Matrix(InferenceContext::kUnknownDim,
393                                    InferenceContext::kUnknownDim));
394       }
395       // Set shape for bias vectors
396       for (int i = 0; i < num_params_biases; i++) {
397         c->set_output(num_params_weights + i,
398                       c->Vector(InferenceContext::kUnknownDim));
399       }
400       return OkStatus();
401     });
402 
403 REGISTER_OP("CudnnRNNCanonicalToParams")
404     .Input("num_layers: int32")
405     .Input("num_units: int32")
406     .Input("input_size: int32")
407     .Input("weights: num_params * T")
408     .Input("biases: num_params * T")
409     .Output("params: T")
410     .Attr("T: {float16, float32, float64}")
411     .Attr("num_params: int")
412     .Attr(kRNNModeAttrs)
413     .Attr(kRNNInputModeAttrs)
414     .Attr(kRNNDirectionAttrs)
415     .Attr("dropout: float = 0.0")
416     .Attr("seed: int = 0")
417     .Attr("seed2: int = 0")
__anonb0821bee0b02(InferenceContext* c) 418     .SetShapeFn([](InferenceContext* c) {
419       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
420       return OkStatus();
421     });
422 
423 REGISTER_OP("CudnnRNNCanonicalToParamsV2")
424     .Input("num_layers: int32")
425     .Input("num_units: int32")
426     .Input("input_size: int32")
427     .Input("weights: num_params_weights * T")
428     .Input("biases: num_params_biases * T")
429     .Output("params: T")
430     .Attr("T: {float16, float32, float64}")
431     .Attr("num_params_weights: int")
432     .Attr("num_params_biases: int")
433     .Attr(kRNNModeAttrs)
434     .Attr(kRNNInputModeAttrs)
435     .Attr(kRNNDirectionAttrs)
436     .Attr("dropout: float = 0.0")
437     .Attr("seed: int = 0")
438     .Attr("seed2: int = 0")
439     .Attr("num_proj: int = 0")
__anonb0821bee0c02(InferenceContext* c) 440     .SetShapeFn([](InferenceContext* c) {
441       c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
442       return OkStatus();
443     });
444 
445 }  // namespace tensorflow
446