#pragma once

#include <ATen/TensorMeta.h>
#include <ATen/core/Dimname.h>
#include <ATen/core/Range.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/DynamicCast.h>
#include <c10/util/FunctionRef.h>
#include <c10/util/MaybeOwned.h>
#include <c10/util/SmallVector.h>
#include <c10/util/TypeCast.h>
#include <c10/util/irange.h>

#include <array>
#include <bitset>

namespace at {
class Tensor;
class OptionalTensorRef;
using NameVector = SmallVector<Dimname, kDimVectorStaticSize>;
} // namespace at

// TensorIterator is a helper class for element-wise operations, such as
// arithmetic, comparisons, and trigonometric functions. It handles
// broadcasting and type conversions of operands.
//
// This is inspired by NumPy's Array Iterator API (NpyIter).
//
// The files Loops.h and Loops.cuh provide functions to build kernels that
// use TensorIterator.
//
// Example:
//
//   auto iter = TensorIteratorConfig()
//     .add_output(output)
//     .add_input(input)
//     .build();
//
// [MyKernel.cpp / MyKernel.cu]
//   cpu_kernel(iter, [](float a, float b) {
//     return a + b;
//   });
//
//   gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float {
//     return a + b;
//   });
//
// Note [Order of Construction]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// When setting up the tensor iterator configuration, the output Tensors
// have to be added first via
// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs,
// the inputs can be added via
// TensorIteratorConfig::add_owned_input(at::Tensor).
// Adding another output after inputs have been added will raise an exception.
//
// Note [Common Dtype Computation]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// Some operations have a natural notion of a "common dtype" or
//   "computation dtype" where all inputs are cast to one dtype, the
//   operation is performed, and then the results are cast to all outputs.
//
// TensorIterator infers a common dtype if all inputs have the same dtype,
//   and it computes one using type promotion rules on its inputs if
//   promote_inputs_to_common_dtype_ is true. Attempting to query
//   a common dtype otherwise will throw an exception.
//
// Note that the outputs are not considered when computing a common dtype.

namespace at {

namespace internal {
// This parameter is heuristically chosen to determine the minimum amount of
// work that warrants parallelism. For example, when summing an array, it is
// deemed inefficient to parallelise over arrays shorter than 32768. Further,
// no parallel algorithm (such as parallel_reduce) should split work into
// smaller than GRAIN_SIZE chunks.
constexpr int64_t GRAIN_SIZE = 32768;
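//
// For example, a sketch of a loop that respects GRAIN_SIZE using
// at::parallel_for from ATen/Parallel.h (`out`, `a`, and `b` here are
// assumed to be raw float arrays of length `n`):
//
//   at::parallel_for(0, n, at::internal::GRAIN_SIZE,
//       [&](int64_t begin, int64_t end) {
//     for (int64_t i = begin; i < end; i++) {
//       out[i] = a[i] + b[i];
//     }
//   });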

// Storage for a non-owning Tensor, without needing to include Tensor.h
class TORCH_API OpaqueOptionalTensorRef {
  alignas(alignof(TensorBase)) std::array<char, sizeof(TensorBase)> data_{};

 public:
  OpaqueOptionalTensorRef();
  OpaqueOptionalTensorRef(const OpaqueOptionalTensorRef&) = default;
  OpaqueOptionalTensorRef& operator=(const OpaqueOptionalTensorRef&) = default;
  OpaqueOptionalTensorRef(OpaqueOptionalTensorRef&&) noexcept = default;
  OpaqueOptionalTensorRef& operator=(OpaqueOptionalTensorRef&&) noexcept =
      default;
  ~OpaqueOptionalTensorRef();

  OptionalTensorRef* get() {
    return reinterpret_cast<OptionalTensorRef*>(data_.data());
  }
  const OptionalTensorRef* get() const {
    return reinterpret_cast<const OptionalTensorRef*>(data_.data());
  }

  OptionalTensorRef& operator*() {
    return *get();
  }
  const OptionalTensorRef& operator*() const {
    return *get();
  }
  OptionalTensorRef* operator->() {
    return get();
  }
  const OptionalTensorRef* operator->() const {
    return get();
  }

  const Tensor& getTensor() const;
};
} // namespace internal

struct TORCH_API OperandInfo {
  using StrideVector = SmallVector<int64_t, 6>;
  OperandInfo() = default;
  C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned<TensorBase>&& t) {
    if (t->defined()) {
      device = t->device();
      target_dtype = t->scalar_type();
      current_dtype = target_dtype;
    }
    tensor(std::move(t));
    validate();
  }

  C10_ALWAYS_INLINE OperandInfo(const OperandInfo&) = default;
  C10_ALWAYS_INLINE OperandInfo& operator=(const OperandInfo&) = default;
  C10_ALWAYS_INLINE OperandInfo(OperandInfo&&) noexcept = default;
  C10_ALWAYS_INLINE OperandInfo& operator=(OperandInfo&&) noexcept = default;
  C10_ALWAYS_INLINE ~OperandInfo() = default;

  /// The data pointer. This may be different from tensor->data_ptr() if the
  /// iterator is split.
  void* data = nullptr;

  /// Stride after broadcasting. The stride is in bytes, not number of elements.
  StrideVector stride_bytes;

  /// The desired device and type for the operand. For inputs, this specifies
  /// that the input should be converted to this type if necessary. For
  /// outputs, this specifies which type to allocate. target_dtype and device
  /// are initialized with the dtype and device of the tensor, but during type
  /// promotion target_dtype may become different from the tensor's dtype.
  /// Also, during type promotion target_dtype and device can be set for an
  /// undefined tensor so that the tensor can be properly constructed later.
  std::optional<Device> device = std::nullopt;
  ScalarType target_dtype = ScalarType::Undefined;
  // Caches the dtype of the tensor, because scalar_type() is an expensive
  // operation. If the dtype of the tensor changes (e.g. as a result of type
  // promotion or in allocate_outputs), this value should be updated too.
  ScalarType current_dtype = ScalarType::Undefined;

  bool is_device_defined() const {
    return device.has_value();
  }
  bool is_type_defined() const {
    return target_dtype != ScalarType::Undefined;
  }
  TensorOptions options() const {
    return TensorOptions(target_dtype).device(device);
  }

  bool is_output = false;

  // will_resize only applies to output tensors.
  // 1) Functional call (like torch.add(self, other)): the output tensor is
  //    undefined, and PyTorch creates a new tensor using the common shape
  //    and computed strides in TensorIterator;
  // 2) Inplace call (like torch.add_(self, other)): the output tensor is the
  //    same as the input tensor, and its size and strides cannot be modified;
  // 3) Op call with output (like torch.add(self, other, out=output)): the
  //    output tensor is defined, but its shape may differ from the common
  //    shape. If it does, the output tensor will be resized using the common
  //    shape and computed strides in TensorIterator. Otherwise its size and
  //    strides cannot be modified.
  bool will_resize = false;

  bool is_read_write = false;

  bool is_const = false;

  void validate() {
    TORCH_CHECK(
        !tensor_base_->defined() || tensor_base_->layout() == kStrided,
        "unsupported tensor layout: ",
        tensor_base_->layout());
  }

  /// The tensor operand. Note that the strides, data pointer, and
  /// other attributes may differ due to dimension reordering and
  /// coalescing.
  const Tensor& tensor() const {
    return tensor_storage_.getTensor();
  }
  const TensorBase& tensor_base() const {
    return *tensor_base_;
  }
  void tensor(c10::MaybeOwned<TensorBase>&& tensor);

  // Save the original tensor operand in cases when an output is modified
  // (e.g. if dtype is changed)
  const Tensor& original_tensor() const {
    return original_tensor_storage_.getTensor();
  }
  const TensorBase& original_tensor_base() const {
    return *original_tensor_base_;
  }

  // Set tensor to a new value, and store the old tensor value in
  // original_tensor. Should only ever be called once for the lifetime of an
  // operand.
  void exchange_tensor(c10::MaybeOwned<TensorBase>&& new_tensor);

  // Move original_tensor back into tensor; exchange_tensor must have been
  // called before.
  void restore_original_tensor();

 private:
  c10::MaybeOwned<TensorBase> tensor_base_;
  c10::MaybeOwned<TensorBase> original_tensor_base_ =
      c10::MaybeOwned<TensorBase>::owned(std::in_place);

  // We store TensorBase visibly in the header to allow inline access.
  // However, we sometimes need a genuine `const Tensor &` for the
  // TensorIterator API. So, we also store a non-owning `Tensor`
  // object in these `_storage_` variables.
  internal::OpaqueOptionalTensorRef tensor_storage_;
  internal::OpaqueOptionalTensorRef original_tensor_storage_;
};

struct SplitUntil32Bit;

enum class FastSetupType : uint8_t {
  NONE,
  CONTIGUOUS,
  CHANNELS_LAST,
  NON_OVERLAPPING_DENSE
};

class TensorIteratorConfig;
struct TensorIterator;

struct TORCH_API TensorIteratorBase : public impl::MetaBase {
  using DimMask = std::bitset<64>;
  using PtrVector = SmallVector<char*, 4>;
  using StrideVector = SmallVector<int64_t, 6>;

  TensorIteratorBase();
  void build(TensorIteratorConfig&);

  // The inner-loop function operates on the fastest moving dimension. It
  // implements element-wise operations in terms of 1-d strided tensors.
  //
  // Arguments:
  //  data: data pointers for each operand (length `ntensors`)
  //  strides: stride for each operand (length `ntensors`)
  //  size: size of inner loop
  //
  // The `size` often matches shape[0], but may be smaller due to
  // parallelization of the inner loop.
  using loop2d_t = c10::function_ref<
      void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>;
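  //
  // For example, a sketch of a typical 1-d loop, assuming the iterator was
  // built with one float output followed by two float inputs (strides are
  // in bytes):
  //
  //   iter.for_each([](char** data, const int64_t* strides, int64_t n) {
  //     for (int64_t i = 0; i < n; i++) {
  //       float* out = reinterpret_cast<float*>(data[0] + i * strides[0]);
  //       float* a = reinterpret_cast<float*>(data[1] + i * strides[1]);
  //       float* b = reinterpret_cast<float*>(data[2] + i * strides[2]);
  //       *out = *a + *b;
  //     }
  //   });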

  using loop_subiter_t = c10::function_ref<void(TensorIteratorBase& subiter)>;

  void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true);

  int ndim() const {
    return static_cast<int>(shape_.size());
  }
  IntArrayRef shape() const {
    return shape_;
  }
  int64_t numel() const;
  int ntensors() const {
    return static_cast<int>(operands_.size());
  }
  int noutputs() const {
    return num_outputs_;
  }
  int ninputs() const {
    return ntensors() - noutputs();
  }
  IntArrayRef view_offsets() const {
    return view_offsets_;
  }

  /// number of elements in the output operand. This is the same as numel() for
  /// operations that are not reductions.
  int64_t num_output_elements() const;

  /// number of reduced dimensions in a reduction operation
  int num_reduce_dims() const;

  /// 1-dimensional iteration and no buffering or type conversion
  bool is_trivial_1d() const;
  /// Reducible to 1-dimensional and all operands are contiguous
  bool is_contiguous() const;
  bool is_dim_reduced(int dim) const;

  /// Accessors for each operand
  IntArrayRef strides(int64_t arg) const {
    return operands_[arg].stride_bytes;
  }
  void* data_ptr(int64_t arg) const;
  ScalarType dtype(int64_t arg = 0) const {
    return operands_[arg].current_dtype;
  }
  ScalarType common_dtype() const {
    TORCH_INTERNAL_ASSERT(
        common_dtype_ != ScalarType::Undefined,
        "Queried for invalid common dtype!");
    return common_dtype_;
  }
  ScalarType input_dtype(int64_t arg = 0) const {
    return operands_[num_outputs_ + arg].current_dtype;
  }
  Device device(int64_t arg = 0) const {
    return operands_[arg].device.value();
  }
  c10::DeviceType device_type(int64_t arg = 0) const {
    return device(arg).type();
  }
  int64_t element_size(int64_t arg) const {
    return static_cast<int64_t>(elementSize(dtype(arg)));
  }
  bool is_scalar(int64_t arg) const;
  bool is_cpu_scalar(int64_t arg) const;

  const TensorBase& tensor_base(int64_t arg) const {
    return operands_[arg].tensor_base();
  }
  const Tensor& tensor(int64_t arg) const {
    return operands_[arg].tensor();
  }

  const TensorBase& output_base(int64_t arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor_base(arg);
  }

  const Tensor& output(int64_t arg = 0) const {
    AT_ASSERT(arg < num_outputs_);
    return tensor(arg);
  }

  const TensorBase& input_base(int64_t arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor_base(num_outputs_ + arg);
  }
  const Tensor& input(int64_t arg = 0) const {
    AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_);
    return tensor(num_outputs_ + arg);
  }

  // Copies from temporary outputs back to the original outputs
  // NOTE: only used on CPU
  void cast_outputs();

  /// Removes an operand from this iterator
  void remove_operand(int64_t arg);
  /// Shrinks an iterated dimension
  void narrow(int dim, int64_t start, int64_t size);
  /// Narrows every dim after and including `start_dim` to size one.
  void select_all_keeping_dim(int start_dim, IntArrayRef starts);
  /// Replaces the data pointer for the operand at index `arg`.
  /// The new pointer should have the same sizes, strides and dtype as the
  /// original
  void unsafe_replace_operand(int64_t arg, void* data);

  /// Splits this TensorIterator into two iterators. Together they iterate over
  /// the entire operation. Used by `with_32bit_indexing()`.
  std::unique_ptr<TensorIterator> split(int dim);

  /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim]
  int get_dim_to_split() const;

  template <typename T>
  T scalar_value(int64_t arg) {
    auto& op = operands_[arg];
    return c10::fetch_and_cast<T>(op.tensor_base().scalar_type(), op.data);
  }

  /// Return the scalar value from original_tensor_base if it is defined. When
  /// common_dtype is Half, casting the scalar input to common_dtype might
  /// overflow. If the scalar is already given in the type Half, then return
  /// the scalar value from tensor_base.
  template <typename T>
  T original_scalar_value(int64_t arg) {
    auto& original_tensor_base = operands_[arg].original_tensor_base();
    if (original_tensor_base.defined()) {
      TORCH_INTERNAL_ASSERT(
          original_tensor_base.scalar_type() != common_dtype());
      return c10::fetch_and_cast<T>(
          original_tensor_base.scalar_type(),
          original_tensor_base.const_data_ptr());
    } else {
      return scalar_value<T>(arg);
    }
  }
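  //
  // For example, a sketch of the common pattern for handling a CPU scalar
  // operand (here assuming operand 2 is a CPU scalar):
  //
  //   if (iter.is_cpu_scalar(2)) {
  //     auto alpha = iter.scalar_value<float>(2);
  //     iter.remove_operand(2);
  //     // ... launch a kernel specialized on `alpha` ...
  //   }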

 private:
  template <typename loop1d_t>
  auto loop_2d_from_1d(const loop1d_t& loop) {
    return
        [loop, ntensor = ntensors()](
            char** base, const int64_t* strides, int64_t size0, int64_t size1) {
          PtrVector data(base, base + ntensor);
          const int64_t* outer_strides = &strides[ntensor];
          for (const auto i : c10::irange(size1)) {
            if (i > 0) {
              for (const auto arg : c10::irange(ntensor)) {
                data[arg] += outer_strides[arg];
              }
            }
            loop(data.data(), strides, size0);
          }
        };
  }

 public:
  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible_v<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>,
          int> = 0>
  void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) {
    for_each(loop_2d_from_1d(loop), grain_size);
  }

  void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE);

  void parallel_reduce(loop2d_t loop);

  template <
      typename loop1d_t,
      std::enable_if_t<
          std::is_convertible_v<
              loop1d_t,
              c10::function_ref<
                  void(char**, const int64_t* strides, int64_t size)>>,
          int> = 0>
  void serial_for_each(loop1d_t loop, Range range) {
    serial_for_each(loop_2d_from_1d(loop), range);
  }

  void serial_for_each(loop2d_t loop, Range range) const;

  /// Create a strides array for a Tensor with shape of this iterator. The
  /// parameter `element_size` specifies the size of Tensor's data type in
  /// bytes (e.g. `4` for `float`)
  StrideVector compatible_stride(int64_t element_size) const;

  /// Inverts the re-ordering done by reorder_dimensions. This can only be
  /// called *before* coalesce_dimensions() is called.
  DimVector invert_perm(IntArrayRef input) const;

  /// Reapplies the same re-ordering as reorder_dimensions. This can
  /// only be called *before* coalesce_dimensions() is called.
  DimVector apply_perm_and_mul(IntArrayRef input, int mul) const;

  /// Helper functions for CPU iteration
  StrideVector get_dim_strides(int dim) const;
  StrideVector get_strides() const;
  StrideVector get_inner_strides() const {
    return get_dim_strides(0);
  }
  PtrVector get_base_ptrs() const;

  // Helper functions for advanced stride manipulations (e.g. torch.flip)
  void _unsafe_set_arg_strides(const int64_t arg, IntArrayRef strides) {
    operands_[arg].stride_bytes = strides;
  }
  void _unsafe_set_arg_data(const int64_t arg, void* data) {
    operands_[arg].data = data;
  }

  // Helper functions for custom devices; a custom device can access
  // OperandInfo and NameVector on its side.
  const OperandInfo& operand(int arg = 0) const {
    return operands_[arg];
  }
  OperandInfo& operand(int arg = 0) {
    return operands_[arg];
  }
  NameVector& get_dim_names() {
    return names_;
  }
  const NameVector& get_dim_names() const {
    return names_;
  }

  /// true if the stride computation can use 32-bit arithmetic. Used by GPU
  /// kernels
  bool can_use_32bit_indexing() const;

  /// An "iterable" object that recursively splits this iterator into
  /// sub-iterators that can use 32-bit indexing.
  SplitUntil32Bit with_32bit_indexing() const;
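  //
  // For example, a sketch of the splitting pattern (launch_my_kernel here is
  // a hypothetical launcher that requires 32-bit indexing):
  //
  //   if (!iter.can_use_32bit_indexing()) {
  //     for (auto& sub_iter : iter.with_32bit_indexing()) {
  //       launch_my_kernel(sub_iter);
  //     }
  //     return;
  //   }
  //   launch_my_kernel(iter);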

  /// If the kernel should accumulate into the output. Only relevant for CUDA
  /// reductions.
  bool should_accumulate() const {
    return accumulate_;
  }

  /// Whether this iterator produces the actual output,
  /// as opposed to something that will be accumulated further. Only relevant
  /// for CUDA reductions.
  bool is_final_output() const {
    return final_output_;
  }

  bool has_contiguous_first_dim() const {
    if (ndim() == 0) {
      return true;
    }

    int num_tensors = ntensors();
    for (const auto i : c10::irange(num_tensors)) {
      if (strides(i)[0] != element_size(i)) {
        return false;
      }
    }
    return true;
  }

  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;

#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic)            \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \
  maybestatic void methodname(                                              \
      const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete;      \
  maybestatic void methodname(                                              \
      TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete;

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, )

  void build_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_float_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op)
  void build_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op)
  void build_unary_float_op(const TensorBase& out, const TensorBase& a);
  void build_borrowing_unary_float_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op)
  void build_unary_op(const TensorBase& out, const TensorBase& a);
  // Odd special case needed for pow. Has to borrow the output because
  // it's a structured kernel, but the argument is potentially a copy.
  void build_output_borrowing_argument_owning_unary_op(
      const TensorBase& out,
      const TensorBase& a);
  void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op)
  void build_borrowing_unary_force_boolean_op(
      const TensorBase& out,
      const TensorBase& a);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op)
  void build_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_borrowing_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op)
  // Another special case: we need to own the second argument for comparison
  // ops.
  void build_borrowing_except_last_argument_comparison_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  void build_ternary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b,
      const TensorBase& c);

#undef TORCH_DISALLOW_TEMPORARIES
 protected:
  // Mutable reference as it moves tensors out of TensorIteratorConfig
  void populate_operands(TensorIteratorConfig&);
  void mark_outputs();
  void mark_resize_outputs(const TensorIteratorConfig&);
  void compute_mem_overlaps(const TensorIteratorConfig&);
  void compute_shape(const TensorIteratorConfig&);
  void compute_strides(const TensorIteratorConfig&);
  void reorder_dimensions();
  void permute_dimensions(IntArrayRef perm);
  void compute_types(const TensorIteratorConfig&);
  ScalarType compute_common_dtype();
  void allocate_or_resize_outputs();
  bool fast_set_up(const TensorIteratorConfig&);
  FastSetupType compute_fast_setup_type(const TensorIteratorConfig&);
  void compute_names(const TensorIteratorConfig&);
  void propagate_names_to_outputs();
  void coalesce_dimensions();

 protected:
  /// Records the "computation" shape of the output tensor. The computation
  /// shape is different from the regular shape in a few ways:
  ///
  ///   - The shape may be permuted (via permute_dimensions) so that we
  ///     process the dimensions in the most computationally efficient order
  ///     (rather than the logical order given to us by the users.)
  ///   - The shape may have adjacent dimensions collapsed (via
  ///     coalesce_dimensions) so that we minimize the number of
  ///     dimensions we have to explicitly iterate over.  For example,
  ///     a pointwise operation on a contiguous tensor "computationally"
  ///     consists of only a single dimension.
  ///
  /// In other words, the computation shape is the output shape as it
  /// actually matters for implementing the kernel, but not necessarily the
  /// output shape that the user will see in the end.
  ///
  /// The lifecycle of mutations to shape_ in TensorIterator:
  ///   - declare_static_shape() sets an initial shape explicitly
  ///     provided by user, otherwise
  ///   - compute_shape() computes the true (non-computational) shape
  ///     specified by the user.
  ///   - reorder_dimensions() reorders dimensions to improve coalescing.
  ///   - coalesce_dimensions() then coalesces adjacent dimensions when
  ///     possible.
  ///
  /// The shape may also be further modified if we create sub-TensorIterators,
  /// e.g., via narrow or select_all_keeping_dim.
  DimVector shape_;

  /// Temporarily records the permutation computed by reorder_dimensions.
  /// This permutation maps the computation output dimension (dim) to
  /// the original true output dimension (perm_[dim]).  It is used by
  /// invert_perm to undo the permutation.  After coalesce_dimensions is
  /// called, the permutation is no longer valid (as, in general, there
  /// is no permutation that maps computation dimensions to
  /// output dimensions); methods that manipulate perm_ are obligated
  /// to test that !has_coalesced_dimensions.
  DimVector perm_;

  /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build())
  /// been called?  This is SOLELY used to check validity of perm_.
  bool has_coalesced_dimensions_ = false;

  /// Whether iteration must be fixed. This disables dimension permuting and
  /// also changes how for_each divides work among threads.
  bool enforce_linear_iteration_ = false;

  /// The index offsets into the original tensors for each dimension.
  /// This is only non-zero when you narrow() a TensorIterator (e.g.,
  /// when you make sub-TensorIterators).
  DimVector view_offsets_;

  /// The computed names of the output tensor.  Computed by compute_names()
  NameVector names_;

  /// The operands of the TensorIterator: both the inputs and outputs.  The
  /// outputs MUST come first in the operands_ list.  There is always an
  /// operand for each output of the TensorIterator, even if TensorIterator
  /// will ultimately be responsible for allocating the output; in those
  /// cases, tensor is simply undefined (and will be populated later
  /// during build()).
  ///
  /// This list is initially populated prior to build(), but build() mutates
  /// OperandInfo to populate more information.
  SmallVector<OperandInfo, 4> operands_;

  /// Number of outputs in operands_ (the length of the outputs prefix
  /// in operands_).
  int num_outputs_ = 0;

  /// Whether or not all operands have the same shape and are 1d+. Having all
  /// the same shape affects whether or not the iterator is eligible for fast
  /// setup.
  bool all_ops_same_shape_ = false;
  /// Whether or not all operands are 0d, this affects type promotion
  bool all_ops_are_scalars_ = false;

  /// The "computation" dtype of TensorIterator, specifying the dtype in
  /// which the internal computation is performed.  Typically,
  /// this matches the dtype of the output tensors, but not always!
  ScalarType common_dtype_ = ScalarType::Undefined;

  /// This is currently defined as kCPU, or the device of the first non-CPU
  /// tensor argument. See TensorIteratorBase::compute_types for details.
  Device common_device_ = kCPU;

  /// Set by split(), see should_accumulate() and is_final_output()
  bool accumulate_ = false;
  bool final_output_ = true;

  // From TensorIteratorConfig
  bool is_reduction_ = false;

  /// Set by populate_operands(), says if we're handling meta tensors
  bool is_meta_ = false;
};

struct TORCH_API TensorIterator final : public TensorIteratorBase {
  TensorIterator() : TensorIteratorBase() {}
  // Slicing is OK, TensorIterator is guaranteed NOT to have any fields
  TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {}

#define TORCH_DISALLOW_TEMPORARIES(methodname) \
  TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static)

  static TensorIterator binary_float_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator binary_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator borrowing_binary_op(
      const TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op)
  static TensorIterator comparison_op(
      TensorBase& out,
      const TensorBase& a,
      const TensorBase& b);
  static TensorIterator unary_op(TensorBase& out, const TensorBase& a);
  static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a);
  static TensorIterator nullary_op(TensorBase& out);
  static TensorIterator borrowing_nullary_op(const TensorBase& out);
  static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete;
  static TensorIterator reduce_op(TensorBase& out, const TensorBase& a);
  static TensorIterator reduce_op(
      TensorBase& out1,
      TensorBase& out2,
      const TensorBase& a);
#undef TORCH_DISALLOW_TEMPORARIES
#undef TORCH_DISALLOW_TEMPORARIES_IMPL

  const Tensor& maybe_get_output(int64_t output_idx) override;
  void set_output_raw_strided(
      int64_t output_idx,
      IntArrayRef sizes,
      IntArrayRef strides,
      TensorOptions options,
      DimnameList names) override;
};

class TORCH_API TensorIteratorConfig final {
 public:
  friend struct TensorIteratorBase;
  friend struct TensorIterator;

  TensorIteratorConfig() = default;

  C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig);

  /// Construction
  // Stores input/output Tensors without incrementing the reference count.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_output(const TensorBase& output) {
    return add_borrowed_output(output);
  }
  TensorIteratorConfig& add_input(const TensorBase& input) {
    return add_borrowed_input(input);
  }
  TensorIteratorConfig& add_const_input(const TensorBase& input) {
    return add_borrowed_const_input(input);
  }

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_input(TensorBase&& input) = delete;
  TensorIteratorConfig& add_const_input(TensorBase&& input) = delete;

  // Stores input/output Tensors while incrementing the reference count.
  // Note that add_{in,out}put are nearly always what you
  // want, and the exception (adding an unnamed temporary) won't
  // compile.
  TensorIteratorConfig& add_owned_output(const TensorBase& output);
  TensorIteratorConfig& add_owned_input(const TensorBase& input);
  TensorIteratorConfig& add_owned_const_input(const TensorBase& input);

  // Advanced API: stores input/output Tensors without incrementing
  // the reference count. The caller must ensure that these Tensors
  // live at least as long as this TensorIteratorConfig and any
  // TensorIteratorBase built from this TensorIteratorConfig.
  // Important: the outputs have to be added before the inputs.
  TensorIteratorConfig& add_borrowed_output(const TensorBase& output);
  TensorIteratorConfig& add_borrowed_input(const TensorBase& input);
  TensorIteratorConfig& add_borrowed_const_input(const TensorBase& input);

  // Borrowing from temporaries is unlikely to go well.
  TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete;
  TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete;
  TensorIteratorConfig& add_borrowed_const_input(TensorBase&& input) = delete;

  // Sets the check_mem_overlap_ flag, which is true by default.
  // If true, inputs are checked for partial overlap with the outputs and
  // outputs are checked for internal overlap (e.g. broadcasted views). An error
  // is raised if unacceptable overlap is detected.
  // If you're migrating an existing operator to using TensorIterator, please
  // consider if the previous implementation checked memory overlap. If it did
  // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then
  // checking memory overlap is BC-breaking. Please don't check memory overlap
  // in that case.
  TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) {
    check_mem_overlap_ = check_mem_overlap;
    return *this;
  }

  // Sets the check_all_same_dtype_ flag, which is true by default
  // If true, checks that all inputs and defined outputs have the same dtype
  // Setting either of promote_inputs_to_common_dtype_
  //   or cast_common_dtype_to_outputs_ to true will set
  //   check_all_same_dtype_ to false.
  TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) {
    check_all_same_dtype_ = _check_all_same_dtype;
    return *this;
  }

  // Sets the check_all_same_device_ flag, which is true by default
  // If true, all operands must be on the same device, with the possible
  //   exception of CPU scalars, which can be passed to some CUDA kernels
  //   as kernel arguments.
  TensorIteratorConfig& check_all_same_device(
      const bool _check_all_same_device) {
    check_all_same_device_ = _check_all_same_device;
    return *this;
  }

  // Sets the enforce_safe_casting_to_output_ flag, which is false by default
  // If true, the iterator's "common dtype" must be computable
  //   (see the [Common Dtype Computation] note) and
  //   canCast(common dtype, output dtype) must be true for all outputs.
  TensorIteratorConfig& enforce_safe_casting_to_output(
      const bool _enforce_safe_casting_to_output) {
    enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output;
    return *this;
  }

  // Sets the enforce_linear_iteration_ flag, which is false by default.
  // If true, iteration goes in the same order as a C-contiguous tensor
  // is laid out in memory, i.e. the last dimension iterates fastest.
  //
  // This iteration order can be less efficient and may even prevent
  // vectorization. So only use if the correctness of your kernel depends on it.
  TensorIteratorConfig& enforce_linear_iteration(
      const bool _enforce_linear_iteration = true) {
    enforce_linear_iteration_ = _enforce_linear_iteration;
    return *this;
  }

  // Sets the promote_inputs_to_common_dtype_ flag, which is false by default
  // If true, the iterator's "common dtype" is always computed (see the
  //   [Common Dtype Computation] note) and, on the CPU, temporary copies of
  //   the inputs in the common dtype are passed as the actual inputs to
  //   the operation.
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& promote_inputs_to_common_dtype(
      const bool _promote_inputs_to_common_dtype) {
    promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype;
    if (_promote_inputs_to_common_dtype) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }
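
  // For example, a sketch of a config that promotes mixed input dtypes
  // (here `out`, `int_tensor`, and `float_tensor` are assumed to be
  // pre-existing Tensors of a suitable dtype):
  //
  //   auto iter = TensorIteratorConfig()
  //     .add_output(out)
  //     .add_const_input(int_tensor)
  //     .add_const_input(float_tensor)
  //     .promote_inputs_to_common_dtype(true)
  //     .build();
  //   // iter.common_dtype() is the promoted dtype (float in this case)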

  // Sets the promote_integer_inputs_to_float_ flag, which is false by default.
  // NOTE: If set to true, promote_inputs_to_common_dtype_ must also be true.
  // If true and the iterator's "common dtype" is an integral type (including
  //   bool), then the common dtype is changed to the default float scalar
  //   type.
  TensorIteratorConfig& promote_integer_inputs_to_float(
      const bool _promote_integer_inputs_to_float) {
    promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float;
    TORCH_INTERNAL_ASSERT(
        !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_);
    return *this;
  }

  TensorIteratorConfig& is_reduction(const bool _is_reduction) {
    is_reduction_ = _is_reduction;
    return *this;
  }

  TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) {
    allow_cpu_scalars_ = _allow_cpu_scalars;
    return *this;
  }

  // Sets the cast_common_dtype_to_outputs_ flag, which is false by default
  // If true, the iterator's "common dtype" must be computable
  //   (see the [Common Dtype Computation] note) and, on the CPU, temporary
  //   copies of the outputs are passed as the actual output to the operation.
  //   These temporaries are then copied to the original outputs after
  //   the operation is performed (see cast_outputs()).
  // Setting this flag to true sets check_all_same_dtype_ to false.
  TensorIteratorConfig& cast_common_dtype_to_outputs(
      const bool _cast_common_dtype_to_outputs) {
    cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs;
    if (_cast_common_dtype_to_outputs) {
      check_all_same_dtype_ = false;
    }
    return *this;
  }

  TensorIteratorConfig& resize_outputs(bool resize_outputs) {
    resize_outputs_ = resize_outputs;
    return *this;
  }

  // Bypass output dtype/device computation and fix the dtype/device as
  // specified here.
  TensorIteratorConfig& declare_static_dtype_and_device(
      ScalarType dtype,
      Device device);
  TensorIteratorConfig& declare_static_dtype(ScalarType dtype);
  TensorIteratorConfig& declare_static_device(Device device);
  TensorIteratorConfig& declare_static_shape(IntArrayRef shape);
  TensorIteratorConfig& declare_static_shape(
      IntArrayRef shape,
      IntArrayRef squash_dims);

  // It would be better if this was && qualified, but this would be at the cost
  // of a lot of boilerplate above
  TensorIterator build() {
    TensorIterator iter;
    iter.build(*this);
    return iter;
  }

 private:
  bool is_tensor_const(size_t idx);

  SmallVector<c10::MaybeOwned<TensorBase>, 4> tensors_;
  int num_outputs_ = 0;
  int num_inputs_ = 0;

  std::optional<DimVector> static_shape_ = std::nullopt;
  std::optional<ScalarType> static_dtype_ = std::nullopt;
  std::optional<Device> static_device_ = std::nullopt;
  bool check_mem_overlap_ = true;
  bool allow_cpu_scalars_ = false;
  bool is_reduction_ = false;
  bool resize_outputs_ = true;
  bool check_all_same_dtype_ = true;
  bool check_all_same_device_ = true;
  bool enforce_safe_casting_to_output_ = false;
  bool enforce_linear_iteration_ = false;
  bool promote_inputs_to_common_dtype_ = false;
  bool promote_integer_inputs_to_float_ = false;
  bool cast_common_dtype_to_outputs_ = false;

  SmallVector<size_t, 4> const_tensor_indices_;
};

/// A container-like struct that acts as if it contains splits of a
/// TensorIterator that can use 32-bit indexing. Taken together the splits cover
/// the original TensorIterator.
struct TORCH_API SplitUntil32Bit {
  struct TORCH_API iterator {
    iterator() = default;
    iterator(const TensorIteratorBase& iter);
    iterator(iterator&&) = default;

    // Guaranteed to be a TensorIterator proper!
    TensorIterator& operator*() const;
    iterator& operator++();
    bool operator==(const iterator& other) const {
      // two iterators are equal if they are the same object or they're both
      // empty
      return this == &other || (vec.empty() && other.vec.empty());
    }
    // needed for C++11 range-based for loop
    bool operator!=(const iterator& other) const {
      return !(*this == other);
    }

    /// stack of TensorIterators to be split
    std::vector<std::unique_ptr<TensorIterator>> vec;
  };

  SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {}

  iterator begin() const;
  iterator end() const;

 private:
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const TensorIteratorBase& iter;
};

} // namespace at