1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/IListRef.h>
4*da0073e9SAndroid Build Coastguard Worker #include <ATen/core/Tensor.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/TensorImpl.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/core/WrapDimMinimal.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker namespace at {
10*da0073e9SAndroid Build Coastguard Worker
// if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the
// range [-1, 0]. This is a special case for scalar tensors and manifests in
// e.g. torch.sum(scalar_tensor, 0). Otherwise, dim should be in the range
// [-dim_post_expr, dim_post_expr-1].
15*da0073e9SAndroid Build Coastguard Worker using c10::maybe_wrap_dim;
16*da0073e9SAndroid Build Coastguard Worker
maybe_wrap_dim(int64_t dim,TensorImpl * tensor)17*da0073e9SAndroid Build Coastguard Worker inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) {
18*da0073e9SAndroid Build Coastguard Worker return maybe_wrap_dim(dim, tensor->dim());
19*da0073e9SAndroid Build Coastguard Worker }
20*da0073e9SAndroid Build Coastguard Worker
maybe_wrap_dim(int64_t dim,TensorList tensors)21*da0073e9SAndroid Build Coastguard Worker inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) {
22*da0073e9SAndroid Build Coastguard Worker if (tensors.empty()) {
23*da0073e9SAndroid Build Coastguard Worker // can't wrap empty TensorList; rely on underlying implementation to throw
24*da0073e9SAndroid Build Coastguard Worker // error if necessary.
25*da0073e9SAndroid Build Coastguard Worker return dim;
26*da0073e9SAndroid Build Coastguard Worker }
27*da0073e9SAndroid Build Coastguard Worker return maybe_wrap_dim(dim, tensors[0].dim());
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker
maybe_wrap_dim(int64_t dim,const std::vector<std::vector<int64_t>> & tensor_sizes)30*da0073e9SAndroid Build Coastguard Worker inline int64_t maybe_wrap_dim(
31*da0073e9SAndroid Build Coastguard Worker int64_t dim,
32*da0073e9SAndroid Build Coastguard Worker const std::vector<std::vector<int64_t>>& tensor_sizes) {
33*da0073e9SAndroid Build Coastguard Worker if (tensor_sizes.empty()) {
34*da0073e9SAndroid Build Coastguard Worker // can't wrap empty list; rely on underlying implementation to throw error
35*da0073e9SAndroid Build Coastguard Worker // if necessary
36*da0073e9SAndroid Build Coastguard Worker return dim;
37*da0073e9SAndroid Build Coastguard Worker }
38*da0073e9SAndroid Build Coastguard Worker return maybe_wrap_dim(dim, tensor_sizes[0].size());
39*da0073e9SAndroid Build Coastguard Worker }
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Worker // Given an array of dimensions `dims` of length `ndims`, this function "Wraps"
42*da0073e9SAndroid Build Coastguard Worker // each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be
43*da0073e9SAndroid Build Coastguard Worker // specified using negative indices.
44*da0073e9SAndroid Build Coastguard Worker //
45*da0073e9SAndroid Build Coastguard Worker // Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will
46*da0073e9SAndroid Build Coastguard Worker // allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for
47*da0073e9SAndroid Build Coastguard Worker // dimensions not in the range [-dim_post_expr, dim_post_expr).
48*da0073e9SAndroid Build Coastguard Worker inline void maybe_wrap_dims_n(
49*da0073e9SAndroid Build Coastguard Worker int64_t* dims,
50*da0073e9SAndroid Build Coastguard Worker int64_t ndims,
51*da0073e9SAndroid Build Coastguard Worker int64_t dim_post_expr,
52*da0073e9SAndroid Build Coastguard Worker bool wrap_scalars = true) {
53*da0073e9SAndroid Build Coastguard Worker if (dim_post_expr <= 0) {
54*da0073e9SAndroid Build Coastguard Worker if (wrap_scalars) {
55*da0073e9SAndroid Build Coastguard Worker dim_post_expr = 1; // this will make range [-1, 0]
56*da0073e9SAndroid Build Coastguard Worker } else {
57*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_INDEX(
58*da0073e9SAndroid Build Coastguard Worker ndims == 0,
59*da0073e9SAndroid Build Coastguard Worker "Dimension specified as ",
60*da0073e9SAndroid Build Coastguard Worker dims[0],
61*da0073e9SAndroid Build Coastguard Worker " but tensor has no dimensions");
62*da0073e9SAndroid Build Coastguard Worker return;
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker }
65*da0073e9SAndroid Build Coastguard Worker int64_t min = -dim_post_expr;
66*da0073e9SAndroid Build Coastguard Worker int64_t max = dim_post_expr - 1;
67*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(ndims)) {
68*da0073e9SAndroid Build Coastguard Worker auto& dim = dims[i];
69*da0073e9SAndroid Build Coastguard Worker if (dim < min || dim > max) {
70*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK_INDEX(
71*da0073e9SAndroid Build Coastguard Worker false,
72*da0073e9SAndroid Build Coastguard Worker "Dimension out of range (expected to be in range of [",
73*da0073e9SAndroid Build Coastguard Worker min,
74*da0073e9SAndroid Build Coastguard Worker ", ",
75*da0073e9SAndroid Build Coastguard Worker max,
76*da0073e9SAndroid Build Coastguard Worker "], but got ",
77*da0073e9SAndroid Build Coastguard Worker dim,
78*da0073e9SAndroid Build Coastguard Worker ")");
79*da0073e9SAndroid Build Coastguard Worker }
80*da0073e9SAndroid Build Coastguard Worker if (dim < 0)
81*da0073e9SAndroid Build Coastguard Worker dim += dim_post_expr;
82*da0073e9SAndroid Build Coastguard Worker }
83*da0073e9SAndroid Build Coastguard Worker }
84*da0073e9SAndroid Build Coastguard Worker
// Container convenience form of maybe_wrap_dims_n: wraps every element of
// `dims` in place for a tensor of rank `dim_post_expr`, so that negative
// indices may be used.
//
// `Container` must expose `data()` and `size()`; both are handed straight to
// maybe_wrap_dims_n together with `dim_post_expr` and `wrap_scalars`. As
// there, when `wrap_scalars` is true a rank-0 (scalar) tensor accepts
// dimensions in [-1, 0]; otherwise an IndexError is raised for dimensions
// outside [-dim_post_expr, dim_post_expr).
template <typename Container>
inline void maybe_wrap_dims(
    Container& dims,
    int64_t dim_post_expr,
    bool wrap_scalars = true) {
  maybe_wrap_dims_n(dims.data(), dims.size(), dim_post_expr, wrap_scalars);
}
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker // previously, size [0] tensors were the only possible empty tensors; thus, it
102*da0073e9SAndroid Build Coastguard Worker // wasn't possible to cat empty tensors unless all the other tensors were
103*da0073e9SAndroid Build Coastguard Worker // 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap
104*da0073e9SAndroid Build Coastguard Worker // dimension behavior and dimension size checking). We maintain this behavior
105*da0073e9SAndroid Build Coastguard Worker // for backwards compatibility, but only for this specific size (i.e. other
106*da0073e9SAndroid Build Coastguard Worker // empty sizes are not skipped).
legacy_cat_wrap_dim(int64_t dim,const std::vector<std::vector<int64_t>> & tensor_sizes)107*da0073e9SAndroid Build Coastguard Worker inline int64_t legacy_cat_wrap_dim(
108*da0073e9SAndroid Build Coastguard Worker int64_t dim,
109*da0073e9SAndroid Build Coastguard Worker const std::vector<std::vector<int64_t>>& tensor_sizes) {
110*da0073e9SAndroid Build Coastguard Worker for (auto& sizes : tensor_sizes) {
111*da0073e9SAndroid Build Coastguard Worker if (sizes.size() == 1 && sizes[0] == 0) {
112*da0073e9SAndroid Build Coastguard Worker continue;
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
115*da0073e9SAndroid Build Coastguard Worker }
116*da0073e9SAndroid Build Coastguard Worker return dim;
117*da0073e9SAndroid Build Coastguard Worker }
118*da0073e9SAndroid Build Coastguard Worker
legacy_cat_wrap_dim_symint(int64_t dim,const std::vector<std::vector<c10::SymInt>> & tensor_sizes)119*da0073e9SAndroid Build Coastguard Worker inline int64_t legacy_cat_wrap_dim_symint(
120*da0073e9SAndroid Build Coastguard Worker int64_t dim,
121*da0073e9SAndroid Build Coastguard Worker const std::vector<std::vector<c10::SymInt>>& tensor_sizes) {
122*da0073e9SAndroid Build Coastguard Worker for (auto& sizes : tensor_sizes) {
123*da0073e9SAndroid Build Coastguard Worker if (sizes.size() == 1) {
124*da0073e9SAndroid Build Coastguard Worker if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) {
125*da0073e9SAndroid Build Coastguard Worker continue;
126*da0073e9SAndroid Build Coastguard Worker }
127*da0073e9SAndroid Build Coastguard Worker }
128*da0073e9SAndroid Build Coastguard Worker return maybe_wrap_dim(dim, static_cast<int64_t>(sizes.size()));
129*da0073e9SAndroid Build Coastguard Worker }
130*da0073e9SAndroid Build Coastguard Worker return dim;
131*da0073e9SAndroid Build Coastguard Worker }
132*da0073e9SAndroid Build Coastguard Worker
legacy_cat_wrap_dim(int64_t dim,const MaterializedITensorListRef & tensors)133*da0073e9SAndroid Build Coastguard Worker inline int64_t legacy_cat_wrap_dim(
134*da0073e9SAndroid Build Coastguard Worker int64_t dim,
135*da0073e9SAndroid Build Coastguard Worker const MaterializedITensorListRef& tensors) {
136*da0073e9SAndroid Build Coastguard Worker for (const Tensor& tensor : tensors) {
137*da0073e9SAndroid Build Coastguard Worker if (tensor.dim() == 1) {
138*da0073e9SAndroid Build Coastguard Worker if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) {
139*da0073e9SAndroid Build Coastguard Worker continue;
140*da0073e9SAndroid Build Coastguard Worker }
141*da0073e9SAndroid Build Coastguard Worker }
142*da0073e9SAndroid Build Coastguard Worker return maybe_wrap_dim(dim, tensor.dim());
143*da0073e9SAndroid Build Coastguard Worker }
144*da0073e9SAndroid Build Coastguard Worker return dim;
145*da0073e9SAndroid Build Coastguard Worker }
146*da0073e9SAndroid Build Coastguard Worker
147*da0073e9SAndroid Build Coastguard Worker // wrap negative dims in a vector
wrap_all_dims(std::vector<int64_t> & dims_to_wrap,int64_t tensor_total_dims)148*da0073e9SAndroid Build Coastguard Worker inline void wrap_all_dims(
149*da0073e9SAndroid Build Coastguard Worker std::vector<int64_t>& dims_to_wrap,
150*da0073e9SAndroid Build Coastguard Worker int64_t tensor_total_dims) {
151*da0073e9SAndroid Build Coastguard Worker for (const auto i : c10::irange(dims_to_wrap.size())) {
152*da0073e9SAndroid Build Coastguard Worker dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims);
153*da0073e9SAndroid Build Coastguard Worker }
154*da0073e9SAndroid Build Coastguard Worker }
155*da0073e9SAndroid Build Coastguard Worker
156*da0073e9SAndroid Build Coastguard Worker } // namespace at
157