/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/types.h"

namespace mlir {

namespace {

namespace cutil = TF::collection_ops_util;

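// Pass that decomposes TensorList ops into ops on dense buffers (tensors)
// paired with explicitly tracked size values.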
struct TensorListOpsDecompositionPass
    : public TF::TensorListOpsDecompositionPassBase<
          TensorListOpsDecompositionPass> {
  void runOnOperation() override;
};

// Updates func's type according to its current arguments and return values.
void UpdateFuncType(func::FuncOp func) {
  llvm::SmallVector<Type, 8> arg_types;
  for (auto arg : func.getArguments()) arg_types.push_back(arg.getType());
  func.setType(
      FunctionType::get(func.getContext(), arg_types,
                        func.front().getTerminator()->getOperandTypes()));
}

// Holds the size value of a tensor list and whether the size is statically
// known (fixed).
struct SizeInfo {
  Value size;
  bool fixed;
};

// Modifies a function's signature to rewrite tensor list arguments to buffers
// and sizes.
void ModifyFunctionSignature(
    func::FuncOp func, Type size_type,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::function_ref<llvm::Optional<Type>(int64_t)> arg_to_buffer_type,
    llvm::function_ref<bool(int64_t)> arg_buffer_size_is_fixed) {
  auto new_input_types = llvm::to_vector<8>(func.getFunctionType().getInputs());
  int64_t original_arg_count = new_input_types.size();
  Location loc = func.getLoc();
  for (int64_t i = 0; i < original_arg_count; ++i) {
    auto buffer_type = arg_to_buffer_type(i);
    if (!buffer_type.has_value()) continue;
    func.getArgument(i).setType(*buffer_type);
    new_input_types[i] = *buffer_type;
    auto size_arg = func.front().addArgument(size_type, loc);
    new_input_types.push_back(size_arg.getType());
    if (buffer_to_size) {
      (*buffer_to_size)[func.getArgument(i)] = {size_arg,
                                                arg_buffer_size_is_fixed(i)};
    }
  }
  UpdateFuncType(func);
}

// Holds information about a decomposed callee function for
// PartitionedCall/StatefulPartitionedCall.
struct PartitionedCallDecompositionInfo {
  bool signature_change;
  func::FuncOp decomposed_callee;
  llvm::SmallDenseMap<int64_t, int64_t> buffer_arg_to_size_arg;
  // Each element is a tuple of (buffer_return_index, size_return_index,
  // fixed_size).
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      buffer_ret_to_size_ret;
};

LogicalResult DecomposeTensorListOpsInternal(
    Block*, ModuleOp, llvm::SmallDenseMap<Value, SizeInfo>*,
    llvm::StringMap<PartitionedCallDecompositionInfo>*);

// Adds the corresponding sizes of tensor list buffers in block's terminator
// to the list of return values. Returns the mapping from the buffer
// indices to the added size indices, which is a list of tuples
// (buffer_return_index, size_return_index, fixed_size).
template <class TerminatorOp>
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
AddTensorListSizesToTerminator(
    Block& block, const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto old_terminator = block.getTerminator();
  auto new_outputs = llvm::to_vector<8>(old_terminator->getOperands());
  llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8>
      output_buffer_to_size;
  for (auto retval : llvm::enumerate(old_terminator->getOperands())) {
    auto it = buffer_to_size.find(retval.value());
    if (it == buffer_to_size.end()) continue;
    output_buffer_to_size.emplace_back(retval.index(), new_outputs.size(),
                                       it->getSecond().fixed);
    new_outputs.push_back(it->getSecond().size);
  }
  OpBuilder(old_terminator)
      .create<TerminatorOp>(old_terminator->getLoc(), new_outputs);
  old_terminator->erase();
  return output_buffer_to_size;
}

// Adds the corresponding sizes of tensor list buffers in func's return values
// to the list of return values. Returns the mapping from the buffer indices to
// the added size indices, which is a list of tuples (buffer_return_index,
// size_return_index, fixed_size).
llvm::SmallVector<std::tuple<int64_t, int64_t, bool>, 8> ModifyFunctionReturn(
    func::FuncOp func,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto output_buffer_to_size = AddTensorListSizesToTerminator<func::ReturnOp>(
      func.front(), buffer_to_size);
  UpdateFuncType(func);
  return output_buffer_to_size;
}

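// Rewrites a TF::WhileOp whose operands/results contain tensor lists: the cond
// and body functions are decomposed, size operands are appended, and a new
// while op replaces the original one.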
LogicalResult HandleWhileOp(
    TF::WhileOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite body.
  auto body = while_op.body_function();
  llvm::SmallDenseMap<Value, SizeInfo> body_map;
  auto find_arg_tensor_list_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(while_op.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[while_op.getOperand(index)].fixed;
  };
  OpBuilder builder(while_op);
  ModifyFunctionSignature(body, cutil::GetSizeType(builder), &body_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &body.front(), module, &body_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = ModifyFunctionReturn(body, body_map);

  // Rewrite cond.
  auto cond = while_op.cond_function();
  llvm::SmallDenseMap<Value, SizeInfo> cond_map;
  ModifyFunctionSignature(cond, cutil::GetSizeType(builder), &cond_map,
                          find_arg_tensor_list_type, arg_buffer_size_is_fixed);
  if (failed(DecomposeTensorListOpsInternal(
          &cond.front(), module, &cond_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  if (output_buffer_to_size.empty()) {
    return success();
  }
  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while = builder.create<TF::WhileOp>(
      while_op.getLoc(), body.getFunctionType().getInputs(),
      new_while_operands, while_op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

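// Rewrites a TF::CaseOp or TF::IfOp whose operands/results contain tensor
// lists: each branch function is decomposed, size operands/results are
// appended, and the op is recreated with the updated signature.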
template <class CaseOrIfOp>
LogicalResult HandleCaseOrIfOp(
    CaseOrIfOp op, ArrayRef<func::FuncOp> branches, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  SmallVector<llvm::SmallDenseMap<Value, SizeInfo>, 2> branch_maps;
  branch_maps.resize(branches.size());

  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(op.getOperand(index + 1));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[op.getOperand(index + 1)].fixed;
  };
  OpBuilder builder(op);
  for (const auto& pair : llvm::zip(branches, branch_maps)) {
    func::FuncOp branch = std::get<0>(pair);
    llvm::SmallDenseMap<Value, SizeInfo>& branch_map = std::get<1>(pair);
    ModifyFunctionSignature(branch, cutil::GetSizeType(builder), &branch_map,
                            find_arg_buffer_type, arg_buffer_size_is_fixed);

    if (failed(DecomposeTensorListOpsInternal(
            &branch.front(), module, &branch_map,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  const bool arg_no_changed = branch_maps.front().empty();
  auto output_buffer_to_size =
      ModifyFunctionReturn(branches.front(), branch_maps.front());
  for (const auto& pair : llvm::drop_begin(llvm::zip(branches, branch_maps), 1))
    ModifyFunctionReturn(std::get<0>(pair), std::get<1>(pair));

  if (output_buffer_to_size.empty() && arg_no_changed) return success();

  // Recreate the op.
  auto new_operands = llvm::to_vector<8>(op.getOperands());
  for (int64_t i = 1; i < op.getNumOperands(); ++i) {
    auto it = buffer_to_size->find(op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_operands.push_back(it->getSecond().size);
  }
  func::FuncOp first_branch = branches.front();
  auto new_op = OpBuilder(op).create<CaseOrIfOp>(
      op.getLoc(), first_branch.getFunctionType().getResults(), new_operands,
      op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  op.replaceAllUsesWith(new_op.getResults().take_front(op.getNumResults()));
  op.erase();
  return success();
}

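// Rewrites a TF::WhileRegionOp whose operands/results contain tensor lists.
// The region-based variant updates block arguments in place and moves the
// regions into a recreated while op with the extra size operands/results.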
LogicalResult HandleWhileRegionOp(
    TF::WhileRegionOp while_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  OpBuilder builder(while_op);
  auto modify_region_arguments = [&](Region& region) {
    int64_t original_arg_count = region.getNumArguments();
    for (int64_t i = 0; i < original_arg_count; ++i) {
      auto operand = while_op.getOperand(i);
      auto it = buffer_to_size->find(operand);
      if (it == buffer_to_size->end()) continue;
      auto buffer_type = it->getFirst().getType();
      region.getArgument(i).setType(buffer_type);
      auto size_arg =
          region.addArgument(cutil::GetSizeType(builder), region.getLoc());
      (*buffer_to_size)[region.getArgument(i)] = {size_arg,
                                                  it->getSecond().fixed};
    }
  };

  // Rewrite body.
  Region& body_region = while_op.body();
  modify_region_arguments(body_region);
  if (failed(DecomposeTensorListOpsInternal(
          &body_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      body_region.front(), *buffer_to_size);

  // Rewrite cond.
  Region& cond_region = while_op.cond();
  modify_region_arguments(cond_region);
  if (failed(DecomposeTensorListOpsInternal(
          &cond_region.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees))) {
    return failure();
  }

  if (output_buffer_to_size.empty()) return success();

  // Create the new while op.
  auto new_while_operands = llvm::to_vector<8>(while_op.getOperands());
  for (int64_t i = 0; i < while_op.getNumResults(); ++i) {
    auto it = buffer_to_size->find(while_op.getOperand(i));
    if (it == buffer_to_size->end()) continue;
    new_while_operands.push_back(it->getSecond().size);
  }
  auto new_while = builder.create<TF::WhileRegionOp>(
      while_op.getLoc(),
      body_region.front().getTerminator()->getOperandTypes(),
      new_while_operands, while_op->getAttrs());
  new_while.body().takeBody(body_region);
  new_while.cond().takeBody(cond_region);
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_while.getResult(std::get<0>(entry))] = {
        new_while.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }
  while_op.replaceAllUsesWith(
      new_while.getResults().take_front(while_op.getNumResults()));
  while_op.erase();
  return success();
}

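// Rewrites a TF::IfRegionOp whose results contain tensor lists: both branches
// are decomposed, size results are appended to their yields, and the branch
// bodies are moved into a recreated if op.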
LogicalResult HandleIfRegionOp(
    TF::IfRegionOp if_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  Region& then_branch = if_op.then_branch();
  Region& else_branch = if_op.else_branch();
  if (failed(DecomposeTensorListOpsInternal(
          &then_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();
  if (failed(DecomposeTensorListOpsInternal(
          &else_branch.front(), module, buffer_to_size,
          decomposed_partitioned_call_callees)))
    return failure();

  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      then_branch.front(), *buffer_to_size);
  AddTensorListSizesToTerminator<TF::YieldOp>(else_branch.front(),
                                              *buffer_to_size);

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(if_op).create<TF::IfRegionOp>(
      if_op.getLoc(), then_branch.front().getTerminator()->getOperandTypes(),
      if_op.getOperand(), if_op->getAttrs());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  new_op.then_branch().takeBody(if_op.then_branch());
  new_op.else_branch().takeBody(if_op.else_branch());

  if_op.replaceAllUsesWith(
      new_op.getResults().take_front(if_op.getNumResults()));
  if_op.erase();
  return success();
}

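// Rewrites a TF::CaseRegionOp whose results contain tensor lists, analogous
// to HandleIfRegionOp but over an arbitrary number of branch regions.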
LogicalResult HandleCaseRegionOp(
    TF::CaseRegionOp case_op, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  // Rewrite the branches.
  RegionRange branches = case_op.getRegions();

  for (Region* branch : branches) {
    if (failed(DecomposeTensorListOpsInternal(
            &branch->front(), module, buffer_to_size,
            decomposed_partitioned_call_callees)))
      return failure();
  }

  // Get the output buffer-index-to-size-index mapping from one of the
  // branches. It should be the same for all branches, so we only compute it
  // for the first one.
  Region* first_branch = branches.front();
  auto output_buffer_to_size = AddTensorListSizesToTerminator<TF::YieldOp>(
      first_branch->front(), *buffer_to_size);
  for (Region* branch : branches.drop_front()) {
    AddTensorListSizesToTerminator<TF::YieldOp>(branch->front(),
                                                *buffer_to_size);
  }

  if (output_buffer_to_size.empty()) return success();

  // Recreate the op.
  auto new_op = OpBuilder(case_op).create<TF::CaseRegionOp>(
      case_op.getLoc(),
      first_branch->front().getTerminator()->getOperandTypes(),
      case_op.getOperand(), case_op->getAttrs(), case_op.getNumRegions());
  for (const auto& entry : output_buffer_to_size) {
    (*buffer_to_size)[new_op.getResult(std::get<0>(entry))] = {
        new_op.getResult(std::get<1>(entry)), std::get<2>(entry)};
  }

  for (auto pair : llvm::zip(new_op.getRegions(), case_op.getRegions())) {
    std::get<0>(pair)->takeBody(*std::get<1>(pair));
  }
  case_op.replaceAllUsesWith(
      new_op.getResults().take_front(case_op.getNumResults()));
  case_op.erase();
  return success();
}

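// Rewrites a PartitionedCall/StatefulPartitionedCall whose callee takes or
// returns tensor lists. The decomposed callee is cached in
// decomposed_partitioned_call_callees so each callee is rewritten only once;
// if its signature changed, the caller is recreated with size operands and
// results appended.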
template <typename CallOp>
LogicalResult HandlePartitionedCallOp(
    CallOp call, func::FuncOp callee, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  auto emplace_res = decomposed_partitioned_call_callees->try_emplace(
      callee.getName(), PartitionedCallDecompositionInfo());
  auto& info = emplace_res.first->second;
  // Recreates the call op using the cached decomposition info.
  auto recreate_caller = [&] {
    auto new_operands = llvm::to_vector<8>(call.getOperands());
    for (int64_t i = 0; i < call.getNumOperands(); ++i) {
      auto arg_it = info.buffer_arg_to_size_arg.find(i);
      if (arg_it == info.buffer_arg_to_size_arg.end()) continue;
      auto it = buffer_to_size->find(call.getOperand(i));
      if (it == buffer_to_size->end()) {
        call.emitOpError("unknown tensor list.");
        return failure();
      }
      assert(arg_it->second == new_operands.size());
      new_operands.push_back(it->getSecond().size);
    }
    OpBuilder builder(call);
    auto new_call = builder.create<CallOp>(
        call.getLoc(), info.decomposed_callee.getFunctionType().getResults(),
        new_operands, call->getAttrs());
    new_call->setAttr(
        "f", SymbolRefAttr::get(
                 builder.getContext(),
                 const_cast<func::FuncOp&>(info.decomposed_callee).getName()));
    for (const auto& entry : info.buffer_ret_to_size_ret) {
      (*buffer_to_size)[new_call.getResult(std::get<0>(entry))] = {
          new_call.getResult(std::get<1>(entry)), std::get<2>(entry)};
    }
    call.replaceAllUsesWith(
        new_call.getResults().take_front(call.getNumResults()));
    call.erase();
    return success();
  };
  if (!emplace_res.second) {
    // This callee was handled before.
    if (!info.signature_change) return success();
    return recreate_caller();
  }
  // Rewrite the callee.
  llvm::SmallDenseMap<Value, SizeInfo> callee_map;
  func::FuncOp lowered_callee = callee;
  if (!callee.isPrivate()) {
    // Clone non-private callee in case of signature change.
    lowered_callee = callee.clone();
    lowered_callee.setPrivate();
  }
  auto find_arg_buffer_type = [&](int64_t index) -> llvm::Optional<Type> {
    auto it = buffer_to_size->find(call.getOperand(index));
    if (it == buffer_to_size->end()) return llvm::None;
    return it->getFirst().getType();
  };
  auto arg_buffer_size_is_fixed = [&](int64_t index) {
    return (*buffer_to_size)[call.getOperand(index)].fixed;
  };
  ModifyFunctionSignature(lowered_callee, cutil::GetSizeType(OpBuilder(call)),
                          &callee_map, find_arg_buffer_type,
                          arg_buffer_size_is_fixed);
  const bool args_no_changed = callee_map.empty();
  if (failed(DecomposeTensorListOpsInternal(
          &lowered_callee.front(), module, &callee_map,
          decomposed_partitioned_call_callees))) {
    return failure();
  }
  info.buffer_ret_to_size_ret =
      ModifyFunctionReturn(lowered_callee, callee_map);
  info.decomposed_callee = lowered_callee;
  if (args_no_changed && info.buffer_ret_to_size_ret.empty()) {
    // Signature is not modified. We do not need to keep two copies.
    info.signature_change = false;
    if (lowered_callee != callee) {
      lowered_callee.setName(
          StringAttr::get(callee->getContext(), callee.getName()));
      callee.erase();
      SymbolTable(module).insert(lowered_callee);
    }
  } else {
    info.signature_change = true;
    for (auto& entry : callee_map) {
      auto buffer_arg = entry.getFirst().dyn_cast<BlockArgument>();
      if (!buffer_arg) continue;
      info.buffer_arg_to_size_arg[buffer_arg.getArgNumber()] =
          entry.getSecond().size.cast<BlockArgument>().getArgNumber();
    }
    if (lowered_callee != callee) {
      // Add the clone with a new name.
      lowered_callee.setName(StringAttr::get(
          callee->getContext(),
          llvm::formatv("{0}_tensorlist_decomposed", callee.getName()).str()));
      SymbolTable(module).insert(lowered_callee);
      callee = lowered_callee;
    }
  }
  if (info.signature_change) return recreate_caller();
  return success();
}

// Parses an R1 value to `shape` if it is a TF::ConstOp output. Otherwise,
// returns an error.
LogicalResult GetConstShapeValue(Value shape_value,
                                 llvm::SmallVector<int64_t, 8>* shape) {
  auto shape_op = shape_value.getDefiningOp();
  if (!shape_op) return failure();
  auto shape_const_op = llvm::dyn_cast<TF::ConstOp>(shape_op);
  if (!shape_const_op) return failure();
  for (const auto& v : shape_const_op.value().getValues<APInt>()) {
    int64_t dim_size = v.getSExtValue();
    if (dim_size == ShapedType::kDynamicSize) return failure();
    shape->push_back(dim_size);
  }
  return success();
}

// Checks the result Variant type to infer the element shape if fully defined.
// If the Variant type has multiple subtypes or does not have a static shape,
// returns an error.
LogicalResult GetElementShapeFromResultType(
    Type type, llvm::SmallVector<int64_t, 8>* shape) {
  auto variant_type = getElementTypeOrSelf(type).dyn_cast<TF::VariantType>();
  if (!variant_type || variant_type.getSubtypes().size() != 1) return failure();
  TensorType tensor_type = variant_type.getSubtypes().front();
  if (!tensor_type.hasStaticShape()) return failure();
  for (auto d : tensor_type.getShape()) shape->push_back(d);
  return success();
}

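// Replaces a TF::EmptyTensorListOp with an initialized buffer and a size of 0.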
LogicalResult HandleEmptyTensorListOp(
    TF::EmptyTensorListOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(
          element_shape, list.max_num_elements(), list, list.element_dtype(),
          builder, &buffer))) {
    return failure();
  }
  Value size = cutil::GetR1Const({0LL}, builder, list.getLoc());
  list.handle().replaceAllUsesWith(buffer);
  (*buffer_to_size)[buffer] = {size, /*fixed=*/false};
  list.erase();
  return success();
}

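// Replaces a TF::TensorListReserveOp with an initialized buffer and a fixed
// size taken from the num_elements operand.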
LogicalResult HandleTensorListReserveOp(
    TF::TensorListReserveOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  Value buffer;
  OpBuilder builder(list);
  llvm::SmallVector<int64_t, 8> element_shape;
  // Infer TensorList element shape from the return type first, and then from
  // the const element shape operand. We first check the return type because
  // shape inference might have successfully inferred the element shape from
  // write operations on the TensorList.
  if (failed(GetElementShapeFromResultType(list.getType(), &element_shape))) {
    if (failed(GetConstShapeValue(list.element_shape(), &element_shape))) {
      return list.emitOpError("unknown tensor list element shape");
    }
  }
  if (failed(cutil::CreateInitBufferValue(element_shape, list.num_elements(),
                                          list, list.element_dtype(), builder,
                                          &buffer))) {
    return failure();
  }
  Value size = cutil::ReshapeScalarToSizeType(builder, list.num_elements(),
                                              list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

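// Replaces a TF::TensorListFromTensorOp with an identity of its input tensor,
// which must have a static shape; the size is the leading dimension.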
LogicalResult HandleTensorListFromTensorOp(
    TF::TensorListFromTensorOp list,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  OpBuilder builder(list);
  Value buffer = builder.create<TF::IdentityOp>(
      list.getLoc(), ArrayRef<Type>{list.tensor().getType()},
      ArrayRef<Value>{list.tensor()});
  auto type = buffer.getType().cast<TensorType>();
  if (!type.hasStaticShape()) {
    return list.emitOpError("TensorListFromTensorOp input has unknown shape.");
  }
  Value size = cutil::GetR1Const({type.getShape()[0]}, builder, list.getLoc());
  (*buffer_to_size)[buffer] = {size, /*fixed=*/true};
  list.output_handle().replaceAllUsesWith(buffer);
  list.erase();
  return success();
}

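// Replaces a TF::TensorListPushBackOp with a write at the current size and an
// increment of the size by one. Pushing on a fixed-size list is an error.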
LogicalResult HandleTensorListPushBackOp(
    TF::TensorListPushBackOp push,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = push.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    return push.emitOpError(
        "found tf.TensorListPushBack on unknown TensorList.");
  }
  if (it->getSecond().fixed) {
    return push.emitError("cannot push on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(push);
  auto new_buffer =
      cutil::SetElement(size, buffer, push.tensor(), builder, push.getLoc());
  auto new_size = builder.create<TF::AddV2Op>(
      push.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, push.getLoc())});
  push.output_handle().replaceAllUsesWith(new_buffer);
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  push.erase();
  return success();
}

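// Replaces a TF::TensorListPopBackOp with a decrement of the size by one and
// a read of the buffer at the decremented size. Popping a fixed-size list is
// an error.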
LogicalResult HandleTensorListPopBackOp(
    TF::TensorListPopBackOp pop,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = pop.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    pop.emitOpError("found tf.TensorListPopBack on unknown TensorList.");
    return failure();
  }
  if (it->getSecond().fixed) {
    return pop.emitError("cannot pop on a fixed-size tensor list");
  }
  auto size = it->getSecond().size;
  OpBuilder builder(pop);
  auto new_buffer = builder.create<TF::IdentityOp>(
      pop.getLoc(), ArrayRef<Type>{buffer.getType()}, ArrayRef<Value>{buffer});
  auto new_size = builder.create<TF::SubOp>(
      pop.getLoc(), ArrayRef<Type>{size.getType()},
      ArrayRef<Value>{size, cutil::GetR1Const({1LL}, builder, pop.getLoc())});
  auto element = cutil::GetElement(new_size, new_buffer, builder, pop.getLoc());
  pop.output_handle().replaceAllUsesWith(new_buffer);
  pop.tensor().replaceAllUsesWith(element);
  pop.erase();
  (*buffer_to_size)[new_buffer] = {new_size, /*fixed=*/false};
  return success();
}

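// Replaces a TF::TensorListGetItemOp with a read of the buffer at the given
// index.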
LogicalResult HandleTensorListGetItemOp(
    TF::TensorListGetItemOp get_item,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto buffer = get_item.input_handle();
  auto it = buffer_to_size.find(buffer);
  if (it == buffer_to_size.end()) {
    get_item.emitOpError("found tf.TensorListGetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(get_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, get_item.index(),
                                              get_item.getLoc());
  auto element =
      cutil::GetElement(index, buffer, OpBuilder(get_item), get_item.getLoc());
  get_item.item().replaceAllUsesWith(element);
  get_item.erase();
  return success();
}

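// Replaces a TF::TensorListSetItemOp with a write to the buffer at the given
// index; the size is unchanged.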
LogicalResult HandleTensorListSetItemOp(
    TF::TensorListSetItemOp set_item,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto buffer = set_item.input_handle();
  auto it = buffer_to_size->find(buffer);
  if (it == buffer_to_size->end()) {
    set_item.emitOpError("found tf.TensorListSetItemOp on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(set_item);
  auto index = cutil::ReshapeScalarToSizeType(builder, set_item.index(),
                                              set_item.getLoc());
  auto new_buffer = cutil::SetElement(index, buffer, set_item.item(), builder,
                                      set_item.getLoc());
  set_item.output_handle().replaceAllUsesWith(new_buffer);
  auto size = it->getSecond();
  (*buffer_to_size)[new_buffer] = size;
  set_item.erase();
  return success();
}

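// Replaces a TF::TensorListLengthOp with either a constant (for fixed-size
// lists) or a reshape of the tracked R1 size to a scalar.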
LogicalResult HandleTensorListLengthOp(
    TF::TensorListLengthOp length,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(length.input_handle());
  if (it == buffer_to_size.end()) {
    length.emitOpError("found tf.TensorListLength on unknown TensorList.");
    return failure();
  }
  OpBuilder builder(length);
  if (it->getSecond().fixed) {
    auto dim = cutil::CreateScalarConst(
        length.input_handle().getType().cast<RankedTensorType>().getDimSize(0),
        builder, length.getLoc());
    length.length().replaceAllUsesWith(dim);
  } else {
    auto current_size = it->getSecond().size;
    // Reshapes the R1 length to a scalar.
    auto reshape = builder.create<TF::ReshapeOp>(
        length.getLoc(),
        ArrayRef<Type>{RankedTensorType::get(
            {}, getElementTypeOrSelf(current_size.getType()))},
        ArrayRef<Value>{current_size,
                        cutil::GetR1Const({}, builder, length.getLoc())});
    length.length().replaceAllUsesWith(reshape);
  }
  length.erase();
  return success();
}

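// Replaces a TF::TensorListElementShapeOp with a constant built from the
// element dimensions of the buffer type.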
LogicalResult HandleTensorListElementShapeOp(
    TF::TensorListElementShapeOp elem_shape,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  if (buffer_to_size.count(elem_shape.input_handle()) == 0) {
    return elem_shape.emitOpError("unknown tensor list");
  }
  auto buffer = elem_shape.input_handle();
  auto result = cutil::GetR1Const(
      buffer.getType().cast<RankedTensorType>().getShape().drop_front(),
      OpBuilder(elem_shape), elem_shape.getLoc(),
      elem_shape.shape_type().getIntOrFloatBitWidth());
  elem_shape.element_shape().replaceAllUsesWith(result);
  elem_shape.erase();
  return success();
}

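// Replaces a TF::TensorListGatherOp with a gather on the buffer.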
LogicalResult HandleTensorListGatherOp(
    TF::TensorListGatherOp gather,
    const llvm::SmallDenseMap<Value, SizeInfo>& buffer_to_size) {
  auto it = buffer_to_size.find(gather.input_handle());
  if (it == buffer_to_size.end()) {
    return gather.emitOpError("unknown tensor list");
  }
  auto buffer = gather.input_handle();
  auto result = cutil::GatherElements(gather.indices(), buffer,
                                      OpBuilder(gather), gather.getLoc());
  gather.values().replaceAllUsesWith(result);
  gather.erase();
  return success();
}

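// Replaces a TF::TensorListScatterIntoExistingListOp with a
// tf.TensorScatterUpdate on the buffer; the size is unchanged.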
LogicalResult HandleTensorListScatterIntoExistingListOp(
    TF::TensorListScatterIntoExistingListOp scatter,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size) {
  auto it = buffer_to_size->find(scatter.input_handle());
  if (it == buffer_to_size->end()) {
    return scatter.emitOpError("unknown tensor list");
  }
  auto buffer = scatter.input_handle();
  OpBuilder builder(scatter);
  // Use dyn_cast so that unranked indices are reported as an error instead of
  // asserting inside cast<>.
  auto indices_type = scatter.indices().getType().dyn_cast<RankedTensorType>();
  if (!indices_type) return scatter.emitOpError("unranked indices shape");
  auto shape_type = RankedTensorType::get({2}, builder.getIntegerType(32));
  auto shape = builder.create<TF::ConstOp>(
      scatter.getLoc(),
      DenseElementsAttr::get(
          shape_type, {static_cast<int>(indices_type.getDimSize(0)), 1}));
  auto indices =
      builder.create<TF::ReshapeOp>(scatter.getLoc(), scatter.indices(), shape);
  Value tensor_scatter_update = builder.create<TF::TensorScatterUpdateOp>(
      scatter.getLoc(), buffer, indices, scatter.tensor());
  scatter.output_handle().replaceAllUsesWith(tensor_scatter_update);
  scatter.erase();
  auto size = it->getSecond();
  (*buffer_to_size)[tensor_scatter_update] = size;
  return success();
}

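// Decomposes tensor list operations in the given block. Each handled op is
// rewritten to operate on a dense buffer plus a tracked size value; control
// flow and call ops are processed recursively.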
LogicalResult DecomposeTensorListOpsInternal(
    Block* block, ModuleOp module,
    llvm::SmallDenseMap<Value, SizeInfo>* buffer_to_size,
    llvm::StringMap<PartitionedCallDecompositionInfo>*
        decomposed_partitioned_call_callees) {
  for (auto& op : llvm::make_early_inc_range(block->getOperations())) {
    // TODO(yuanzx): Add a pass to remove identities in device computation.
    if (llvm::isa<TF::IdentityOp, TF::IdentityNOp, TF::StopGradientOp>(&op)) {
      op.replaceAllUsesWith(op.getOperands());
      op.erase();
    } else if (auto list = llvm::dyn_cast<TF::EmptyTensorListOp>(&op)) {
      if (failed(HandleEmptyTensorListOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListReserveOp>(&op)) {
      if (failed(HandleTensorListReserveOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto list = llvm::dyn_cast<TF::TensorListFromTensorOp>(&op)) {
      if (failed(HandleTensorListFromTensorOp(list, buffer_to_size))) {
        return failure();
      }
    } else if (auto push = llvm::dyn_cast<TF::TensorListPushBackOp>(&op)) {
      if (failed(HandleTensorListPushBackOp(push, buffer_to_size))) {
        return failure();
      }
    } else if (auto pop = llvm::dyn_cast<TF::TensorListPopBackOp>(&op)) {
      if (failed(HandleTensorListPopBackOp(pop, buffer_to_size))) {
        return failure();
      }
    } else if (auto get_item = llvm::dyn_cast<TF::TensorListGetItemOp>(&op)) {
      if (failed(HandleTensorListGetItemOp(get_item, *buffer_to_size))) {
        return failure();
      }
    } else if (auto set_item = llvm::dyn_cast<TF::TensorListSetItemOp>(&op)) {
      if (failed(HandleTensorListSetItemOp(set_item, buffer_to_size))) {
        return failure();
      }
    } else if (auto length = llvm::dyn_cast<TF::TensorListLengthOp>(&op)) {
      if (failed(HandleTensorListLengthOp(length, *buffer_to_size))) {
        return failure();
      }
    } else if (auto stack = llvm::dyn_cast<TF::TensorListStackOp>(&op)) {
      stack.tensor().replaceAllUsesWith(stack.input_handle());
      stack.erase();
    } else if (auto elem_shape =
                   llvm::dyn_cast<TF::TensorListElementShapeOp>(&op)) {
      if (failed(HandleTensorListElementShapeOp(elem_shape, *buffer_to_size))) {
        return failure();
      }
    } else if (auto gather = llvm::dyn_cast<TF::TensorListGatherOp>(&op)) {
      if (failed(HandleTensorListGatherOp(gather, *buffer_to_size))) {
        return failure();
      }
    } else if (auto scatter =
                   llvm::dyn_cast<TF::TensorListScatterIntoExistingListOp>(
                       &op)) {
      if (failed(HandleTensorListScatterIntoExistingListOp(scatter,
                                                           buffer_to_size))) {
        return failure();
      }
    } else if (auto addn = llvm::dyn_cast<TF::AddNOp>(&op)) {
      auto it = buffer_to_size->find(addn.getOperand(0));
      if (it != buffer_to_size->end()) {
        addn.sum().setType(addn.getOperand(0).getType());
        auto size = it->getSecond();
        (*buffer_to_size)[addn.sum()] = size;
      }
    } else if (auto zeros = llvm::dyn_cast<TF::ZerosLikeOp>(&op)) {
      if (buffer_to_size->count(zeros.x()) > 0) {
        zeros.y().setType(zeros.x().getType());
        auto size = (*buffer_to_size)[zeros.x()];
        (*buffer_to_size)[zeros.y()] = size;
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) {
      if (failed(HandleWhileOp(while_op, module, buffer_to_size,
                               decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfOp>(&op)) {
      if (failed(HandleCaseOrIfOp(
              if_op, {if_op.then_function(), if_op.else_function()}, module,
              buffer_to_size, decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseOp>(&op)) {
      SmallVector<func::FuncOp, 2> branches;
      case_op.get_branch_functions(branches);
      if (failed(HandleCaseOrIfOp(case_op, branches, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto pcall = llvm::dyn_cast<TF::PartitionedCallOp>(&op)) {
      if (!pcall.func())
        return pcall.emitOpError(
            "TensorList decomposition does not support call with nested "
            "references.");

      if (failed(HandlePartitionedCallOp(
              pcall, pcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto spcall =
                   llvm::dyn_cast<TF::StatefulPartitionedCallOp>(&op)) {
      if (failed(HandlePartitionedCallOp(
              spcall, spcall.func(), module, buffer_to_size,
              decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(&op)) {
      if (failed(HandleWhileRegionOp(while_op, module, buffer_to_size,
                                     decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto if_op = llvm::dyn_cast<TF::IfRegionOp>(&op)) {
      if (failed(HandleIfRegionOp(if_op, module, buffer_to_size,
                                  decomposed_partitioned_call_callees))) {
        return failure();
      }
    } else if (auto case_op = llvm::dyn_cast<TF::CaseRegionOp>(&op)) {
      if (failed(HandleCaseRegionOp(case_op, module, buffer_to_size,
                                    decomposed_partitioned_call_callees))) {
        return failure();
      }
    }
  }
  return success();
}

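// Entry point that decomposes tensor list ops starting from the given block,
// with fresh buffer-to-size and decomposed-callee caches.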
LogicalResult DecomposeTensorListOps(Block* block, ModuleOp module) {
  llvm::SmallDenseMap<Value, SizeInfo> buffer_to_size;
  llvm::StringMap<PartitionedCallDecompositionInfo>
      decomposed_partitioned_call_callees;
  return DecomposeTensorListOpsInternal(block, module, &buffer_to_size,
                                        &decomposed_partitioned_call_callees);
}

void TensorListOpsDecompositionPass::runOnOperation() {
  auto module = getOperation();
  auto main = module.lookupSymbol<func::FuncOp>("main");
  if (!main) return;
  if (failed(DecomposeTensorListOps(&main.front(), module))) {
    signalPassFailure();
  }
}

}  // namespace

namespace TF {
std::unique_ptr<OperationPass<ModuleOp>>
CreateTensorListOpsDecompositionPass() {
  return std::make_unique<TensorListOpsDecompositionPass>();
}
}  // namespace TF
}  // namespace mlir