#pragma once

#include <ATen/TensorMeta.h>
#include <ATen/core/Dimname.h>
#include <ATen/core/Range.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/DynamicCast.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <c10/util/irange.h>

#include <array>
#include <bitset>

namespace at {
class Tensor;
class OptionalTensorRef;
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
} // namespace at

// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input)
//     .build()
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Order of Construction]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When setting up the tensor iterator configuration, the output Tensors
// have to be added first via
// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all
// outputs, the inputs can be added via
// TensorIteratorConfig::add_owned_input(at::Tensor).
// Adding another output after inputs have been added will raise an exception.
//
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
// "computation dtype" where all inputs are cast to one dtype, the
// operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
// and it computes one using type promotion rules on its inputs if
// promote_inputs_to_common_dtype_ is true. Attempting to query
// a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.

namespace at {

namespace internal {
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr int64_t GRAIN_SIZE = 32768;
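
// GRAIN_SIZE is the granularity used throughout ATen when deciding whether
// and how to parallelize. A minimal sketch of the usual pattern (assuming
// at::parallel_for from ATen/Parallel.h and a hypothetical per-element task):
//
//   at::parallel_for(0, n, at::internal::GRAIN_SIZE,
//                    [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; i++) {
//       // process element i
//     }
//   });
//
// Ranges smaller than the grain size are typically run on the calling thread.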

// Storage for a non-owning Tensor, without needing to include Tensor.h
class TORCH_API OpaqueOptionalTensorRef {
  alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_{};

 public:
  OpaqueOptionalTensorRef();
  OpaqueOptionalTensorRef(const OpaqueOptionalTensorRef&) = default;
  OpaqueOptionalTensorRef& operator=(const OpaqueOptionalTensorRef&) = default;
  OpaqueOptionalTensorRef(OpaqueOptionalTensorRef&&) noexcept = default;
  OpaqueOptionalTensorRef& operator=(OpaqueOptionalTensorRef&&) noexcept =
      default;
  ~OpaqueOptionalTensorRef();

  OptionalTensorRef* get() {
    return reinterpret_cast<OptionalTensorRef*>(data_.data());
  }
  const OptionalTensorRef* get() const {
    return reinterpret_cast<const OptionalTensorRef*>(data_.data());
  }

  OptionalTensorRef& operator*() {
    return *get();
  }
  const OptionalTensorRef& operator*() const {
    return *get();
  }
  OptionalTensorRef* operator->() {
    return get();
  }
  const OptionalTensorRef* operator->() const {
    return get();
  }

  const Tensor& getTensor() const;
};
} // namespace internal

struct TORCH_API OperandInfo {
  using StrideVector = SmallVector<int64_t, 6>;
  OperandInfo() = default;
  C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) {
    if (t->defined()) {
      device = t->device();
      target_dtype = t->scalar_type();
      current_dtype = target_dtype;
    }
    tensor(std::move(t));
    validate();
  }

  C10_ALWAYS_INLINE OperandInfo(const OperandInfo&) = default;
  C10_ALWAYS_INLINE OperandInfo& operator=(const OperandInfo&) = default;
  C10_ALWAYS_INLINE OperandInfo(OperandInfo&&) noexcept = default;
  C10_ALWAYS_INLINE OperandInfo& operator=(OperandInfo&&) noexcept = default;
  C10_ALWAYS_INLINE ~OperandInfo() = default;

  /// The data pointer. This may be different from tensor->data_ptr() if the
  /// iterator is split.
  void* data = nullptr;

  /// Stride after broadcasting. The stride is in bytes, not number of
  /// elements.
  StrideVector stride_bytes;

  /// The desired device and type for the operand. For inputs, this specifies
  /// that the input should be converted to this type if necessary. For
  /// outputs, this specifies which type to allocate. target_dtype and device
  /// are initialized with the dtype and device of the tensor, but during type
  /// promotion target_dtype may become different from the tensor's dtype.
  /// Also, during type promotion target_dtype and device can be set for an
  /// undefined tensor so that the tensor can be properly constructed later.
  std::optional<Device> device = std::nullopt;
  ScalarType target_dtype = ScalarType::Undefined;
  // Caches the dtype of the tensor, because scalar_type() is an expensive
  // operation. If the dtype of the tensor is changed (e.g. as a result of
  // type promotion or in allocate_outputs), this value should be changed too.
  ScalarType current_dtype = ScalarType::Undefined;
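
  // For example (illustrative): when an int tensor and a float tensor are
  // added with promote_inputs_to_common_dtype(true), the common dtype is
  // kFloat. The int operand's target_dtype becomes kFloat while current_dtype
  // still reflects the original tensor's dtype, until a casted copy (or a
  // casting kernel) is used for the actual computation.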

  bool is_device_defined() const {
    return device.has_value();
  }
  bool is_type_defined() const {
    return target_dtype != ScalarType::Undefined;
  }
  TensorOptions options() const {
    return TensorOptions(target_dtype).device(device);
  }

  bool is_output = false;

  // will_resize is only for output tensors.
  // 1) Functional call (like torch.add(self, other)): the output tensor is
  //    undefined, and PyTorch creates a new tensor using the common shape
  //    and computed strides in TensorIterator;
  // 2) Inplace call (like torch.add_(self, other)): the output tensor is the
  //    same as the input tensor, and its size and stride can't be modified;
  // 3) Op call with output (like torch.add(self, other, out=output)): the
  //    output tensor is defined, but its shape may differ from the common
  //    shape. If it does, this output tensor is resized using the common
  //    shape and computed strides in TensorIterator. Otherwise its size and
  //    stride can't be modified.
  bool will_resize = false;

  bool is_read_write = false;

  bool is_const = false;

  void validate() {
    TORCH_CHECK(
        !tensor_base_->defined() || tensor_base_->layout() == kStrided,
        "unsupported tensor layout: ",
        tensor_base_->layout());
  }

  /// The tensor operand. Note that the strides, data pointer, and
  /// other attributes may differ due to dimension reordering and
  /// coalescing.
  const Tensor& tensor() const {
    return tensor_storage_.getTensor();
  }
  const TensorBase& tensor_base() const {
    return *tensor_base_;
  }
  void tensor(c10::MaybeOwned<TensorBase>&& tensor);

  // Save the original tensor operand in cases when an output is modified
  // (e.g. if dtype is changed)
  const Tensor& original_tensor() const {
    return original_tensor_storage_.getTensor();
  }
  const TensorBase& original_tensor_base() const {
    return *original_tensor_base_;
  }

  // Set tensor to a new value, and store the old tensor value in
  // original_tensor. Should only ever be called once for the lifetime of an
  // operand.
  void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor);

  // Move original_tensor back into tensor; exchange_tensor must have been
  // called before.
  void restore_original_tensor();

 private:
  c10::MaybeOwned<TensorBase> tensor_base_;
  c10::MaybeOwned<TensorBase> original_tensor_base_ =
      c10::MaybeOwned<TensorBase>::owned(std::in_place);

  // We store TensorBase visibly in the header to allow inline access.
  // However, we sometimes need a genuine `const Tensor &` for the
  // TensorIterator API. So, we also store a non-owning `Tensor`
  // object in these `_storage_` variables.
  internal::OpaqueOptionalTensorRef tensor_storage_;
  internal::OpaqueOptionalTensorRef original_tensor_storage_;
};

struct SplitUntil32Bit;

enum class FastSetupType : uint8_t {
  NONE,
  CONTIGUOUS,
  CHANNELS_LAST,
  NON_OVERLAPPING_DENSE
};

class TensorIteratorConfig;
struct TensorIterator;

struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // The inner-loop function operates on the fastest moving dimension. It
  // implements element-wise operations in terms of 1-d strided tensors.
  //
  // Arguments:
  //  data: data pointers for each operand (length `ntensors`)
  //  strides: stride for each operand (length `ntensors`)
  //  size: size of inner loop
  //
  // The `size` often matches shape[0], but may be smaller due to
  // parallelization of the inner loop.
  using loop2d_t = c10::function_ref<
      void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;

  using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;

  void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true);

  int ndim() const {
    return static_cast<int>(shape_.size());
  }
  IntArrayRef shape() const {
    return shape_;
  }
  int64_t numel() const;
  int ntensors() const {
    return static_cast<int>(operands_.size());
  }
  int noutputs() const {
    return num_outputs_;
  }
  int ninputs() const {
    return ntensors() - noutputs();
  }
  IntArrayRef view_offsets() const {
    return view_offsets_;
  }

  /// number of elements in the output operand. this is the same as numel() for
  /// operations that are not reductions.
  int64_t num_output_elements() const;

  /// number of reduced dimensions in a reduction operation
  int num_reduce_dims() const;

  /// 1-dimensional iteration and no buffering or type conversion
  bool is_trivial_1d() const;
  /// Reducible to 1-dimensional and all operands are contiguous
  bool is_contiguous() const;
  bool is_dim_reduced(int dim) const;

  /// Accessors for each operand
  IntArrayRef strides(int64_t arg) const {
    return operands_[arg].stride_bytes;
  }
  void* data_ptr(int64_t arg) const;
  ScalarType dtype(int64_t arg = 0) const {
    return operands_[arg].current_dtype;
  }
  ScalarType common_dtype() const {
    TORCH_INTERNAL_ASSERT(
        common_dtype_ != ScalarType::Undefined,
        "Queried for invalid common dtype!");
    return common_dtype_;
  }
  ScalarType input_dtype(int64_t arg = 0) const {
    return operands_[num_outputs_ + arg].current_dtype;
  }
  Device device(int64_t arg = 0) const {
    return operands_[arg].device.value();
  }
  c10::DeviceType device_type(int64_t arg = 0) const {
    return device(arg).type();
  }
  int64_t element_size(int64_t arg) const {
    return static_cast<int64_t>(elementSize(dtype(arg)));
  }
  bool is_scalar(int64_t arg) const;
  bool is_cpu_scalar(int64_t arg) const;

  const TensorBase& tensor_base(int64_t arg) const {
    return operands_[arg].tensor_base();
  }
  const Tensor& tensor(int64_t arg) const {
    return operands_[arg].tensor();
  }

  const TensorBase& output_base(int64_t arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor_base(arg);
  }

  const Tensor& output(int64_t arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor(arg);
  }

  const TensorBase& input_base(int64_t arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor_base(num_outputs_ + arg);
  }
  const Tensor& input(int64_t arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor(num_outputs_ + arg);
  }
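
  // Example (illustrative sketch of the accessors above; `out`, `a` and `b`
  // are assumed to be already-defined Tensors):
  //
  //   auto iter = TensorIterator::binary_op(out, a, b);
  //   // Operands are ordered outputs-first: ntensors() == 3, operand 0 is
  //   // `out`, operands 1 and 2 are `a` and `b`.
  //   ScalarType out_dtype = iter.dtype();       // same as iter.dtype(0)
  //   ScalarType in_dtype = iter.input_dtype();  // dtype of the first input
  //   Device dev = iter.device();                // device of operand 0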

  // Copies from temporary outputs back to the original outputs
  // NOTE: only used on CPU
  void cast_outputs();

  /// Removes an operand from this iterator
  void remove_operand(int64_t arg);
  /// Shrinks an iterated dimension
  void narrow(int dim, int64_t start, int64_t size);
  /// Narrows every dim after and including `start_dim` to size one.
  void select_all_keeping_dim(int start_dim, IntArrayRef starts);
  /// Replaces the data pointer for the operand at index `arg`.
  /// The new pointer should have the same sizes, strides and dtype as the
  /// original.
  void unsafe_replace_operand(int64_t arg, void* data);

  /// Splits this TensorIterator into two iterators. Together they iterate over
  /// the entire operation. Used by `with_32bit_indexing()`.
  std::unique_ptr<TensorIterator> split(int dim);

  /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
  int get_dim_to_split() const;

  template <typename T>
  T scalar_value(int64_t arg) {
    auto& op = operands_[arg];
    return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data);
  }

  /// Return the scalar value from original_tensor_base if it is defined. When
  /// common_dtype is Half, casting a scalar input to common_dtype might
  /// overflow. If the scalar is already given as Half, then return the scalar
  /// value from tensor_base.
  template <typename T>
  T original_scalar_value(int64_t arg) {
    auto& original_tensor_base = operands_[arg].original_tensor_base();
    if (original_tensor_base.defined()) {
      TORCH_INTERNAL_ASSERT(
          original_tensor_base.scalar_type() != common_dtype());
      return c10::fetch_and_cast<T>(
          original_tensor_base.scalar_type(),
          original_tensor_base.const_data_ptr());
    } else {
      return scalar_value<T>(arg);
    }
  }

 private:
  template <typename loop1d_t>
  auto loop_2d_from_1d(const loop1d_t& loop) {
    return
        [loop, ntensor = ntensors()](
            char** base, const int64_t* strides, int64_t size0, int64_t size1) {
          PtrVector data(base, base + ntensor);
          const int64_t* outer_strides = &strides[ntensor];
          for (const auto i : c10::irange(size1)) {
            if (i > 0) {
              for (const auto arg : c10::irange(ntensor)) {
                data[arg] += outer_strides[arg];
              }
            }
            loop(data.data(), strides, size0);
          }
        };
  }

 public:
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible_v<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>,
          int> = 0>
  void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) {
    for_each(loop_2d_from_1d(loop), grain_size);
  }

  void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);

  void parallel_reduce(loop2d_t loop);

  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible_v<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>,
          int> = 0>
  void serial_for_each(loop1d_t loop, Range range) {
    serial_for_each(loop_2d_from_1d(loop), range);
  }

  void serial_for_each(loop2d_t loop, Range range) const;
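
  // Example (illustrative sketch of a hand-written 1-d loop; assumes the
  // iterator was built for a float binary op, e.g. out = a + b, so operand 0
  // is the output and operands 1 and 2 are the inputs; strides are in bytes):
  //
  //   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
  //     for (int64_t i = 0; i < n; i++) {
  //       float* out = reinterpret_cast<float*>(data[0] + i * strides[0]);
  //       const float* a =
  //           reinterpret_cast<const float*>(data[1] + i * strides[1]);
  //       const float* b =
  //           reinterpret_cast<const float*>(data[2] + i * strides[2]);
  //       *out = *a + *b;
  //     }
  //   });
  //
  // In most kernels you would use cpu_kernel / gpu_kernel from Loops.h /
  // Loops.cuh instead of writing this loop by hand.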

  /// Create a strides array for a Tensor with shape of this iterator. The
  /// parameter `element_size` specifies the size of Tensor's data type in
  /// bytes (e.g. `4` for `float`)
  StrideVector compatible_stride(int64_t element_size) const;

  /// Inverts the re-ordering done by reorder_dimensions. This can only be
  /// called *before* coalesce_dimensions() is called.
  DimVector invert_perm(IntArrayRef input) const;

  /// Reapplies the same re-ordering as done by reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
  DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;

  /// Helper functions for CPU iteration
  StrideVector get_dim_strides(int dim) const;
  StrideVector get_strides() const;
  StrideVector get_inner_strides() const {
    return get_dim_strides(0);
  }
  PtrVector get_base_ptrs() const;

  // Helper functions for advanced stride manipulations (e.g. torch.flip)
  void _unsafe_set_arg_strides(const int64_t arg, IntArrayRef strides) {
    operands_[arg].stride_bytes = strides;
  }
  void _unsafe_set_arg_data(const int64_t arg, void* data) {
    operands_[arg].data = data;
  }

  // Helper functions for custom devices; a custom device can get OperandInfo
  // and NameVector on its side.
  const OperandInfo& operand(int arg = 0) const {
    return operands_[arg];
  }
  OperandInfo& operand(int arg = 0) {
    return operands_[arg];
  }
  NameVector& get_dim_names() {
    return names_;
  }
  const NameVector& get_dim_names() const {
    return names_;
  }

  /// true if the stride computation can use 32-bit arithmetic. Used by GPU
  /// kernels
  bool can_use_32bit_indexing() const;

  /// An "iterable" object that recursively splits this iterator into
  /// sub-iterators that can use 32-bit indexing.
  SplitUntil32Bit with_32bit_indexing() const;
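
  // Example (illustrative sketch of the pattern used by CUDA kernels;
  // `launch_my_kernel` is a hypothetical launcher):
  //
  //   void launch_my_kernel(TensorIteratorBase& iter) {
  //     if (!iter.can_use_32bit_indexing()) {
  //       for (auto& sub_iter : iter.with_32bit_indexing()) {
  //         launch_my_kernel(sub_iter);
  //       }
  //       return;
  //     }
  //     // ... launch with 32-bit index arithmetic ...
  //   }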

  /// If the kernel should accumulate into the output. Only relevant for CUDA
  /// reductions.
  bool should_accumulate() const {
    return accumulate_;
  }

  /// Whether this iterator produces the actual output,
  /// as opposed to something that will be accumulated further. Only relevant
  /// for CUDA reductions.
  bool is_final_output() const {
    return final_output_;
  }

  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;
    }

    int num_tensors = ntensors();
    for (const auto i : c10::irange(num_tensors)) {
      if (strides(i)[0] != element_size(i)) {
        return false;
      }
    }
    return true;
  }

  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;

#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic)            \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete;

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, )

  void build_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op)
  void build_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
  void build_unary_float_op(const TensorBase& out, const TensorBase& a);
  void build_borrowing_unary_float_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
  void build_unary_op(const TensorBase& out, const TensorBase& a);
  // Odd special case needed for pow. Has to borrow the output because
  // it's a structured kernel, but the argument is potentially a copy.
  void build_output_borrowing_argument_owning_unary_op(
      const TensorBase& out,
      const TensorBase& a);
  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
  void build_borrowing_unary_force_boolean_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
  void build_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
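
  // The build_*_op helpers above are typically called from the meta function
  // of a structured kernel. A minimal sketch (assuming the usual structured
  // kernel machinery; alpha handling and extra checks omitted):
  //
  //   TORCH_META_FUNC2(add, Tensor)
  //   (const Tensor& self, const Tensor& other, const Scalar& alpha) {
  //     build_borrowing_binary_op(maybe_get_output(), self, other);
  //   }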

  // Another special case: we need to own the second argument for comparison
  // ops.
  void build_borrowing_except_last_argument_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_ternary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b,
      const TensorBase& c);

#undef TORCH_DISALLOW_TEMPORARIES
 protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&);
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&);
  void compute_mem_overlaps(const TensorIteratorConfig&);
  void compute_shape(const TensorIteratorConfig&);
  void compute_strides(const TensorIteratorConfig&);
  void reorder_dimensions();
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&);
  ScalarType compute_common_dtype();
  void allocate_or_resize_outputs();
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

 protected:
  /// Records the "computation" shape of the output tensor. The computation
  /// shape is different from the regular shape in a few ways:
  ///
  ///   - The shape may be permuted (via permute_dimensions) so that we
  ///     process the dimensions in the most computationally efficient order
  ///     (rather than the logical order given to us by the users.)
  ///   - The shape may have adjacent dimensions collapsed (via
  ///     coalesce_dimensions) so that we minimize the number of
  ///     dimensions we have to explicitly iterate over. For example,
  ///     a pointwise operation on a contiguous tensor "computationally"
  ///     consists of only a single dimension.
  ///
  /// In other words, the computation shape is the output shape as it
  /// actually matters for implementing the kernel, but not necessarily the
  /// output shape that the user will see in the end.
  ///
  /// The lifecycle of mutations to shape_ in TensorIterator:
  ///   - declare_static_shape() sets an initial shape explicitly
  ///     provided by the user; otherwise
  ///   - compute_shape() computes the true (non-computational) shape
  ///     specified by the user.
  ///   - reorder_dimensions() reorders dimensions to improve coalescing.
  ///   - coalesce_dimensions() then coalesces adjacent dimensions when
  ///     possible.
  ///
  /// The shape may also be further modified if we create sub-TensorIterators,
  /// e.g., via narrow or select_all_keeping_dim.
  DimVector shape_;

  /// Temporarily records the permutation computed by reorder_dimensions.
  /// This permutation maps the computation output dimension (dim) to
  /// the original true output dimension (perm_[dim]). It is used by
  /// invert_perm to undo the permutation. After coalesce_dimensions is
  /// called, the permutation is no longer valid (as, in general, there
  /// is no permutation that will map computation dimensions to
  /// output dimensions); methods that manipulate perm_ are obligated
  /// to test that !has_coalesced_dimensions
  DimVector perm_;

  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build())
  /// been called? This is SOLELY used to check validity of perm_.
  bool has_coalesced_dimensions_ = false;

  /// Whether iteration must be fixed. This disables dimension permuting and
  /// also changes how for_each divides work among threads.
  bool enforce_linear_iteration_ = false;

  /// The index offsets into the original tensors for each dimension.
  /// This is only non-zero when you narrow() a TensorIterator (e.g.,
  /// when you make sub-TensorIterators).
  DimVector view_offsets_;

  /// The computed names of the output tensor. Computed by compute_names()
  NameVector names_;

  /// The operands of the TensorIterator: both the inputs and outputs. The
  /// outputs MUST come first in the operands_ list. There is always an
  /// operand for each output of the TensorIterator, even if TensorIterator
  /// will ultimately be responsible for allocating the output; in those
  /// cases, tensor is simply undefined (and will be populated later
  /// during build()).
  ///
  /// This list is initially populated prior to build(), but build() mutates
  /// OperandInfo to populate more information.
  SmallVector<OperandInfo, 4> operands_;

  /// Number of outputs in operands_ (the length of the outputs prefix
  /// in operands_).
  int num_outputs_ = 0;

  /// Whether or not all operands have the same shape and are 1d+. Having all
  /// the same shape affects whether or not the iterator is eligible for fast
  /// setup.
  bool all_ops_same_shape_ = false;
  /// Whether or not all operands are 0d; this affects type promotion.
  bool all_ops_are_scalars_ = false;

  /// The "computation" dtype of TensorIterator, specifying the dtype in which
  /// we will do the internal computation. Typically, this matches the dtype
  /// of the output tensors, but not always!
  ScalarType common_dtype_ = ScalarType::Undefined;

  /// This is currently defined as kCPU, or the device of the first non-CPU
  /// tensor argument. See TensorIteratorBase::compute_types for details.
  Device common_device_ = kCPU;

  /// Set by split(), see should_accumulate() and is_final_output()
  bool accumulate_ = false;
  bool final_output_ = true;

  // From TensorIteratorConfig
  bool is_reduction_ = false;

  /// Set by populate_operands(), says if we're handling meta tensors
  bool is_meta_ = false;
};

struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  static TensorIterator binary_float_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator binary_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
  static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
  static TensorIterator nullary_op(TensorBase& out);
  static TensorIterator borrowing_nullary_op(const TensorBase& out);
  static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
  static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
  static TensorIterator reduce_op(
      TensorBase& out1,
      TensorBase& out2,
      const TensorBase& a);
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;
};
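
// Example (illustrative sketch; `out`, `a`, `b` are assumed float Tensors and
// cpu_kernel is declared in ATen/native/cpu/Loops.h):
//
//   auto iter = TensorIterator::binary_op(out, a, b);
//   cpu_kernel(iter, [](float x, float y) { return x + y; });
//
// The static factory functions above are a convenience over configuring a
// TensorIteratorConfig by hand for the most common cases.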

class TORCH_API TensorIteratorConfig final {
 public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() = default;

  C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig);

  /// Construction
  // Stores input/output Tensors without incrementing the reference count.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_output(const TensorBase& output) {
    return add_borrowed_output(output);
  }
  TensorIteratorConfig& add_input(const TensorBase& input) {
    return add_borrowed_input(input);
  }
  TensorIteratorConfig& add_const_input(const TensorBase& input) {
    return add_borrowed_const_input(input);
  }

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_input(TensorBase&& input) = delete;
  TensorIteratorConfig& add_const_input(TensorBase&& input) = delete;

  // Stores input/output Tensors while incrementing the reference count.
  // Note that add_{in,out}put are nearly always what you
  // want, and the exception (adding an unnamed temporary) won't
  // compile.
  TensorIteratorConfig& add_owned_output(const TensorBase& output);
  TensorIteratorConfig& add_owned_input(const TensorBase& input);
  TensorIteratorConfig& add_owned_const_input(const TensorBase& input);

  // Advanced API: stores input/output Tensors without incrementing
  // the reference count. The caller must ensure that these Tensors
  // live at least as long as this TensorIteratorConfig and any
  // TensorIteratorBase built from this TensorIteratorConfig.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_borrowed_output(const TensorBase& output);
  TensorIteratorConfig& add_borrowed_input(const TensorBase& input);
  TensorIteratorConfig& add_borrowed_const_input(const TensorBase& input);

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete;
  TensorIteratorConfig& add_borrowed_const_input(TensorBase&& input) = delete;

  // Sets the check_mem_overlap_ flag, which is true by default.
  // If true, inputs are checked for partial overlap with the outputs and
  // outputs are checked for internal overlap (e.g. broadcasted views). An
  // error is raised if unacceptable overlap is detected.
  // If you're migrating an existing operator to using TensorIterator, please
  // consider if the previous implementation checked memory overlap. If it did
  // not, and if the operator is idempotent (for example, Tensor.fill_(0)),
  // then checking memory overlap is BC-breaking. Please don't check memory
  // overlap in that case.
  TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
    check_mem_overlap_ = check_mem_overlap;
    return *this;
  }

  // Sets the check_all_same_dtype_ flag, which is true by default.
  // If true, checks that all inputs and defined outputs have the same dtype.
  // Setting either of promote_inputs_to_common_dtype_
  // or cast_common_dtype_to_outputs_ to true will set
  // check_all_same_dtype_ to false.
  TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

  // Sets the check_all_same_device_ flag, which is true by default.
  // If true, all operands must be on the same device, with the possible
  // exception of CPU scalars, which can be passed to some CUDA kernels
  // as kernel arguments.
  TensorIteratorConfig& check_all_same_device(
      const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

  // Sets the enforce_safe_casting_to_output_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and
  // canCast(common dtype, output dtype) must be true for all outputs.
  TensorIteratorConfig& enforce_safe_casting_to_output(
      const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }
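
  // For example (illustrative; `int_out`, `float_a`, `float_b` are
  // hypothetical tensors of dtype kInt, kFloat and kFloat):
  //
  //   auto iter = TensorIteratorConfig()
  //     .add_owned_output(int_out)
  //     .add_owned_const_input(float_a)
  //     .add_owned_const_input(float_b)
  //     .promote_inputs_to_common_dtype(true)
  //     .enforce_safe_casting_to_output(true)
  //     .build();  // errors: the common dtype is kFloat and
  //                // canCast(kFloat, kInt) is false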

  // Sets the enforce_linear_iteration_ flag, which is false by default.
  // If true, iteration goes in the same order as a C-contiguous tensor
  // is laid out in memory, i.e. the last dimension iterates fastest.
  //
  // This iteration order can be less efficient and may even prevent
  // vectorization. So only use if the correctness of your kernel depends on it.
  TensorIteratorConfig& enforce_linear_iteration(
      const bool _enforce_linear_iteration = true) {
    enforce_linear_iteration_ = _enforce_linear_iteration;
    return *this;
  }

  // Sets the promote_inputs_to_common_dtype_ flag, which is false by default.
  // If true, the iterator's "common dtype" is always computed (see the
  // [Common Dtype Computation] note) and, on the CPU, temporary copies of
  // the inputs in the common dtype are passed as the actual inputs to
  // the operation.
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& promote_inputs_to_common_dtype(
      const bool _promote_inputs_to_common_dtype) {
    promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype;
    if (_promote_inputs_to_common_dtype) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // NOTE: If set to true, promote_inputs_to_common_dtype_ must also be true.
  // If true and the iterator's "common dtype" is an integral type (including
  // bool), then the common dtype is changed to the default float scalar type.
  TensorIteratorConfig& promote_integer_inputs_to_float(
      const bool _promote_integer_inputs_to_float) {
    promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float;
    TORCH_INTERNAL_ASSERT(
        !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_);
    return *this;
  }

  TensorIteratorConfig& is_reduction(const bool _is_reduction) {
    is_reduction_ = _is_reduction;
    return *this;
  }

  TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) {
    allow_cpu_scalars_ = _allow_cpu_scalars;
    return *this;
  }

  // Sets the cast_common_dtype_to_outputs_ flag, which is false by default.
  // If true, the iterator's "common dtype" must be computable
  // (see the [Common Dtype Computation] note) and, on the CPU, temporary
  // copies of the outputs are passed as the actual outputs to the operation.
  // These temporaries are then copied to the original outputs after
  // the operation is performed (see cast_outputs()).
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& cast_common_dtype_to_outputs(
      const bool _cast_common_dtype_to_outputs) {
    cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs;
    if (_cast_common_dtype_to_outputs) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  TensorIteratorConfig& resize_outputs(bool resize_outputs) {
    resize_outputs_ = resize_outputs;
    return *this;
  }
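
  // Example (illustrative sketch of a full float-op configuration, similar in
  // spirit to what build_binary_float_op sets up; `out`, `a`, `b` are assumed
  // Tensors):
  //
  //   auto iter = TensorIteratorConfig()
  //     .set_check_mem_overlap(true)
  //     .add_owned_output(out)
  //     .add_owned_const_input(a)
  //     .add_owned_const_input(b)
  //     .allow_cpu_scalars(true)
  //     .promote_inputs_to_common_dtype(true)
  //     .promote_integer_inputs_to_float(true)
  //     .cast_common_dtype_to_outputs(true)
  //     .enforce_safe_casting_to_output(true)
  //     .build();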

  // Bypass output dtype/device computation and fix the dtype/device as
  // specified here.
  TensorIteratorConfig& declare_static_dtype_and_device(
      ScalarType dtype,
      Device device);
  TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
  TensorIteratorConfig& declare_static_device(Device device);
  TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
  TensorIteratorConfig& declare_static_shape(
      IntArrayRef shape,
      IntArrayRef squash_dims);

  // It would be better if this was && qualified, but this would be at the cost
  // of a lot of boilerplate above
  TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

 private:
  bool is_tensor_const(size_t idx);

  SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_;
  int num_outputs_ = 0;
  int num_inputs_ = 0;

  std::optional<DimVector> static_shape_ = std::nullopt;
  std::optional<ScalarType> static_dtype_ = std::nullopt;
  std::optional<Device> static_device_ = std::nullopt;
  bool check_mem_overlap_ = true;
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false;
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true;
  bool check_all_same_device_ = true;
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false;
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false;

  SmallVector<size_t, 4> const_tensor_indices_;
};

/// A container-like struct that acts as if it contains splits of a
/// TensorIterator that can use 32-bit indexing. Taken together the splits
/// cover the original TensorIterator.
struct TORCH_API SplitUntil32Bit {
  struct TORCH_API iterator {
    iterator() = default;
    iterator(const TensorIteratorBase& iter);
    iterator(iterator&&) = default;

    // Guaranteed to be a TensorIterator proper!
    TensorIterator& operator*() const;
    iterator& operator++();
    bool operator==(const iterator& other) const {
      // two iterators are equal if they are the same object or they're both
      // empty
      return this == &other || (vec.empty() && other.vec.empty());
    }
    // needed for C++11 range-based for loop
    bool operator!=(const iterator& other) const {
      return !(*this == other);
    }

    /// stack of TensorIterators to be split
    std::vector<std::unique_ptr<TensorIterator>> vec;
  };

  SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const TensorIteratorBase& iter;
};

} // namespace at