xref: /aosp_15_r20/external/pytorch/aten/src/ATen/NestedTensorImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/Exception.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Logging.h>

#include <algorithm>
#include <functional>
#include <limits>
#include <numeric>
#include <utility>

namespace {
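// Checks the basic invariants of the nested tensor metadata triple:
// `nested_sizes` and `nested_strides` must be contiguous and either both
// 0-dim or both 2-dim with matching shapes, and `offsets` must hold one entry
// per nested component (zero entries when the size metadata is 0-dim).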
inline void validate_nested_tensor_metadata(
    const at::Tensor& nested_sizes,
    const at::Tensor& nested_strides,
    const at::Tensor& offsets) {
  TORCH_INTERNAL_ASSERT(nested_sizes.is_contiguous());
  int64_t size_dim = nested_sizes.dim();
  TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
  TORCH_INTERNAL_ASSERT(nested_strides.is_contiguous());
  TORCH_INTERNAL_ASSERT(nested_strides.dim() == size_dim);
  TORCH_INTERNAL_ASSERT(nested_sizes.sizes() == nested_strides.sizes());
  TORCH_INTERNAL_ASSERT(
      (size_dim == 0 && offsets.size(0) == 0) ||
      (size_dim == 2 && nested_sizes.size(0) == offsets.size(0)));
}

/**
 * Generates a nested key set from a non-nested tensor.
 *
 * When creating a nested tensor from a non-nested tensor,
 * we want to maintain the same key set as the buffer but
 * swap the non-nested keys for their nested counterparts.
 *
 * @return Appropriate key set for a nested tensor
 */
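// Illustrative effect (a sketch, not exercised in this file): the buffer's
// backend bits are kept while the Dense and Autograd functionality keys are
// replaced by NestedTensor (plus the nested autograd keys when autograd was
// present on the buffer).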
inline c10::DispatchKeySet generate_nested_key_set_from_buffer(
    const at::Tensor& buffer) {
  auto nested_key_set = buffer.key_set();
  const bool has_autograd = nested_key_set.has_any(c10::autograd_dispatch_keyset);
  // Remove non-nested tensor specific keys
  nested_key_set = nested_key_set -
      c10::DispatchKeySet{c10::DispatchKey::Dense, c10::DispatchKey::Autograd};

  // Add nested tensor specific keys
  nested_key_set =
      nested_key_set | c10::DispatchKeySet{c10::DispatchKey::NestedTensor};
  nested_key_set =
      has_autograd ? nested_key_set | c10::autograd_nested : nested_key_set;
  return nested_key_set;
}

/**
 * Generates the correct view key set.
 *
 * When creating a nested tensor view of a base tensor,
 * the appropriate key set depends on whether the base
 * is itself nested.
 *
 * @return Appropriate key set for a nested tensor view
 */
c10::DispatchKeySet get_view_key_set(const at::Tensor& base) {
  return base.is_nested() ? base.key_set()
                          : generate_nested_key_set_from_buffer(base);
}

} // namespace

namespace at::native {

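// Computes the "optional sizes" of a nested tensor from its 2-dim size
// metadata: entry 0 is the number of components, and entry d + 1 is the
// common size of component dimension d, or -1 if the components disagree.
// Illustrative example: sizes = [[2, 3], [4, 3]] yields {2, -1, 3}.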
inline std::vector<int64_t> construct_opt_sizes(const at::Tensor& sizes) {
  // torch.tensor([]) is considered to have `dim() = 1` and `size(0) = 0`
  // torch.nested_tensor([]) should also have `dim() = 1` and `size(0) = 0`
  if (sizes.dim() == 0) {
    return std::vector<int64_t>({0});
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  std::vector<int64_t> result(1, sizes.sizes()[0]);
  if (sizes.dim() > 0) {
    size_t nested_dim = result.size();
    const int64_t* sizes_ptr = sizes.const_data_ptr<int64_t>();
    result.resize(nested_dim + sizes.sizes()[1]);
    int64_t sizes_size_0 = sizes.sizes()[0];
    int64_t sizes_size_1 = sizes.sizes()[1];
    for (const auto i : c10::irange(sizes_size_1)) {
      result[nested_dim + i] = sizes_ptr[i];
    }
    for (const auto j : c10::irange(sizes_size_1)) {
      for (const auto i : c10::irange(sizes_size_0)) {
        if (result[nested_dim + j] &&
            (result[nested_dim + j] != sizes_ptr[i * sizes.size(1) + j])) {
          result[nested_dim + j] = -1;
        }
      }
    }
  }
  return result;
}

// Assuming contiguity, the strides can be constructed from the sizes.
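// Illustrative example: sizes = [[2, 3], [4, 3]] yields strides [[3, 1], [3, 1]].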
at::Tensor construct_nested_strides(const at::Tensor& sizes) {
  // empty `sizes` means empty nested tensor, so return empty strides
  if (sizes.dim() == 0) {
    return sizes;
  }
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(sizes.dim() == 2);
  int64_t orig_dim = sizes.size(1);
  // `sizes`.sizes() = ntensors x 0 means `sizes` is empty but shaped,
  // i.e. every nested component is a scalar; the strides are then
  // also empty but shaped, so `sizes` itself can be reused
  if (orig_dim == 0) {
    return sizes;
  }
  at::Tensor strides = sizes.new_empty(sizes.sizes());
  const int64_t* sizes_ptr = sizes.const_data_ptr<int64_t>();
  int64_t* strides_ptr = strides.data_ptr<int64_t>();
  for (int64_t i = 0; i < sizes.size(0); i++) {
    strides_ptr[orig_dim - 1] = 1;
    int64_t product = sizes_ptr[orig_dim - 1];
    for (int64_t j = orig_dim - 2; j >= 0; j--) {
      strides_ptr[j] = product;
      product *= sizes_ptr[j];
    }
    sizes_ptr += orig_dim;
    strides_ptr += orig_dim;
  }
  return strides;
}

/**
 * Create a tensor of offsets assuming the nested tensor is contiguous.
 *
 * This function iterates over the implicit ntensors outer dimension,
 * populating a tensor with the cumulative number of elements preceding
 * each implicit tensor. The first element is always 0 and the length of
 * the returned tensor is ntensors.
 *
 * @return A tensor of offsets
 */
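// Illustrative example: sizes = [[2, 3], [4, 3]] yields offsets {0, 6}.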
at::Tensor construct_offsets(const at::Tensor& sizes) {
  // empty `sizes` means empty nested tensor, so return empty offsets
  if (sizes.dim() == 0) {
    return at::empty({0}, sizes.options().dtype(kLong));
  }
  int64_t ntensors = sizes.size(0), orig_dim = sizes.size(1);
  auto offsets = at::empty({ntensors}, sizes.options());
  int64_t *offsets_ptr = offsets.mutable_data_ptr<int64_t>();
  // nesting scalars have easy offsets
  if (orig_dim == 0) {
    std::iota(offsets_ptr, offsets_ptr + ntensors, 0);
    return offsets;
  }
  const int64_t* sizes_ptr = sizes.const_data_ptr<int64_t>();
  offsets_ptr[0] = 0;
  for (const auto i : c10::irange(ntensors - 1)) {
    const int64_t row_product = std::accumulate(
        sizes_ptr, sizes_ptr + orig_dim, static_cast<int64_t>(1), std::multiplies());
    offsets_ptr[i + 1] = offsets_ptr[i] + row_product;
    sizes_ptr += orig_dim;
  }
  return offsets;
}

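// Primary constructor: takes ownership of the storage and the metadata
// tensors, validates the metadata, and switches on the custom sizes/strides
// policy so that size/stride queries are routed through this class.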
NestedTensorImpl::NestedTensorImpl(
    Storage storage,
    c10::DispatchKeySet key_set,
    const caffe2::TypeMeta data_type,
    at::Tensor nested_sizes,
    at::Tensor nested_strides,
    at::Tensor storage_offsets)
    : TensorImpl(std::move(storage), key_set, data_type),
      nested_sizes_(std::move(nested_sizes)),
      nested_strides_(std::move(nested_strides)),
      storage_offsets_(std::move(storage_offsets)),
      opt_sizes_(std::nullopt) {
  C10_LOG_API_USAGE_ONCE("torch.NestedTensor");
  TORCH_WARN_ONCE(
      "The PyTorch API of nested tensors is in prototype stage and will change "
      "in the near future.");
  auto storage_device = storage_.device();
  TORCH_INTERNAL_ASSERT(
      storage_device.is_cpu() || storage_device.is_cuda() || storage_device.is_xpu() || storage_device.is_privateuseone(),
      "NestedTensorImpl storage must be either CUDA, CPU, XPU or ", get_privateuse1_backend(), " but got ",
      storage_device);
  validate_nested_tensor_metadata(nested_sizes_, nested_strides_, storage_offsets_);
  refresh_dim();
  set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

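// Convenience constructor: derives the key set and dtype from a 1-dim buffer
// and delegates to the primary constructor above.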
NestedTensorImpl::NestedTensorImpl(
    const at::Tensor& buffer,
    at::Tensor nested_sizes,
    at::Tensor nested_strides,
    at::Tensor storage_offsets)
    : NestedTensorImpl(
          buffer.storage(),
          generate_nested_key_set_from_buffer(buffer),
          buffer.dtype(),
          std::move(nested_sizes),
          std::move(nested_strides),
          std::move(storage_offsets)) {
  TORCH_INTERNAL_ASSERT(
      buffer.dim() == 1,
      "NestedTensorImpl buffer is required to be 1 dimensional but got a buffer with ",
      buffer.dim(),
      " dimensions.");
}

// Assuming contiguity, `nested_strides` and `offsets`
// can be inferred from `nested_sizes`.
NestedTensorImpl::NestedTensorImpl(
    const at::Tensor& buffer,
    const at::Tensor& nested_sizes)
    : NestedTensorImpl(
          buffer,
          nested_sizes,
          construct_nested_strides(nested_sizes),
          construct_offsets(nested_sizes))
{}

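// View constructor: shares the storage of `base_tensor` (used when creating
// nested views), choosing the key set according to whether the base is
// itself nested.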
NestedTensorImpl::NestedTensorImpl(
    c10::TensorImpl::ImplType impl_type,
    const at::Tensor& base_tensor,
    at::Tensor nested_sizes,
    at::Tensor nested_strides,
    at::Tensor storage_offsets)
    : TensorImpl(impl_type, Storage(base_tensor.storage()), get_view_key_set(base_tensor), base_tensor.dtype()),
      nested_sizes_(std::move(nested_sizes)),
      nested_strides_(std::move(nested_strides)),
      storage_offsets_(std::move(storage_offsets)),
      opt_sizes_(std::nullopt) {
  validate_nested_tensor_metadata(nested_sizes_, nested_strides_, storage_offsets_);
  refresh_dim();
  set_custom_sizes_strides(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

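// Returns the size of dimension `d` if it is shared by all nested components,
// std::nullopt otherwise. Illustrative example: for components of shapes
// (2, 5), (4, 5), and (3, 5), opt_size(0) == 3 (the number of components),
// opt_size(1) == std::nullopt, and opt_size(2) == 5.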
std::optional<int64_t> NestedTensorImpl::opt_size(int64_t d) const {
  if (C10_UNLIKELY(!opt_sizes_.has_value())) {
    // Cache the metadata to avoid recomputing it each time.
    opt_sizes_ = std::make_optional(construct_opt_sizes(nested_sizes_));
  }
  d = at::maybe_wrap_dim(d, dim(), false);
  if ((*opt_sizes_)[d] == -1) {
    return std::nullopt;
  }
  return (*opt_sizes_)[d];
}

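// The nested tensor's dim is the per-component dim plus one for the implicit
// outer (component) dimension; a 0-dim size tensor corresponds to dim 1.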
void NestedTensorImpl::refresh_dim() {
  const auto my_dim = nested_sizes_.dim() ? nested_sizes_.sizes()[1] + 1 : 1;
  sizes_and_strides_.resize(my_dim);
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(dim() == my_dim);
}

int64_t NestedTensorImpl::dim_custom() const {
  return dim_default();
}

// Currently sizes and strides assume contiguous
int64_t NestedTensorImpl::numel_custom() const {
  if (nested_sizes_.dim() == 0) {
    return 0;
  }
  return get_numel_from_nested_size_tensor(nested_sizes_);
}

c10::SymInt NestedTensorImpl::sym_numel_custom() const {
  return NestedTensorImpl::numel_custom();
}

bool NestedTensorImpl::is_contiguous_custom(MemoryFormat) const {
  return nested_tensor_impl_is_contiguous(this);
}
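// A nested tensor has no single well-defined sizes()/strides() vector, so the
// generic accessors below are disabled; callers should consult the nested
// size and stride metadata instead.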
IntArrayRef NestedTensorImpl::sizes_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue.");
}
c10::SymIntArrayRef NestedTensorImpl::sym_sizes_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support sizes. Please file an issue.");
}

c10::SymIntArrayRef NestedTensorImpl::sym_strides_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue.");
}

IntArrayRef NestedTensorImpl::strides_custom() const {
  TORCH_CHECK(false, "Internal error: NestedTensorImpl doesn't support strides. Please file an issue.");
}

const char* NestedTensorImpl::tensorimpl_type_name() const {
  return "NestedTensorImpl";
}

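// Shared implementation for both shallow_copy_and_detach overloads: when the
// Python dispatch key is active, the attached Python interpreter's detach is
// used; otherwise the TensorImpl metadata is copied into a fresh
// NestedTensorImpl that aliases the same storage and nested metadata tensors.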
template <typename VariableVersion>
c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach_core(
    VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  if (key_set_.has(DispatchKey::Python) &&
      !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) {
    auto r = pyobj_slot_.load_pyobj_interpreter()->detach(this);
    if (r) {
      r->set_version_counter(std::forward<VariableVersion>(version_counter));
      r->set_allow_tensor_metadata_change(allow_tensor_metadata_change);
      return r;
    }
    // otherwise just copy the TensorImpl and not the PyObject.  Since
    // the interpreter is dead no one can call us out on it
  }
  auto impl = c10::make_intrusive<NestedTensorImpl>(
      storage_,
      key_set_,
      data_type_,
      nested_sizes_,
      nested_strides_,
      storage_offsets_);

  copy_tensor_metadata(
      /*src_impl=*/this,
      /*dest_impl=*/impl.get(),
      /*version_counter=*/std::forward<VariableVersion>(version_counter),
      /*allow_tensor_metadata_change=*/allow_tensor_metadata_change);
  return impl;
}

c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach(
    const c10::VariableVersion& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      version_counter, allow_tensor_metadata_change);
}

c10::intrusive_ptr<TensorImpl> NestedTensorImpl::shallow_copy_and_detach(
    c10::VariableVersion&& version_counter,
    bool allow_tensor_metadata_change) const {
  return shallow_copy_and_detach_core(
      std::move(version_counter), allow_tensor_metadata_change);
}

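// Sums the element counts of all nested components from the 2-dim size
// metadata, checking both the per-component products and the running total
// for overflow.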
int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor) {
  constexpr auto numel_max = std::min(
      static_cast<uint64_t>(std::numeric_limits<int64_t>::max()),
      static_cast<uint64_t>(std::numeric_limits<size_t>::max()));

  const int64_t* sizes_ptr = tensor.const_data_ptr<int64_t>();
  const auto nt_dim = tensor.size(1);
  uint64_t num_elements{0};

  for (const auto i : c10::irange(tensor.size(0))) {
    uint64_t n = 1;
    const auto start{sizes_ptr + i * nt_dim};
    const auto end{start + nt_dim};
    bool overflows = c10::safe_multiplies_u64(start, end, &n);
    num_elements += n;
    overflows |= (num_elements > numel_max);
    TORCH_CHECK(!overflows, "numel: integer multiplication overflow");
  }
  return static_cast<int64_t>(num_elements);
}

} // namespace at::native