// Copyright 2004-present Facebook. All Rights Reserved. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #endif namespace at::native { namespace { template std::tuple _rowwise_prune_helper( const Tensor& weights, const Tensor& mask, ScalarType compressed_indices_dtype) { int num_non_masked_rows = 0; auto mask_contig = mask.contiguous(); auto mask_data = mask_contig.data_ptr(); for (const auto i : c10::irange(mask.numel())) { num_non_masked_rows += (((mask_data[i] == true)) ? 1 : 0); } int num_cols = weights.size(1); auto pruned_2d_tensor = at::empty({num_non_masked_rows, num_cols}, weights.options()); auto compressed_indices_mapping = at::empty({mask.numel()}, compressed_indices_dtype); AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weights.scalar_type(), "rowwise_prune_helper", [&]() { auto* pruned_2d_tensor_data = pruned_2d_tensor.data_ptr(); auto compressed_indices_mapping_data = compressed_indices_mapping.data_ptr(); auto weights_data = weights.data_ptr(); int last_row_kept = 0; for (const auto i : c10::irange(mask.numel())) { if (mask_data[i]) { memcpy(pruned_2d_tensor_data + last_row_kept * num_cols, weights_data + i * num_cols, num_cols * sizeof (scalar_t)); compressed_indices_mapping_data[i] = last_row_kept; last_row_kept++; } else { compressed_indices_mapping_data[i] = -1; } } }); return std::tuple(pruned_2d_tensor, compressed_indices_mapping); } } // namespace // This operator introduces sparsity to the 'weights' matrix with the help // of the importance indicator 'mask'. // // A row is considered important and not pruned if the mask value for that // particular row is 1(True) and not important otherwise. // // This operator doesn't zero out the pruned rows in-place. Instead, it // returns a tuple that contains a pruned weights tensor as well as a map that // can be used to look up the original row in the pruned weights tensor. // We refer this map as 'compressed indices map' going forward. // The 'compressed indices map' is an 1D tensor that contains one entry per // original row in 'weights'. The array index is the index for the original // non-pruned weight tensor and the value would be the re-mapped index in the // pruned weights tensor. If the value for a index is -1, it means the // corresponding row has been pruned from the original weight tensor. // Arguments: // 'weights' - two dimensional matrix that needs to be prune. // 'mask' - 1D boolean tensor that represents whether a row is important or // not. A mask value of 1 means the row should be kept and 0 means the row // should be pruned. // // Returns: // A tuple containing two tensors, // 1. A pruned weight tensor that contains only the weights that are preserved // post pruning. // 2. An 1D tensor that contains the mapping between original weight row and // the corresponding row in the pruned weights tensor. std::tuple _rowwise_prune(const Tensor& weights, const Tensor& mask, ScalarType compressed_indices_dtype) { TORCH_CHECK(weights.ndimension() == 2, "'weights' should have 2 dimensions."); TORCH_CHECK( mask.numel() == weights.size(0), "Number of elements in 'mask' should be equivalent to the " "number of rows in 'weights'." ) TORCH_CHECK( compressed_indices_dtype == ScalarType::Int || compressed_indices_dtype == ScalarType::Long, "compressed_indices_dtype should be either int(int32) or long(int64)."); if (compressed_indices_dtype == at::ScalarType::Int) { return _rowwise_prune_helper(weights, mask, compressed_indices_dtype); } return _rowwise_prune_helper(weights, mask, compressed_indices_dtype); } } // namespace at::native