#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/FusedSGD.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_fused_sgd.h>
#include <ATen/ops/_fused_sgd_native.h>
#endif

namespace at::native {

void _fused_sgd_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const double lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  const float* grad_scale_ptr =
      grad_scale.has_value() ? grad_scale->data_ptr<float>() : nullptr;
  const float* found_inf_ptr =
      found_inf.has_value() ? found_inf->data_ptr<float>() : nullptr;
  // AMP support: if the grad scaler flagged a non-finite gradient, skip the
  // whole step.
  if (found_inf_ptr && *found_inf_ptr == 1.0) {
    return;
  }
  size_t n_tensors = params.size();
  TORCH_CHECK(grads.size() == n_tensors);
  // With momentum == 0 no momentum buffers are used; otherwise there must be
  // exactly one buffer per parameter.
  bool no_momentum_buffer = momentum == 0.0;
  if (no_momentum_buffer) {
    TORCH_CHECK(momentum_buffer_list.empty());
  } else {
    TORCH_CHECK(momentum_buffer_list.size() == n_tensors);
  }
  for (size_t i = 0; i < n_tensors; i++) {
    fused_sgd_stub(
        kCPU,
        params[i],
        grads[i],
        no_momentum_buffer ? Tensor() : momentum_buffer_list[i],
        weight_decay,
        momentum,
        lr,
        dampening,
        nesterov,
        maximize,
        is_first_step,
        grad_scale_ptr);
  }
}

// Overload taking the learning rate as a 0-dim tensor: unwrap the scalar and
// forward to the double-lr overload above.
void _fused_sgd_kernel_cpu_(
    at::TensorList params,
    at::TensorList grads,
    at::TensorList momentum_buffer_list,
    const double weight_decay,
    const double momentum,
    const at::Tensor& lr,
    const double dampening,
    const bool nesterov,
    const bool maximize,
    const bool is_first_step,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf) {
  _fused_sgd_kernel_cpu_(
      params,
      grads,
      momentum_buffer_list,
      weight_decay,
      momentum,
      lr.item<double>(),
      dampening,
      nesterov,
      maximize,
      is_first_step,
      grad_scale,
      found_inf);
}

DEFINE_DISPATCH(fused_sgd_stub);

} // namespace at::native
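// Reference semantics (a sketch for readers; not part of the build): per
// tensor i, fused_sgd_stub is expected to perform the update documented for
// torch.optim.SGD. With p = params[i], g = grads[i] (first divided by
// *grad_scale_ptr when a grad scale is supplied), and buf the matching
// momentum buffer:
//
//   if (maximize)            g = -g;
//   if (weight_decay != 0.0) g += weight_decay * p;
//   if (momentum != 0.0) {
//     buf = is_first_step ? g : momentum * buf + (1.0 - dampening) * g;
//     g   = nesterov ? g + momentum * buf : buf;
//   }
//   p -= lr * g;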