/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <iterator>
#include <memory>
#include <tuple>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/DebugStringHelper.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"

using mlir::func::FuncOp;

namespace mlir {
namespace TF {

namespace {

constexpr char kDeviceAttr[] = "device";

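// Decomposes tf.ReduceDataset ops that carry a compile device type attribute
// into explicit dataset iteration. Roughly (schematic outline, not exact IR),
// each such op is rewritten into:
//   * a tf.AnonymousIteratorV3 / tf.MakeIterator pair that materializes an
//     iterator over the input dataset, and
//   * a tf.WhileRegion that fetches values with tf.IteratorGetNextAsOptional,
//     checks tf.OptionalHasValue, and, inside a tf.IfRegion, extracts the
//     values with tf.OptionalGetValue and calls the reduce function while
//     data remains.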
struct DecomposeReduceDatasetPass
    : public DecomposeReduceDatasetPassBase<DecomposeReduceDatasetPass> {
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<tf_device::TensorFlowDeviceDialect>();
  }

  void runOnOperation() override;
};

// Creates an AnonymousIteratorV3 op for `reduce_dataset` with `dataset_types`
// using `builder`, and connects it to the input dataset with a MakeIterator
// op.
AnonymousIteratorV3Op CreateIterator(OpBuilder builder,
                                     llvm::ArrayRef<Type> dataset_types,
                                     ReduceDatasetOp reduce_dataset) {
  llvm::SmallVector<Attribute, 2> shape_attrs;
  llvm::SmallVector<Attribute, 2> type_attrs;
  for (Type type : dataset_types) {
    shape_attrs.push_back(
        TF::ShapeAttr::get(builder.getContext(), type.cast<ShapedType>()));
    type_attrs.push_back(TypeAttr::get(getElementTypeOrSelf(type)));
  }

  auto anonymous_iterator = builder.create<AnonymousIteratorV3Op>(
      reduce_dataset.getLoc(),
      RankedTensorType::get({}, builder.getType<ResourceType>()),
      /*output_types=*/builder.getArrayAttr(type_attrs),
      /*output_shapes=*/builder.getArrayAttr(shape_attrs));
  builder.create<MakeIteratorOp>(reduce_dataset.getLoc(),
                                 reduce_dataset.input_dataset(),
                                 anonymous_iterator.getResult());
  return anonymous_iterator;
}

// Creates a WhileRegionOp that turns `reduce_dataset` into a dataset iteration
// with a call to reduce_fn.
WhileRegionOp CreateDatasetWhile(OpBuilder builder,
                                 ReduceDatasetOp reduce_dataset) {
  auto const_true = builder.create<TF::ConstOp>(
      reduce_dataset.getLoc(),
      DenseIntElementsAttr::get(
          RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));

  SmallVector<Value, 4> while_input_values;
  SmallVector<Type, 4> while_input_types;
  while_input_values.push_back(const_true.getResult());
  while_input_types.push_back(const_true.getResult().getType());
  for (int i = 1; i < reduce_dataset.getNumOperands(); ++i) {
    while_input_values.push_back(reduce_dataset.getOperand(i));
    while_input_types.push_back(reduce_dataset.getOperand(i).getType());
  }

  auto dataset_while = builder.create<TF::WhileRegionOp>(
      reduce_dataset.getLoc(), while_input_types, /*input=*/while_input_values,
      /*parallel_iterations=*/10, /*is_stateless=*/false,
      /*shape_invariant=*/false);

  // `_lower_using_switch_merge` is the default for While ops created
  // in TensorFlow and allows lowering to V1 control flow for loop
  // parallelization.
  dataset_while->setAttr("_lower_using_switch_merge",
                         builder.getBoolAttr(true));

  return dataset_while;
}

// Populates the cond region of `dataset_while`. The cond block just yields the
// condition of whether to continue to the next iteration.
void PopulateDatasetWhileCond(OpBuilder builder, WhileRegionOp dataset_while,
                              Location loc) {
  auto& cond_region = dataset_while.cond();
  Block* cond_block = builder.createBlock(&cond_region);
  auto while_input_types = dataset_while.getOperandTypes();
  cond_block->addArguments(
      while_input_types, SmallVector<Location>(while_input_types.size(), loc));
  builder.create<YieldOp>(loc, cond_block->getArgument(0));
}

// Creates an IfRegionOp with a predicate from `optional_has_value`. If true,
// the then branch uses `get_next` to fetch the next dataset values and calls
// `reduce_func`; the else branch passes the state values in `body_args`
// through unchanged. `dataset_types` is used for constructing the CallOp for
// `reduce_func`.
IfRegionOp CreateOptionalDatasetIf(
    OpBuilder builder, ReduceDatasetOp reduce_dataset, FuncOp reduce_func,
    IteratorGetNextAsOptionalOp get_next, OptionalHasValueOp optional_has_value,
    ArrayRef<Value> body_args, ArrayRef<Type> dataset_types) {
  const Location loc = reduce_dataset.getLoc();
  // The IfRegionOp returns the (possibly updated) state variables.
  SmallVector<Type, 4> if_return_types;
  const int state_size =
      reduce_dataset->getAttrOfType<ArrayAttr>("Tstate").size();
  for (int i = 1; i < state_size + 1; i++) {
    if_return_types.push_back(reduce_dataset.getOperand(i).getType());
  }

  auto dataset_if = builder.create<TF::IfRegionOp>(
      loc, if_return_types, optional_has_value.getResult(),
      /*is_stateless=*/false,
      /*_then_func_name=*/nullptr,
      /*_else_func_name=*/nullptr);
  // `_lower_using_switch_merge` allows lowering to V1 control flow for loop
  // parallelization.
  dataset_if->setAttr("_lower_using_switch_merge", builder.getBoolAttr(true));
  // Empty else branch: if there is no more data, do nothing.
  auto& else_branch = dataset_if.else_branch();
  else_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&else_branch.front());
  // Return only the state variables from the body arguments.
  SmallVector<Value, 4> else_returns;
  for (int i = 1; i < state_size + 1; i++) {
    else_returns.push_back(body_args[i]);
  }
  builder.create<TF::YieldOp>(loc,
                              /*operands=*/else_returns);

  // The then branch fetches the data and calls the reduce function.
  auto& then_branch = dataset_if.then_branch();
  then_branch.push_back(new Block);
  builder.setInsertionPointToEnd(&then_branch.front());
  // Extract the actual values from the optional inside the then branch.
  auto get_value = builder.create<TF::OptionalGetValueOp>(loc, dataset_types,
                                                          get_next.getResult());
  SmallVector<Value, 4> reduce_fn_args;

  // Function arguments are state values, dataset values, and then passthrough
  // arguments.
  // The first argument to the body is the while loop condition, so state
  // values start at index 1.
  for (int i = 1; i < state_size + 1; ++i) {
    reduce_fn_args.push_back(body_args[i]);
  }
  for (Value value : get_value.getResults()) {
    reduce_fn_args.push_back(value);
  }
  for (int i = state_size + 1; i < body_args.size(); ++i) {
    reduce_fn_args.push_back(body_args[i]);
  }

  auto reduce_call =
      builder.create<mlir::func::CallOp>(loc, reduce_func, reduce_fn_args);

  // Both the device attribute and the compile device type attribute should be
  // propagated to the reduce function call.
  reduce_call->setAttr(kDeviceAttr,
                       reduce_dataset->getAttrOfType<StringAttr>(kDeviceAttr));
  reduce_call->setAttr(
      TF::kCompileDeviceTypeAttr,
      reduce_dataset->getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr));

  SmallVector<Value, 4> if_returns;

  builder.create<TF::YieldOp>(loc,
                              /*operands=*/reduce_call.getResults());
  return dataset_if;
}

// Populates the body of the WhileRegionOp that replaces `reduce_dataset`. The
// body iterates `anonymous_iterator` with `dataset_types` and optionally calls
// `reduce_func`.
void PopulateDatasetWhileBody(OpBuilder builder, ReduceDatasetOp reduce_dataset,
                              FuncOp reduce_func, WhileRegionOp dataset_while,
                              AnonymousIteratorV3Op anonymous_iterator,
                              ArrayRef<Type> dataset_types) {
  const Location loc = reduce_dataset.getLoc();
  auto while_input_types = dataset_while.getOperandTypes();
  auto& body_region = dataset_while.body();
  Block* body_block = builder.createBlock(&body_region);
  auto body_arguments = body_block->addArguments(
      while_input_types, SmallVector<Location>(while_input_types.size(), loc));
  auto get_next = builder.create<IteratorGetNextAsOptionalOp>(
      loc, RankedTensorType::get({}, builder.getType<VariantType>()),
      anonymous_iterator.getResult(), anonymous_iterator.output_types(),
      anonymous_iterator.output_shapes());
  auto optional_has_value = builder.create<OptionalHasValueOp>(
      loc, RankedTensorType::get({}, builder.getI1Type()),
      get_next.getResult());

  SmallVector<Value, 4> body_args;
  for (Value value : body_arguments) {
    body_args.push_back(value);
  }

  IfRegionOp dataset_if =
      CreateOptionalDatasetIf(builder, reduce_dataset, reduce_func, get_next,
                              optional_has_value, body_args, dataset_types);

  builder.setInsertionPointToEnd(body_block);
  // The body results consist of the loop condition (whether the iterator has a
  // next value), the state returned by the IfRegionOp, and the pass-through
  // values.
  SmallVector<Value, 4> body_returns;
  body_returns.push_back(optional_has_value.getResult());

  const int state_size =
      reduce_dataset->getAttrOfType<ArrayAttr>("Tstate").size();
  for (int i = 0; i < state_size; ++i) {
    body_returns.push_back(dataset_if.getResult(i));
  }
  // Copy the remaining arguments but skip the loop condition and the states,
  // which are updated in the while body.
  for (int i = state_size + 1; i < body_args.size(); ++i) {
    body_returns.push_back(body_args[i]);
  }
  builder.create<TF::YieldOp>(loc,
                              /*operands=*/body_returns);
}

// Decomposes any ReduceDatasetOp in `function` into a dataset iteration and a
// call to the reduce function in the ReduceDatasetOp.
LogicalResult DecomposeReduceDatasetInFunction(FuncOp function) {
  if (!llvm::hasSingleElement(function))
    return function.emitOpError("Expecting a single block function");

  auto decompose_result = function.walk([&](ReduceDatasetOp reduce_dataset) {
    if (!reduce_dataset->hasAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr))
      return WalkResult::advance();
    OpBuilder builder(reduce_dataset);
    Location loc = reduce_dataset.getLoc();

    // Get the reduce function signature for the dataset iteration types.
    // Note: lookupSymbol is a linear lookup, so the overall complexity is
    // # ReduceDataset ops x # of functions in the module.
    func::FuncOp reduce_func =
        function->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
            reduce_dataset.f());

    // The reduce function arguments consist of three parts in this order:
    // 1. Reduction state inputs.
    // 2. Dataset inputs.
    // 3. Captured inputs.
    // The number of dataset inputs can be indirectly determined as
    // total_number_of_inputs - state_inputs - captured_inputs.
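    // For example (illustrative numbers only): with 2 state inputs and 1
    // captured input, a reduce function with 6 inputs has 6 - 2 - 1 = 3
    // dataset inputs, located at argument indices 2 through 4.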
    auto func_inputs = reduce_func.getFunctionType().getInputs();
    const int func_input_size = func_inputs.size();
    const int argument_size =
        reduce_dataset->getAttrOfType<ArrayAttr>("Targuments").size();
    const int state_size =
        reduce_dataset->getAttrOfType<ArrayAttr>("Tstate").size();
    const int dataset_input_size = func_input_size - state_size - argument_size;

    SmallVector<Type, 2> dataset_types;
    for (int i = 0; i < dataset_input_size; ++i) {
      dataset_types.push_back(func_inputs[state_size + i]);
    }

    // Create the dataset iterator and iterate over the dataset in a while loop
    // that calls reduce_fn.
    AnonymousIteratorV3Op anonymous_iterator =
        CreateIterator(builder, dataset_types, reduce_dataset);
    WhileRegionOp dataset_while = CreateDatasetWhile(builder, reduce_dataset);
    PopulateDatasetWhileCond(builder, dataset_while, loc);
    PopulateDatasetWhileBody(builder, reduce_dataset, reduce_func,
                             dataset_while, anonymous_iterator, dataset_types);

    // Replace uses of the reduce_dataset results with the corresponding while
    // loop results (offset by one for the loop condition), then erase the
    // rewritten reduce_dataset op.
    for (int i = 0; i < state_size; ++i) {
      reduce_dataset.getResult(i).replaceAllUsesWith(
          dataset_while.getResult(i + 1));
    }
    reduce_dataset.erase();

    return WalkResult::advance();
  });

  return failure(decompose_result.wasInterrupted());
}

void DecomposeReduceDatasetPass::runOnOperation() {
  if (failed(DecomposeReduceDatasetInFunction(getOperation()))) {
    return signalPassFailure();
  }
}

}  // anonymous namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateDecomposeReduceDatasetPass() {
  return std::make_unique<DecomposeReduceDatasetPass>();
}

}  // namespace TF
}  // namespace mlir