/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kReplicateSharding[] = "";

struct TPUResourceReadsWritesPartitioningPass
    : public TF::TPUResourceReadsWritesPartitioningPassBase<
          TPUResourceReadsWritesPartitioningPass> {
  void runOnOperation() override;
};

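// Returns true if each resource-typed tensor in `resources` has exactly one
// subtype, i.e. the type of the value held by the resource is known.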
bool AllResourceTypesHaveSubtypes(TypeRange resources) {
  for (Type resource : resources)
    if (!llvm::hasSingleElement(resource.cast<TensorType>()
                                    .getElementType()
                                    .cast<TF::ResourceType>()
                                    .getSubtypes()))
      return false;

  return true;
}

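// Returns the subtype of a resource tensor type, i.e. the type of the value
// the resource holds.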
Type GetResourceSubtype(Type type) {
  return type.cast<TensorType>()
      .getElementType()
      .cast<TF::ResourceType>()
      .getSubtypes()
      .front();
}

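// Convenience overload that extracts the subtype from a resource value's type.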
Type GetResourceSubtype(Value resource) {
  return GetResourceSubtype(resource.getType());
}

// Rewrites unpartitioned resource reads and writes to partitioned resource
// reads and writes. The TPU computation from the frontend is generated in such
// a way that resource operations operate on the unpartitioned resource handle
// (from a `tf.TPUReplicatedInput`). This results in resource reads and writes
// on the unpartitioned resource handle post resource op decomposition/lifting.
// Here each unpartitioned resource read and write is expanded into individual
// resource reads and writes per associated partitioned resource handle.
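//
// As a rough sketch of the rewrite (types, attributes, and SSA value names are
// elided or hypothetical):
//
//   %handle = "tf.TPUPartitionedInput"(%handle0, %handle1)
//   %read = "tf.ReadVariableOp"(%handle)
//   %computation = "tf_device.cluster_func"(%read)
//   "tf.AssignVariableOp"(%handle, %computation)
//
// is rewritten to:
//
//   %read0 = "tf.ReadVariableOp"(%handle0)
//   %read1 = "tf.ReadVariableOp"(%handle1)
//   %input = "tf.TPUPartitionedInput"(%read0, %read1)
//   %computation = "tf_device.cluster_func"(%input)
//   %output:2 = "tf.TPUPartitionedOutput"(%computation)
//   "tf.AssignVariableOp"(%handle0, %output#0)
//   "tf.AssignVariableOp"(%handle1, %output#1)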
void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
          "use_spmd_for_xla_partitioning"))
    use_spmd = use_spmd_attr.getValue();

  if (!use_spmd) return;

  OpBuilder builder(cluster_func);
  // Rewrite results before rewriting operands, as `tf.TPUPartitionedInput`
  // resource handle results are an indicator for a partitioned resource
  // variable. These `tf.TPUPartitionedInput` ops will be removed when
  // rewriting the operands.
  for (Value result : cluster_func.results()) {
    if (!result.hasOneUse()) continue;
    auto assign_var =
        llvm::dyn_cast<TF::AssignVariableOp>(*result.getUsers().begin());
    if (!assign_var || assign_var.value() != result) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        assign_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

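    // Replace the single assign to the partitioned handle with a
    // `tf.TPUPartitionedOutput` over the cluster result and one
    // `tf.AssignVariableOp` per partitioned resource handle.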
    builder.setInsertionPoint(assign_var);
    llvm::SmallVector<Type, 4> partitioned_output_types;
    partitioned_output_types.reserve(partitioned_input.N());
    for (Type input_type : partitioned_input.inputs().getTypes())
      partitioned_output_types.push_back(GetResourceSubtype(input_type));
    auto partitioned_output = builder.create<TF::TPUPartitionedOutputOp>(
        cluster_func->getLoc(), partitioned_output_types, result,
        partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    for (auto resource_write :
         llvm::zip(partitioned_input.inputs(), partitioned_output.output()))
      builder.create<TF::AssignVariableOp>(
          assign_var->getLoc(), /*resource=*/std::get<0>(resource_write),
          /*value=*/std::get<1>(resource_write));
    assign_var.erase();
  }

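  // Rewrite operands: expand a read of a partitioned resource handle into one
  // `tf.ReadVariableOp` per partitioned handle, merged back together by a new
  // `tf.TPUPartitionedInput` that feeds the cluster.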
  for (OpOperand& operand : cluster_func->getOpOperands()) {
    auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.get().getDefiningOp());
    if (!read_var || !read_var.value().hasOneUse()) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        read_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

    builder.setInsertionPoint(partitioned_input);
    llvm::SmallVector<Value, 4> partitioned_reads;
    for (Value input : partitioned_input.inputs()) {
      auto partitioned_read = builder.create<TF::ReadVariableOp>(
          read_var->getLoc(), GetResourceSubtype(input), input);
      partitioned_reads.push_back(partitioned_read.value());
    }
    auto partitioned_read = builder.create<TF::TPUPartitionedInputOp>(
        partitioned_input->getLoc(), read_var.value().getType(),
        partitioned_reads, partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    operand.set(partitioned_read);
    read_var->erase();
    if (partitioned_input->use_empty()) partitioned_input->erase();
  }
}

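// Collects all `tf_device.cluster_func` ops and partitions the resource reads
// and writes of each one.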
void TPUResourceReadsWritesPartitioningPass::runOnOperation() {
  llvm::SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
  getOperation()->walk([&cluster_funcs](tf_device::ClusterFuncOp cluster_func) {
    cluster_funcs.push_back(cluster_func);
  });
  for (tf_device::ClusterFuncOp cluster_func : cluster_funcs)
    PartitionResourceReadsWrites(cluster_func);
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUResourceReadsWritesPartitioningPass() {
  return std::make_unique<TPUResourceReadsWritesPartitioningPass>();
}

}  // namespace TFTPU
}  // namespace mlir
161