#include #include #include #include namespace at { namespace autocast { /********************************************************************** Autocast wrapper for MIOpen RNNs **********************************************************************/ std::tuple miopen_rnn(const Tensor & input_r, TensorList weight, int64_t weight_stride0, const Tensor & hx, const std::optional& cx_opt, int64_t fn_mode, int64_t fn_hidden_size, int64_t fn_num_layers, bool batch_first, double fn_dropout, bool fn_train, bool fn_bidirectional, IntArrayRef fn_batch_sizes, const std::optional& fn_dropout_state_opt) { #if AT_ROCM_ENABLED() c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast); return at::miopen_rnn( cached_cast(at::kHalf, input_r), cached_cast(at::kHalf, weight), weight_stride0, cached_cast(at::kHalf, hx), cached_cast(at::kHalf, cx_opt), fn_mode, fn_hidden_size, fn_num_layers, batch_first, fn_dropout, fn_train, fn_bidirectional, fn_batch_sizes, fn_dropout_state_opt); #else AT_ERROR("autocast::miopen_rnn: ATen not compiled with ROCm enabled"); return {Tensor{}, Tensor{}, Tensor{}, Tensor{}, Tensor{}}; // placate the compiler #endif } // Register Autocast dispatch namespace { TORCH_LIBRARY_IMPL(aten, Autocast, m) { m.impl("miopen_rnn", TORCH_FN((&at::autocast::miopen_rnn))); } } // anonymous namespace } // namespace autocast } // namespace at