#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/ScalarOps.h>
#include <ATen/core/Tensor.h>
#include <ATen/core/TensorBody.h>
#include <c10/core/SymInt.h>
#include <c10/util/irange.h>
#include <optional>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/alias.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/scalar_tensor.h>
#include <ATen/ops/zeros.h>
#endif

#include <ATen/core/List.h>

#include <utility>

namespace at::indexing {

constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int();
constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1);

enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor };

constexpr std::nullopt_t None = std::nullopt;

struct TORCH_API EllipsisIndexType final {
  EllipsisIndexType() = default;
};
TORCH_API extern const EllipsisIndexType Ellipsis;

struct TORCH_API Slice final {
 public:
  Slice(
      std::optional<c10::SymInt> start_index = std::nullopt,
      std::optional<c10::SymInt> stop_index = std::nullopt,
      std::optional<c10::SymInt> step_index = std::nullopt) {
    if (!step_index.has_value()) {
      step_ = c10::SymInt(1);
    } else {
      step_ = std::move(step_index).value();
    }

    TORCH_CHECK_VALUE(
        step_.sym_ne(0).expect_true(__FILE__, __LINE__),
        "slice step cannot be zero");

    if (!start_index.has_value()) {
      start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0);
    } else {
      start_ = std::move(start_index).value();
    }

    if (!stop_index.has_value()) {
      stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX);
    } else {
      stop_ = std::move(stop_index).value();
    }
  }

  inline c10::SymInt start() const {
    return start_;
  }

  inline c10::SymInt stop() const {
    return stop_;
  }

  inline c10::SymInt step() const {
    return step_;
  }

 private:
  c10::SymInt start_;
  c10::SymInt stop_;
  c10::SymInt step_;
};

TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice);

// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as
// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}`
// into its equivalent `std::vector<TensorIndex>`, so that further tensor
// indexing operations can be performed using the supplied indices.
//
// There is a one-to-one correspondence between Python and C++ tensor index types:
// Python                  | C++
// -----------------------------------------------------
// `None`                  | `at::indexing::None`
// `Ellipsis`              | `at::indexing::Ellipsis`
// `...`                   | `"..."`
// `123`                   | `123`
// `True` / `False`        | `true` / `false`
// `:`                     | `Slice()` / `Slice(None, None)`
// `::`                    | `Slice()` / `Slice(None, None, None)`
// `1:`                    | `Slice(1, None)`
// `1::`                   | `Slice(1, None, None)`
// `:3`                    | `Slice(None, 3)`
// `:3:`                   | `Slice(None, 3, None)`
// `::2`                   | `Slice(None, None, 2)`
// `1:3`                   | `Slice(1, 3)`
// `1::2`                  | `Slice(1, None, 2)`
// `:3:2`                  | `Slice(None, 3, 2)`
// `1:3:2`                 | `Slice(1, 3, 2)`
// `torch.tensor([1, 2])`  | `torch::tensor({1, 2})`
struct TORCH_API TensorIndex final {
  // Case 1: `at::indexing::None`
  TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {}

  // Case 2: "..." / `at::indexing::Ellipsis`
  TensorIndex(at::indexing::EllipsisIndexType)
      : type_(TensorIndexType::Ellipsis) {}
  TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) {
    TORCH_CHECK_VALUE(
        strcmp(str, "...") == 0,
        "Expected \"...\" to represent an ellipsis index, but got \"",
        str,
        "\"");
  }

  // Case 3: (Sym) Integer value
  TensorIndex(SymInt integer)
      : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {}
  TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {}
  TensorIndex(int integer) : TensorIndex(SymInt(integer)) {}

  // Case 4: Boolean value
  template <class T, class = std::enable_if_t<std::is_same_v<bool, T>>>
  TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {}

  // Case 5: Slice represented in `at::indexing::Slice` form
  TensorIndex(Slice slice)
      : slice_(std::move(slice)), type_(TensorIndexType::Slice) {}

  // Case 6: Tensor value
  TensorIndex(Tensor tensor)
      : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {}

  inline bool is_none() const {
    return type_ == TensorIndexType::None;
  }

  inline bool is_ellipsis() const {
    return type_ == TensorIndexType::Ellipsis;
  }

  inline bool is_integer() const {
    return type_ == TensorIndexType::SymInt;
  }

  inline SymInt integer() const {
    return integer_;
  }

  inline bool is_boolean() const {
    return type_ == TensorIndexType::Boolean;
  }

  inline bool boolean() const {
    return boolean_;
  }

  inline bool is_slice() const {
    return type_ == TensorIndexType::Slice;
  }

  inline const Slice& slice() const {
    return slice_;
  }

  inline bool is_tensor() const {
    return type_ == TensorIndexType::Tensor;
  }

  inline const Tensor& tensor() const {
    return tensor_;
  }

 private:
  SymInt integer_ = 0;
  bool boolean_ = false;
  Slice slice_;
  Tensor tensor_;
  TensorIndexType type_;
};

TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const TensorIndex& tensor_index);
TORCH_API std::ostream& operator<<(
    std::ostream& stream,
    const std::vector<TensorIndex>& tensor_indices);
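
// A minimal usage sketch (illustrative only; `x` is a hypothetical example
// input): a braced list of the values from the table above converts
// implicitly to `std::vector<TensorIndex>` / `ArrayRef<TensorIndex>`, which
// the `get_item` / `set_item` functions defined later in this file consume.
// ```
// using namespace at::indexing;
// at::Tensor x = at::randn({4, 5, 6});
// // Python: x[None, ..., 0, 1:5:2]
// at::Tensor y = get_item(x, {None, Ellipsis, 0, Slice(1, 5, 2)});
// ```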

namespace impl {
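// Applies a single `Slice` along `dim`. As an optimization, when the size of
// the dimension is known and the slice covers the whole dimension with step 1
// (and `disable_slice_optimization` is false), `self` is returned unchanged
// instead of dispatching a slice call.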
inline Tensor applySlice(
    const Tensor& self,
    int64_t dim,
    c10::SymInt start,
    c10::SymInt stop,
    c10::SymInt step,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const std::optional<SymIntArrayRef>& self_sizes) {
  // TODO: implement negative step
  TORCH_CHECK_VALUE(
      step.sym_gt(0).expect_true(__FILE__, __LINE__),
      "step must be greater than zero");

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    // Skip this optimization if we are tracing, as the trace may be polymorphic
    // over the shape of the `self` tensor, and we still want to record
    // the slice.
    SymInt length = (self_device == at::kCPU || self_device == at::kCUDA)
        ? (*self_sizes)[dim]
        : self.sym_size(dim);
    if (!disable_slice_optimization &&
        TORCH_GUARD_SIZE_OBLIVIOUS(start.sym_eq(0)) &&
        TORCH_GUARD_SIZE_OBLIVIOUS(length.sym_eq(stop)) && step == 1) {
      return self;
    }
  }
  return self.slice_symint(
      dim, std::move(start), std::move(stop), std::move(step));
}

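// Selects a single (possibly symbolic) integer index along `dim`, bounds
// checking against `self_sizes` when the sizes are known. Negative indices
// are deliberately not normalized here; see the comment at the end of the
// function.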
inline Tensor applySelect(
    const Tensor& self,
    int64_t dim,
    SymInt index,
    int64_t real_dim,
    const at::Device& /*self_device*/,
    const std::optional<SymIntArrayRef>& self_sizes) {
  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    auto maybe_index = index.maybe_as_int();
    if (maybe_index.has_value()) {
      TORCH_CHECK_INDEX(
          !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()),
          "invalid index of a 0-dim tensor. ",
          "Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number");
    }

    auto size = (*self_sizes)[dim];
    // Note: `size >= -index` is not equivalent to `size > -1 - index` when
    // `index` is INT64_MIN: for std::numeric_limits<int64_t>::min() the result
    // of unary minus is undefined by the standard, but in practice equals the
    // operand itself. On the other hand, index wrapping is valid for all
    // negative int64_t values, as x[INT64_MIN] is the same as x[INT64_MAX].
    TORCH_CHECK_INDEX(
        size > -1 - index && size > index,
        "index ",
        index,
        " is out of bounds for dimension ",
        real_dim,
        " with size ",
        size);
  }

  // If the index is negative, do not normalize it, because that would fix the
  // index on the current tensor size in the tracer. aten::select also works on
  // negative indices.
  return self.select_symint(dim, std::move(index));
}

inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::empty({1}, self.options().dtype(kLong)).fill_(0.);
  } else {
    return at::empty({0}, self.options().dtype(kLong));
  }
}

inline Tensor boolToIndexingTensorNonNativeDeviceType(
    const Tensor& self,
    bool value) {
  // booleans add a dimension of size 1. true indexes this dimension as if 0:,
  // false as empty.
  if (value) {
    return at::zeros({1}, self.options().dtype(kLong));
  } else {
    return at::empty({0}, self.options().dtype(kLong));
  }
}

inline Tensor boolToIndexingTensor(
    const Tensor& self,
    bool value,
    const at::Device& self_device) {
  if (self_device == at::kCPU || self_device == at::kCUDA) {
    return boolToIndexingTensorCPUOrCUDA(self, value);
  } else {
    return boolToIndexingTensorNonNativeDeviceType(self, value);
  }
}
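// For example (illustrative shapes only): with a tensor `x` of shape {2, 3},
// Python's `x[True]` yields shape {1, 2, 3} and `x[False]` yields {0, 2, 3};
// the helpers above build the index tensor that produces this result once the
// extra unit dimension has been unsqueezed in.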

inline Tensor scalarToTensorNonNativeDeviceType(
    const Scalar& v,
    const TensorOptions& options) {
  return at::scalar_tensor(v, options);
}

inline void recordTensorIndex(
    const Tensor& tensor,
    std::vector<Tensor>& outIndices,
    int64_t* dim_ptr) {
  // TODO: check scalarType
  outIndices.resize(*dim_ptr + 1);
  outIndices[*dim_ptr] = tensor;
  (*dim_ptr)++;
}

inline c10::List<::std::optional<Tensor>> typeConvertIndices(
    const Tensor& /*self*/,
    std::vector<Tensor>&& indices) {
  c10::List<::std::optional<Tensor>> converted_inds;
  converted_inds.reserve(indices.size());
  for (auto&& i : std::move(indices)) {
    converted_inds.push_back(std::move(i));
  }
  return converted_inds;
}

// NOTE: Why do we mirror instead of replace the `count_specified_dimensions`
// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because
// `count_specified_dimensions` is on the hot path of Python tensor multi-dim
// indexing (i.e. it's called by `applySlicing`, which is called by
// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more
// than one dimension). If we were to merge the Python/C++
// `count_specified_dimensions` functions, on the Python side we would have to
// construct a `std::vector` container to be consumed by the C++
// `count_specified_dimensions` function, which adds hundreds of nanoseconds of
// overhead and is undesirable.
inline int64_t count_specified_dimensions(
    const ArrayRef<TensorIndex>& indices) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  int64_t count = 0;
  for (auto& obj : indices) {
    if (obj.is_tensor()) {
      auto& tensor = obj.tensor();
      if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) {
        count += tensor.dim();
      } else {
        count++;
      }
    } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) {
      count++;
    }
  }
  return count;
}
} // namespace impl

// NOTE: Many functions below are only for consumption from the Python indexing
// implementation; they include:
//
// - `Tensor scalarToTensor(...)`
// - `IntArrayRef slicePrefix1sSize(...)`
// - `void copy_to(...)`
// - `Tensor handleDimInMultiDimIndexing(...)`
// - `Tensor dispatch_index(...)`
// - `Tensor dispatch_index_put_(...)`
// - `Tensor get_item(...)`
// - `void set_item(...)`
//
// The rest of the functions are in the `at::indexing::impl` namespace,
// signifying that they shouldn't be used from the Python indexing
// implementation.
inline Tensor scalarToTensor(
    const Scalar& v,
    const TensorOptions& options,
    const at::Device& self_device) {
  if (self_device == at::kCPU && !v.isSymbolic()) {
    return at::detail::scalar_tensor_static(
        v, options.dtype_opt()->toScalarType(), self_device);
  } else {
    return impl::scalarToTensorNonNativeDeviceType(v, options);
  }
}

// To match numpy semantics:
// As a special case for backwards compatibility,
// strip away unit dimensions from the left of 'src'
inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) {
  size_t first_non1_src = sizes.size();
  for (const auto i : c10::irange(sizes.size())) {
    // Unbacked SymInt has different behavior, but this is sound because
    // failing to slice will only ever cause an error, not divergent
    // behavior
    if (!sizes[i].has_hint() || sizes[i] != 1) {
      first_non1_src = i;
      break;
    }
  }

  return sizes.slice(first_non1_src);
}
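// For example (illustrative only): sizes {1, 1, 3, 4} become {3, 4}, while
// {2, 1, 3} is returned unchanged because its leading dimension is not 1.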

inline void copy_to(const Tensor& dst, const Tensor& src) {
  if (dst.sym_sizes().equals(src.sym_sizes())) {
    // A shortcut to avoid generating hard-coded constant sizes during tracing.
    // This is not a perfect solution: when src & dst have different shapes,
    // constants will still appear. Users can work around that case with
    // dst[index..] = src.reshape(..)
    dst.copy_(src);
    return;
  } else if (src.dim() == 0 && src.device().type() == at::kCPU) {
    dst.fill_(src);
    return;
  }
  auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes()));
  c10::MaybeOwned<Tensor> b_src = expand_inplace(dst, src_view, "setitem");
  dst.copy_(*b_src);
}

// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
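// Applies one `TensorIndex` of a multi-dimensional indexing expression to
// `prev_dim_result`, advancing `*dim_ptr` past the dimensions it consumes and
// appending to `outIndices` any tensor indices that must later be resolved by
// advanced ("tensor") indexing via `dispatch_index` / `dispatch_index_put_`.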
inline Tensor handleDimInMultiDimIndexing(
    const Tensor& prev_dim_result,
    const Tensor& original_tensor,
    const TensorIndex& index,
    int64_t* dim_ptr,
    int64_t* specified_dims_ptr,
    int64_t real_dim,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& original_tensor_device,
    const std::optional<SymIntArrayRef>& prev_dim_result_sizes) {
  if (index.is_integer()) {
    return impl::applySelect(
        prev_dim_result,
        *dim_ptr,
        index.integer(),
        real_dim,
        original_tensor_device,
        prev_dim_result_sizes);
  } else if (index.is_slice()) {
    Tensor result = impl::applySlice(
        prev_dim_result,
        *dim_ptr,
        index.slice().start(),
        index.slice().stop(),
        index.slice().step(),
        /*disable_slice_optimization=*/disable_slice_optimization,
        original_tensor_device,
        prev_dim_result_sizes);
    (*dim_ptr)++;
    return result;
  } else if (index.is_ellipsis()) {
    (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr);
    return prev_dim_result;
  } else if (index.is_none()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    (*dim_ptr)++;
    return result;
  } else if (index.is_boolean()) {
    Tensor result = prev_dim_result.unsqueeze(*dim_ptr);
    impl::recordTensorIndex(
        impl::boolToIndexingTensor(
            result, index.boolean(), original_tensor_device),
        outIndices,
        dim_ptr);
    return result;
  } else if (index.is_tensor()) {
    Tensor result = prev_dim_result;
    const Tensor& tensor = index.tensor();
    auto scalar_type = tensor.scalar_type();
    if (tensor.dim() == 0 &&
        at::isIntegralType(scalar_type, /*includeBool=*/true)) {
      if (scalar_type != at::kByte && scalar_type != at::kBool) {
        result = impl::applySelect(
            result,
            *dim_ptr,
            tensor.item<int64_t>(),
            real_dim,
            original_tensor_device,
            prev_dim_result_sizes);
      } else {
        result = result.unsqueeze(*dim_ptr);
        if (scalar_type == at::kBool) {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<bool>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        } else {
          impl::recordTensorIndex(
              impl::boolToIndexingTensor(
                  result, tensor.item<uint8_t>() != 0, original_tensor_device),
              outIndices,
              dim_ptr);
        }
      }
    } else {
      impl::recordTensorIndex(tensor, outIndices, dim_ptr);
    }
    return result;
  } else {
    TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type");
  }
}

namespace impl {
// This mirrors `applySlicing` in
// torch/csrc/autograd/python_variable_indexing.cpp
inline Tensor applySlicing(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    std::vector<Tensor>& outIndices,
    bool disable_slice_optimization,
    const at::Device& self_device,
    const std::optional<SymIntArrayRef>& self_sizes) {
  int64_t dim = 0;
  int64_t specified_dims = impl::count_specified_dimensions(indices);

  // See NOTE [nested tensor size for indexing]
  if (self_sizes.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= (int64_t)self_sizes->size(),
        "too many indices for tensor of dimension ",
        (int)self_sizes->size());
  }

  Tensor result = self;
  for (const auto i : c10::irange(indices.size())) {
    auto& obj = indices[i];
    // See NOTE [nested tensor size for indexing]
    std::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? std::optional<SymIntArrayRef>(std::nullopt)
        : std::optional<SymIntArrayRef>(result.sym_sizes());
    result = handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/obj,
        /*dim_ptr=*/&dim,
        /*specified_dims_ptr=*/&specified_dims,
        /*real_dim=*/static_cast<int64_t>(i),
        /*outIndices=*/outIndices,
        /*disable_slice_optimization=*/disable_slice_optimization,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}
} // namespace impl

inline Tensor dispatch_index(
    const Tensor& self,
    std::vector<Tensor>&& indices) {
  return self.index(impl::typeConvertIndices(self, std::move(indices)));
}

inline Tensor dispatch_index_put_(
    Tensor& self,
    std::vector<Tensor>&& indices,
    const Tensor& value) {
  return self.index_put_(
      impl::typeConvertIndices(self, std::move(indices)), value);
}

// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing
// functions from Python ]
//
// Question: When should we set `disable_slice_optimization` to `true` when
// calling C++ tensor indexing functions from Python indexing code?
//
// Answer: What "slice optimization" means: when we have a slicing expression
// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we
// would skip dispatching the actual slice call as an optimization. However,
// here are the cases where we DON'T want this optimization:
//
// 1. When we are doing 1-D slicing (e.g. `tensor[:]`).
//    Reason: we always return a shallow copy for expressions such as
//    `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for
//    `tensor[:, :]`, we return an alias of `tensor` by doing the following:
//    ```
//    Tensor sliced = impl::applySlicing(self, indices, tensorIndices,
//        disable_slice_optimization, self_device, self_sizes);
//    if (tensorIndices.empty()) {
//      if (sliced.is_same(self)) {
//        // ensure we return a shallow copy for things like x[...]
//        sliced = at::alias(sliced);
//      }
//      return sliced;
//    }
//    ```)
// 2. When we are doing JIT tracing.
//    Reason: JIT tracing needs the `self.slice(...)` call to properly trace the
//    slice operation.

// This mirrors `THPVariable_getitem` in
// torch/csrc/autograd/python_variable_indexing.cpp.
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
inline Tensor get_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  // NOTE [nested tensor size for indexing]
  // Nested tensors do not have a size (yet), so for now we represent their
  // size as null. This may need to change once there is a better solution for
  // nested tensor sizes.
  std::optional<SymIntArrayRef> self_sizes = self.is_nested()
      ? std::optional<SymIntArrayRef>(std::nullopt)
      : std::optional<SymIntArrayRef>(self.sym_sizes());

  // handle simple types: integers, slices, none, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_integer()) {
      return impl::applySelect(
          self, 0, index.integer(), 0, self_device, self_sizes);
    } else if (index.is_slice()) {
      return impl::applySlice(
          self,
          0,
          index.slice().start(),
          index.slice().stop(),
          index.slice().step(),
          /*disable_slice_optimization=*/true,
          self_device,
          self_sizes);
    } else if (index.is_none()) {
      return self.unsqueeze(0);
    } else if (index.is_ellipsis()) {
      return at::alias(self);
    } else if (index.is_boolean()) {
      Tensor result = self.unsqueeze(0);
      return dispatch_index(
          result,
          std::vector<Tensor>{impl::boolToIndexingTensor(
              result, index.boolean(), self_device)});
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    if (sliced.is_same(self)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return sliced;
  }

  // indexing by tensors ("advanced" indexing)
  return dispatch_index(sliced, std::move(tensorIndices));
}
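
// A minimal `get_item` usage sketch (illustrative only; `x` is a hypothetical
// example input):
// ```
// using namespace at::indexing;
// at::Tensor x = at::arange(12).reshape({3, 4});
// at::Tensor row   = get_item(x, {1});                    // Python: x[1]
// at::Tensor col   = get_item(x, {Slice(), 2});           // Python: x[:, 2]
// at::Tensor evens = get_item(x, {Slice(None, None, 2)}); // Python: x[::2]
// ```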

// This mirrors `THPVariable_setitem` in
// torch/csrc/autograd/python_variable_indexing.cpp for the case where the
// assigned value is a Tensor.
// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
// indexing functions from Python ]
inline void set_item(
    const Tensor& self,
    const ArrayRef<TensorIndex>& indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  at::Device self_device = self.device();
  SymIntArrayRef self_sizes = self.sym_sizes();

  // handle simple types: integers, slices, ellipsis, bool
  if (indices.size() == 1) {
    const TensorIndex& index = indices[0];
    if (index.is_boolean() && !index.boolean()) {
      // do nothing for false (technically we should check the size, but we
      // don't have real 0-sized shapes).
      return;
    } else if (index.is_ellipsis()) {
      copy_to(self, value);
      return;
    } else if (index.is_none() || (index.is_boolean() && index.boolean())) {
      copy_to(self.unsqueeze(0), value);
      return;
    } else if (index.is_integer()) {
      copy_to(
          impl::applySelect(
              self, 0, index.integer(), 0, self_device, self_sizes),
          value);
      return;
    } else if (index.is_slice()) {
      copy_to(
          impl::applySlice(
              self,
              0,
              index.slice().start(),
              index.slice().stop(),
              index.slice().step(),
              /*disable_slice_optimization=*/disable_slice_optimization,
              self_device,
              self_sizes),
          value);
      return;
    }
  }

  std::vector<Tensor> tensorIndices;
  Tensor sliced = impl::applySlicing(
      self,
      indices,
      tensorIndices,
      disable_slice_optimization,
      self_device,
      self_sizes);
  if (tensorIndices.empty()) {
    copy_to(sliced, value);
    return;
  }

  SymIntArrayRef valueSizes = value.sym_sizes();
  SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes);
  Tensor valuesSliced;
  if (!valueSizes.equals(slicedValueSizes)) {
    valuesSliced = value.view_symint(slicedValueSizes);
  } else {
    valuesSliced = value;
  }
  dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced);
  return;
}
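
// A minimal `set_item` usage sketch (illustrative only; `x` is a hypothetical
// example input):
// ```
// using namespace at::indexing;
// at::Tensor x = at::zeros({3, 4});
// set_item(x, {0}, at::ones({4}));           // Python: x[0] = torch.ones(4)
// set_item(x, {Slice(), 1}, at::ones({3}));  // Python: x[:, 1] = torch.ones(3)
// ```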

} // namespace at::indexing