// torch/csrc/autograd/python_variable_indexing.cpp
#include <torch/csrc/autograd/python_variable_indexing.h>

#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/frontend/tracer.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/numpy_stub.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_symnode.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/utils/tensor_types.h>

#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>
#include <ATen/Functions.h>
#include <ATen/TensorIndexing.h>
#include <ATen/TracerMode.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/irange.h>

#include <c10/core/Layout.h>

using namespace at;
using namespace torch::autograd::utils;

namespace torch::autograd {

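// Implements Python's __len__ for Tensor. Delegates to __torch_function__
// overrides when present; otherwise returns the size of the first dimension
// (or 0 for a 0-dim tensor).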
Py_ssize_t THPVariable_length(PyObject* self) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function(self, "__len__"));
    Py_ssize_t length = PyLong_AsSsize_t(ret.ptr());
    if (PyErr_Occurred()) {
      throw python_error();
    }
    return length;
  }
  const auto& self_ = THPVariable_Unpack(self);
  if (self_.dim() == 0) {
    return 0;
  }
  // TODO: Maybe this should return a SymInt directly?
  // Add the guard to get a nice error message if/when we hit this.
  return (Py_ssize_t)self_.sym_size(0).guard_int(__FILE__, __LINE__);
  END_HANDLE_TH_ERRORS_RET(-1)
}

// We allow indexing by integers, slices, ellipsis, None, Variables,
// and tuples of those types. We also handle bools as if they were a
// Variable[ByteTensor].

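// Example (illustrative): for x[0, :, None, mask] with `mask` a 2-D bool
// tensor, count_specified_dimensions returns
// 1 (int) + 1 (slice) + 0 (None) + 2 (mask.dim()) = 4.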
static inline int64_t count_specified_dimensions(PyObject* index) {
  // Count the number of indexed dimensions (everything but ellipsis and None)
  // -1 is a sentinel for __torch_function__
  int64_t count = 0;
  auto size = PyTuple_GET_SIZE(index);
  for (Py_ssize_t i = 0; i < size; i++) {
    PyObject* obj = PyTuple_GET_ITEM(index, i);
    if (check_has_torch_function(obj))
      return -1;
    if (THPVariable_Check(obj)) {
      const auto& var = THPVariable_Unpack(obj);
      const auto& var_scalar_type = var.scalar_type();
      if (var_scalar_type == kByte || var_scalar_type == kBool) {
        count += var.dim();
      } else {
        count++;
      }
    } else if (
        obj != Py_None && obj != Py_Ellipsis && obj != Py_True &&
        obj != Py_False) {
      count++;
    }
  }
  return count;
}

[[noreturn]] static inline void invalid_index(PyObject* obj) {
  TORCH_CHECK_INDEX(
      false,
      "only integers, slices (`:`), ellipsis (`...`), None and long or byte "
      "Variables are valid indices (got ",
      Py_TYPE(obj)->tp_name,
      ")");
}

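// Converts a Python sequence used as an index (e.g. a nested list of ints)
// into an indexing tensor via torch::utils::indexing_tensor_from_data.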
static inline Variable sequenceToVariable(
    c10::TensorOptions options,
    PyObject* seq) {
  return torch::utils::indexing_tensor_from_data(
      options, kLong, std::nullopt, seq);
}

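// Converts the right-hand side of an indexed assignment into a Tensor:
// Tensors pass through unchanged; Python and symbolic scalars are wrapped
// into 0-dim tensors via at::indexing::scalarToTensor; anything else raises
// a TypeError.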
inline Variable valueToTensor(
    c10::TensorOptions options,
    PyObject* value,
    const at::Device& device) {
  if (THPVariable_Check(value)) {
    return THPVariable_Unpack(value);
  }
  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;
  Scalar scalar;
  if (THPUtils_checkLong(value) || PyBool_Check(value)) {
    scalar = Scalar(THPUtils_unpackLong(value));
  } else if (PyFloat_Check(value)) {
    scalar = Scalar(THPUtils_unpackDouble(value));
  } else if (PyComplex_Check(value)) {
    scalar = Scalar(THPUtils_unpackComplexDouble(value));
  } else if (torch::is_symint(value)) {
    scalar = Scalar(py::cast<c10::SymInt>(py::handle(value)));
  } else if (torch::is_symfloat(value)) {
    scalar = Scalar(py::cast<c10::SymFloat>(py::handle(value)));
  } else if (torch::is_symbool(value)) {
    scalar = Scalar(py::cast<c10::SymBool>(py::handle(value)));
  } else {
    throw TypeError(
        "can't assign a %s to a %s",
        Py_TYPE(value)->tp_name,
        torch::utils::options_to_string(options).c_str());
  }
  // lift_fresh is supposed to be used in situations where you are guaranteed
  // to get a plain Tensor, which is true for the CPU device but not for
  // non-CPU devices.
  if (device == at::kCPU && !scalar.isSymbolic()) {
    return at::lift_fresh(
        at::indexing::scalarToTensor(scalar, options, device));
  } else {
    return at::indexing::scalarToTensor(scalar, options, device);
  }
}

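// When tracing, stash slice bounds that are themselves Tensors so the JIT
// tracer records them as traced values rather than baking in constants.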
static inline void recordSliceTrace(PyObject* obj) {
  PySliceObject* sliceobj = (PySliceObject*)obj;
  if (THPVariable_Check(sliceobj->start)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("start"),
        1,
        THPVariable_Unpack(sliceobj->start),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->stop)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("end"),
        1,
        THPVariable_Unpack(sliceobj->stop),
        torch::jit::IntType::get());
  }
  if (THPVariable_Check(sliceobj->step)) {
    torch::jit::tracer::ArgumentStash::stashValue(
        std::string("step"),
        1,
        THPVariable_Unpack(sliceobj->step),
        torch::jit::IntType::get());
  }
}

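// When tracing, stash a 0-dim integer tensor used as a select index so the
// tracer records it as a traced value.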
static inline void recordSelectTrace(const Tensor& index_tensor) {
  torch::jit::tracer::ArgumentStash::stashValue(
      std::string("index"), 1, index_tensor, torch::jit::IntType::get());
}

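// Walks the index tuple dimension by dimension: basic indices (integers,
// slices, None, Ellipsis, bool) are applied eagerly, while tensor indices are
// collected into outIndices for a later advanced-indexing dispatch.
// Example (illustrative): for x[1:3, t] with t an integer tensor, the slice is
// applied here and t is appended to outIndices; the caller then calls
// dispatch_index (or dispatch_index_put_ for assignment).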
static inline Variable applySlicing(
    const Variable& self,
    PyObject* index,
    variable_list& outIndices,
    bool is_tracing,
    const at::Device& self_device,
    const std::optional<int64_t>& self_ndim,
    int64_t specified_dims) {
  int64_t size = PyTuple_GET_SIZE(index);
  int64_t dim = 0;

  // See NOTE [nested tensor size for indexing]
  if (self_ndim.has_value()) {
    TORCH_CHECK_INDEX(
        specified_dims <= self_ndim.value(),
        "too many indices for tensor of dimension ",
        self_ndim.value());
  }

  Variable result = self;
  for (const auto i : c10::irange(size)) {
    PyObject* obj = PyTuple_GET_ITEM(index, i);
    // NOTE [nested tensor size for indexing]
    // Nested tensors do not have a size (yet), so for now we represent their
    // size as null; this may need to change once there is a better solution
    // for nested tensor sizes.
    std::optional<SymIntArrayRef> result_sizes = result.is_nested()
        ? std::optional<SymIntArrayRef>(std::nullopt)
        : std::optional<SymIntArrayRef>(result.sym_sizes());
    result = at::indexing::handleDimInMultiDimIndexing(
        /*prev_dim_result=*/result,
        /*original_tensor=*/self,
        /*index=*/([&]() {
          if (THPUtils_checkLong(obj)) {
            if (is_tracing && THPVariable_Check(obj)) {
              recordSelectTrace(THPVariable_Unpack(obj));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(obj));
          } else if (PySlice_Check(obj)) {
            auto val = __PySlice_Unpack(obj);
            if (is_tracing) {
              recordSliceTrace(obj);
            }
            return at::indexing::TensorIndex(
                at::indexing::Slice(val.start, val.stop, val.step));
          } else if (obj == Py_Ellipsis) {
            return at::indexing::TensorIndex(at::indexing::Ellipsis);
          } else if (obj == Py_None) {
            return at::indexing::TensorIndex(at::indexing::None);
          } else if (PyBool_Check(obj)) {
            return at::indexing::TensorIndex(obj == Py_True);
          } else if (THPVariable_Check(obj)) {
            Tensor tensor = THPVariable_Unpack(obj);
            if (is_tracing) {
              auto scalar_type = tensor.scalar_type();
              if (tensor.dim() == 0 &&
                  at::isIntegralType(scalar_type, /*includeBool=*/false) &&
                  scalar_type != at::kByte) {
                recordSelectTrace(tensor);
              }
            }
            return at::indexing::TensorIndex(std::move(tensor));
          } else if (PySequence_Check(obj)) {
            return at::indexing::TensorIndex(
                sequenceToVariable(self.options(), obj));
          } else {
            auto idx = THPObjectPtr(PyNumber_Index(obj));
            if (!idx) {
              PyErr_Clear();
              invalid_index(obj);
            }
            if (is_tracing && THPVariable_Check(idx)) {
              recordSelectTrace(THPVariable_Unpack(idx));
            }
            return at::indexing::TensorIndex(THPUtils_unpackLong(idx));
          }
        })(),
        /*dim_ptr=*/&dim,
        /*specified_dims_ptr=*/&specified_dims,
        /*real_dim=*/i,
        /*outIndices=*/outIndices,
        // See NOTE [ Setting `disable_slice_optimization` when calling C++
        // tensor indexing functions from Python ]
        /*disable_slice_optimization=*/is_tracing,
        /*original_tensor_device=*/self_device,
        /*prev_dim_result_sizes=*/result_sizes);
  }
  return result;
}

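// Example (illustrative): x[[0, 2]] (a short list of ints) is not treated as
// a tuple and becomes a single index tensor, whereas x[[0, slice(None)]] is
// unpacked into per-dimension indices.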
static inline bool treatSequenceAsTuple(PyObject* index) {
  if (PyTuple_Check(index)) {
    return true;
  }
  if (THPVariable_Check(index)) {
    return false;
  }
  //  Allow indexing with ndarray if numpy compilation is enabled. An ndarray
  //  index should not be treated as a tuple since the indexing has a different
  //  syntax.
#ifdef USE_NUMPY
  if (::torch::utils::is_numpy_available() && PyArray_CheckExact(index)) {
    return false;
  }
#endif
  if (!PySequence_Check(index)) {
    return false;
  }
  // This uses a heuristic from NumPy for determining whether to treat
  // non-tuple sequences as if they were a tuple. From the NumPy code comments:
  //
  // "At this point, we're left with a non-tuple, non-array, sequence:
  //  typically, a list. We use some somewhat-arbitrary heuristics from here
  //  onwards to decide whether to treat that list as a single index, or a
  //  list of indices. Backwards compatibility only takes effect for short
  //  sequences - otherwise we treat it like any other scalar."
  auto n = PySequence_Size(index);
  if (n < 0) {
    // Negative size indicates a Python error in the PySequence_Size call.
    PyErr_Clear();
    return false;
  }
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  if (n >= 32) {
    return false;
  }
  for (Py_ssize_t i = 0; i < n; i++) {
    auto obj = THPObjectPtr{PySequence_GetItem(index, i)};
    if (!obj.get()) {
      PyErr_Clear();
      return false;
    }
    if (THPVariable_Check(obj.get()) || PySequence_Check(obj.get()) ||
        PySlice_Check(obj.get())) {
      return true;
    }
    if (obj.get() == Py_Ellipsis || obj.get() == Py_None) {
      return true;
    }
  }
  return false;
}

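// Normalizes an index expression into a tuple: sequences that look like a
// tuple of per-dimension indices (per treatSequenceAsTuple) are converted with
// PySequence_Tuple; everything else is wrapped into a 1-element tuple.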
static inline THPObjectPtr wrapTuple(PyObject* index) {
  THPObjectPtr res;
  if (treatSequenceAsTuple(index)) {
    res = PySequence_Tuple(index);
  } else {
    res = PyTuple_Pack(1, index);
  }
  if (!res)
    throw python_error();
  return res;
}

// NOTE: Here is the dispatch structure for `THPVariable_getitem`:
//
// 1. Python 1-D getter calls C++ `at::indexing::get_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D getter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index`.
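//
// For example (illustrative):
//   x[1]        -> 1-D path: get_item with a single integer TensorIndex
//   x[:, None]  -> N-D path: handleDimInMultiDimIndexing per dimension
//   x[mask]     -> advanced indexing: dispatch_index with a bool index tensor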
PyObject* THPVariable_getitem(PyObject* self, PyObject* index) {
  HANDLE_TH_ERRORS
  if (check_has_torch_function(self)) {
    return handle_torch_function_indexing(self, index);
  }
  const auto& self_ = THPVariable_Unpack(self);
  OptionalDeviceGuard device_guard(device_of(self_));

  // handle simple types: none, ellipsis
  if (index == Py_None) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}));
  } else if (index == Py_Ellipsis) {
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}));
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices, bool
  if (THPUtils_checkLong(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_, {at::indexing::TensorIndex(THPUtils_unpackLong(index))}));
  } else if (PySlice_Check(index)) {
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    return THPVariable_Wrap(at::indexing::get_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))}));
  } else if (index == Py_False || index == Py_True) {
    return THPVariable_Wrap(([&]() {
      pybind11::gil_scoped_release no_gil;
      return at::indexing::get_item(
          self_, {at::indexing::TensorIndex(index == Py_True)});
    })());
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    return handle_torch_function_indexing(self, holder.get());
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_.device(),
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    if (sliced.is_same(self_)) {
      // ensure we return a shallow copy for things like x[...]
      sliced = at::alias(sliced);
    }
    return THPVariable_Wrap(std::move(sliced));
  }

  // indexing by tensors ("advanced" indexing)
  return THPVariable_Wrap(([&]() {
    pybind11::gil_scoped_release no_gil;
    return at::indexing::dispatch_index(sliced, std::move(variableIndices));
  })());

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

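// Releases the GIL and forwards to at::indexing::set_item.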
void dispatch_set_item(
    const Tensor& self,
    ArrayRef<at::indexing::TensorIndex> indices,
    const Tensor& value,
    bool disable_slice_optimization = false) {
  pybind11::gil_scoped_release no_gil;
  at::indexing::set_item(self, indices, value, disable_slice_optimization);
}

// NOTE: Here is the dispatch structure for `THPVariable_setitem`:
//
// 1. Python 1-D setter calls C++ `at::indexing::set_item` after
// converting Python index to C++ TensorIndex.
//
// 2. Python N-D setter calls C++ `at::indexing::handleDimInMultiDimIndexing`
// for each dim, after converting Python index to C++ TensorIndex. If advanced
// indexing is needed, it calls C++ `at::indexing::dispatch_index_put_`.
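//
// For example (illustrative):
//   x[0] = 1.0     -> 1-D path: set_item with an integer TensorIndex
//   x[1:3, 0] = y  -> N-D path: applySlicing, then copy_to on the sliced view
//   x[mask] = 0    -> advanced indexing: dispatch_index_put_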
int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
  HANDLE_TH_ERRORS
  if (py_value == nullptr) {
    throw TypeError("Tensor does not support deleting items");
  }
  if ((check_has_torch_function(self)) ||
      (check_has_torch_function(py_value))) {
    py::object ret = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }

  const auto& self_ = THPVariable_Unpack(self);
  if (self_.layout() == kSparse || self_.layout() == kSparseCsr ||
      self_.layout() == kSparseCsc || self_.layout() == kSparseBsr ||
      self_.layout() == kSparseBsc) {
    throw TypeError("Cannot assign to a sparse tensor");
  }
  OptionalDeviceGuard device_guard(device_of(self_));
  at::Device self_device = self_.device();
  Variable value;
  // TODO: This qint special case looks very suspicious...
  if (isQIntType(self_.scalar_type())) {
    value =
        valueToTensor(device(kCPU).dtype(kFloat), py_value, at::Device(kCPU));
  } else if (self_device.is_cuda()) {
    value = valueToTensor(self_.options(), py_value, at::Device(kCPU));
  } else {
    value = valueToTensor(self_.options(), py_value, self_device);
  }

  // handle simple types: ellipsis, none, bool
  if (index == Py_False) {
    // do nothing for false (technically we should check the size, but we
    // don't have real 0-sized shapes).
    return 0;
  } else if (index == Py_Ellipsis) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::Ellipsis)}, value);
    return 0;
  } else if (index == Py_None) {
    dispatch_set_item(
        self_, {at::indexing::TensorIndex(at::indexing::None)}, value);
    return 0;
  } else if (index == Py_True) {
    dispatch_set_item(self_, {at::indexing::TensorIndex(true)}, value);
    return 0;
  }

  bool is_tracing = torch::jit::tracer::isTracing();

  // handle simple types: integers, slices
  if (THPUtils_checkLong(index) || torch::is_symint(index)) {
    if (is_tracing && THPVariable_Check(index)) {
      recordSelectTrace(THPVariable_Unpack(index));
    }
    auto symint = torch::is_symint(index) ? py::cast<SymInt>(index)
                                          : SymInt(THPUtils_unpackLong(index));
    dispatch_set_item(self_, {at::indexing::TensorIndex(symint)}, value);
    return 0;
  } else if (PySlice_Check(index)) {
    auto val = __PySlice_Unpack(index);
    if (is_tracing) {
      recordSliceTrace(index);
    }
    // See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor
    // indexing functions from Python ]
    dispatch_set_item(
        self_,
        {at::indexing::TensorIndex(
            at::indexing::Slice(val.start, val.stop, val.step))},
        value,
        /*disable_slice_optimization=*/is_tracing);
    return 0;
  }

  // wrap index in a tuple if it's not already one
  THPObjectPtr holder = wrapTuple(index);

  variable_list variableIndices;
  int64_t specified_dims = count_specified_dimensions(holder.get());
  if (specified_dims == -1) {
    py::object val = py::reinterpret_steal<py::object>(
        handle_torch_function_indexing(self, index, py_value));
    return 0;
  }
  Variable sliced = applySlicing(
      self_,
      holder.get(),
      variableIndices,
      /*is_tracing=*/is_tracing,
      self_device,
      self_.ndimension(),
      specified_dims);
  if (variableIndices.empty()) {
    pybind11::gil_scoped_release no_gil;
    at::indexing::copy_to(sliced, value);
    return 0;
  }

  {
    pybind11::gil_scoped_release no_gil;
    SymIntArrayRef valueSizes = value.sym_sizes();
    SymIntArrayRef slicedValueSizes =
        at::indexing::slicePrefix1sSize(valueSizes);
    torch::autograd::Variable valuesSliced;
    if (!valueSizes.equals(slicedValueSizes)) {
      valuesSliced = value.view_symint(slicedValueSizes);
    } else {
      valuesSliced = value;
    }
    at::indexing::dispatch_index_put_(
        sliced, std::move(variableIndices), valuesSliced);
    return 0;
  }
  END_HANDLE_TH_ERRORS_RET(-1)
}

} // namespace torch::autograd