/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iostream>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;

struct BlockArgumentInfo {
  unsigned arg_num;
  unsigned num_users;
};

// TODO(wangtao): add a pass to check if it is profitable to apply the space
// to depth transform, and invoke the transform only when it is needed.
struct TPUSpaceToDepthPass
    : public TF::TPUSpaceToDepthPassBase<TPUSpaceToDepthPass> {
  void runOnOperation() override;
};
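
// At a high level, the pass rewrites a strided first convolution into a
// stride-1 convolution over a space-to-depth transformed input. Illustrative
// shapes (assuming NHWC and block_size 2): input [8, 224, 224, 3] becomes
// [8, 112, 112, 12], and filter [7, 7, 3, 64] becomes [4, 4, 12, 64].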

// Updates func argument type to have the updated input shape.
void UpdateFuncType(func::FuncOp func) {
  auto arg_types = func.front().getArgumentTypes();
  auto result_types = func.front().getTerminator()->getOperandTypes();
  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
}

void HandleFuncOp(Operation* op) {
  auto func = llvm::cast<func::FuncOp>(op);
  UpdateFuncType(func);
}

// Handles cast ops between the first convolution and the block argument,
// rewriting their types to the new input shape.
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
  auto cast_input = cast_op.x();
  // Update input type.
  auto transform_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
  cast_input.setType(transform_result_type);
  auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
  auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
  while (block_arg || cast_op_input) {
    if (block_arg) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
      block_arg = nullptr;
      cast_op_input = nullptr;
    } else {
      auto cast_input = cast_op_input.x();
      // Update input type.
      auto transform_result_type =
          RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
      cast_input.setType(transform_result_type);
      // Update block arg and cast_op_input.
      block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
      cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
    }
  }
  return success();
}

// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
  auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return failure();
  auto pad_input_shape = ranked_type.getShape();
  Location loc = op.getLoc();
  OpBuilder builder(op);
  builder.setInsertionPoint(op);
  auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));

  // Calculate paddings.
  int32_t pad_total = kernel_size - 1;
  int32_t pad_beg = (pad_total / 2 + 1) / block_size;
  int32_t pad_end = (pad_total / 2) / block_size;
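  // For example, kernel_size = 7 and block_size = 2 give pad_total = 6,
  // pad_beg = (3 + 1) / 2 = 2, and pad_end = 3 / 2 = 1.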
  SmallVector<int32_t, 8> values = {0, 0, pad_beg, pad_end,
                                    pad_beg, pad_end, 0, 0};
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  // Update pad_op paddings.
  op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));

  // Set input type.
  auto input = op.getOperand(0);
  SmallVector<int64_t, 4> transform_shape = {
      pad_input_shape[0], pad_input_shape[1] / block_size,
      pad_input_shape[2] / block_size,
      pad_input_shape[3] * block_size * block_size};
  // Input of the pad op could be a cast op.
  if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
    if (failed(HandleCast(cast_op, transform_shape))) return failure();

  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
  op.setOperand(0, input);
  return success();
}

// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
  MLIRContext* context = conv2d.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), v);
  });
  // TODO(b/157276506): change type of strides to DenseElementsAttr
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
  conv2d->setAttr("strides", strides);
}

// Transforms input shape for the first convolution.
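// For example (illustrative shapes), block_size 2 maps an input of shape
// [8, 224, 224, 3] to [8, 112, 112, 12].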
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
  auto input = conv2d.input();
  auto input_shape = input.getType().cast<RankedTensorType>().getShape();
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
}

// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
                                  OpBuilder* builder, int32_t pad_h,
                                  int32_t pad_w) {
  SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
  auto padding_type =
      RankedTensorType::get({4, 2}, builder->getIntegerType(32));
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
  std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
                                    filter_shape[1] + pad_w, filter_shape[2],
                                    filter_shape[3]};
  SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());

  auto expand_result_type =
      RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
  return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
                                    paddings_value);
}

// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
                                          Value input, OpBuilder* builder) {
  auto reshape_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
  auto reshape_type = RankedTensorType::get(
      {static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
  auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
  auto reshape_value =
      builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
  return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
                                        input, reshape_value);
}

// Creates transpose op for space to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
  SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
  auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
  auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
  auto permute_value =
      builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
  return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
}

void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
  // For example, if the filter shape is [7, 7, 3, 64] with block_size 2, the
  // following transforms will be applied to the filter:
  //   1. Pad the filter to [8, 8, 3, 64];
  //   2. Reshape to [4, 2, 4, 2, 3, 64];
  //   3. Transpose to [4, 4, 2, 2, 3, 64];
  //   4. Reshape to [4, 4, 12, 64].
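  // The {0, 2, 1, 3, 4, 5} permutation in step 3 moves the two block_size
  // dimensions next to each other, so that the final reshape in step 4 folds
  // them into the input-channel dimension.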
  auto filter = conv2d.filter();
  OpBuilder builder(conv2d);
  builder.setInsertionPoint(conv2d);
  // Bookkeeping of filter information.
  auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
  int64_t height = filter_shape[0];
  int64_t width = filter_shape[1];
  int64_t channel = filter_shape[2];
  int64_t out_channel = filter_shape[3];
  // Value/Op before reshape op.
  Value before_reshape_value = filter;
  if (height % block_size != 0 || width % block_size != 0) {
    // Calculate paddings for height and width.
    int32_t pad_h = block_size - height % block_size;
    int32_t pad_w = block_size - width % block_size;
    auto pad_op =
        GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
    // Update op, height and width before reshape.
    before_reshape_value = pad_op;
    height = height + pad_h;
    width = width + pad_w;
  }

  // Reshape.
  SmallVector<int64_t, 6> new_shape = {
      height / block_size, block_size, width / block_size,
      block_size, channel, out_channel};
  auto reshape_op =
      GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);

  // Reshape back.
  SmallVector<int64_t, 4> final_shape = {
      height / block_size, width / block_size,
      channel * block_size * block_size, out_channel};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
  // Update filter of Conv2D.
  conv2d.setOperand(1, final_reshape_op);
}

// Creates slice op for the filter in the backprop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
    ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
  SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
                                     old_filter_shape.end());
  auto slice_result_type =
      RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
  auto slice_size_op = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder->getIntegerType(32)),
          old_filter_shape));
  SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
  auto start_position_type =
      RankedTensorType::get({4}, builder->getIntegerType(64));
  auto start_position = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(start_position_type, slice_start_position));
  return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
                                      start_position, slice_size_op);
}

// Transforms Conv2DBackpropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
                                ArrayRef<int32_t> old_filter_shape,
                                ArrayRef<int32_t> new_filter_shape,
                                int64_t block_size) {
  OpBuilder builder(backprop);
  builder.setInsertionPoint(backprop);

  auto input = backprop.input();
  // Get new filter size from new_filter_shape.
  auto new_filter_sizes = builder.create<TF::ConstOp>(
      backprop.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder.getIntegerType(32)),
          new_filter_shape));

  // Set stride to [1, 1, 1, 1].
  MLIRContext* context = backprop.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));

  // New result type.
  SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
                                    new_filter_shape.end());
  auto new_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));

  // Build the new Conv2DBackpropFilterOp.
  auto loc = backprop.getLoc();
  auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
      loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
      strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
      backprop.explicit_paddings(), backprop.data_format(),
      backprop.dilations());

  // For example, if the new filter shape is [4, 4, 12, 64] and the old filter
  // shape is [7, 7, 3, 64] with block_size 2, the following transforms will be
  // applied to the filter gradient:
  //   1. Reshape to [4, 4, 2, 2, 3, 64];
  //   2. Transpose to [4, 2, 4, 2, 3, 64];
  //   3. Reshape to [8, 8, 3, 64];
  //   4. Slice to [7, 7, 3, 64].
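  // These steps invert the forward filter transform in HandleConv2DFilter, so
  // the gradient is recovered in the original filter layout.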
  SmallVector<int64_t, 6> first_reshape_shape = {
      new_filter_shape[0],
      new_filter_shape[1],
      block_size,
      block_size,
      new_filter_shape[2] / (block_size * block_size),
      new_filter_shape[3]};
  auto first_reshape_op =
      GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);

  // Last reshape op.
  SmallVector<int64_t, 4> last_reshape_shape = {
      new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
      new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);

  // Create the slice op.
  auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
                                                    final_reshape_op, &builder);

  // Update backprop's users with the slice op.
  backprop.replaceAllUsesWith(slice_op.getResult());
}

// Checks if the input producer op is supported in this transform. Right now,
// we only check if it is a host tf.IteratorGetNext.
bool IsSupportedHostInputOp(Operation* op) {
  TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iter) return false;
  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
  if (!device) return false;
  tensorflow::DeviceNameUtils::ParsedName parsed_device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                  &parsed_device)) {
    return false;
  }
  return parsed_device.type == "CPU";
}

// Builds a SpaceToDepthOp for the given input with the given block size.
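// SpaceToDepth moves spatial blocks into the channel dimension: an NHWC input
// of shape [N, H, W, C] becomes
// [N, H / block_size, W / block_size, C * block_size * block_size].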
TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
                                     Value input, int32_t block_size,
                                     ArrayRef<int64_t> input_shape) {
  auto input_op = input.getDefiningOp();
  OpBuilder builder(input_op);
  builder.setInsertionPointAfter(input_op);
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  return builder.create<TF::SpaceToDepthOp>(
      cluster_func.getLoc(), transform_result_type, input, block_size);
}

// Performs transformation for a non-replicated input.
TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
                                   tf_device::ClusterFuncOp cluster_func,
                                   int32_t block_size,
                                   ArrayRef<int64_t> input_shape) {
  auto space_to_depth =
      BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
  cluster_func.setOperand(index, space_to_depth);
  return space_to_depth;
}

// Performs transformation for replicated inputs. Returns true if this is a
// supported case (and thus the transform happened).
bool HandleHostReplicatedInputs(int64_t index,
                                tf_device::ClusterFuncOp cluster_func,
                                BlockArgument block_arg,
                                tf_device::ReplicateOp replicate,
                                int32_t block_size) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(block_arg);
  for (auto& input : inputs) {
    auto input_op = input.get().getDefiningOp();
    if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
  }
  for (auto entry : llvm::enumerate(inputs)) {
    Value input = entry.value().get();
    auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
    if (!ranked_type) return false;
    auto input_shape = ranked_type.getShape();
    auto space_to_depth =
        BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
    entry.value().set(space_to_depth);
    block_arg.setType(space_to_depth.getType());
  }
  return true;
}

// Performs the transform on the cluster_func inputs that feed the given block
// argument number.
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
                   unsigned arg_num) {
  auto maybe_replicate =
      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());

  llvm::SmallVector<int64_t, 8> transform_input_indices;
  for (auto input : llvm::enumerate(cluster_func.operands())) {
    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
      if (block_arg.getArgNumber() != arg_num) continue;
      // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
      if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
        HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
                                   maybe_replicate, block_size);
      }
    } else {
      // For an op output, consider transforms only when 1) there is no
      // replication or 2) it is outside the replicate node that encloses the
      // execute node. (If the op is inside the replicate, it is probably not
      // on the host.)
      if (input.index() != arg_num) continue;
      auto input_op = input.value().getDefiningOp();
      if (maybe_replicate &&
          maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
        continue;
      }
      if (!IsSupportedHostInputOp(input_op)) continue;
      auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
      if (!ranked_type) continue;
      auto input_shape = ranked_type.getShape();
      HandleHostInput(input.value(), input.index(), cluster_func, block_size,
                      input_shape);
    }
  }
}

// Checks if the input shape of the convolution is good for the space to depth
// transform.
bool Conv2DInputShapeCanTransform(Value input) {
  auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return false;
  auto input_shape = ranked_type.getShape();
  int32_t batch_size = input_shape[0];
  int32_t channel = input_shape[3];
  if (batch_size > 8 || channel > 8) {
    return false;
  }
  return true;
}

// Gets the block argument id and the number of users for the input arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
  if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
    if (!Conv2DInputShapeCanTransform(arg)) return None;
    unsigned num_users =
        std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
    BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
    return block_arg_info;
  }
  return None;
}

// Gets the input block argument id and number of users for the input,
// recursively. Currently supported ops between the convolution input and the
// block arguments are PadOp and CastOp.
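// For example (an illustrative chain): block argument -> tf.Cast -> tf.Pad ->
// tf.Conv2D input.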
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
  auto block_arg_num = GetBlockArgNum(input);
  if (block_arg_num.has_value()) return block_arg_num;

  Value next_input = input;
  auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
  auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());

  while (pad_op || cast_op) {
    if (pad_op) {
      auto block_arg_num = GetBlockArgNum(pad_op.input());
      if (block_arg_num.has_value()) return block_arg_num;
      next_input = pad_op.input();
    } else {
      auto block_arg_num = GetBlockArgNum(cast_op.x());
      if (block_arg_num.has_value()) return block_arg_num;
      next_input = cast_op.x();
    }
    pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
    cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
  }

  return None;
}

// Checks if a convolution can apply the SpaceToDepth transform.
// Only the first convolution in the graph whose batch size and input feature
// size are at most 8 can be transformed.
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
    return None;
  }
  // Currently supported ops between the convolution input and the block
  // arguments are PadOp and CastOp.
  return GetInputBlockArgNum(conv2d.input());
}

// Applies the space to depth transform for the first convolution on the TPU
// device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
  // Check if input and filter type are RankedTensorType.
  auto input_tensor_type =
      conv2d.input().getType().dyn_cast<RankedTensorType>();
  auto filter_tensor_type =
      conv2d.filter().getType().dyn_cast<RankedTensorType>();
  if (!input_tensor_type || !filter_tensor_type) return;
  // Bookkeeping of the filter shape for padding and backprop filter rewrite.
  auto filter_shape = filter_tensor_type.getShape();
  SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
                                           filter_shape.end());
  // Handle the input.
  auto conv2d_input = conv2d.input();
  if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
    // Change on device function type/shape.
    HandleFuncOp(block_arg.getOwner()->getParentOp());
  }

  if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
    // Rewrite pad_op before the convolution.
    if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
    auto pad_input = pad_op.input();
    if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
    }
  }

  // Handle Conv2D input, stride and filter.
  HandleConv2DInput(conv2d, block_size);
  HandleConv2DStride(conv2d);
  HandleConv2DFilter(conv2d, block_size);

  // Bookkeeping of the new filter shape for the backprop filter rewrite. The
  // filter shape is defined in HandleConv2DFilter, thus it is RankedTensorType.
  filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
  SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
                                           filter_shape.end());

  // Rewrite Conv2DBackpropFilter ops that are users of the first convolution's
  // input.
  if (!conv2d_input.getDefiningOp()) return;
  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
    if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
      HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                 block_size);
    }
  }
}

// Gets the block size, which equals the stride in the spatial dimensions of
// the convolution.
// The space to depth transform won't be triggered if the block size is <= 1.
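// For example, strides [1, 2, 2, 1] yield block size 2, while [1, 2, 3, 1]
// (unequal spatial strides) yield block size 1.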
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
  SmallVector<int32_t, 4> strides(4, 1);
  for (int i = 0; i < 4; ++i) {
    strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
  }

  // Space to depth only supports striding at the spatial dimensions.
  if (strides[0] != 1 || strides[3] != 1) return 1;

  // Space to depth only supports the height_stride == width_stride case.
  if (strides[1] != strides[2]) return 1;

  return strides[1];
}

void TPUSpaceToDepthPass::runOnOperation() {
  Optional<tf_device::ClusterFuncOp> cluster_func;
  // Space to depth only supports the training loop.
  auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
    cluster_func = cluster;
    return WalkResult::interrupt();
  });

  // Return if there is no tf_device::ClusterFuncOp in the training loop.
  if (!func_result.wasInterrupted() || !cluster_func.has_value()) {
    return;
  }

  // Get the function on device.
  auto device_func = cluster_func->getFunc();
  if (!device_func) return;

  TF::Conv2DOp first_conv;
  // Maps a block argument id to the convolutions that consume it.
  llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
      argnum_and_convolutions;
  // Maps a block argument id to its number of users.
  llvm::SmallDenseMap<unsigned, int> argnum_num_users;

  // Find the qualifying convolutions and their block argument ids.
  auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
    Optional<BlockArgumentInfo> arg_num_and_num_users =
        GetConv2DInputArgNum(conv2d);
    if (arg_num_and_num_users.has_value()) {
      // Get the block size for the first convolution.
      int64_t block_size = GetConv2DBlockSize(conv2d);
      auto arg_num = arg_num_and_num_users.getValue().arg_num;
      auto num_users = arg_num_and_num_users.getValue().num_users;
      argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
      argnum_num_users[arg_num] = num_users;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (!conv2d_result.wasInterrupted()) {
    return;
  }

  // Iterate through the block arguments and their convolution users. The space
  // to depth transform will be applied only if all of the below conditions are
  // satisfied:
  //   1. All users of the block argument lead to convolutions;
  //   2. the block_size for the space to depth transform is the same for all
  //      of these convolutions;
  //   3. that block_size is larger than 1.
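  // Condition 1 ensures that no other op observes the block argument in its
  // original, untransformed layout.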
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto arg_num = argnum_and_convolution.getFirst();
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    // Continue if the number of users of the block argument doesn't equal the
    // number of transformable convolutions, or there is no qualifying
    // convolution for the transform.
    if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
        conv2d_and_block_sizes.empty()) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Continue if the block size is smaller than 2.
    int64_t block_size = conv2d_and_block_sizes[0].second;
    if (block_size < 2) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Continue if not all the block sizes for the space to depth transform are
    // the same.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      if (conv2d_and_block_size.second != block_size) {
        argnum_and_convolutions.erase(arg_num);
        break;
      }
    }
  }

  // Return if there is no qualifying space to depth transform.
  if (argnum_and_convolutions.empty()) {
    return;
  }

  // Apply the space to depth transform.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    int64_t block_size = conv2d_and_block_sizes[0].second;
    // Apply the space to depth transform to the input on the host.
    HandleCluster(cluster_func.getValue(), block_size,
                  argnum_and_convolution.getFirst());
    // Transform the convolutions.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      HandleFirstConvolution(conv2d_and_block_size.first,
                             conv2d_and_block_size.second);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
  return std::make_unique<TPUSpaceToDepthPass>();
}

}  // namespace TFTPU
}  // namespace mlir