1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // Implements the kernel for the CSRSoftmax op, which performs softmax 17 // along the innermost (col) dimension of a CSRSparseMatrix object 18 // stored in a DT_VARIANT. 19 20 #define EIGEN_USE_THREADS 21 22 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 23 #include "tensorflow/core/util/cuda_sparse.h" 24 #define EIGEN_USE_GPU 25 #endif 26 27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 28 #include "tensorflow/core/framework/op.h" 29 #include "tensorflow/core/framework/op_kernel.h" 30 #include "tensorflow/core/framework/tensor_types.h" 31 #include "tensorflow/core/framework/variant_op_registry.h" 32 #include "tensorflow/core/kernels/dense_update_functor.h" 33 #include "tensorflow/core/kernels/fill_functor.h" 34 #include "tensorflow/core/kernels/slice_op.h" 35 #include "tensorflow/core/kernels/sparse/kernels.h" 36 #include "tensorflow/core/kernels/sparse/sparse_matrix.h" 37 38 namespace tensorflow { 39 40 typedef Eigen::ThreadPoolDevice CPUDevice; 41 typedef Eigen::GpuDevice GPUDevice; 42 43 template <typename Device, typename T> 44 class CSRSoftmaxOp : public OpKernel { 45 public: CSRSoftmaxOp(OpKernelConstruction * ctx)46 explicit CSRSoftmaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 47 Compute(OpKernelContext * ctx)48 void Compute(OpKernelContext* ctx) override { 49 const CSRSparseMatrix* logits_matrix; 50 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &logits_matrix)); 51 OP_REQUIRES( 52 ctx, logits_matrix->dtype() == DataTypeToEnum<T>::value, 53 errors::InvalidArgument("dtype of logits is not equal to 'type': ", 54 DataTypeString(logits_matrix->dtype()), " vs. ", 55 DataTypeString(DataTypeToEnum<T>::value))); 56 57 // Allocate output shapes 58 const int total_nnz = logits_matrix->total_nnz(); 59 Tensor output_values_t; 60 OP_REQUIRES_OK( 61 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, 62 TensorShape({total_nnz}), &output_values_t)); 63 64 CSRSparseMatrix output_matrix; 65 66 Tensor dense_shape_t = logits_matrix->dense_shape(); 67 68 OP_REQUIRES_OK( 69 ctx, 70 CSRSparseMatrix::CreateCSRSparseMatrix( 71 DataTypeToEnum<T>::value, dense_shape_t, 72 logits_matrix->batch_pointers(), logits_matrix->row_pointers(), 73 logits_matrix->col_indices(), output_values_t, &output_matrix)); 74 75 if (total_nnz > 0) { 76 functor::CSRSparseMatrixSoftmax<Device, T> softmax; 77 OP_REQUIRES_OK( 78 ctx, softmax(ctx, *logits_matrix, output_matrix.values().vec<T>())); 79 } 80 81 Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({})); 82 output_t.scalar<Variant>()() = std::move(output_matrix); 83 ctx->set_output(0, output_t); 84 } 85 }; 86 87 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 88 #define REGISTER(DEV, T) \ 89 REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax") \ 90 .Device(DEVICE_##DEV) \ 91 .TypeConstraint<T>("type"), \ 92 CSRSoftmaxOp<DEV##Device, T>); 93 94 REGISTER(GPU, float) 95 REGISTER(GPU, double) 96 97 #undef REGISTER 98 99 namespace functor { 100 #define DECLARE_GPU_SPEC(T) \ 101 template <> \ 102 Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()( \ 103 OpKernelContext* ctx, const CSRSparseMatrix& logits, \ 104 typename TTypes<T>::Vec softmax_values); \ 105 extern template struct CSRSparseMatrixSoftmax<GPUDevice, T>; 106 107 DECLARE_GPU_SPEC(float); 108 DECLARE_GPU_SPEC(double); 109 110 #undef DECLARE_GPU_SPEC 111 } // namespace functor 112 113 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 114 115 template <typename Device, typename T> 116 class CSRSoftmaxGradOp : public OpKernel { 117 public: CSRSoftmaxGradOp(OpKernelConstruction * ctx)118 explicit CSRSoftmaxGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} 119 Compute(OpKernelContext * ctx)120 void Compute(OpKernelContext* ctx) override { 121 const CSRSparseMatrix* softmax_matrix; 122 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &softmax_matrix)); 123 OP_REQUIRES(ctx, softmax_matrix->dtype() == DataTypeToEnum<T>::value, 124 errors::InvalidArgument( 125 "dtype of softmax is not equal to 'type': ", 126 DataTypeString(softmax_matrix->dtype()), " vs. ", 127 DataTypeString(DataTypeToEnum<T>::value))); 128 129 const CSRSparseMatrix* grad_softmax_matrix; 130 OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &grad_softmax_matrix)); 131 OP_REQUIRES(ctx, grad_softmax_matrix->dtype() == DataTypeToEnum<T>::value, 132 errors::InvalidArgument( 133 "dtype of grad_softmax is not equal to 'type': ", 134 DataTypeString(grad_softmax_matrix->dtype()), " vs. ", 135 DataTypeString(DataTypeToEnum<T>::value))); 136 137 OP_REQUIRES( 138 ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(), 139 errors::InvalidArgument( 140 "Ranks of softmax and grad_softmax matrices differ: ", 141 softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims())); 142 143 OP_REQUIRES( 144 ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(), 145 errors::InvalidArgument( 146 "Ranks of softmax and grad_softmax matrices differ: ", 147 softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims())); 148 149 Tensor dense_shape_t = softmax_matrix->dense_shape(); 150 auto host_dense_shape = 151 static_cast<const Tensor>(dense_shape_t).vec<int64_t>(); 152 153 auto host_grad_dense_shape = 154 grad_softmax_matrix->dense_shape().vec<int64_t>(); 155 156 for (int i = 0; i < host_dense_shape.size(); ++i) { 157 OP_REQUIRES(ctx, host_dense_shape(i) == host_grad_dense_shape(i), 158 errors::InvalidArgument( 159 "Shapes of softmax and grad_softmax matrices differ: ", 160 dense_shape_t.SummarizeValue(3), " vs. ", 161 grad_softmax_matrix->dense_shape().SummarizeValue(3))); 162 } 163 164 // Allocate output shapes. Note that since the Softmax Gradient 165 // tensor is the elementwise product of some function with the 166 // softmax value, it will keep the sparsity structure of the softmax. 167 const int total_nnz = softmax_matrix->total_nnz(); 168 Tensor gradient_values; 169 OP_REQUIRES_OK( 170 ctx, ctx->allocate_temp(DataTypeToEnum<T>::value, 171 TensorShape({total_nnz}), &gradient_values)); 172 173 CSRSparseMatrix gradient_matrix; 174 175 OP_REQUIRES_OK( 176 ctx, 177 CSRSparseMatrix::CreateCSRSparseMatrix( 178 DataTypeToEnum<T>::value, dense_shape_t, 179 softmax_matrix->batch_pointers(), softmax_matrix->row_pointers(), 180 softmax_matrix->col_indices(), gradient_values, &gradient_matrix)); 181 182 if (total_nnz > 0) { 183 functor::CSRSparseMatrixSoftmaxGrad<Device, T> softmax_grad; 184 OP_REQUIRES_OK(ctx, 185 softmax_grad(ctx, *softmax_matrix, *grad_softmax_matrix, 186 gradient_matrix.values().vec<T>())); 187 } 188 189 Tensor gradient_t(cpu_allocator(), DT_VARIANT, TensorShape({})); 190 gradient_t.scalar<Variant>()() = std::move(gradient_matrix); 191 ctx->set_output(0, gradient_t); 192 } 193 }; 194 195 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM 196 #define REGISTER(DEV, T) \ 197 REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \ 198 .Device(DEVICE_##DEV) \ 199 .TypeConstraint<T>("type"), \ 200 CSRSoftmaxGradOp<DEV##Device, T>); 201 202 REGISTER(GPU, float) 203 REGISTER(GPU, double) 204 205 #undef REGISTER 206 207 namespace functor { 208 #define DECLARE_GPU_SPEC(T) \ 209 template <> \ 210 Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()( \ 211 OpKernelContext* ctx, const CSRSparseMatrix& softmax, \ 212 const CSRSparseMatrix& grad_softmax, \ 213 typename TTypes<T>::Vec gradient_values); \ 214 extern template struct CSRSparseMatrixSoftmaxGrad<GPUDevice, T>; 215 216 DECLARE_GPU_SPEC(float); 217 DECLARE_GPU_SPEC(double); 218 219 #undef DECLARE_GPU_SPEC 220 } // namespace functor 221 222 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM 223 224 } // namespace tensorflow 225