xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/sparse/softmax_op.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Implements the kernels for the SparseMatrixSoftmax and
// SparseMatrixSoftmaxGrad ops, which compute softmax (and its gradient)
// along the innermost (col) dimension of a CSRSparseMatrix object
// stored in a DT_VARIANT.
19 
20 #define EIGEN_USE_THREADS
21 
22 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
23 #include "tensorflow/core/util/cuda_sparse.h"
24 #define EIGEN_USE_GPU
25 #endif
26 
27 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
28 #include "tensorflow/core/framework/op.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/framework/tensor_types.h"
31 #include "tensorflow/core/framework/variant_op_registry.h"
32 #include "tensorflow/core/kernels/dense_update_functor.h"
33 #include "tensorflow/core/kernels/fill_functor.h"
34 #include "tensorflow/core/kernels/slice_op.h"
35 #include "tensorflow/core/kernels/sparse/kernels.h"
36 #include "tensorflow/core/kernels/sparse/sparse_matrix.h"
37 
38 namespace tensorflow {
39 
40 typedef Eigen::ThreadPoolDevice CPUDevice;
41 typedef Eigen::GpuDevice GPUDevice;
42 
43 template <typename Device, typename T>
44 class CSRSoftmaxOp : public OpKernel {
45  public:
CSRSoftmaxOp(OpKernelConstruction * ctx)46   explicit CSRSoftmaxOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
47 
Compute(OpKernelContext * ctx)48   void Compute(OpKernelContext* ctx) override {
49     const CSRSparseMatrix* logits_matrix;
50     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &logits_matrix));
51     OP_REQUIRES(
52         ctx, logits_matrix->dtype() == DataTypeToEnum<T>::value,
53         errors::InvalidArgument("dtype of logits is not equal to 'type': ",
54                                 DataTypeString(logits_matrix->dtype()), " vs. ",
55                                 DataTypeString(DataTypeToEnum<T>::value)));
56 
57     // Allocate output shapes
58     const int total_nnz = logits_matrix->total_nnz();
59     Tensor output_values_t;
60     OP_REQUIRES_OK(
61         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
62                                 TensorShape({total_nnz}), &output_values_t));
63 
64     CSRSparseMatrix output_matrix;
65 
66     Tensor dense_shape_t = logits_matrix->dense_shape();
67 
68     OP_REQUIRES_OK(
69         ctx,
70         CSRSparseMatrix::CreateCSRSparseMatrix(
71             DataTypeToEnum<T>::value, dense_shape_t,
72             logits_matrix->batch_pointers(), logits_matrix->row_pointers(),
73             logits_matrix->col_indices(), output_values_t, &output_matrix));
74 
75     if (total_nnz > 0) {
76       functor::CSRSparseMatrixSoftmax<Device, T> softmax;
77       OP_REQUIRES_OK(
78           ctx, softmax(ctx, *logits_matrix, output_matrix.values().vec<T>()));
79     }
80 
81     Tensor output_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
82     output_t.scalar<Variant>()() = std::move(output_matrix);
83     ctx->set_output(0, output_t);
84   }
85 };
86 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Registers the GPU kernel for the "SparseMatrixSoftmax" op for value type T.
#define REGISTER_SOFTMAX_GPU_KERNEL(T)                    \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmax")     \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<T>("type"), \
                          CSRSoftmaxOp<GPUDevice, T>);

REGISTER_SOFTMAX_GPU_KERNEL(float)
REGISTER_SOFTMAX_GPU_KERNEL(double)

#undef REGISTER_SOFTMAX_GPU_KERNEL

namespace functor {

// Declares the GPU specialization of the softmax functor for value type T.
// The `extern template` line suppresses implicit instantiation in this
// translation unit; the definitions are presumably provided by a separately
// compiled GPU source file.
#define DECLARE_SOFTMAX_GPU_SPEC(T)                        \
  template <>                                              \
  Status CSRSparseMatrixSoftmax<GPUDevice, T>::operator()( \
      OpKernelContext* ctx, const CSRSparseMatrix& logits, \
      typename TTypes<T>::Vec softmax_values);             \
  extern template struct CSRSparseMatrixSoftmax<GPUDevice, T>

DECLARE_SOFTMAX_GPU_SPEC(float);
DECLARE_SOFTMAX_GPU_SPEC(double);

#undef DECLARE_SOFTMAX_GPU_SPEC

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
114 
115 template <typename Device, typename T>
116 class CSRSoftmaxGradOp : public OpKernel {
117  public:
CSRSoftmaxGradOp(OpKernelConstruction * ctx)118   explicit CSRSoftmaxGradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
119 
Compute(OpKernelContext * ctx)120   void Compute(OpKernelContext* ctx) override {
121     const CSRSparseMatrix* softmax_matrix;
122     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 0, &softmax_matrix));
123     OP_REQUIRES(ctx, softmax_matrix->dtype() == DataTypeToEnum<T>::value,
124                 errors::InvalidArgument(
125                     "dtype of softmax is not equal to 'type': ",
126                     DataTypeString(softmax_matrix->dtype()), " vs. ",
127                     DataTypeString(DataTypeToEnum<T>::value)));
128 
129     const CSRSparseMatrix* grad_softmax_matrix;
130     OP_REQUIRES_OK(ctx, ExtractVariantFromInput(ctx, 1, &grad_softmax_matrix));
131     OP_REQUIRES(ctx, grad_softmax_matrix->dtype() == DataTypeToEnum<T>::value,
132                 errors::InvalidArgument(
133                     "dtype of grad_softmax is not equal to 'type': ",
134                     DataTypeString(grad_softmax_matrix->dtype()), " vs. ",
135                     DataTypeString(DataTypeToEnum<T>::value)));
136 
137     OP_REQUIRES(
138         ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(),
139         errors::InvalidArgument(
140             "Ranks of softmax and grad_softmax matrices differ: ",
141             softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims()));
142 
143     OP_REQUIRES(
144         ctx, softmax_matrix->dims() == grad_softmax_matrix->dims(),
145         errors::InvalidArgument(
146             "Ranks of softmax and grad_softmax matrices differ: ",
147             softmax_matrix->dims(), " vs. ", grad_softmax_matrix->dims()));
148 
149     Tensor dense_shape_t = softmax_matrix->dense_shape();
150     auto host_dense_shape =
151         static_cast<const Tensor>(dense_shape_t).vec<int64_t>();
152 
153     auto host_grad_dense_shape =
154         grad_softmax_matrix->dense_shape().vec<int64_t>();
155 
156     for (int i = 0; i < host_dense_shape.size(); ++i) {
157       OP_REQUIRES(ctx, host_dense_shape(i) == host_grad_dense_shape(i),
158                   errors::InvalidArgument(
159                       "Shapes of softmax and grad_softmax matrices differ: ",
160                       dense_shape_t.SummarizeValue(3), " vs. ",
161                       grad_softmax_matrix->dense_shape().SummarizeValue(3)));
162     }
163 
164     // Allocate output shapes.  Note that since the Softmax Gradient
165     // tensor is the elementwise product of some function with the
166     // softmax value, it will keep the sparsity structure of the softmax.
167     const int total_nnz = softmax_matrix->total_nnz();
168     Tensor gradient_values;
169     OP_REQUIRES_OK(
170         ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
171                                 TensorShape({total_nnz}), &gradient_values));
172 
173     CSRSparseMatrix gradient_matrix;
174 
175     OP_REQUIRES_OK(
176         ctx,
177         CSRSparseMatrix::CreateCSRSparseMatrix(
178             DataTypeToEnum<T>::value, dense_shape_t,
179             softmax_matrix->batch_pointers(), softmax_matrix->row_pointers(),
180             softmax_matrix->col_indices(), gradient_values, &gradient_matrix));
181 
182     if (total_nnz > 0) {
183       functor::CSRSparseMatrixSoftmaxGrad<Device, T> softmax_grad;
184       OP_REQUIRES_OK(ctx,
185                      softmax_grad(ctx, *softmax_matrix, *grad_softmax_matrix,
186                                   gradient_matrix.values().vec<T>()));
187     }
188 
189     Tensor gradient_t(cpu_allocator(), DT_VARIANT, TensorShape({}));
190     gradient_t.scalar<Variant>()() = std::move(gradient_matrix);
191     ctx->set_output(0, gradient_t);
192   }
193 };
194 
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Registers the GPU kernel for the "SparseMatrixSoftmaxGrad" op for value
// type T.
#define REGISTER_SOFTMAX_GRAD_GPU_KERNEL(T)               \
  REGISTER_KERNEL_BUILDER(Name("SparseMatrixSoftmaxGrad") \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<T>("type"), \
                          CSRSoftmaxGradOp<GPUDevice, T>);

REGISTER_SOFTMAX_GRAD_GPU_KERNEL(float)
REGISTER_SOFTMAX_GRAD_GPU_KERNEL(double)

#undef REGISTER_SOFTMAX_GRAD_GPU_KERNEL

namespace functor {

// Declares the GPU specialization of the softmax-gradient functor for value
// type T. The `extern template` line suppresses implicit instantiation in
// this translation unit; the definitions are presumably provided by a
// separately compiled GPU source file.
#define DECLARE_SOFTMAX_GRAD_GPU_SPEC(T)                       \
  template <>                                                  \
  Status CSRSparseMatrixSoftmaxGrad<GPUDevice, T>::operator()( \
      OpKernelContext* ctx, const CSRSparseMatrix& softmax,    \
      const CSRSparseMatrix& grad_softmax,                     \
      typename TTypes<T>::Vec gradient_values);                \
  extern template struct CSRSparseMatrixSoftmaxGrad<GPUDevice, T>

DECLARE_SOFTMAX_GRAD_GPU_SPEC(float);
DECLARE_SOFTMAX_GRAD_GPU_SPEC(double);

#undef DECLARE_SOFTMAX_GRAD_GPU_SPEC

}  // namespace functor

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
223 
224 }  // namespace tensorflow
225