#include #include #include namespace at { namespace indexing { const EllipsisIndexType Ellipsis = EllipsisIndexType(); std::ostream& operator<<(std::ostream& stream, const Slice& slice) { stream << slice.start() << ":" << slice.stop() << ":" << slice.step(); return stream; } std::ostream& operator<<(std::ostream& stream, const TensorIndex& tensor_index) { if (tensor_index.is_none()) { stream << "None"; } else if (tensor_index.is_ellipsis()) { stream << "..."; } else if (tensor_index.is_integer()) { stream << tensor_index.integer(); } else if (tensor_index.is_boolean()) { stream << std::boolalpha << tensor_index.boolean(); } else if (tensor_index.is_slice()) { stream << tensor_index.slice(); } else if (tensor_index.is_tensor()) { stream << tensor_index.tensor(); } return stream; } std::ostream& operator<<(std::ostream& stream, const std::vector& tensor_indices) { stream << "("; for (const auto i : c10::irange(tensor_indices.size())) { stream << tensor_indices[i]; if (i < tensor_indices.size() - 1) stream << ", "; } stream << ")"; return stream; } // This mirrors `THPVariable_setitem` in torch/csrc/autograd/python_variable_indexing.cpp // for "the assigned value is a Scalar" case static inline void set_item(const Tensor& self, ArrayRef indices, const Scalar& v) { Tensor value; { at::AutoDispatchBelowADInplaceOrView guard; at::Device self_device = self.device(); // TODO: This qint special case looks very suspicious... if (isQIntType(self.scalar_type())) { value = at::indexing::scalarToTensor(v, device(kCPU).dtype(kFloat), at::Device(kCPU)); } else if (self_device.is_cuda()) { value = at::indexing::scalarToTensor(v, self.options(), at::Device(kCPU)); } else { value = at::indexing::scalarToTensor(v, self.options(), self_device); } } return set_item(self, indices, value); } } // namespace indexing Tensor Tensor::index(ArrayRef indices) const { TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index() is not valid syntax"); OptionalDeviceGuard device_guard(device_of(*this)); return at::indexing::get_item(*this, indices); } Tensor Tensor::index(std::initializer_list indices) const { return index(ArrayRef(indices)); } Tensor & Tensor::index_put_(ArrayRef indices, Tensor const & rhs) { TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); OptionalDeviceGuard device_guard(device_of(*this)); at::indexing::set_item(*this, indices, rhs); return *this; } Tensor & Tensor::index_put_(ArrayRef indices, const Scalar& v) { TORCH_CHECK(!indices.empty(), "Passing an empty index list to Tensor::index_put_() is not valid syntax"); OptionalDeviceGuard device_guard(device_of(*this)); at::indexing::set_item(*this, indices, v); return *this; } Tensor & Tensor::index_put_(std::initializer_list indices, Tensor const & rhs) { return index_put_(ArrayRef(indices), rhs); } Tensor & Tensor::index_put_(std::initializer_list indices, const Scalar& v) { return index_put_(ArrayRef(indices), v); } } // namespace at