/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_

#include <atomic>
#include <cstring>
#include <type_traits>

#include "third_party/eigen3/Eigen/Core"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/determinism.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

class OpKernelContext;
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace scatter_op {

enum class UpdateOp { ASSIGN, ADD, SUB, MUL, DIV, MIN, MAX };

namespace internal {

template <scatter_op::UpdateOp Op>
struct Assign {};
template <>
struct Assign<scatter_op::UpdateOp::ASSIGN> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p.setConstant(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::ADD> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p += u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::SUB> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p -= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p + static_cast<Update>(-u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MUL> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p *= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p * u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::DIV> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p /= u;
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p / u;
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MIN> {
  // This method requires that Params and Update are tensor types.
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMin(u);
  }
  // Same thing, but for Update being a scalar type.
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMin(u);
  }
};
template <>
struct Assign<scatter_op::UpdateOp::MAX> {
  template <typename Params, typename Update>
  static void Run(Params p, Update u) {
    p = p.cwiseMax(u);
  }
  template <typename Params, typename Update>
  static void RunScalar(Params p, Update u) {
    p = p.cwiseMax(u);
  }
};

}  // namespace internal
}  // namespace scatter_op
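
// Illustrative sketch: how the Assign functors above are applied by the
// scatter functors below, assuming `params`, `updates`, `update`, `index`,
// and `i` have the same meaning as in ScatterFunctorBase further down.
//
//   // Row-wise update: params(index, :) += updates(i, :)
//   scatter_op::internal::Assign<scatter_op::UpdateOp::ADD>::Run(
//       params.template chip<0>(index), updates.template chip<0>(i));
//
//   // Scalar broadcast: params(index, :) = cwiseMax(params(index, :), update)
//   scatter_op::internal::Assign<scatter_op::UpdateOp::MAX>::RunScalar(
//       params.template chip<0>(index), update());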

namespace functor {

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices);
};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctorBase {
  Index ParallelExecute(OpKernelContext* c, const Device& d,
                        typename TTypes<T>::Matrix params,
                        typename TTypes<T>::ConstMatrix updates,
                        typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index kMaxLocks = 1024;
    const Index entries_per_lock = (limit + kMaxLocks - 1) / kMaxLocks;
    // To reduce the number of locks and the memory usage, we divide the whole
    // index space into kMaxLocks regions with each lock serializing access to
    // a region.
    mutex accessed[kMaxLocks];
    std::atomic<Index> bad_index(-1);
    auto ParallelScatter = [&](Index start, Index end) {
      for (Index i = start; i < end; ++i) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) {
          bad_index = i;
          return;
        }
        const Index lock_id = index / entries_per_lock;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        {
          mutex_lock l(accessed[lock_id]);
          scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                                updates.template chip<0>(i));
        }
      }
    };
    const float kMovingCost = 2.5f;
    float shard_cost = kMovingCost * params.dimension(1);
    const DeviceBase::CpuWorkerThreads& worker_threads =
        *(c->device()->tensorflow_cpu_worker_threads());
    Shard(worker_threads.num_threads, worker_threads.workers, N, shard_cost,
          ParallelScatter);  // TODO: Come up with a good cost estimate.
    return bad_index;
  }
  Index SerialExecute(OpKernelContext* c, const Device& d,
                      typename TTypes<T>::Matrix params,
                      typename TTypes<T>::ConstMatrix updates,
                      typename TTypes<Index>::ConstFlat indices) {
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; ++i) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in
      // between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      scatter_op::internal::Assign<op>::Run(params.template chip<0>(index),
                                            updates.template chip<0>(i));
    }
    return -1;
  }
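
  // Both SerialExecute and ParallelExecute return -1 if every index is in
  // range, or the position in `indices` of the first out-of-bounds index
  // otherwise; operator() below picks between the two paths.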
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
#ifdef PLATFORM_GOOGLE
    // The parallel version is significantly slower internally. Only call the
    // serial version for now.
    // TODO(penporn): Avoid locking in parallelization (sort beforehand).
    return SerialExecute(c, d, params, updates, indices);
#else
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index min_n_threshold = 1024;
    const Index ser_par_ratio = 10000;
    // For parallelizing the updates, duplicate entries need to be handled
    // correctly: multiple updates to the same index have to be serialized.
    // This can lead to lock contention, which may nullify the benefits of
    // parallelization. Assuming a uniform random distribution of the indices,
    // we use a rough heuristic to decide whether the updates execute serially
    // or in parallel. Also, if 'N' is small, the overheads of parallel
    // execution outweigh its benefits, hence the check on the value of N.
    const bool execute_serial = N < min_n_threshold ||
                                (N / limit) > ser_par_ratio ||
                                OpDeterminismRequired();
    if (execute_serial)
      return SerialExecute(c, d, params, updates, indices);
    else
      return ParallelExecute(c, d, params, updates, indices);
#endif  // PLATFORM_GOOGLE
  }
};
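
// Scatter into Variant-valued params: only ASSIGN is specialized here, and
// each element of the destination row is copied individually rather than
// going through Assign<>::Run.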
template <typename Device, typename Index>
struct ScatterFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   typename TTypes<Variant>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    DCHECK_EQ(N, updates.dimension(0));
    DCHECK_EQ(cols, updates.dimension(1));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Copy last Ndim-1 dimensions of updates[i] to params[index]
      for (int j = 0; j < cols; ++j) {
        const Variant& to_scatter = updates(i, j);
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterFunctor<CPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<CPUDevice, Index> {};

template <typename Index>
struct ScatterFunctor<GPUDevice, Variant, Index, scatter_op::UpdateOp::ASSIGN>
    : ScatterFunctorVariantAssignBase<GPUDevice, Index> {};

template <typename T, typename Index>
struct ScatterFunctorBase<CPUDevice, T, Index, scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   typename TTypes<T>::ConstMatrix updates,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    if (!std::is_same<T, tstring>::value) {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        memmove(params.data() + index * params.dimension(1),
                updates.data() + i * updates.dimension(1),
                updates.dimension(1) * sizeof(T));
      }
    } else {
      for (Index i = 0; i < N; i++) {
        // Grab the index and check its validity.  Do this carefully,
        // to avoid checking the value and grabbing it again from
        // memory a second time (a security risk since it may change in
        // between).
        const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
        if (!FastBoundsCheck(index, limit)) return i;
        // Copy last Ndim-1 dimensions of updates[i] to params[index]
        scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::Run(
            params.template chip<0>(index), updates.template chip<0>(i));
      }
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterFunctor<CPUDevice, T, Index, op>
    : ScatterFunctorBase<CPUDevice, T, Index, op> {};
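
// Scalar variants: instead of a matrix of updates, a single `update` value is
// combined with (or assigned to) every element of params[index] for each
// index, via Assign<op>::RunScalar.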
template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices);
};

template <typename Device, typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctorBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<op>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename Device, typename Index>
struct ScatterScalarFunctorVariantAssignBase {
  Index operator()(OpKernelContext* c, const Device& d,
                   typename TTypes<Variant>::Matrix params,
                   const typename TTypes<Variant>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    const Index cols = static_cast<Index>(params.dimension(1));
    const Variant& to_scatter = update();
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      for (Index j = 0; j < cols; ++j) {
        params(index, j) = to_scatter;
      }
    }
    return -1;
  }
};

template <typename Index>
struct ScatterScalarFunctor<CPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<CPUDevice, Index> {};
template <typename Index>
struct ScatterScalarFunctor<GPUDevice, Variant, Index,
                            scatter_op::UpdateOp::ASSIGN>
    : ScatterScalarFunctorVariantAssignBase<GPUDevice, Index> {};

template <typename T, typename Index>
struct ScatterScalarFunctorBase<CPUDevice, T, Index,
                                scatter_op::UpdateOp::ASSIGN> {
  Index operator()(OpKernelContext* c, const CPUDevice& d,
                   typename TTypes<T>::Matrix params,
                   const typename TTypes<T>::ConstScalar update,
                   typename TTypes<Index>::ConstFlat indices) {
    // indices and params sizes were validated in DoCompute().
    const Index N = static_cast<Index>(indices.size());
    const Index limit = static_cast<Index>(params.dimension(0));
    for (Index i = 0; i < N; i++) {
      // Grab the index and check its validity.  Do this carefully,
      // to avoid checking the value and grabbing it again from
      // memory a second time (a security risk since it may change in between).
      const Index index = ::tensorflow::internal::SubtleMustCopy(indices(i));
      if (!FastBoundsCheck(index, limit)) return i;
      // Broadcast update to params[index]
      scatter_op::internal::Assign<scatter_op::UpdateOp::ASSIGN>::RunScalar(
          params.template chip<0>(index), update());
    }
    return -1;
  }
};

template <typename T, typename Index, scatter_op::UpdateOp op>
struct ScatterScalarFunctor<CPUDevice, T, Index, op>
    : ScatterScalarFunctorBase<CPUDevice, T, Index, op> {};

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_