/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/aggregate_ops.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"
#include "tensorflow/core/kernels/variant_ops_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;

    const Tensor& input0 = ctx->input(0);
    const int num = ctx->num_inputs();

    if (num == 1) {
      ctx->set_output(0, input0);
      return;
    }

    // Try to forward and accumulate the result in one of the input buffers.
    int reused_input = -1;
    gtl::InlinedVector<int, 8> input_indices(num);
    std::iota(input_indices.begin(), input_indices.end(), 0);
    Tensor* output = nullptr;
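    // forward_input_to_output_with_shape() lets the kernel reuse one of the
    // input buffers as output 0 when that buffer can safely be overwritten,
    // avoiding a separate allocation so the sum is accumulated in place.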
    for (int input_idx = 0; input_idx < num; ++input_idx) {
      if (ctx->forward_input_to_output_with_shape(input_idx, 0, input0.shape(),
                                                  &output)) {
        reused_input = input_idx;
        break;
      }
    }
    if (reused_input == -1) {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
    } else if (reused_input > 0) {
      // Move the forwarded buffer to the front so we don't double count
      // anything if there are more than 8 inputs.
      input_indices[0] = reused_input;
      input_indices[reused_input] = 0;
    }
    auto To = output->flat<T>();

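// I(IDX) reads the input at position input_indices[IDX] as a flat Eigen
// tensor. Routing through input_indices means the forwarded (reused) input,
// if any, is treated as input 0, so the buffer it shares with the output is
// consumed in the first chunk before later chunks accumulate into it.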
#define I(IDX) ctx->input(input_indices[IDX]).template flat<T>()

#if defined(__ANDROID_TYPES_SLIM__)
    // On Android by default, we only support additions of two arguments, so
    // we can reduce the number of template instantiations.
    OP_REQUIRES(ctx, num == 2,
                errors::InvalidArgument("Only additions of two arguments "
                                        "supported. Num inputs: ",
                                        num));
    functor::Add2Functor<Device, T> functor2;
    functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
#else
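    // Sum in chunks of kWidth inputs. The switch handles the first
    // num % kWidth inputs (remainders 0 and 1 become full chunks of 8 and 9,
    // since at least two inputs reach this point), and the loop below then
    // accumulates the remaining inputs eight at a time.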
    static const int kWidth = 8;
    int r = num % kWidth;

    switch (r) {
      case 2: {
        functor::Add2Functor<Device, T> functor2;
        functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
        break;
      }
      case 3: {
        functor::Add3Functor<Device, T> functor3;
        functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
        break;
      }
      case 4: {
        functor::Add4Functor<Device, T> functor4;
        functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3));
        break;
      }
      case 5: {
        functor::Add5Functor<Device, T> functor5;
        functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4));
        break;
      }
      case 6: {
        functor::Add6Functor<Device, T> functor6;
        functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5));
        break;
      }
      case 7: {
        functor::Add7Functor<Device, T> functor7;
        functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6));
        break;
      }
      case 0: {
        functor::Add8Functor<Device, T> functor8;
        functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7));
        r = 8;
        break;
      }
      case 1: {
        functor::Add9Functor<Device, T> functor9;
        functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7), I(8));
        r = 9;
        break;
      }
    }

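    // Example: with num = 13, r = 13 % 8 = 5, so Add5Functor above writes
    // I(0) + ... + I(4) into To, and a single pass of this loop adds I(5)
    // through I(12). Add8pFunctor accumulates into To rather than
    // overwriting it.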
    for (; r < num; r += kWidth) {
      functor::Add8pFunctor<Device, T> functor8p;
      functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
                I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
    }
#endif  // defined(__ANDROID_TYPES_SLIM__)

#undef I
  }
};

template <typename Device>
class AddNOp<Device, Variant> : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
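    // Variants are reduced pairwise: binary_add dispatches to the
    // ADD_VARIANT_BINARY_OP registered for the stored type, and AddNVariant
    // applies it across all inputs to produce a single Variant sum.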
    auto binary_add = [](OpKernelContext* cc_ctx, const Variant& a,
                         const Variant& b, Variant* out) {
      return BinaryOpVariants<Device>(cc_ctx, ADD_VARIANT_BINARY_OP, a, b, out);
    };
    AddNVariant(ctx, binary_add);
  }

 private:
  // AddVariantTo efficiently performs:
  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
  // where array(ix) := (temp_filled[ix]
  //                     ? temp[ix]
  //                     : ctx->input(ix).scalar<Variant>()())
  // This reduces (possibly expensive) copying of Variants from
  // the inputs into temp at the lowest levels of the summation tree.
  static inline Status AddVariantTo(OpKernelContext* ctx, const int lhs_ix,
                                    const int rhs_ix,
                                    gtl::InlinedVector<Variant, 4>* temp,
                                    gtl::InlinedVector<bool, 4>* temp_filled) {
    Variant tmp;
    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
    const Variant& a = temp_filled->at(lhs_ix)
                           ? tmp
                           : ctx->input(lhs_ix).template scalar<Variant>()();
    const Variant& b = temp_filled->at(rhs_ix)
                           ? temp->at(rhs_ix)
                           : ctx->input(rhs_ix).template scalar<Variant>()();
    Variant* c = &temp->at(lhs_ix);
    TF_RETURN_IF_ERROR(
        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
    temp_filled->at(lhs_ix) = true;
    return OkStatus();
  }
};

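// REGISTER_ADDN instantiates AddNOp for one (element type, device) pair; the
// TF_CALL_* macros below expand it over the supported type lists.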
#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)

#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)

TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
REGISTER_ADDN_CPU(Variant);

#undef REGISTER_ADDN_CPU

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_ADDN_GPU(type) REGISTER_ADDN(type, GPU)
TF_CALL_int64(REGISTER_ADDN_GPU);
TF_CALL_uint32(REGISTER_ADDN_GPU);
TF_CALL_variant(REGISTER_ADDN_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ADDN_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_ADDN_GPU);
#undef REGISTER_ADDN_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("AddN")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("inputs")
                            .HostMemory("sum"),
                        AddNOp<CPUDevice, int32>);

#undef REGISTER_ADDN

}  // namespace tensorflow