xref: /aosp_15_r20/external/pytorch/aten/src/ATen/TensorIndexing.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/TensorIndexing.h>
2 
3 #include <c10/util/Exception.h>
4 #include <c10/util/irange.h>
5 
6 namespace at {
7 namespace indexing {
8 
9 const EllipsisIndexType Ellipsis = EllipsisIndexType();
10 
operator <<(std::ostream & stream,const Slice & slice)11 std::ostream& operator<<(std::ostream& stream, const Slice& slice) {
12   stream << slice.start() << ":" << slice.stop() << ":" << slice.step();
13   return stream;
14 }
15 
operator <<(std::ostream & stream,const TensorIndex & tensor_index)16 std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) {
17   if (tensor_index.is_none()) {
18     stream << "None";
19   } else if (tensor_index.is_ellipsis()) {
20     stream << "...";
21   } else if (tensor_index.is_integer()) {
22     stream << tensor_index.integer();
23   } else if (tensor_index.is_boolean()) {
24     stream << std::boolalpha << tensor_index.boolean();
25   } else if (tensor_index.is_slice()) {
26     stream << tensor_index.slice();
27   } else if (tensor_index.is_tensor()) {
28     stream << tensor_index.tensor();
29   }
30   return stream;
31 }
32 
operator <<(std::ostream & stream,const std::vector<TensorIndex> & tensor_indices)33 std::ostream& operator<<(std::ostream& stream, const std::vector<TensorIndex>& tensor_indices) {
34   stream << "(";
35   for (const auto i : c10::irange(tensor_indices.size())) {
36     stream << tensor_indices[i];
37     if (i < tensor_indices.size() - 1) stream << ", ";
38   }
39   stream << ")";
40   return stream;
41 }
42 
43 // This mirrors `THPVariable_setitem` in torch/csrc/autograd/python_variable_indexing.cpp
44 // for "the assigned value is a Scalar" case
set_item(const Tensor & self,ArrayRef<TensorIndex> indices,const Scalar & v)45 static inline void set_item(const Tensor& self, ArrayRef<TensorIndex> indices, const Scalar& v) {
46   Tensor value;
47 
48   {
49     at::AutoDispatchBelowADInplaceOrView guard;
50     at::Device self_device = self.device();
51 
52     // TODO: This qint special case looks very suspicious...
53     if (isQIntType(self.scalar_type())) {
54       value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU));
55     } else if (self_device.is_cuda()) {
56       value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU));
57     } else {
58       value = at::indexing::scalarToTensor(v, self.options(), self_device);
59     }
60   }
61 
62   return set_item(self, indices, value);
63 }
64 
65 } // namespace indexing
66 
index(ArrayRef<at::indexing::TensorIndex> indices) const67 Tensor Tensor::index(ArrayRef<at::indexing::TensorIndex> indices) const {
68   TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index() is not valid syntax");
69   OptionalDeviceGuard device_guard(device_of(*this));
70   return at::indexing::get_item(*this, indices);
71 }
index(std::initializer_list<at::indexing::TensorIndex> indices) const72 Tensor Tensor::index(std::initializer_list<at::indexing::TensorIndex> indices) const {
73   return index(ArrayRef<at::indexing::TensorIndex>(indices));
74 }
75 
index_put_(ArrayRef<at::indexing::TensorIndex> indices,Tensor const & rhs)76 Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, Tensor const & rhs) {
77   TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
78   OptionalDeviceGuard device_guard(device_of(*this));
79   at::indexing::set_item(*this, indices, rhs);
80   return *this;
81 }
index_put_(ArrayRef<at::indexing::TensorIndex> indices,const Scalar & v)82 Tensor & Tensor::index_put_(ArrayRef<at::indexing::TensorIndex> indices, const Scalar& v) {
83   TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
84   OptionalDeviceGuard device_guard(device_of(*this));
85   at::indexing::set_item(*this, indices, v);
86   return *this;
87 }
index_put_(std::initializer_list<at::indexing::TensorIndex> indices,Tensor const & rhs)88 Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, Tensor const & rhs) {
89   return index_put_(ArrayRef<at::indexing::TensorIndex>(indices), rhs);
90 }
index_put_(std::initializer_list<at::indexing::TensorIndex> indices,const Scalar & v)91 Tensor & Tensor::index_put_(std::initializer_list<at::indexing::TensorIndex> indices, const Scalar& v) {
92   return index_put_(ArrayRef<at::indexing::TensorIndex>(indices), v);
93 }
94 
95 } // namespace at
96