#include <torch/csrc/autograd/python_variable_indexing.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/numpy_stub.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/utils/tensor_types.h>

#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TracerMode.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

#include <c10/core/Layout.h>

using namespace at;
using namespace torch::autograd::utils;

namespace torch::autograd {

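// Implements `len(tensor)` (Python `__len__`): dispatches to a
// __torch_function__ override if one is present, otherwise returns the size
// of the first dimension (0 for 0-dim tensors).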
Py_ssize_t THPVariable_length(PyObject* self) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function(self, "__len__"));
    Py_ssize_t length = PyLong_AsSsize_t(ret.ptr());
    if (PyErr_Occurred()) {
      throw python_error();
    }
    return length;
  }
  const auto& self_ = THPVariable_Unpack(self);
  if (self_.dim() == 0) {
    return 0;
  }
  // TODO: Maybe this should return a SymInt directly?
  // Add the guard to get a nice error message if/when we hit this.
  return (Py_ssize_t)self_.sym_size(0).guard_int(__FILE__, __LINE__);
  END_HANDLE_TH_ERRORS_RET(-1)
}

// We allow indexing by integers, slices, ellipsis, None, Variables,
// and tuples of those types. We also handle bools as if they were a
// Variable[ByteTensor].

static inline int64_t count_specified_dimensions(PyObject* index) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  // -1 is a sentinel for __torch_function__
  int64_t count = 0;
  auto size = PyTuple_GET_SIZE(index);
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* obj = PyTuple_GET_ITEM(index, i);
    if (check_has_torch_function(obj))
      return -1;
    if (THPVariable_Check(obj)) {
      const auto& var = THPVariable_Unpack(obj);
      const auto& var_scalar_type = var.scalar_type();
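      // A byte or bool tensor is a mask: it indexes over as many dimensions
      // as the mask itself has.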
      if (var_scalar_type == kByte || var_scalar_type == kBool) {
        count += var.dim();
      } else {
        count++;
      }
    } else if (
        obj != Py_None && obj != Py_Ellipsis && obj != Py_True &&
        obj != Py_False) {
      count++;
    }
  }
  return count;
}

[[noreturn]] static inline void invalid_index(PyObject* obj) {
  TORCH_CHECK_INDEX(
      false,
      "only integers, slices (`:`), ellipsis (`...`), None and long or byte "
      "Variables are valid indices (got ",
      Py_TYPE(obj)->tp_name,
      ")");
}

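// Converts an arbitrary Python sequence used as an index (e.g. a list of
// ints) into an index tensor via torch::utils::indexing_tensor_from_data.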
static inline Variable sequenceToVariable(
    c10::TensorOptions options,
    PyObject* seq) {
  return torch::utils::indexing_tensor_from_data(
      options, kLong, std::nullopt, seq);
}

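// Converts the right-hand side of an indexed assignment to a Tensor: tensors
// are unwrapped as-is, while Python and symbolic scalars are wrapped into a
// tensor on the given device.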
inline Variable valueToTensor(
    c10::TensorOptions options,
    PyObject* value,
    const at::Device& device) {
  if (THPVariable_Check(value)) {
    return THPVariable_Unpack(value);
  }
  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;
  Scalar scalar;
  if (THPUtils_checkLong(value) || PyBool_Check(value)) {
    scalar = Scalar(THPUtils_unpackLong(value));
  } else if (PyFloat_Check(value)) {
    scalar = Scalar(THPUtils_unpackDouble(value));
  } else if (PyComplex_Check(value)) {
    scalar = Scalar(THPUtils_unpackComplexDouble(value));
  } else if (torch::is_symint(value)) {
    scalar = Scalar(py::cast<c10::SymInt>(py::handle(value)));
  } else if (torch::is_symfloat(value)) {
    scalar = Scalar(py::cast<c10::SymFloat>(py::handle(value)));
  } else if (torch::is_symbool(value)) {
    scalar = Scalar(py::cast<c10::SymBool>(py::handle(value)));
  } else {
    throw TypeError(
        "can't assign a %s to a %s",
        Py_TYPE(value)->tp_name,
        torch::utils::options_to_string(options).c_str());
  }
  // lift_fresh is supposed to be used in situations where you are guaranteed
  // to get a plain Tensor, which is the case for the CPU device but not for
  // non-CPU devices.
  if (device == at::kCPU && !scalar.isSymbolic()) {
    return at::lift_fresh(
        at::indexing::scalarToTensor(scalar, options, device));
  } else {
    return at::indexing::scalarToTensor(scalar, options, device);
  }
}

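// When JIT tracing, stash slice bounds that are themselves Variables so the
// tracer records them as dynamic values instead of baking in constants.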
static inline void recordSliceTrace(PyObject* obj) {
  PySliceObject* sliceobj = (PySliceObject*)obj;
  if (THPVariable_Check(sliceobj->start)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("start"),
        1,
        THPVariable_Unpack(sliceobj->start),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->stop)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("end"),
        1,
        THPVariable_Unpack(sliceobj->stop),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->step)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("step"),
        1,
        THPVariable_Unpack(sliceobj->step),
        torch::jit::IntType::get());
  }
}

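// Same as above, but for a tensor used as an integer (select) index.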
static inline void recordSelectTrace(const Tensor& index_tensor) {
  torch::jit::tracer::ArgumentStash::stashValue(
      std::string("index"), 1, index_tensor, torch::jit::IntType::get());
}

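// Handles the N-dimensional (tuple) indexing case: applies each index in
// `index` to the corresponding dimension, collecting any tensor indices into
// `outIndices` for a later advanced-indexing dispatch.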
static inline Variable applySlicing(
    const Variable& self,
    PyObject* index,
    variable_list& outIndices,
    bool is_tracing,
    const at::Device& self_device,
    const std::optional<int64_t>& self_ndim,
    int64_t specified_dims) {
  int64_t size = PyTuple_GET_SIZE(index);
  int64_t dim = 0;

  // See NOTE [nested tensor size for indexing]
  if (self_ndim.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= self_ndim.value(),
        "too many indices for tensor of dimension ",
        self_ndim.value());
  }

  Variable result = self;
  for (const auto i : c10::irange(size)) {
    PyObject* obj = PyTuple_GET_ITEM(index, i);
    // NOTE [nested tensor size for indexing]
    // Nested tensors do not have a well-defined size (yet), so for now we
    // represent their size as null; this may need to change once there is a
    // better solution for nested tensor sizes.
    std::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? std::optional<SymIntArrayRef>(std::nullopt)
        : std::optional<SymIntArrayRef>(result.sym_sizes());
    result = at::indexing::handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/([&]() {
          if (THPUtils_checkLong(obj)) {
            if (is_tracing && THPVariable_Check(obj)) {
              recordSelectTrace(THPVariable_Unpack(obj));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
          } else if (PySlice_Check(obj)) {
            auto val = __PySlice_Unpack(obj);
            if (is_tracing) {
              recordSliceTrace(obj);
            }
            return at::indexing::TensorIndex(
                at::indexing::Slice(val.start, val.stop, val.step));
          } else if (obj == Py_Ellipsis) {
            return at::indexing::TensorIndex(at::indexing::Ellipsis);
          } else if (obj == Py_None) {
            return at::indexing::TensorIndex(at::indexing::None);
          } else if (PyBool_Check(obj)) {
            return at::indexing::TensorIndex(obj == Py_True);
          } else if (THPVariable_Check(obj)) {
            Tensor tensor = THPVariable_Unpack(obj);
            if (is_tracing) {
              auto scalar_type = tensor.scalar_type();
              if (tensor.dim() == 0 &&
                  at::isIntegralType(scalar_type, /*includeBool=*/false) &&
                  scalar_type != at::kByte) {
                recordSelectTrace(tensor);
              }
            }
            return at::indexing::TensorIndex(std::move(tensor));
          } else if (PySequence_Check(obj)) {
            return at::indexing::TensorIndex(
                sequenceToVariable(self.options(), obj));
          } else {
            auto idx = THPObjectPtr(PyNumber_Index(obj));
            if (!idx) {
              PyErr_Clear();
              invalid_index(obj);
            }
            if (is_tracing && THPVariable_Check(idx)) {
              recordSelectTrace(THPVariable_Unpack(idx));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(idx));
          }
        })(),
        /*dim_ptr=*/&dim,
        /*specified_dims_ptr=*/&specified_dims,
        /*real_dim=*/i,
        /*outIndices=*/outIndices,
        // See NOTE [ Setting `disable_slice_optimization` when calling C++
        // tensor indexing functions from Python ]
        /*disable_slice_optimization=*/is_tracing,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}

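// Decides whether a non-tuple Python sequence used as an index should be
// treated as a tuple of indices (NumPy-compatible heuristic) rather than as
// a single index.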
static inline bool treatSequenceAsTuple(PyObject* index) {
  if (PyTuple_Check(index)) {
    return true;
  }
  if (THPVariable_Check(index)) {
    return false;
  }
  // Allow indexing with ndarray if numpy compilation is enabled. An ndarray
  // index should not be treated as a tuple since the indexing has a different
  // syntax.
#ifdef USE_NUMPY
  if (::torch::utils::is_numpy_available() && PyArray_CheckExact(index)) {
    return false;
  }
#endif
  if (!PySequence_Check(index)) {
    return false;
  }
  // This uses a heuristic from NumPy for determining whether to treat
  // non-tuple sequences as if they were a tuple. From the NumPy code comments:
  //
  // "At this point, we're left with a non-tuple, non-array, sequence:
  //  typically, a list. We use some somewhat-arbitrary heuristics from here
  //  onwards to decide whether to treat that list as a single index, or a
  //  list of indices. Backwards compatibility only takes effect for short
  //  sequences - otherwise we treat it like any other scalar."
  auto n = PySequence_Size(index);
  if (n < 0) {
    // Negative size indicates a Python error in the PySequence_Size call.
    PyErr_Clear();
    return false;
  }
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  if (n >= 32) {
    return false;
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    auto obj = THPObjectPtr{PySequence_GetItem(index, i)};
    if (!obj.get()) {
      PyErr_Clear();
      return false;
    }
    if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) ||
        PySlice_Check(obj.get())) {
      return true;
    }
    if (obj.get() == Py_Ellipsis || obj.get() == Py_None) {
      return true;
    }
  }
  return false;
}

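// Normalizes an index into a tuple: qualifying sequences are converted to a
// tuple of their elements, anything else is packed into a 1-element tuple.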
static inline THPObjectPtr wrapTuple(PyObject* index) {
  THPObjectPtr res;
  if (treatSequenceAsTuple(index)) {
    res = PySequence_Tuple(index);
  } else {
    res = PyTuple_Pack(1, index);
  }
  if (!res)
    throw python_error();
  return res;
}

// NOTE: Here is the dispatch structure for `THPVariable_getitem`:
//
// 1. Python 1-D getter calls C++ `at::indexing::get_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D getter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_indexing(self, index);
  }
  const auto& self_ = THPVariable_Unpack(self);
  OptionalDeviceGuard device_guard(device_of(self_));

  // handle simple types: none, ellipsis
  if (index == Py_None) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}));
  } else if (index == Py_Ellipsis) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}));
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices, bool
  if (THPUtils_checkLong(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
  } else if (PySlice_Check(index)) {
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))}));
  } else if (index == Py_False || index == Py_True) {
    return THPVariable_Wrap(([&]() {
      pybind11::gil_scoped_release no_gil;
      return at::indexing::get_item(
          self_, {at::indexing::TensorIndex(index == Py_True)});
    })());
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    return handle_torch_function_indexing(self, holder.get());
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_.device(),
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    if (sliced.is_same(self_)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return THPVariable_Wrap(std::move(sliced));
  }

  // indexing by tensors ("advanced" indexing)
  return THPVariable_Wrap(([&]() {
    pybind11::gil_scoped_release no_gil;
    return at::indexing::dispatch_index(sliced, std::move(variableIndices));
  })());

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

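// Releases the GIL and forwards to the C++ `at::indexing::set_item`.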
void dispatch_set_item(
    const Tensor& self,
    ArrayRef<at::indexing::TensorIndex> indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  pybind11::gil_scoped_release no_gil;
  at::indexing::set_item(self, indices, value, disable_slice_optimization);
}

// NOTE: Here is the dispatch structure for `THPVariable_setitem`:
//
// 1. Python 1-D setter calls C++ `at::indexing::set_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D setter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index_put_`.
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
  HANDLE_TH_ERRORS
  if (py_value == nullptr) {
    throw TypeError("Tensor does not support deleting items");
  }
  if ((check_has_torch_function(self)) ||
      (check_has_torch_function(py_value))) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }

  const auto& self_ = THPVariable_Unpack(self);
  if (self_.layout() == kSparse || self_.layout() == kSparseCsr ||
      self_.layout() == kSparseCsc || self_.layout() == kSparseBsr ||
      self_.layout() == kSparseBsc) {
    throw TypeError("Cannot assign to a sparse tensor");
  }
  OptionalDeviceGuard device_guard(device_of(self_));
  at::Device self_device = self_.device();
  Variable value;
  // TODO: This qint special case looks very suspicious...
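  // Note: for CUDA tensors the scalar value is materialized on the CPU here;
  // the copy to the device happens later inside the indexing kernel
  // (presumably to avoid an extra synchronization for scalar assignments).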
  if (isQIntType(self_.scalar_type())) {
    value =
        valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU));
  } else if (self_device.is_cuda()) {
    value = valueToTensor(self_.options(), py_value, at::Device(kCPU));
  } else {
    value = valueToTensor(self_.options(), py_value, self_device);
  }

  // handle simple types: ellipsis, none, bool
  if (index == Py_False) {
    // do nothing for false (technically we should check the size, but we
    // don't have real 0-sized shapes).
    return 0;
  } else if (index == Py_Ellipsis) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value);
    return 0;
  } else if (index == Py_None) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}, value);
    return 0;
  } else if (index == Py_True) {
    dispatch_set_item(self_, {at::indexing::TensorIndex(true)}, value);
    return 0;
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices
  if (THPUtils_checkLong(index) || torch::is_symint(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    auto symint = torch::is_symint(index) ? py::cast<SymInt>(index)
                                          : SymInt(THPUtils_unpackLong(index));
    dispatch_set_item(self_, {at::indexing::TensorIndex(symint)}, value);
    return 0;
  } else if (PySlice_Check(index)) {
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
    // indexing functions from Python ]
    dispatch_set_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))},
        value,
        /*disable_slice_optimization=*/is_tracing);
    return 0;
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    py::object val = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_device,
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    pybind11::gil_scoped_release no_gil;
    at::indexing::copy_to(sliced, value);
    return 0;
  }

  {
    pybind11::gil_scoped_release no_gil;
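    // Strip leading 1-sized dimensions from the value's shape so it
    // broadcasts against the sliced result in index_put_.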
    SymIntArrayRef valueSizes = value.sym_sizes();
    SymIntArrayRef slicedValueSizes =
        at::indexing::slicePrefix1sSize(valueSizes);
    torch::autograd::Variable valuesSliced;
    if (!valueSizes.equals(slicedValueSizes)) {
      valuesSliced = value.view_symint(slicedValueSizes);
    } else {
      valuesSliced = value;
    }
    at::indexing::dispatch_index_put_(
        sliced, std::move(variableIndices), valuesSliced);
    return 0;
  }
  END_HANDLE_TH_ERRORS_RET(-1)
}

} // namespace torch::autograd