1 #pragma once
2
3 #include <ATen/core/ATen_fwd.h>
4 #include <ATen/NestedTensorImpl.h>
5 #include <c10/macros/Macros.h>
6
7 namespace at::native {
8
// Exported declaration; the definition is not in this header.
// NOTE(review): judging only by the name and parameters, this presumably
// converts the nested tensor `t` into a regular padded tensor, filling
// positions not covered by any component with `padding`; `output_size`, when
// provided, presumably fixes the shape of the padded result. Confirm both
// against the definition before relying on them.
TORCH_API Tensor NestedTensor_to_padded_tensor_generic(
    const Tensor& t,
    double padding,
    OptionalIntArrayRef output_size);
13
14 template <typename Func>
map_nt(const Tensor & nt,Func f)15 Tensor map_nt(const Tensor& nt, Func f) {
16 auto* nt_impl = get_nested_tensor_impl(nt);
17 const auto& sizes = nt_impl->get_nested_sizes();
18 return at::detail::make_tensor<NestedTensorImpl>(f(nt_impl->get_buffer()), sizes);
19 }
20 template <typename Func>
map_nt_binary(const Tensor & nt_1,const Tensor & nt_2,Func f)21 Tensor map_nt_binary(const Tensor& nt_1, const Tensor& nt_2, Func f){
22 auto* nt_impl_1 = get_nested_tensor_impl(nt_1);
23 auto* nt_impl_2 = get_nested_tensor_impl(nt_2);
24 const auto& sizes = nt_impl_1->get_nested_sizes();
25 return at::detail::make_tensor<NestedTensorImpl>(f(nt_impl_1->get_buffer(), nt_impl_2->get_buffer()), sizes);
26 }
27
_check_nested_layer_norm_inputs(const NestedTensorImpl & input,IntArrayRef normalized_shape,const Tensor & weight,const Tensor & bias)28 C10_ALWAYS_INLINE std::pair<int64_t, int64_t> _check_nested_layer_norm_inputs(
29 const NestedTensorImpl& input,
30 IntArrayRef normalized_shape,
31 const Tensor& weight /* optional */,
32 const Tensor& bias /* optional */) {
33
34 const size_t normalized_ndim = normalized_shape.size();
35 TORCH_CHECK(
36 normalized_ndim >= 1,
37 "Expected normalized_shape to be at least 1-dimensional, i.e., ",
38 "containing at least one element, but got normalized_shape = ",
39 normalized_shape);
40 TORCH_CHECK(
41 !weight.defined() || weight.sizes().equals(normalized_shape),
42 "Expected weight to be of same shape as normalized_shape, but got ",
43 "weight of shape ",
44 weight.sizes(),
45 " and normalized_shape = ",
46 normalized_shape);
47 TORCH_CHECK(
48 !bias.defined() || bias.sizes().equals(normalized_shape),
49 "Expected bias to be of same shape as normalized_shape, but got ",
50 "bias of shape ",
51 bias.sizes(),
52 " and normalized_shape = ",
53 normalized_shape);
54
55 // Check that the normalized_shape has the exact same sizes as the last dimensions from the NestedTensor input
56 // Also, compute M and N considering the idiosyncracies of NestedTensors
57 int64_t N = 1;
58 for (const auto i: c10::irange(normalized_ndim)) {
59 TORCH_CHECK(
60 input.opt_size(-normalized_ndim + i) != std::nullopt,
61 "normalized_shape extends into irregular dimensions for the nested tensor"
62 );
63 TORCH_CHECK(
64 normalized_shape[i] == *input.opt_size(-normalized_ndim + i),
65 "The shape at dimension ",
66 i,
67 "of normalized_shape doesn't match the input"
68 );
69 N *= normalized_shape[i];
70 }
71
72 const int64_t M = input.numel() / N;
73
74 return std::make_pair(M, N);
75 }
76
// Declaration only; the definition is not in this header.
// NOTE(review): from the name and parameters this presumably reshapes the
// nested tensor `self` to `proposed_shape`. Confirm the exact semantics (e.g.
// how -1 entries are handled and which dimensions may be irregular) against
// the definition.
Tensor reshape_nested(const Tensor& self, IntArrayRef proposed_shape);
78
79 } // namespace at::native
80