/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {
namespace {

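// Lowers SpaceToBatchND to a pad / reshape / transpose / reshape sequence of
// XLA ops. As a worked example (shapes chosen here purely for illustration):
// given input of shape [1, 5, 5, 1], block_shape [2, 2], and paddings
// [[1, 0], [0, 1]], the steps below produce
//   1. pad:       [1, 6, 6, 1]
//   2. reshape:   [1, 3, 2, 3, 2, 1]
//   3. transpose: [2, 2, 1, 3, 3, 1]  (permutation {2, 4, 0, 1, 3, 5})
//   4. reshape:   [4, 3, 3, 1]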
void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
                  DataType input_dtype, const TensorShape& input_tensor_shape,
                  absl::Span<const int64_t> block_shape,
                  const xla::Literal& paddings) {
  const int input_rank = input_tensor_shape.dims();
  const absl::InlinedVector<int64_t, 4> input_shape =
      input_tensor_shape.dim_sizes();
  const int block_rank = block_shape.size();

  OP_REQUIRES(
      ctx, input_rank >= 1 + block_rank,
      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                              " instead of ", input_rank));
  absl::Span<const int64_t> remainder_shape(input_shape);
  remainder_shape.remove_prefix(1 + block_rank);

  OP_REQUIRES(
      ctx,
      paddings.shape().rank() == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
      errors::InvalidArgument("paddings should have shape [", block_rank,
                              ", 2] instead of ",
                              xla::ShapeUtil::HumanString(paddings.shape())));

  xla::XlaBuilder* b = ctx->builder();

  // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
  //    input according to `paddings` to produce `padded` of shape
  //    `padded_shape`.
  xla::PaddingConfig padding_config;
  std::vector<int64_t> padded_shape(input_shape.begin(), input_shape.end());
  int64_t block_num_elems = 1LL;
  padding_config.add_dimensions();  // Don't pad the batch dimension.
  for (int i = 0; i < block_rank; ++i) {
    auto* dim = padding_config.add_dimensions();
    int64_t pad_start = paddings.Get<int64_t>({i, 0});
    int64_t pad_end = paddings.Get<int64_t>({i, 1});
    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    OP_REQUIRES(ctx, block_shape[i] >= 1,
                errors::InvalidArgument(
                    "All values in block_shape must be positive, got value ",
                    block_shape[i], " at index ", i, "."));
    dim->set_edge_padding_low(pad_start);
    dim->set_edge_padding_high(pad_end);
    padded_shape[1 + i] += pad_start + pad_end;
    block_num_elems = MultiplyWithoutOverflow(block_num_elems, block_shape[i]);
  }
  // Don't pad the remainder dimensions.
  for (int i = 0; i < remainder_shape.size(); ++i) {
    padding_config.add_dimensions();
  }
  OP_REQUIRES(ctx, block_num_elems > 0,
              errors::InvalidArgument(
                  "The product of the block dimensions must be positive"));
  const int64_t batch_size = input_shape[0];
  const int64_t output_dim =
      MultiplyWithoutOverflow(batch_size, block_num_elems);
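  // MultiplyWithoutOverflow returns a negative value if the product overflows
  // int64, which the check below reports as an error.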
  OP_REQUIRES(
      ctx, output_dim >= 0,
      errors::InvalidArgument("Negative output dimension size caused by "
                              "overflow when multiplying ",
                              batch_size, " and ", block_num_elems));

  xla::XlaOp padded =
      xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);
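  // With the example shapes in the comment above, `padded` is [1, 6, 6, 1].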

  // 2. Reshape `padded` to `reshaped_padded` of shape:
  //
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1],
  //       block_shape[M-1]] +
  //      remainder_shape
  std::vector<int64_t> reshaped_padded_shape(input_rank + block_rank);
  reshaped_padded_shape[0] = batch_size;
  for (int i = 0; i < block_rank; ++i) {
    OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
                errors::InvalidArgument("padded_shape[", 1 + i,
                                        "]=", padded_shape[1 + i],
                                        " is not divisible by block_shape[", i,
                                        "]=", block_shape[i]));

    reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
    reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            reshaped_padded_shape.begin() + 1 + 2 * block_rank);

  xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape);
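  // With the example shapes, `reshaped_padded` is [1, 3, 2, 3, 2, 1].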

  // 3. Permute dimensions of `reshaped_padded` to produce
  //    `permuted_reshaped_padded` of shape:
  //
  //      block_shape +
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remainder_shape
  std::vector<int64_t> permutation(reshaped_padded_shape.size());
  for (int i = 0; i < block_rank; ++i) {
    permutation[i] = 1 + 2 * i + 1;
    permutation[block_rank + 1 + i] = 1 + 2 * i;
  }
  permutation[block_rank] = 0;
  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
            1 + block_rank * 2);
  xla::XlaOp permuted_reshaped_padded =
      xla::Transpose(reshaped_padded, permutation);
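  // With the example shapes, `permutation` is {2, 4, 0, 1, 3, 5} and
  // `permuted_reshaped_padded` is [2, 2, 1, 3, 3, 1].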

  // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
  //    batch dimension, producing an output tensor of shape:
  //
  //      [batch * prod(block_shape)] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remainder_shape
  std::vector<int64_t> output_shape(input_rank);
  output_shape[0] = output_dim;
  for (int i = 0; i < block_rank; ++i) {
    output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            output_shape.begin() + 1 + block_rank);

  xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape);
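  // With the example shapes, `output` is [4, 3, 3, 1]: the 2x2 block factors
  // have been folded into the batch dimension.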
  ctx->SetOutput(0, output);
}

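// Kernel for the rank-generic SpaceToBatchND op. `block_shape` and `paddings`
// must be compile-time constants (see the REGISTER_XLA_OP call below), since
// they determine the static shapes of the reshapes and the transpose.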
class SpaceToBatchNDOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64_t> block_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));

    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 block_shape, paddings);
  }
};
REGISTER_XLA_OP(Name("SpaceToBatchND")
                    .CompileTimeConstantInput("paddings")
                    .CompileTimeConstantInput("block_shape"),
                SpaceToBatchNDOp);

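// Kernel for the legacy rank-4 SpaceToBatch op: the special case of
// SpaceToBatchND with a square block, block_shape = {block_size, block_size}.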
class SpaceToBatchOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 {block_size_, block_size_}, paddings);
  }

 private:
  int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstantInput("paddings"),
                SpaceToBatchOp);

}  // namespace
}  // namespace tensorflow