#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorMeta.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/TriangularOpsUtils.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/trace_backward_native.h>
#include <ATen/ops/tril_native.h>
#include <ATen/ops/triu_native.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::meta {

TORCH_META_FUNC(tril)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "tril: input tensor must have at least 2 dimensions")
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

TORCH_META_FUNC(triu)(const Tensor& self, int64_t k) {
  TORCH_CHECK(self.dim() >= 2, "triu: input tensor must have at least 2 dimensions")
  set_output_raw_strided(0, self.sizes(), {}, self.options());
}

}  // namespace at::meta

namespace at::native {
namespace {

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Writes the upper/lower triangle of a single (n, m) matrix into `result`,
// zeroing out the excluded triangle and, when not operating in place,
// copying the kept triangle from `self`.
template <typename scalar_t>
void apply_triu_tril_single(
    scalar_t* result,
    const scalar_t* self,
    bool inplace,
    int64_t k,
    int64_t n,
    int64_t m,
    int64_t res_row_stride,
    int64_t res_col_stride,
    int64_t self_row_stride,
    int64_t self_col_stride,
    bool upper) {
  constexpr int64_t zero = 0;

  if (upper) {
    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
      for (int64_t i : c10::irange(start, end)) {
        for (int64_t j = 0; j < std::min(m, i + k); j++) {
          result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
        }
        if (!inplace) {  // copy the rest of the self if not inplace
          for (int64_t j = std::max(zero, i + k); j < m; j++) {
            result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
          }
        }
      }
    });
  } else {
    parallel_for(0, n, 0, [&](int64_t start, int64_t end) {
      for (int64_t i : c10::irange(start, end)) {
        for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
          result[i * res_row_stride + j * res_col_stride] = static_cast<scalar_t>(0);
        }
        if (!inplace) {  // copy the rest of the self if not inplace
          for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
            result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
          }
        }
      }
    });
  }
}
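// Applies apply_triu_tril_single to every matrix in a (possibly batched)
// tensor, resolving batch/row/column strides for both `self` and `result`
// and parallelizing over the batch dimension.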
template <typename scalar_t>
void apply_triu_tril(const Tensor& result, const Tensor& self, bool inplace, int64_t k, bool upper) {
  auto n = self.size(-2);
  auto m = self.size(-1);
  auto self_data = self.const_data_ptr<scalar_t>();
  auto self_stride = (self.dim() > 2 && self.stride(-3) > 0) ? self.stride(-3) : 1;
  auto batchsize = batchCountTrilTriu(result);
  auto self_row_stride = self.stride(-2);
  auto self_col_stride = self.stride(-1);

  auto result_data = result.data_ptr<scalar_t>();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  int64_t result_stride, result_row_stride, result_col_stride;
  if (result_data != self_data) {
    result_stride = (result.dim() > 2 && result.stride(-3) > 0) ? result.stride(-3) : 1;
    result_row_stride = result.stride(-2);
    result_col_stride = result.stride(-1);
  } else {
    result_stride = self_stride;
    result_row_stride = self_row_stride;
    result_col_stride = self_col_stride;
  }

  parallel_for(0, batchsize, 0, [&](int64_t start, int64_t end) {
    for (const auto b : c10::irange(start, end)) {
      const scalar_t* self_batch = &self_data[b * self_stride];
      scalar_t* result_batch = &result_data[b * result_stride];
      apply_triu_tril_single(
          result_batch,
          self_batch,
          inplace,
          k,
          n,
          m,
          result_row_stride,
          result_col_stride,
          self_row_stride,
          self_col_stride,
          upper);
    }
  });
}

struct UpperTriangle {
  static constexpr const char* op_name = "triu";
  static constexpr bool upper = true;
};

struct LowerTriangle {
  static constexpr const char *op_name = "tril";
  static constexpr bool upper = false;
};

// Shared implementation for tril_cpu/triu_cpu: makes `self` batch-contiguous
// if needed, dispatches on dtype, and copies back into `result` when an
// in-place update could not be performed directly.
template <typename Triangle>
void compute_triu_tril(const Tensor& self, int64_t k, const Tensor &result) {
  if (self.numel() == 0) {
    return;
  }

  bool inplace_op = self.is_same(result);

  bool inplace_update = false;
  Tensor self_c;
  std::tie(inplace_update, self_c) = checkTrilTriuBatchContiguous(self, inplace_op);

  Tensor result_c;
  if (inplace_op && !inplace_update) {
    result_c = at::empty_like(result, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  } else {
    result_c = result;
  }

  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      ScalarType::ComplexHalf,
      ScalarType::BFloat16,
      ScalarType::Half,
      ScalarType::Bool,
      self.scalar_type(),
      Triangle::op_name,
      [&]{
        apply_triu_tril<scalar_t>(result_c, self_c, inplace_op && inplace_update, k, Triangle::upper);
      });

  if (inplace_op && !inplace_update) {
    result.copy_(result_c);
  }
}

}  // namespace

TORCH_IMPL_FUNC(tril_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
  compute_triu_tril<LowerTriangle>(self, k, result);
}

TORCH_IMPL_FUNC(triu_cpu)(const Tensor& self, int64_t k, const Tensor &result) {
  compute_triu_tril<UpperTriangle>(self, k, result);
}

Tensor trace_backward_symint(const Tensor& grad, c10::SymIntArrayRef sizes) {
  if (sizes.size() != 2) {
    throw std::runtime_error("expected matrix input");
  }

  auto grad_input = at::zeros_symint(sizes[0] * sizes[1], grad.options());
  auto indices = at::arange(0, grad_input.numel(), sizes[1] + 1, grad.options().dtype(at::kLong));
  // for composite compliance, use out-of-place variant of
  // `index_fill` if grad tensor is a Tensor Subclass.
  if (isTensorSubclassLike(grad)) {
    grad_input = grad_input.index_fill(0, indices, grad);
  } else {
    grad_input.index_fill_(0, indices, grad);
  }
  return grad_input.view_symint(sizes);
}

}  // namespace at::native