#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/LossMulti.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/multilabel_margin_loss_backward_native.h>
#include <ATen/ops/multilabel_margin_loss_forward.h>
#include <ATen/ops/multilabel_margin_loss_forward_native.h>
#include <ATen/ops/multilabel_margin_loss_native.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace at::native {

namespace {

// Per-sample forward sum: over each valid target index (the list is
// terminated by the first -1), accumulates max(0, 1 - x[target] + x[d]) for
// every non-target dimension d. Also fills is_target_data with the 0/1 mask
// of target indices, which is saved for the backward pass.
template <typename scalar_t>
inline scalar_t multilabel_margin_loss_forward_inner_sum_cpu(
    const scalar_t* input_data,
    const int64_t* target_data,
    scalar_t* is_target_data,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  accscalar_t sum = 0;
  for (const auto ddt : c10::irange(dim)) {
    int64_t target_idx = target_data[ddt];
    if (target_idx < 0) {
      break;
    }
    is_target_data[target_idx] = 1;
  }
  for (const auto dt : c10::irange(dim)) {
    int64_t target_idx = target_data[dt];
    if (target_idx < 0) {
      break;
    }
    scalar_t input_target = input_data[target_idx];
    for (const auto d : c10::irange(dim)) {
      if (!is_target_data[d]) {
        scalar_t z = 1 - input_target + input_data[d];
        if (z > 0) {
          sum += z;
        }
      }
    }
  }

  return sum;
}

template <typename scalar_t>
static void multilabel_margin_loss_forward_out_frame(
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction,
    int64_t nframe,
    int64_t dim) {
  using accscalar_t = at::acc_type<scalar_t, false>;
  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  scalar_t* is_target_data = is_target.data_ptr<scalar_t>();

  if (reduction != Reduction::None || output.dim() == 0) {
    scalar_t* output_data = output.data_ptr<scalar_t>();

    accscalar_t sum = 0;

    for (C10_UNUSED const auto t : c10::irange(nframe)) {
      sum += multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);

      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }

    sum /= dim;
    if (reduction == Reduction::Mean) {
      sum /= nframe;
    }

    *output_data = sum; // write scalar output value
  } else {
    auto output_acc = output.accessor<scalar_t, 1>();

    for (const auto t : c10::irange(nframe)) {
      scalar_t sum = multilabel_margin_loss_forward_inner_sum_cpu(
          input_data, target_data, is_target_data, dim);

      sum /= dim;
      output_acc[t] = sum;

      input_data += dim;
      target_data += dim;
      is_target_data += dim;
    }
  }
}

static void multilabel_margin_loss_forward_out_cpu_template(
    const Tensor& input,
    const Tensor& target,
    Tensor& output,
    Tensor& is_target,
    int64_t reduction) {
#ifndef STRIP_ERROR_MESSAGES
  auto target_arg = TensorArg(target, "target", 2);
#endif
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  const int64_t ndims = input.dim();
  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);

  // special case target.dim() <= 1: produce scalar output for scalar inputs
  // even if reduction == Reduction::None
  if (reduction != Reduction::None || target.dim() <= 1) {
    output.resize_({});
  } else {
    output.resize_({nframe});
  }

  is_target.resize_as_(target);
  TORCH_CHECK(is_target.is_contiguous(), "is_target must be contiguous");
  is_target.zero_();

  if (input.numel() == 0) {
    return;
  }

  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_forward_out_frame", [&] {
        multilabel_margin_loss_forward_out_frame<scalar_t>(
            input_contiguous,
            target_contiguous,
            output,
            is_target,
            reduction,
            nframe,
            dim);
      });
}
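
// Per-sample backward pass for the margin term above: every pair
// (target_idx, d) that produced a positive z = 1 - x[target_idx] + x[d] in
// the forward pass contributes -g to grad_input[target_idx] and +g to
// grad_input[d], where g = 1/dim (scaled by an extra 1/nframe under mean
// reduction). The incoming grad_output factor is applied in a second pass.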
template <typename scalar_t>
static void multilabel_margin_loss_backward_out_frame(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input_contiguous,
    const Tensor& target_contiguous,
    int64_t reduction,
    const Tensor& is_target_contiguous,
    int64_t nframe,
    int64_t dim) {
#ifndef STRIP_ERROR_MESSAGES
  auto is_target_arg = TensorArg(is_target_contiguous, "is_target", 5);
#endif

  TORCH_CHECK(
      is_target_contiguous.min().item<scalar_t>() >= 0,
      is_target_arg,
      " is out of range");
  TORCH_CHECK(
      is_target_contiguous.max().item<scalar_t>() <= 1,
      is_target_arg,
      " is out of range");

  const scalar_t* input_data = input_contiguous.const_data_ptr<scalar_t>();
  const int64_t* target_data = target_contiguous.const_data_ptr<int64_t>();
  const scalar_t* is_target_data =
      is_target_contiguous.const_data_ptr<scalar_t>();
  scalar_t g = static_cast<scalar_t>(
      // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
      reduction == Reduction::Mean ? 1. / (nframe * dim) : 1. / dim);

  scalar_t* grad_input_row_data = grad_input.mutable_data_ptr<scalar_t>();
  for (C10_UNUSED const auto t : c10::irange(nframe)) {
    for (const auto dt : c10::irange(dim)) {
      int64_t target_idx = target_data[dt];
      if (target_idx < 0) {
        break;
      }

      scalar_t input_target = input_data[target_idx];
      for (const auto d : c10::irange(dim)) {
        if (!is_target_data[d]) {
          scalar_t z = 1 - input_target + input_data[d];
          if (z > 0) {
            grad_input_row_data[target_idx] -= g;
            grad_input_row_data[d] += g;
          }
        }
      }
    }
    input_data += dim;
    target_data += dim;
    is_target_data += dim;
    grad_input_row_data += dim;
  }

  scalar_t* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
  if (reduction != Reduction::None || grad_output.dim() == 0) {
    assert(
        reduction != Reduction::None || grad_output.dim() > 0 || nframe == 1);
    const auto d = *grad_output.const_data_ptr<scalar_t>();
    for (int64_t t = 0; t < nframe * dim; t++) {
      grad_input_data[t] *= d;
    }
  } else {
    check_dim_size(grad_output, 1, 0, nframe);
    auto grad_output_acc = grad_output.accessor<scalar_t, 1>();
    for (const auto t : c10::irange(nframe)) {
      for (const auto d : c10::irange(dim)) {
        grad_input_data[t * dim + d] *= grad_output_acc[t];
      }
    }
  }
}

static void multilabel_margin_loss_backward_out_cpu_template(
    Tensor& grad_input,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t nframe, dim;
  CheckedFrom c = "multilabel_margin_loss_backward_cpu_template";
  auto target_arg = TensorArg(target, "target", 3);
  auto is_target_arg = TensorArg(is_target, "is_target", 5);
  const int64_t ndims = input.dim();

  multilabel_margin_loss_shape_check(nframe, dim, ndims, input, target);
  checkSameSize(c, target_arg, is_target_arg);

  grad_input.resize_as_(input);
  if (grad_input.numel() == 0) {
    return;
  }

  TORCH_CHECK(grad_input.is_contiguous(), "grad_input must be contiguous");
  grad_input.zero_();

  TORCH_CHECK(
      target.min().item<int64_t>() >= -1, target_arg, " is out of range");
  TORCH_CHECK(
      target.max().item<int64_t>() < dim, target_arg, " is out of range");

  auto input_contiguous = input.contiguous();
  auto target_contiguous = target.contiguous();
  auto is_target_contiguous = is_target.contiguous();

  AT_DISPATCH_FLOATING_TYPES(
      input.scalar_type(), "multilabel_margin_loss_backward_out_frame", [&] {
        multilabel_margin_loss_backward_out_frame<scalar_t>(
            grad_input,
            grad_output,
            input_contiguous,
            target_contiguous,
            reduction,
            is_target_contiguous,
            nframe,
            dim);
      });
}

} // namespace

std::tuple<Tensor&, Tensor&> multilabel_margin_loss_forward_out_cpu(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    Tensor& output,
    Tensor& is_target) {
  multilabel_margin_loss_forward_out_cpu_template(
      self, target, output, is_target, reduction);
  return std::tuple<Tensor&, Tensor&>(output, is_target);
}
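
// Note: is_target has the same shape as target and holds the 0/1 target mask
// computed during the forward pass; it is returned alongside the loss so the
// backward kernel can reuse it instead of recomputing the mask (the autograd
// derivative formula routes it into multilabel_margin_loss_backward).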
std::tuple<Tensor, Tensor> multilabel_margin_loss_forward_cpu(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  auto output = at::empty({0}, self.options());
  auto is_target = at::empty({0}, self.options());
  at::native::multilabel_margin_loss_forward_out_cpu(
      self, target, reduction, output, is_target);
  return std::make_tuple(output, is_target);
}

Tensor& multilabel_margin_loss_backward_cpu_out(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target,
    Tensor& grad_input) {
  multilabel_margin_loss_backward_out_cpu_template(
      grad_input, grad_output, self, target, reduction, is_target);
  return grad_input;
}

Tensor multilabel_margin_loss_backward_cpu(
    const Tensor& grad_output,
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    const Tensor& is_target) {
  auto grad_input = at::zeros_like(self, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  at::native::multilabel_margin_loss_backward_cpu_out(
      grad_output, self, target, reduction, is_target, grad_input);
  return grad_input;
}

Tensor& multilabel_margin_loss_out(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction,
    Tensor& output) {
  Tensor is_target = at::empty({0}, self.options());
  return std::get<0>(at::multilabel_margin_loss_forward_out(
      output, is_target, self, target, reduction));
}

Tensor multilabel_margin_loss(
    const Tensor& self,
    const Tensor& target,
    int64_t reduction) {
  return std::get<0>(
      at::multilabel_margin_loss_forward(self, target, reduction));
}

} // namespace at::native
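
// Usage sketch (illustrative only, not part of this file). The functional
// entry point above is reachable through the dispatcher as
// at::multilabel_margin_loss; each target row holds class indices, and the
// first -1 terminates that sample's target list (later entries are ignored):
//
//   at::Tensor input = at::randn({2, 4});
//   at::Tensor target =
//       at::tensor({3, 0, -1, 1, 0, 1, -1, 2}, at::kLong).view({2, 4});
//   at::Tensor loss =
//       at::multilabel_margin_loss(input, target, at::Reduction::Mean);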