Home
last modified time | relevance | path

Searched full:dropout_state (Results 1 – 10 of 10) sorted by relevance

/aosp_15_r20/external/pytorch/torch/backends/cudnn/
H A Drnn.py45 def init_dropout_state(dropout, train, dropout_seed, dropout_state): argument
48 if (dropout_desc_name not in dropout_state) or (
49 dropout_state[dropout_desc_name].get() is None
52 dropout_state[dropout_desc_name] = Unserializable(None)
54 dropout_state[dropout_desc_name] = Unserializable(
63 dropout_ts = dropout_state[dropout_desc_name].get()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/miopen/
H A DRNN_miopen.cpp772 const Tensor& dropout_state = c10::value_or_else(dropout_state_opt, [] {return Tensor();}); in miopen_rnn_backward() local
781 … num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, {out… in miopen_rnn_backward()
784 …, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, ws); in miopen_rnn_backward()
832 Tensor dropout_state = at::empty({0}, input.options()); in _miopen_impl() local
837 dropout_p, train, bidirectional, batch_sizes, dropout_state); in _miopen_impl()
851 Tensor dropout_state = at::empty({0}, input.options()); in _miopen_impl() local
856 train, bidirectional, /*batch_sizes=*/{}, dropout_state); in _miopen_impl()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cudnn/
H A DRNN.cpp127 Tensor dropout_state; member
132 dropout_state = dropout_state_; in set()
140 dropout_desc.set(handle, dropout_p, dropout_state); in descriptor()
2121 const Tensor& dropout_state = in _cudnn_rnn_backward() local
2161 dropout_state, in _cudnn_rnn_backward()
2183 dropout_state, in _cudnn_rnn_backward()
2537 auto& dropout_state = get_dropout_state(dropout_p, train, input.options()); in _cudnn_impl() local
2538 std::unique_lock<DropoutState> lock{dropout_state}; in _cudnn_impl()
2563 dropout_state.buffer); in _cudnn_impl()
2601 auto& dropout_state = get_dropout_state(dropout_p, train, input.options()); in _cudnn_impl() local
[all …]
/aosp_15_r20/external/pytorch/aten/src/ATen/cudnn/
H A DAutocastRNN.cpp37 const std::optional<Tensor>& dropout_state) { in _cudnn_rnn_cast_reflatten() argument
114 dropout_state); in _cudnn_rnn_cast_reflatten()
/aosp_15_r20/external/pytorch/tools/autograd/
H A Dderivatives.yaml2707 …ropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, T…
2708 dropout_state: non_differentiable
2710 … num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variab…
2712 …ropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserv…
2713 dropout_state: non_differentiable
2743 …t dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, T…
2744 dropout_state: non_differentiable
2746 … num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, retain_variab…
2748 …t dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserv…
2749 dropout_state: non_differentiable
/aosp_15_r20/external/pytorch/torch/csrc/inductor/aoti_torch/generated/
H A Dc_shim_cuda.h21 …nst int64_t* batch_sizes, int64_t batch_sizes_len_, AtenTensorHandle* dropout_state, AtenTensorHan…
/aosp_15_r20/external/pytorch/aten/src/ATen/native/
H A DRNN.cpp72 bool use_miopen(const at::Tensor& input, const double dropout_state) { in use_miopen() argument
H A Dnative_functions.yaml247 …ropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, T…
256 …ropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserv…
4081 …t dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, T…
4088 …t dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserv…
/aosp_15_r20/external/pytorch/torch/
H A Doverrides.py799 … num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state: -1 # noqa: B…
H A D_meta_registrations.py5933 dropout_state, argument