/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <iostream>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
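// Pairs a candidate convolution with the block size derived from its stride
// (see GetConv2DBlockSize below).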
typedef std::pair<TF::Conv2DOp, int64_t> Conv2DWithBlockSize;

struct BlockArgumentInfo {
  unsigned arg_num;
  unsigned num_users;
};

// TODO(wangtao): add a pass to check whether the space to depth transform is
// profitable, and invoke the transform only when it is.
struct TPUSpaceToDepthPass
    : public TF::TPUSpaceToDepthPassBase<TPUSpaceToDepthPass> {
  void runOnOperation() override;
};

// Updates func argument type to have the updated input shape.
void UpdateFuncType(func::FuncOp func) {
  auto arg_types = func.front().getArgumentTypes();
  auto result_types = func.front().getTerminator()->getOperandTypes();
  func.setType(FunctionType::get(func.getContext(), arg_types, result_types));
}

void HandleFuncOp(Operation* op) {
  auto func = llvm::cast<func::FuncOp>(op);
  UpdateFuncType(func);
}

// Handles cast op between the first convolution and the block argument.
LogicalResult HandleCast(TF::CastOp cast_op, ArrayRef<int64_t> new_shape) {
  auto cast_input = cast_op.x();
  // Update input type.
  auto transform_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
  cast_input.setType(transform_result_type);
  auto block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
  auto cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
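  // Walk up the chain of CastOps feeding this cast, rewriting each
  // intermediate type to the transformed shape, until the chain bottoms out
  // at a block argument (whose owning function is then updated).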
  while (block_arg || cast_op_input) {
    if (block_arg) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
      block_arg = nullptr;
      cast_op_input = nullptr;
    } else {
      auto cast_input = cast_op_input.x();
      // Update input type.
      auto transform_result_type =
          RankedTensorType::get(new_shape, getElementTypeOrSelf(cast_input));
      cast_input.setType(transform_result_type);
      // Update block arg and cast_op_input.
      block_arg = cast_input.dyn_cast<mlir::BlockArgument>();
      cast_op_input = dyn_cast_or_null<TF::CastOp>(cast_input.getDefiningOp());
    }
  }
  return success();
}

// Handles padding before convolution for space to depth transform.
LogicalResult HandlePad(TF::PadOp op, int32_t kernel_size, int32_t block_size) {
  auto ranked_type = op.input().getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return failure();
  auto pad_input_shape = ranked_type.getShape();
  Location loc = op.getLoc();
  OpBuilder builder(op);
  builder.setInsertionPoint(op);
  auto padding_type = RankedTensorType::get({4, 2}, builder.getIntegerType(32));

  // Calculate paddings.
  int32_t pad_total = kernel_size - 1;
  int32_t pad_beg = (pad_total / 2 + 1) / block_size;
  int32_t pad_end = (pad_total / 2) / block_size;
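  // For example, with kernel_size = 7 and block_size = 2: pad_total = 6,
  // pad_beg = (3 + 1) / 2 = 2, and pad_end = 3 / 2 = 1.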
  SmallVector<int32_t, 8> values = {0,       0,       pad_beg, pad_end,
                                    pad_beg, pad_end, 0,       0};
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  // Update pad_op paddings.
  op.setOperand(1, builder.create<TF::ConstOp>(loc, paddings));

  // Set input type.
  auto input = op.getOperand(0);
  SmallVector<int64_t, 4> transform_shape = {
      pad_input_shape[0], pad_input_shape[1] / block_size,
      pad_input_shape[2] / block_size,
      pad_input_shape[3] * block_size * block_size};
  // Input of the pad op could be a cast op.
  if (auto cast_op = dyn_cast_or_null<TF::CastOp>(input.getDefiningOp()))
    if (failed(HandleCast(cast_op, transform_shape))) return failure();

  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
  op.setOperand(0, input);
  return success();
}

// Handles stride for the first convolution for the transform.
void HandleConv2DStride(TF::Conv2DOp conv2d) {
  MLIRContext* context = conv2d.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), v);
  });
  // TODO(b/157276506): change type of strides to DenseElementsAttr
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));
  conv2d->setAttr("strides", strides);
}

// Transforms input shape for the first convolution.
void HandleConv2DInput(TF::Conv2DOp conv2d, int64_t block_size) {
  auto input = conv2d.input();
  auto input_shape = input.getType().cast<RankedTensorType>().getShape();
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
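  // For example, an NHWC input of shape [2, 224, 224, 3] with block_size 2
  // becomes [2, 112, 112, 12].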
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  input.setType(transform_result_type);
}

// Adds padding for convolution filter for space to depth transform.
TF::PadOp GetPadOpForConv2DFilter(ArrayRef<int64_t> filter_shape, Value filter,
                                  OpBuilder* builder, int32_t pad_h,
                                  int32_t pad_w) {
  SmallVector<int32_t, 8> values = {pad_h, 0, pad_w, 0, 0, 0, 0, 0};
  auto padding_type =
      RankedTensorType::get({4, 2}, builder->getIntegerType(32));
  auto paddings = DenseIntElementsAttr::get(padding_type, values);
  auto paddings_value = builder->create<TF::ConstOp>(filter.getLoc(), paddings);
  std::vector<int64_t> pad_shape = {filter_shape[0] + pad_h,
                                    filter_shape[1] + pad_w, filter_shape[2],
                                    filter_shape[3]};
  SmallVector<int64_t, 4> expand_shape(pad_shape.begin(), pad_shape.end());

  auto expand_result_type =
      RankedTensorType::get(expand_shape, getElementTypeOrSelf(filter));
  return builder->create<TF::PadOp>(filter.getLoc(), expand_result_type, filter,
                                    paddings_value);
}

// Creates reshape op for space to depth transform.
TF::ReshapeOp GetReshapeOpForConv2DFilter(ArrayRef<int64_t> new_shape,
                                          Value input, OpBuilder* builder) {
  auto reshape_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));
  auto reshape_type = RankedTensorType::get(
      {static_cast<int64_t>(new_shape.size())}, builder->getIntegerType(64));
  auto reshape_sizes = DenseIntElementsAttr::get(reshape_type, new_shape);
  auto reshape_value =
      builder->create<TF::ConstOp>(input.getLoc(), reshape_sizes);
  return builder->create<TF::ReshapeOp>(input.getLoc(), reshape_result_type,
                                        input, reshape_value);
}

// Creates transpose op for space to depth transform.
TF::TransposeOp GetTransposeOpForConv2DFilter(OpBuilder* builder, Value input) {
  SmallVector<int32_t, 6> permutation = {0, 2, 1, 3, 4, 5};
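  // Swapping dimensions 1 and 2 moves the two block dimensions next to each
  // other, e.g. a reshaped filter of [4, 2, 4, 2, 3, 64] becomes
  // [4, 4, 2, 2, 3, 64].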
  auto permute_type = RankedTensorType::get({6}, builder->getIntegerType(32));
  auto permute_attr = DenseIntElementsAttr::get(permute_type, permutation);
  auto permute_value =
      builder->create<TF::ConstOp>(input.getLoc(), permute_attr);
  return builder->create<TF::TransposeOp>(input.getLoc(), input, permute_value);
}

void HandleConv2DFilter(TF::Conv2DOp conv2d, int64_t block_size) {
  // For example, if the filter shape is [7, 7, 3, 64] with block_size 2,
  // the following transforms will be applied to the filter:
  // 1. Pad the filter to [8, 8, 3, 64];
  // 2. Reshape to [4, 2, 4, 2, 3, 64];
  // 3. Transpose to [4, 4, 2, 2, 3, 64];
  // 4. Reshape to [4, 4, 12, 64].
  auto filter = conv2d.filter();
  OpBuilder builder(conv2d);
  builder.setInsertionPoint(conv2d);
  // Bookkeeping of filter information.
  auto filter_shape = filter.getType().cast<RankedTensorType>().getShape();
  int64_t height = filter_shape[0];
  int64_t width = filter_shape[1];
  int64_t channel = filter_shape[2];
  int64_t out_channel = filter_shape[3];
  // Value/Op before reshape op.
  Value before_reshape_value = filter;
  if (height % block_size != 0 || width % block_size != 0) {
    // Calculate paddings for height and width.
    int32_t pad_h = block_size - height % block_size;
    int32_t pad_w = block_size - width % block_size;
    auto pad_op =
        GetPadOpForConv2DFilter(filter_shape, filter, &builder, pad_h, pad_w);
    // Update op, height and width before reshape.
    before_reshape_value = pad_op;
    height = height + pad_h;
    width = width + pad_w;
  }

  // Reshape.
  SmallVector<int64_t, 6> new_shape = {
      height / block_size, block_size, width / block_size,
      block_size,          channel,    out_channel};
  auto reshape_op =
      GetReshapeOpForConv2DFilter(new_shape, before_reshape_value, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, reshape_op);

  // Reshape back.
  SmallVector<int64_t, 4> final_shape = {
      height / block_size, width / block_size,
      channel * block_size * block_size, out_channel};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(final_shape, transpose_op, &builder);
  // Update filter of Conv2D.
  conv2d.setOperand(1, final_reshape_op);
}

// Creates slice op for filter in back prop pass.
TF::SliceOp GetSliceOpForConv2DBackPropFilter(
    ArrayRef<int32_t> old_filter_shape, Value input, OpBuilder* builder) {
  SmallVector<int64_t, 4> slice_size(old_filter_shape.begin(),
                                     old_filter_shape.end());
  auto slice_result_type =
      RankedTensorType::get(slice_size, getElementTypeOrSelf(input));
  auto slice_size_op = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder->getIntegerType(32)),
          old_filter_shape));
  SmallVector<int64_t, 4> slice_start_position = {0, 0, 0, 0};
  auto start_position_type =
      RankedTensorType::get({4}, builder->getIntegerType(64));
  auto start_position = builder->create<TF::ConstOp>(
      input.getLoc(),
      DenseIntElementsAttr::get(start_position_type, slice_start_position));
  return builder->create<TF::SliceOp>(input.getLoc(), slice_result_type, input,
                                      start_position, slice_size_op);
}

// Transforms Conv2DBackPropFilter for space to depth.
void HandleConv2DBackPropFilter(TF::Conv2DBackpropFilterOp backprop,
                                ArrayRef<int32_t> old_filter_shape,
                                ArrayRef<int32_t> new_filter_shape,
                                int64_t block_size) {
  OpBuilder builder(backprop);
  builder.setInsertionPoint(backprop);

  auto input = backprop.input();
  // Get new filter size from new_filter_shape.
  auto new_filter_sizes = builder.create<TF::ConstOp>(
      backprop.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get({4}, builder.getIntegerType(32)),
          new_filter_shape));

  // Set stride to [1, 1, 1, 1].
  MLIRContext* context = backprop.getContext();
  SmallVector<int64_t, 4> values = {1, 1, 1, 1};
  auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
    return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
  });
  auto strides = ArrayAttr::get(context, llvm::to_vector<4>(attrs));

  // New result type.
  SmallVector<int64_t, 4> new_shape(new_filter_shape.begin(),
                                    new_filter_shape.end());
  auto new_result_type =
      RankedTensorType::get(new_shape, getElementTypeOrSelf(input));

  // Build new BackPropFilterOp.
  auto loc = backprop.getLoc();
  auto new_backprop = builder.create<TF::Conv2DBackpropFilterOp>(
      loc, new_result_type, input, new_filter_sizes, backprop.out_backprop(),
      strides, backprop.use_cudnn_on_gpu(), backprop.padding(),
      backprop.explicit_paddings(), backprop.data_format(),
      backprop.dilations());

  // For example, if the new filter shape is [4, 4, 12, 64] and the old filter
  // shape is [7, 7, 3, 64] with block_size 2, the following transforms will
  // be applied to the filter:
  // 1. Reshape to [4, 4, 2, 2, 3, 64];
  // 2. Transpose to [4, 2, 4, 2, 3, 64];
  // 3. Reshape to [8, 8, 3, 64];
  // 4. Slice to [7, 7, 3, 64].
  SmallVector<int64_t, 6> first_reshape_shape = {
      new_filter_shape[0],
      new_filter_shape[1],
      block_size,
      block_size,
      new_filter_shape[2] / (block_size * block_size),
      new_filter_shape[3]};
  auto first_reshape_op =
      GetReshapeOpForConv2DFilter(first_reshape_shape, new_backprop, &builder);

  // Transpose.
  auto transpose_op = GetTransposeOpForConv2DFilter(&builder, first_reshape_op);

  // Last reshape op.
  SmallVector<int64_t, 4> last_reshape_shape = {
      new_filter_shape[0] * block_size, new_filter_shape[1] * block_size,
      new_filter_shape[2] / (block_size * block_size), new_filter_shape[3]};
  auto final_reshape_op =
      GetReshapeOpForConv2DFilter(last_reshape_shape, transpose_op, &builder);

  // Create slice op.
  auto slice_op = GetSliceOpForConv2DBackPropFilter(old_filter_shape,
                                                    final_reshape_op, &builder);

  // Update backprop's users with the slice op.
  backprop.replaceAllUsesWith(slice_op.getResult());
}

// Checks if the input producer op is supported in this transform. Right now,
// only a host (CPU-placed) tf.IteratorGetNext qualifies.
bool IsSupportedHostInputOp(Operation* op) {
  TF::IteratorGetNextOp iter = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iter) return false;
  auto device = op->getAttrOfType<StringAttr>(kDeviceAttr);
  if (!device) return false;
  tensorflow::DeviceNameUtils::ParsedName parsed_device;
  if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                  &parsed_device)) {
    return false;
  }
  return parsed_device.type == "CPU";
}

// Builds a SpaceToDepthOp that transforms the given input of the cluster
// function with the given block size.
TF::SpaceToDepthOp BuildSpaceToDepth(tf_device::ClusterFuncOp cluster_func,
                                     Value input, int32_t block_size,
                                     ArrayRef<int64_t> input_shape) {
  auto input_op = input.getDefiningOp();
  OpBuilder builder(input_op);
  builder.setInsertionPointAfter(input_op);
  SmallVector<int64_t, 4> transform_shape = {
      input_shape[0], input_shape[1] / block_size, input_shape[2] / block_size,
      input_shape[3] * block_size * block_size};
  auto transform_result_type =
      RankedTensorType::get(transform_shape, getElementTypeOrSelf(input));
  return builder.create<TF::SpaceToDepthOp>(
      cluster_func.getLoc(), transform_result_type, input, block_size);
}

// Performs transformation for a non-replicated input.
TF::SpaceToDepthOp HandleHostInput(Value input, int64_t index,
                                   tf_device::ClusterFuncOp cluster_func,
                                   int32_t block_size,
                                   ArrayRef<int64_t> input_shape) {
  auto space_to_depth =
      BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
  cluster_func.setOperand(index, space_to_depth);
  return space_to_depth;
}

// Performs transformation for replicated inputs. Returns true if this is a
// supported case (thus transform happened).
bool HandleHostReplicatedInputs(int64_t index,
                                tf_device::ClusterFuncOp cluster_func,
                                BlockArgument block_arg,
                                tf_device::ReplicateOp replicate,
                                int32_t block_size) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(block_arg);
  for (auto& input : inputs) {
    auto input_op = input.get().getDefiningOp();
    if (!input_op || !IsSupportedHostInputOp(input_op)) return false;
  }
  for (auto entry : llvm::enumerate(inputs)) {
    Value input = entry.value().get();
    auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
    if (!ranked_type) return false;
    auto input_shape = ranked_type.getShape();
    auto space_to_depth =
        BuildSpaceToDepth(cluster_func, input, block_size, input_shape);
    entry.value().set(space_to_depth);
    block_arg.setType(space_to_depth.getType());
  }
  return true;
}

// Performs transformation on a pair of execute and compile ops. The compile
// op should not have other uses.
void HandleCluster(tf_device::ClusterFuncOp cluster_func, int32_t block_size,
                   unsigned arg_num) {
  auto maybe_replicate =
      llvm::dyn_cast<tf_device::ReplicateOp>(cluster_func->getParentOp());

  llvm::SmallVector<int64_t, 8> transform_input_indices;
  for (auto input : llvm::enumerate(cluster_func.operands())) {
    if (auto block_arg = input.value().dyn_cast<BlockArgument>()) {
      if (block_arg.getArgNumber() != arg_num) continue;
      // For a block argument, consider transforms only when it is a replicated
      // input (defining ops will be outside the replicate node).
      if (maybe_replicate == block_arg.getParentRegion()->getParentOp()) {
        HandleHostReplicatedInputs(input.index(), cluster_func, block_arg,
                                   maybe_replicate, block_size);
      }
    } else {
      // For an op output, consider transforms only when 1) there is no
      // replication or 2) it is outside the replicate node that encloses the
      // execute node. (If the op is inside the replicate, it is probably not
      // on the host.)
      if (input.index() != arg_num) continue;
      auto input_op = input.value().getDefiningOp();
      if (maybe_replicate &&
          maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
        continue;
      }
      if (!IsSupportedHostInputOp(input_op)) continue;
      auto ranked_type = input.value().getType().dyn_cast<RankedTensorType>();
      if (!ranked_type) continue;
      auto input_shape = ranked_type.getShape();
      HandleHostInput(input.value(), input.index(), cluster_func, block_size,
                      input_shape);
    }
  }
}

// Checks if input shape of convolution is good for space to depth transform.
bool Conv2DInputShapeCanTransform(Value input) {
  auto ranked_type = input.getType().dyn_cast<RankedTensorType>();
  if (!ranked_type) return false;
  auto input_shape = ranked_type.getShape();
  int32_t batch_size = input_shape[0];
  int32_t channel = input_shape[3];
  if (batch_size > 8 || channel > 8) {
    return false;
  }
  return true;
}

// Get block argument id and number of users for the input arg.
Optional<BlockArgumentInfo> GetBlockArgNum(Value arg) {
  if (auto block_arg = arg.dyn_cast<mlir::BlockArgument>()) {
    if (!Conv2DInputShapeCanTransform(arg)) return None;
    unsigned num_users =
        std::distance(block_arg.getUsers().begin(), block_arg.getUsers().end());
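    // The user count is compared later (in runOnOperation) against the number
    // of qualifying convolutions, so the transform only fires when every user
    // of the argument is such a convolution.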
    BlockArgumentInfo block_arg_info = {block_arg.getArgNumber(), num_users};
    return block_arg_info;
  }
  return None;
}

// Gets input block argument id and number of users for the input recursively.
// Currently supported ops between the convolution input and the block
// arguments are PadOp and CastOp.
Optional<BlockArgumentInfo> GetInputBlockArgNum(Value input) {
  auto block_arg_num = GetBlockArgNum(input);
  if (block_arg_num.has_value()) return block_arg_num;

  Value next_input = input;
  auto pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
  auto cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());

  while (pad_op || cast_op) {
    if (pad_op) {
      auto block_arg_num = GetBlockArgNum(pad_op.input());
      if (block_arg_num.has_value()) return block_arg_num;
      next_input = pad_op.input();
    } else {
      auto block_arg_num = GetBlockArgNum(cast_op.x());
      if (block_arg_num.has_value()) return block_arg_num;
      next_input = cast_op.x();
    }
    pad_op = dyn_cast_or_null<TF::PadOp>(next_input.getDefiningOp());
    cast_op = dyn_cast_or_null<TF::CastOp>(next_input.getDefiningOp());
  }

  return None;
}

// Checks if a convolution can apply the SpaceToDepth transform.
// Only the first convolution in the graph, whose batch size is no larger
// than 8 and whose input feature (channel) size is no larger than 8, can be
// transformed.
Optional<BlockArgumentInfo> GetConv2DInputArgNum(TF::Conv2DOp conv2d) {
  if (conv2d.data_format() != "NHWC" || conv2d.strides().size() != 4) {
    return None;
  }
  // Currently supported ops between the convolution input and the block
  // arguments are PadOp and CastOp.
  return GetInputBlockArgNum(conv2d.input());
}

// Applies space to depth transform for the first convolution on TPU device.
void HandleFirstConvolution(TF::Conv2DOp conv2d, int64_t block_size) {
  // Check if input and filter type are RankedTensorType.
  auto input_tensor_type =
      conv2d.input().getType().dyn_cast<RankedTensorType>();
  auto filter_tensor_type =
      conv2d.filter().getType().dyn_cast<RankedTensorType>();
  if (!input_tensor_type || !filter_tensor_type) return;
  // Bookkeeping of filter shape for padding and backprop filter rewrite.
  auto filter_shape = filter_tensor_type.getShape();
  SmallVector<int32_t, 4> old_filter_shape(filter_shape.begin(),
                                           filter_shape.end());
  // Handles input.
  auto conv2d_input = conv2d.input();
  if (auto block_arg = conv2d_input.dyn_cast<mlir::BlockArgument>()) {
    // Change on device function type/shape.
    HandleFuncOp(block_arg.getOwner()->getParentOp());
  }

  if (auto pad_op = dyn_cast_or_null<TF::PadOp>(conv2d_input.getDefiningOp())) {
    // Rewrite pad_op before the convolution.
    if (failed(HandlePad(pad_op, filter_shape[0], block_size))) return;
    auto pad_input = pad_op.input();
    if (auto block_arg = pad_input.dyn_cast<mlir::BlockArgument>()) {
      // Change on device function type/shape.
      HandleFuncOp(block_arg.getOwner()->getParentOp());
    }
  }

  // Handle Conv2D input, stride and filter.
  HandleConv2DInput(conv2d, block_size);
  HandleConv2DStride(conv2d);
  HandleConv2DFilter(conv2d, block_size);

  // Bookkeeping of the new filter shape for backprop filter rewrite.
  // Filter shape is set in HandleConv2DFilter, thus it is RankedTensorType.
  filter_shape = conv2d.filter().getType().cast<RankedTensorType>().getShape();
  SmallVector<int32_t, 4> new_filter_shape(filter_shape.begin(),
                                           filter_shape.end());

  // Rewrite Conv2DBackpropFilter ops that are users of the first
  // convolution's input.
  if (!conv2d_input.getDefiningOp()) return;
  for (Operation* user : conv2d_input.getDefiningOp()->getUsers()) {
    if (auto backprop = dyn_cast<TF::Conv2DBackpropFilterOp>(user)) {
      HandleConv2DBackPropFilter(backprop, old_filter_shape, new_filter_shape,
                                 block_size);
    }
  }
}

// Gets the block size for the space to depth transform, which equals the
// convolution's spatial stride.
// The space to depth transform won't be triggered if block size <= 1.
int32_t GetConv2DBlockSize(TF::Conv2DOp conv2d) {
  SmallVector<int32_t, 4> strides(4, 1);
  for (int i = 0; i < 4; ++i) {
    strides[i] = conv2d.strides()[i].cast<mlir::IntegerAttr>().getInt();
  }

  // Space to depth only supports striding at the spatial dimensions.
  if (strides[0] != 1 || strides[3] != 1) return 1;

  // Space to depth only supports the height_stride == width_stride case.
  if (strides[1] != strides[2]) return 1;

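  // For example, strides of [1, 2, 2, 1] yield a block size of 2, while
  // [1, 2, 1, 1] or [1, 1, 1, 1] yield 1 and the transform is skipped.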
  return strides[1];
}

void TPUSpaceToDepthPass::runOnOperation() {
  Optional<tf_device::ClusterFuncOp> cluster_func;
  // Space to depth only supports the training loop.
  auto func_result = getOperation().walk([&](tf_device::ClusterFuncOp cluster) {
    cluster_func = cluster;
    return WalkResult::interrupt();
  });

  // Return if there is no tf_device::ClusterFuncOp in the training loop.
  if (!func_result.wasInterrupted() || !cluster_func.has_value()) {
    return;
  }

  // Get the function on device.
  auto device_func = cluster_func->getFunc();
  if (!device_func) return;

  TF::Conv2DOp first_conv;
  // Maps a block argument id to the convolutions that consume it.
  llvm::SmallDenseMap<unsigned, std::vector<Conv2DWithBlockSize>>
      argnum_and_convolutions;
  // Maps a block argument id to its number of users.
  llvm::SmallDenseMap<unsigned, int> argnum_num_users;

  // Find the qualifying convolutions and their block argument ids.
  auto conv2d_result = device_func.walk([&](TF::Conv2DOp conv2d) {
    Optional<BlockArgumentInfo> arg_num_and_num_users =
        GetConv2DInputArgNum(conv2d);
    if (arg_num_and_num_users.has_value()) {
      // Get block size for the first convolution.
      int64_t block_size = GetConv2DBlockSize(conv2d);
      auto arg_num = arg_num_and_num_users.getValue().arg_num;
      auto num_users = arg_num_and_num_users.getValue().num_users;
      argnum_and_convolutions[arg_num].emplace_back(conv2d, block_size);
      argnum_num_users[arg_num] = num_users;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (!conv2d_result.wasInterrupted()) {
    return;
  }

  // Iterate through the block arguments and their convolution users. The
  // space to depth transform is applied only if all of the following
  // conditions are satisfied:
  //  1. All users of the block argument are qualifying convolutions;
  //  2. The block sizes for the space to depth transform of these
  //     convolutions are the same;
  //  3. That common block size is larger than 1.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto arg_num = argnum_and_convolution.getFirst();
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    // Skip if the number of users of the block argument doesn't equal the
    // number of transformable convolutions, or there is no qualifying
    // convolution for the transform.
    if (argnum_num_users[arg_num] != conv2d_and_block_sizes.size() ||
        conv2d_and_block_sizes.empty()) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Skip if the block size is smaller than 2.
    int64_t block_size = conv2d_and_block_sizes[0].second;
    if (block_size < 2) {
      argnum_and_convolutions.erase(arg_num);
      continue;
    }
    // Skip if not all the block sizes for the space to depth transform are
    // the same.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      if (conv2d_and_block_size.second != block_size) {
        argnum_and_convolutions.erase(arg_num);
        break;
      }
    }
  }

  // Return if there is no qualifying space to depth transform.
  if (argnum_and_convolutions.empty()) {
    return;
  }

  // Apply the space to depth transform.
  for (auto argnum_and_convolution : argnum_and_convolutions) {
    auto conv2d_and_block_sizes = argnum_and_convolution.getSecond();
    int64_t block_size = conv2d_and_block_sizes[0].second;
    // Apply space to depth transform to the input on the host.
    HandleCluster(cluster_func.getValue(), block_size,
                  argnum_and_convolution.getFirst());
    // Transform the convolution.
    for (auto conv2d_and_block_size : conv2d_and_block_sizes) {
      HandleFirstConvolution(conv2d_and_block_size.first,
                             conv2d_and_block_size.second);
    }
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUSpaceToDepthPass() {
  return std::make_unique<TPUSpaceToDepthPass>();
}

}  // namespace TFTPU
}  // namespace mlir