/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_

#include <atomic>
#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace scatter_op {

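// Element-wise operation applied when combining an existing params entry with
// an incoming update.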
enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };

namespace internal {

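// Assign<Op>::Run(p, u) applies the element-wise update `u` to the params
// slice `p` (both are Eigen expressions, typically row chips), while
// RunScalar(p, u) broadcasts a scalar `u` across `p`. The unspecialized
// template is left empty, so using an unsupported op fails to compile.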
template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p.setConstant(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + static_cast<Update>(-u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p * u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p / u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MIN> {
  // This method requires that Params and Update are tensor types.
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMin(u);
  }
  // Same thing, but for Update being a scalar type.
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMin(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MAX> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMax(u);
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMax(u);
  }
};

}  // namespace internal
}  // namespace scatter_op

namespace functor {
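// Applies updates.chip<0>(i) to params.chip<0>(indices(i)) for every i, using
// the combiner selected by `op`. Returns -1 if all indices are in range;
// otherwise the position in `indices` of an out-of-range entry.
//
// Usage sketch (illustrative only; `context`, `params`, `updates`, and
// `indices` are hypothetical kernel-local variables, not defined here):
//
//   functor::ScatterFunctor<CPUDevice, float, int32,
//                           scatter_op::UpdateOp::ADD> scatter;
//   const int32 bad_i =
//       scatter(context, context->eigen_device<CPUDevice>(),
//               params.matrix<float>(), updates.matrix<float>(),
//               indices.flat<int32>());
//   OP_REQUIRES(context, bad_i < 0,
//               errors::InvalidArgument("index out of range"));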
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};

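// Default implementation used on CPU. operator() picks between SerialExecute
// and ParallelExecute based on a rough cost heuristic, and always runs
// serially when deterministic op execution is required.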
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index ParallelExecute(OpKernelContext* c, const Device& d,
                        typename TTypes<T>::Matrix params,
                        typename TTypes<T>::ConstMatrix updates,
                        typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index kMaxLocks = 1024;
    const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
    // To reduce the number of locks and the memory usage, we divide the whole
    // index space into kMaxLocks regions, with each lock serializing access to
    // a region.
    mutex accessed[kMaxLocks];
    std::atomic<Index> bad_index(-1);
    auto ParallelScatter = [&](Index start, Index end) {
      for (Index i = start; i < end; ++i) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) {
          bad_index = i;
          return;
        }
        const Index lock_id = index / entries_per_lock;
        // Apply updates[i] (the last Ndim-1 dimensions) to params[index].
        {
          mutex_lock l(accessed[lock_id]);
          scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                                updates.template chip<0>(i));
        }
      }
    };
    const float kMovingCost = 2.5f;
    float shard_cost = kMovingCost * params.dimension(1);
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost,
          ParallelScatter);  // TODO: Come up with a good cost estimate.
    return bad_index;
  }
  Index SerialExecute(OpKernelContext* c, const Device& d,
                      typename TTypes<T>::Matrix params,
                      typename TTypes<T>::ConstMatrix updates,
                      typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; ++i) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in
      // between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Apply updates[i] (the last Ndim-1 dimensions) to params[index].
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }

  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
#ifdef PLATFORM_GOOGLE
    // The parallel version is significantly slower internally. Only call the
    // serial version for now.
    // TODO(penporn): Avoid locking in parallelization (sort beforehand).
    return SerialExecute(c, d, params, updates, indices);
#else
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index min_n_threshold = 1024;
    const Index ser_par_ratio = 10000;
    // For parallelizing the updates, duplicate entries need to be handled
    // correctly: multiple updates to the same index have to be serialized.
    // This can lead to lock contention, which may nullify the benefits of
    // parallelization. Assuming a uniform random distribution of the indices,
    // we use a rough heuristic to decide whether the updates execute serially
    // or in parallel. Also, if 'N' is small, the overhead of parallel
    // execution outweighs its benefits, hence the check on the value of N.
    const bool execute_serial = N < min_n_threshold ||
                                (N / limit) > ser_par_ratio ||
                                OpDeterminismRequired();
    if (execute_serial)
      return SerialExecute(c, d, params, updates, indices);
    else
      return ParallelExecute(c, d, params, updates, indices);
#endif  // PLATFORM_GOOGLE
  }
};

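// Assign-only implementation for Variant elements. Variant values are not
// trivially copyable and do not support Eigen arithmetic, so each element is
// assigned individually.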
template <typename Device, typename Index>
struct ScatterFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   typename TTypes<Variant>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    DCHECK_EQ(N, updates.dimension(0));
    DCHECK_EQ(cols, updates.dimension(1));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      for (int j = 0; j < cols; ++j) {
        const Variant& to_scatter = updates(i, j);
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};

template <typename Index>
struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};

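// ASSIGN specialization for CPU: every element type except tstring is copied
// row-by-row with a single memmove (such types are assumed to be trivially
// copyable); tstring falls back to the element-wise Eigen assignment.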
template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    if (!std::is_same<T, tstring>::value) {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};

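// Same contract as ScatterFunctor, but the update is a single scalar that is
// broadcast across params[indices(i)] for every i.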
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices);
};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<op>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

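// Assign-only scalar broadcast for Variant elements: the single Variant value
// is copied into every column of params[indices(i)].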
template <typename Device, typename Index>
struct ScatterScalarFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   const typename TTypes<Variant>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    const Variant& to_scatter = update();
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      for (Index j = 0; j < cols; ++j) {
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterScalarFunctor<CPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
template <typename Index>
struct ScatterScalarFunctor<GPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};

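// ASSIGN specialization of the scalar broadcast for CPU; behaves the same as
// the generic base instantiated with UpdateOp::ASSIGN.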
template <typename T, typename Index>
struct ScatterScalarFunctorBase<CPUDevice, T, Index,
                                scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor<CPUDevice, T, Index, op>
    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_