xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/NestedTensorMath.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/ATen_fwd.h>
4 #include <ATen/NestedTensorImpl.h>
5 #include <c10/macros/Macros.h>
6 
7 namespace at::native {
8 
9 TORCH_API Tensor NestedTensor_to_padded_tensor_generic(
10     const Tensor& t,
11     double padding,
12     OptionalIntArrayRef output_size);
13 
14 template <typename Func>
map_nt(const Tensor & nt,Func f)15 Tensor map_nt(const Tensor& nt, Func f) {
16   auto* nt_impl = get_nested_tensor_impl(nt);
17   const auto& sizes = nt_impl->get_nested_sizes();
18   return at::detail::make_tensor<NestedTensorImpl>(f(nt_impl->get_buffer()), sizes);
19 }
20 template <typename Func>
map_nt_binary(const Tensor & nt_1,const Tensor & nt_2,Func f)21 Tensor map_nt_binary(const Tensor& nt_1, const Tensor& nt_2, Func f){
22   auto* nt_impl_1 = get_nested_tensor_impl(nt_1);
23   auto* nt_impl_2 = get_nested_tensor_impl(nt_2);
24   const auto& sizes = nt_impl_1->get_nested_sizes();
25   return at::detail::make_tensor<NestedTensorImpl>(f(nt_impl_1->get_buffer(), nt_impl_2->get_buffer()), sizes);
26 }
27 
_check_nested_layer_norm_inputs(const NestedTensorImpl & input,IntArrayRef normalized_shape,const Tensor & weight,const Tensor & bias)28 C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_nested_layer_norm_inputs(
29     const NestedTensorImpl& input,
30     IntArrayRef normalized_shape,
31     const Tensor& weight /* optional */,
32     const Tensor& bias /* optional */) {
33 
34   const size_t normalized_ndim = normalized_shape.size();
35   TORCH_CHECK(
36       normalized_ndim >= 1,
37       "Expected normalized_shape to be at least 1-dimensional, i.e., ",
38       "containing at least one element, but got normalized_shape = ",
39       normalized_shape);
40   TORCH_CHECK(
41       !weight.defined() || weight.sizes().equals(normalized_shape),
42       "Expected weight to be of same shape as normalized_shape, but got ",
43       "weight of shape ",
44       weight.sizes(),
45       " and normalized_shape = ",
46       normalized_shape);
47   TORCH_CHECK(
48       !bias.defined() || bias.sizes().equals(normalized_shape),
49       "Expected bias to be of same shape as normalized_shape, but got ",
50       "bias of shape ",
51       bias.sizes(),
52       " and normalized_shape = ",
53       normalized_shape);
54 
55   // Check that the normalized_shape has the exact same sizes as the last dimensions from the NestedTensor input
56   // Also, compute M and N considering the idiosyncracies of NestedTensors
57   int64_t N = 1;
58   for (const auto i: c10::irange(normalized_ndim)) {
59     TORCH_CHECK(
60       input.opt_size(-normalized_ndim + i) != std::nullopt,
61       "normalized_shape extends into irregular dimensions for the nested tensor"
62     );
63     TORCH_CHECK(
64       normalized_shape[i] == *input.opt_size(-normalized_ndim + i),
65       "The shape at dimension ",
66       i,
67       "of normalized_shape doesn't match the input"
68     );
69     N *= normalized_shape[i];
70   }
71 
72   const int64_t M = input.numel() / N;
73 
74   return std::make_pair(M, N);
75 }
76 
77 Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape);
78 
79 } // namespace at::native
80