xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/RNN.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/native/DispatchStub.h>
5 
6 namespace at::native {
7 
// Function-pointer signatures for the backend RNN dispatch stubs declared
// below. Parameter names are not spelled out here; the positional layout
// appears to mirror the corresponding at::lstm / at::gru / at::rnn_* native
// ops — output tensor(s) first, then input (plus batch_sizes for the packed
// variants), hidden state(s), flat weights, and the scalar hyperparameters
// (has_biases, num_layers, dropout, train, bidirectional, batch_first).
// TODO(review): confirm the exact positional mapping against the stub
// definitions in RNN.cpp / the cudnn/miopen/mkldnn backends.
using lstm_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool, bool);
using rnn_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool, bool);
using lstm_packed_fn = void(*)(Tensor&, Tensor&, Tensor&, const Tensor&, const Tensor&, TensorList, TensorList, bool, int64_t, double, bool, bool);
using rnn_packed_fn = void(*)(Tensor&, Tensor&, const Tensor&, const Tensor&, const Tensor&, TensorList, bool, int64_t, double, bool, bool);

// Per-backend kernel registrations via the DispatchStub mechanism
// (ATen/native/DispatchStub.h): a stub declared here is defined once and
// filled in by whichever backend (cuDNN, MIOpen, MKL-DNN) is compiled in.

// Padded-sequence variants.
DECLARE_DISPATCH(lstm_fn, lstm_cudnn_stub);
DECLARE_DISPATCH(lstm_fn, lstm_miopen_stub);
DECLARE_DISPATCH(lstm_fn, lstm_mkldnn_stub);
DECLARE_DISPATCH(rnn_fn, gru_cudnn_stub);
DECLARE_DISPATCH(rnn_fn, gru_miopen_stub);
DECLARE_DISPATCH(rnn_fn, rnn_tanh_cudnn_stub);
DECLARE_DISPATCH(rnn_fn, rnn_tanh_miopen_stub);
DECLARE_DISPATCH(rnn_fn, rnn_relu_cudnn_stub);
DECLARE_DISPATCH(rnn_fn, rnn_relu_miopen_stub);
// Packed-sequence variants (take an extra batch_sizes tensor, no batch_first).
DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_cudnn_stub);
DECLARE_DISPATCH(lstm_packed_fn, lstm_packed_miopen_stub);
DECLARE_DISPATCH(rnn_packed_fn, gru_packed_cudnn_stub);
DECLARE_DISPATCH(rnn_packed_fn, gru_packed_miopen_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_cudnn_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_tanh_packed_miopen_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_cudnn_stub);
DECLARE_DISPATCH(rnn_packed_fn, rnn_relu_packed_miopen_stub);
30 
31 inline void check_attributes(const Tensor& input, const TensorList& params, const TensorList& hiddens, bool check_dtype=false) {
32   auto input_device = input.device();
33   auto input_dtype = input.scalar_type();
34 
35   auto check_tensors = [&](const std::string& name, const Tensor& t) {
36     if (!t.defined()) return;
37     auto t_device = t.device();
38     TORCH_CHECK(input_device == t_device,
39              "Input and ", name, " tensors are not at the same device, found input tensor at ",
40              input_device, " and ", name, " tensor at ", t_device);
41     if (check_dtype) {
42       auto t_dtype = t.scalar_type();
43       TORCH_CHECK(input_dtype == t_dtype,
44                "Input and ", name, " tensors are not the same dtype, found input tensor with ",
45                input_dtype, " and ", name, " tensor with ", t_dtype);
46     }
47   };
48 
49   for (const auto& h : hiddens) check_tensors("hidden", h);
50   for (const auto& p : params) check_tensors("parameter", p);
51 }
52 
53 } // namespace at::native
54