/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <numeric>
#include <vector>

#include "absl/container/inlined_vector.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/util/overflow.h"

namespace tensorflow {
namespace {

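// Lowers SpaceToBatchND into a pad / reshape / transpose / reshape sequence
// of XLA ops. Running illustration used in the step comments below (values
// chosen for these comments, not taken from any test): an input of shape
// [1, 2, 4, 1] with block_shape = [2, 2] and paddings = [[1, 1], [0, 0]]
// produces an output of shape [4, 2, 2, 1].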
void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp& input,
                  DataType input_dtype, const TensorShape& input_tensor_shape,
                  absl::Span<const int64_t> block_shape,
                  const xla::Literal& paddings) {
  const int input_rank = input_tensor_shape.dims();
  const absl::InlinedVector<int64_t, 4> input_shape =
      input_tensor_shape.dim_sizes();
  const int block_rank = block_shape.size();

  OP_REQUIRES(
      ctx, input_rank >= 1 + block_rank,
      errors::InvalidArgument("input rank should be >= ", 1 + block_rank,
                              " instead of ", input_rank));
  absl::Span<const int64_t> remainder_shape(input_shape);
  remainder_shape.remove_prefix(1 + block_rank);
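  // remainder_shape now holds the trailing dimensions that are neither the
  // batch dimension nor one of the `block_rank` block dimensions.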

  OP_REQUIRES(
      ctx,
      paddings.shape().rank() == 2 &&
          block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) &&
          2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1),
      errors::InvalidArgument("paddings should have shape [", block_rank,
                              ", 2] instead of ",
                              xla::ShapeUtil::HumanString(paddings.shape())));

  xla::XlaBuilder* b = ctx->builder();

  // 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the
  // input according to `paddings` to produce `padded` of shape `padded_shape`.
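  // Running example: paddings = [[1, 1], [0, 0]] grows the input of shape
  // [1, 2, 4, 1] to padded_shape = [1, 4, 4, 1].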
  xla::PaddingConfig padding_config;
  std::vector<int64_t> padded_shape(input_shape.begin(), input_shape.end());
  int64_t block_num_elems = 1LL;
  padding_config.add_dimensions();  // Don't pad the batch dimension.
  for (int i = 0; i < block_rank; ++i) {
    auto* dim = padding_config.add_dimensions();
    int64_t pad_start = paddings.Get<int64_t>({i, 0});
    int64_t pad_end = paddings.Get<int64_t>({i, 1});
    OP_REQUIRES(ctx, pad_start >= 0 && pad_end >= 0,
                errors::InvalidArgument("Paddings must be non-negative"));
    OP_REQUIRES(ctx, block_shape[i] >= 1,
                errors::InvalidArgument(
                    "All values in block_shape must be positive, got value ",
                    block_shape[i], " at index ", i, "."));
    dim->set_edge_padding_low(pad_start);
    dim->set_edge_padding_high(pad_end);
    padded_shape[1 + i] += pad_start + pad_end;
    block_num_elems = MultiplyWithoutOverflow(block_num_elems, block_shape[i]);
  }
  // Don't pad the remainder dimensions.
  for (size_t i = 0; i < remainder_shape.size(); ++i) {
    padding_config.add_dimensions();
  }
  OP_REQUIRES(ctx, block_num_elems > 0,
              errors::InvalidArgument(
                  "The product of the block dimensions must be positive"));
  const int64_t batch_size = input_shape[0];
  const int64_t output_dim =
      MultiplyWithoutOverflow(batch_size, block_num_elems);
  OP_REQUIRES(
      ctx, output_dim >= 0,
      errors::InvalidArgument("Negative output dimension size caused by "
                              "overflow when multiplying ",
                              batch_size, " and ", block_num_elems));

  xla::XlaOp padded =
      xla::Pad(input, XlaHelpers::Zero(b, input_dtype), padding_config);

  // 2. Reshape `padded` to `reshaped_padded` of shape:
  //
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //        block_shape[0],
  //       ...,
  //        padded_shape[M] / block_shape[M-1],
  //        block_shape[M-1]] +
  //      remaining_shape
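  // Running example: [1, 4, 4, 1] reshapes to [1, 2, 2, 2, 2, 1] for
  // block_shape = [2, 2].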
  std::vector<int64_t> reshaped_padded_shape(input_rank + block_rank);
  reshaped_padded_shape[0] = batch_size;
  for (int i = 0; i < block_rank; ++i) {
    OP_REQUIRES(ctx, padded_shape[1 + i] % block_shape[i] == 0,
                errors::InvalidArgument("padded_shape[", 1 + i,
                                        "]=", padded_shape[1 + i],
                                        " is not divisible by block_shape[", i,
                                        "]=", block_shape[i]));

    reshaped_padded_shape[1 + i * 2] = padded_shape[1 + i] / block_shape[i];
    reshaped_padded_shape[1 + i * 2 + 1] = block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            reshaped_padded_shape.begin() + 1 + 2 * block_rank);

  xla::XlaOp reshaped_padded = xla::Reshape(padded, reshaped_padded_shape);

  // 3. Permute dimensions of `reshaped_padded` to produce
  // `permuted_reshaped_padded` of shape:
  //
  //      block_shape +
  //      [batch] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
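  // Running example: permutation = [2, 4, 0, 1, 3, 5], taking
  // [1, 2, 2, 2, 2, 1] to [2, 2, 1, 2, 2, 1].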
  std::vector<int64_t> permutation(reshaped_padded_shape.size());
  for (int i = 0; i < block_rank; ++i) {
    permutation[i] = 1 + 2 * i + 1;
    permutation[block_rank + 1 + i] = 1 + 2 * i;
  }
  permutation[block_rank] = 0;
  std::iota(permutation.begin() + 1 + block_rank * 2, permutation.end(),
            1 + block_rank * 2);
  xla::XlaOp permuted_reshaped_padded =
      xla::Transpose(reshaped_padded, permutation);

  // 4. Reshape `permuted_reshaped_padded` to flatten `block_shape` into the
  // batch dimension, producing an output tensor of shape:
  //
  //      [batch * prod(block_shape)] +
  //      [padded_shape[1] / block_shape[0],
  //       ...,
  //       padded_shape[M] / block_shape[M-1]] +
  //      remaining_shape
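  // Running example: [2, 2, 1, 2, 2, 1] collapses to
  // output_shape = [4, 2, 2, 1].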
  std::vector<int64_t> output_shape(input_rank);
  output_shape[0] = output_dim;
  for (int i = 0; i < block_rank; ++i) {
    output_shape[1 + i] = padded_shape[1 + i] / block_shape[i];
  }
  std::copy(remainder_shape.begin(), remainder_shape.end(),
            output_shape.begin() + 1 + block_rank);

  xla::XlaOp output = xla::Reshape(permuted_reshaped_padded, output_shape);
  ctx->SetOutput(0, output);
}

class SpaceToBatchNDOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchNDOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    std::vector<int64_t> block_shape;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &block_shape));

    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(2, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 block_shape, paddings);
  }
};
REGISTER_XLA_OP(Name("SpaceToBatchND")
                    .CompileTimeConstantInput("paddings")
                    .CompileTimeConstantInput("block_shape"),
                SpaceToBatchNDOp);

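// SpaceToBatch is the 2-D special case of SpaceToBatchND: a scalar
// `block_size` b behaves like block_shape = [b, b].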
class SpaceToBatchOp : public XlaOpKernel {
 public:
  explicit SpaceToBatchOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_));
    OP_REQUIRES(
        ctx, block_size_ > 1,
        errors::InvalidArgument("Block size should be > 1: ", block_size_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Literal paddings;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputAsInt64Literal(1, &paddings));

    SpaceToBatch(ctx, ctx->Input(0), input_type(0), ctx->InputShape(0),
                 {block_size_, block_size_}, paddings);
  }

 private:
  int block_size_;
};
REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstantInput("paddings"),
                SpaceToBatchOp);

}  // namespace
}  // namespace tensorflow