1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <iterator>
18 #include <memory>
19 #include <tuple>
20 #include <utility>
21
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/DenseSet.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/StringRef.h"
29 #include "llvm/ADT/iterator_range.h"
30 #include "llvm/Support/Casting.h"
31 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
32 #include "mlir/IR/Attributes.h" // from @llvm-project
33 #include "mlir/IR/Builders.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/Types.h" // from @llvm-project
37 #include "mlir/IR/Value.h" // from @llvm-project
38 #include "mlir/Pass/Pass.h" // from @llvm-project
39 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
40 #include "mlir/Support/DebugStringHelper.h" // from @llvm-project
41 #include "mlir/Support/LogicalResult.h" // from @llvm-project
42 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
43 #include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
46 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/attribute_utils.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
49
50 using mlir::func::FuncOp;
51
52 namespace mlir {
53 namespace TF {
54
55 namespace {
56
57 constexpr char kDeviceAttr[] = "device";
58
59 struct DecomposeReduceDatasetPass
60 : public DecomposeReduceDatasetPassBase<DecomposeReduceDatasetPass> {
getDependentDialectsmlir::TF::__anonc43256870111::DecomposeReduceDatasetPass61 void getDependentDialects(DialectRegistry& registry) const override {
62 registry.insert<tf_device::TensorFlowDeviceDialect>();
63 }
64
65 void runOnOperation() override;
66 };
67
68 // Create the AnonymousIterator for `reduce_dataset` with `dataset_types` using
69 // `builder`.
CreateIterator(OpBuilder builder,llvm::ArrayRef<Type> dataset_types,ReduceDatasetOp reduce_dataset)70 AnonymousIteratorV3Op CreateIterator(OpBuilder builder,
71 llvm::ArrayRef<Type> dataset_types,
72 ReduceDatasetOp reduce_dataset) {
73 llvm::SmallVector<Attribute, 2> shape_attrs;
74 llvm::SmallVector<Attribute, 2> type_attrs;
75 for (Type type : dataset_types) {
76 shape_attrs.push_back(
77 TF::ShapeAttr::get(builder.getContext(), type.cast<ShapedType>()));
78 type_attrs.push_back(TypeAttr::get(getElementTypeOrSelf(type)));
79 }
80
81 auto anonymous_iterator = builder.create<AnonymousIteratorV3Op>(
82 reduce_dataset.getLoc(),
83 RankedTensorType::get({}, builder.getType<ResourceType>()),
84 /*output_types=*/builder.getArrayAttr(type_attrs),
85 /*shape_types=*/builder.getArrayAttr(shape_attrs));
86 builder.create<MakeIteratorOp>(reduce_dataset.getLoc(),
87 reduce_dataset.input_dataset(),
88 anonymous_iterator.getResult());
89 return anonymous_iterator;
90 }
91
92 // Create a WhileRegionOp turning `reduce_dataset` into a dataset iteration with
93 // reduce_fn call.
CreateDatasetWhile(OpBuilder builder,ReduceDatasetOp reduce_dataset)94 WhileRegionOp CreateDatasetWhile(OpBuilder builder,
95 ReduceDatasetOp reduce_dataset) {
96 auto const_true = builder.create<TF::ConstOp>(
97 reduce_dataset.getLoc(),
98 DenseIntElementsAttr::get(
99 RankedTensorType::get(/*shape=*/{}, builder.getI1Type()), true));
100
101 SmallVector<Value, 4> while_input_values;
102 SmallVector<Type, 4> while_input_types;
103 while_input_values.push_back(const_true.getResult());
104 while_input_types.push_back(const_true.getResult().getType());
105 for (int i = 1; i < reduce_dataset.getNumOperands(); ++i) {
106 while_input_values.push_back(reduce_dataset.getOperand(i));
107 while_input_types.push_back(reduce_dataset.getOperand(i).getType());
108 }
109
110 auto dataset_while = builder.create<TF::WhileRegionOp>(
111 reduce_dataset.getLoc(), while_input_types, /*input=*/while_input_values,
112 /*parallel_iterations=*/10, false,
113 /*shape_invariant=*/false);
114
115 // `_lower_using_switch_merge` is the default for While ops created
116 // in TensorFlow and allows lowering to V1 control flow for loop
117 // parallelization.
118 dataset_while->setAttr("_lower_using_switch_merge",
119 builder.getBoolAttr(true));
120
121 return dataset_while;
122 }
123
124 // Populate the cond of `dataset_while`. The cond body just returns the
125 // condition of whether to continue to next iteration.
PopulateDatasetWhileCond(OpBuilder builder,WhileRegionOp dataset_while,Location loc)126 void PopulateDatasetWhileCond(OpBuilder builder, WhileRegionOp dataset_while,
127 Location loc) {
128 auto& cond_region = dataset_while.cond();
129 Block* cond_block = builder.createBlock(&cond_region);
130 auto while_input_types = dataset_while.getOperandTypes();
131 cond_block->addArguments(
132 while_input_types, SmallVector<Location>(while_input_types.size(), loc));
133 builder.create<YieldOp>(loc, cond_block->getArgument(0));
134 }
135
136 // Create an IfRegionOp with a predicate from `optional_has_value`. If true, it
137 // uses `get_next` to get the next value and calls `reduce_func`. `body_args`
138 // is used as pass through of state values for else branch. `dataset_types` is
139 // used for constructing the CallOp for `reduce_func`.
CreateOptionalDatasetIf(OpBuilder builder,ReduceDatasetOp reduce_dataset,FuncOp reduce_func,IteratorGetNextAsOptionalOp get_next,OptionalHasValueOp optional_has_value,ArrayRef<Value> body_args,ArrayRef<Type> dataset_types)140 IfRegionOp CreateOptionalDatasetIf(
141 OpBuilder builder, ReduceDatasetOp reduce_dataset, FuncOp reduce_func,
142 IteratorGetNextAsOptionalOp get_next, OptionalHasValueOp optional_has_value,
143 ArrayRef<Value> body_args, ArrayRef<Type> dataset_types) {
144 const Location loc = reduce_dataset.getLoc();
145 // If returns are the state variables.
146 SmallVector<Type, 4> if_return_types;
147 const int state_size =
148 reduce_dataset->getAttrOfType<ArrayAttr>("Tstate").size();
149 for (int i = 1; i < state_size + 1; i++) {
150 if_return_types.push_back(reduce_dataset.getOperand(i).getType());
151 }
152
153 auto dataset_if = builder.create<TF::IfRegionOp>(
154 loc, if_return_types, optional_has_value.getResult(), false,
155 /*_then_func_name=*/nullptr,
156 /*_else_func_name=*/nullptr);
157 // `_lower_using_switch_merge` allows lowering to V1 control flow for loop
158 // parallelization.
159 dataset_if->setAttr("_lower_using_switch_merge", builder.getBoolAttr(true));
160 // Empty else branch, if there is no more data, do nothing.
161 auto& else_branch = dataset_if.else_branch();
162 else_branch.push_back(new Block);
163 builder.setInsertionPointToEnd(&else_branch.front());
164 // Return only the state variables from the body arguments.
165 SmallVector<Value, 4> else_returns;
166 for (int i = 1; i < state_size + 1; i++) {
167 else_returns.push_back(body_args[i]);
168 }
169 builder.create<TF::YieldOp>(loc,
170 /*operands=*/else_returns);
171
172 // Then branch gets the data and calls the reduce_function.
173 auto& then_branch = dataset_if.then_branch();
174 then_branch.push_back(new Block);
175 builder.setInsertionPointToEnd(&then_branch.front());
176 // Add iterator operational data access inside if.
177 auto get_value = builder.create<TF::OptionalGetValueOp>(loc, dataset_types,
178 get_next.getResult());
179 SmallVector<Value, 4> reduce_fn_args;
180
181 // Function arguments are state values, dataset values, and then passthrough
182 // arguments.
183 // First argument to body is the while loop condition and state values start
184 // at index=1.
185 for (int i = 1; i < state_size + 1; ++i) {
186 reduce_fn_args.push_back(body_args[i]);
187 }
188 for (Value value : get_value.getResults()) {
189 reduce_fn_args.push_back(value);
190 }
191 for (int i = state_size + 1; i < body_args.size(); ++i) {
192 reduce_fn_args.push_back(body_args[i]);
193 }
194
195 auto reduce_call =
196 builder.create<mlir::func::CallOp>(loc, reduce_func, reduce_fn_args);
197
198 // Both the device attribute and compile_device_type attribute should be
199 // propagated to the reduce function call.
200 reduce_call->setAttr(kDeviceAttr,
201 reduce_dataset->getAttrOfType<StringAttr>(kDeviceAttr));
202 reduce_call->setAttr(
203 TF::kCompileDeviceTypeAttr,
204 reduce_dataset->getAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr));
205
206 SmallVector<Value, 4> if_returns;
207
208 builder.create<TF::YieldOp>(loc,
209 /*operands=*/reduce_call.getResults());
210 return dataset_if;
211 }
212
213 // Populates WhileRegionOp body which is replacing `reduce_dataset`. Iterates
214 // `anonymous_iterator` with `dataset_types` and optional calls `reduce_func`.
PopulateDatasetWhileBody(OpBuilder builder,ReduceDatasetOp reduce_dataset,FuncOp reduce_func,WhileRegionOp dataset_while,AnonymousIteratorV3Op anonymous_iterator,ArrayRef<Type> dataset_types)215 void PopulateDatasetWhileBody(OpBuilder builder, ReduceDatasetOp reduce_dataset,
216 FuncOp reduce_func, WhileRegionOp dataset_while,
217 AnonymousIteratorV3Op anonymous_iterator,
218 ArrayRef<Type> dataset_types) {
219 const Location loc = reduce_dataset.getLoc();
220 auto while_input_types = dataset_while.getOperandTypes();
221 auto& body_region = dataset_while.body();
222 Block* body_block = builder.createBlock(&body_region);
223 auto body_arguments = body_block->addArguments(
224 while_input_types, SmallVector<Location>(while_input_types.size(), loc));
225 auto get_next = builder.create<IteratorGetNextAsOptionalOp>(
226 loc, RankedTensorType::get({}, builder.getType<VariantType>()),
227 anonymous_iterator.getResult(), anonymous_iterator.output_types(),
228 anonymous_iterator.output_shapes());
229 auto optional_has_value = builder.create<OptionalHasValueOp>(
230 loc, RankedTensorType::get({}, builder.getI1Type()),
231 get_next.getResult());
232
233 SmallVector<Value, 4> body_args;
234 for (Value value : body_arguments) {
235 body_args.push_back(value);
236 }
237
238 IfRegionOp dataset_if =
239 CreateOptionalDatasetIf(builder, reduce_dataset, reduce_func, get_next,
240 optional_has_value, body_args, dataset_types);
241
242 builder.setInsertionPointToEnd(body_block);
243 // The body returns consist of the loop condition (whether the next iterator
244 // has a value), the state returned by the IfRegionOp, and the pass through
245 // values.
246 SmallVector<Value, 4> body_returns;
247 body_returns.push_back(optional_has_value.getResult());
248
249 const int state_size =
250 reduce_dataset->getAttrOfType<ArrayAttr>("Tstate").size();
251 for (int i = 0; i < state_size; ++i) {
252 body_returns.push_back(dataset_if.getResult(i));
253 }
254 // Copy the arguments but skip the states and the loop condition
255 // which are updated in while body.
256 for (int i = state_size + 1; i < body_args.size(); ++i) {
257 body_returns.push_back(body_args[i]);
258 }
259 builder.create<TF::YieldOp>(loc,
260 /*operands=*/body_returns);
261 }
262
263 // Decomposes any ReduceDatasetOps in `function` into a dataset iteration and a
264 // call to the reduce function in the ReduceDatasetOp.
DecomposeReduceDatasetInFunction(FuncOp function)265 LogicalResult DecomposeReduceDatasetInFunction(FuncOp function) {
266 if (!llvm::hasSingleElement(function))
267 return function.emitOpError("Expecting a single block function");
268
269 auto decompose_result = function.walk([&](ReduceDatasetOp reduce_dataset) {
270 if (!reduce_dataset->hasAttrOfType<StringAttr>(TF::kCompileDeviceTypeAttr))
271 return WalkResult::advance();
272 OpBuilder builder(reduce_dataset);
273 Location loc = reduce_dataset.getLoc();
274
275 // Get reduce function signature for dataset iteration types.
276 // Note: lookupSymbol is a linear lookup which means the overall
277 // complexity = # ReduceDataset ops x # of functions in module.
278 func::FuncOp reduce_func =
279 function->getParentOfType<ModuleOp>().lookupSymbol<FuncOp>(
280 reduce_dataset.f());
281
282 // The reduce function arguments consist of three part in this order:
283 // 1. Reduction state inputs.
284 // 2. Dataset inputs.
285 // 3. Captures inputs.
286 // The number of dataset inputs can be indirectly determined to be
287 // total_number_of_inputs - state_inputs - captured_inputs.
288 auto func_inputs = reduce_func.getFunctionType().getInputs();
289 const int func_input_size = func_inputs.size();
290 const int argument_size =
291 reduce_dataset->getAttrOfType<ArrayAttr>("Targuments").size();
292 const int state_size =
293 reduce_dataset->getAttrOfType<ArrayAttr>("Tstate").size();
294 const int dataset_input_size = func_input_size - state_size - argument_size;
295
296 SmallVector<Type, 2> dataset_types;
297 for (int i = 0; i < dataset_input_size; ++i) {
298 dataset_types.push_back(func_inputs[state_size + i]);
299 }
300
301 // Create dataset iterator and iterate dataset in while loop which calls
302 // reduce_fn.
303 AnonymousIteratorV3Op anonymous_iterator =
304 CreateIterator(builder, dataset_types, reduce_dataset);
305 WhileRegionOp dataset_while = CreateDatasetWhile(builder, reduce_dataset);
306 PopulateDatasetWhileCond(builder, dataset_while, loc);
307 PopulateDatasetWhileBody(builder, reduce_dataset, reduce_func,
308 dataset_while, anonymous_iterator, dataset_types);
309
310 // Updates usage and erases rewritten reduce_dataset op based on the number
311 // of state variables.
312 for (int i = 0; i < state_size; ++i) {
313 reduce_dataset.getResult(i).replaceAllUsesWith(
314 dataset_while.getResult(i + 1));
315 }
316 reduce_dataset.erase();
317
318 return WalkResult::advance();
319 });
320
321 return failure(decompose_result.wasInterrupted());
322 }
323
runOnOperation()324 void DecomposeReduceDatasetPass::runOnOperation() {
325 if (failed(DecomposeReduceDatasetInFunction(getOperation()))) {
326 return signalPassFailure();
327 }
328 }
329
330 } // anonymous namespace
331
332 std::unique_ptr<OperationPass<func::FuncOp>>
CreateDecomposeReduceDatasetPass()333 CreateDecomposeReduceDatasetPass() {
334 return std::make_unique<DecomposeReduceDatasetPass>();
335 }
336
337 } // namespace TF
338 } // namespace mlir
339