1 #pragma once
2
3 #include <ATen/DimVector.h>
4 #include <c10/core/ScalarType.h>
5 #include <c10/core/SymIntArrayRef.h>
6 #include <c10/util/DimVector.h>
7 #include <optional>
8 #include <sstream>
9 #include <vector>
10
11 namespace at {
12
// Infers the size of the (at most one) dimension given as -1 in `shape`, and
// checks that `shape` is compatible with `numel` elements.
//
// Templated so the same logic serves the std::vector<int64_t>, DimVector and
// symbolic (c10::SymIntArrayRef / c10::SymInt / SymDimVector) use cases; see
// the infer_size / infer_size_dv wrappers below.
//
// Parameters:
//   shape - requested sizes. Every entry must be >= 0, except that at most one
//           entry may be -1, meaning "infer this dimension".
//   numel - total number of elements the shape has to describe.
//   res   - output. Only the entry at the -1 position is written (it becomes
//           numel / product-of-the-other-dims); all other entries are left
//           untouched, so callers should pass in a copy of `shape`.
//
// Errors:
//   - more than one -1 entry: throws std::runtime_error.
//   - any entry < -1: TORCH_CHECK failure ("invalid shape dimension ...").
//   - a -1 entry while the explicit dims multiply to 0 and numel is 0: the
//     inferred size is ambiguous, TORCH_CHECK failure.
//   - any other mismatch with numel: throws std::runtime_error
//     ("shape '...' is invalid for input of size ...").
template <typename InputArrayRef, typename NumelType, typename ResultVec>
inline void infer_size_impl(
    InputArrayRef shape,
    NumelType numel,
    ResultVec& res) {
  // Product of every explicitly specified (non -1) dimension.
  NumelType newsize = 1;
  // N.B. this is an index, not a sym dim!
  std::optional<int64_t> infer_dim;
  for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
    if (shape[dim] == -1) {
      if (infer_dim) {
        throw std::runtime_error("only one dimension can be inferred");
      }
      infer_dim = dim;
    } else if (shape[dim] >= 0) {
      newsize *= shape[dim];
    } else {
      // AT_ERROR is deprecated in c10 in favor of TORCH_CHECK(false, ...),
      // which raises the same error type with the same message.
      TORCH_CHECK(false, "invalid shape dimension ", shape[dim]);
    }
  }

  // Accept when the explicit sizes already account for every element, or
  // when a -1 dimension can absorb an exact multiple of their product.
  // NOTE(review): the size-oblivious guard presumably avoids specializing on
  // symbolic sizes here; see TORCH_GUARD_SIZE_OBLIVIOUS.
  if (TORCH_GUARD_SIZE_OBLIVIOUS(sym_eq(numel, newsize)) ||
      (infer_dim && newsize > 0 && numel % newsize == 0)) {
    if (infer_dim) {
      // We have a degree of freedom here to select the dimension size; follow
      // NumPy semantics and just bail. However, a nice error message is needed
      // because users often use `view` as a way to flatten & unflatten
      // dimensions and will otherwise be confused why
      //   empty_tensor.view( 0, 0)
      // works yet
      //   empty_tensor.view(-1, 0)
      // doesn't.
      TORCH_CHECK(
          newsize != 0,
          "cannot reshape tensor of 0 elements into shape ",
          shape,
          " because the unspecified dimension size -1 can be any "
          "value and is ambiguous");
      res[*infer_dim] = numel / newsize;
    }
    return;
  }

  std::ostringstream ss;
  ss << "shape '" << shape << "' is invalid for input of size " << numel;
  throw std::runtime_error(ss.str());
}
66
infer_size(IntArrayRef shape,int64_t numel)67 inline std::vector<int64_t> infer_size(IntArrayRef shape, int64_t numel) {
68 auto res = shape.vec();
69 infer_size_impl(shape, numel, res);
70 return res;
71 }
72
infer_size_dv(IntArrayRef shape,int64_t numel)73 inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) {
74 auto res = at::DimVector(shape);
75 infer_size_impl(shape, numel, res);
76 return res;
77 }
78
infer_size_dv(c10::SymIntArrayRef shape,c10::SymInt numel)79 inline at::SymDimVector infer_size_dv(
80 c10::SymIntArrayRef shape,
81 c10::SymInt numel) {
82 auto res = at::SymDimVector(shape);
83 infer_size_impl<c10::SymIntArrayRef, c10::SymInt, at::SymDimVector>(
84 shape, std::move(numel), res);
85 return res;
86 }
87
88 } // namespace at
89