/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kReplicateSharding[] = "";

struct TPUResourceReadsWritesPartitioningPass
    : public TF::TPUResourceReadsWritesPartitioningPassBase<
          TPUResourceReadsWritesPartitioningPass> {
  void runOnOperation() override;
};

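// Returns true if each resource type in `resources` has exactly one subtype,
// i.e. the type of the value stored in the resource variable, which is needed
// to derive the per-partition read/write types below.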
bool AllResourceTypesHaveSubtypes(TypeRange resources) {
  for (Type resource : resources)
    if (!llvm::hasSingleElement(resource.cast<TensorType>()
                                    .getElementType()
                                    .cast<TF::ResourceType>()
                                    .getSubtypes()))
      return false;

  return true;
}

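// Returns the subtype of a resource tensor type, i.e. the type of the value
// stored in the resource variable.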
Type GetResourceSubtype(Type type) {
  return type.cast<TensorType>()
      .getElementType()
      .cast<TF::ResourceType>()
      .getSubtypes()
      .front();
}

Type GetResourceSubtype(Value resource) {
  return GetResourceSubtype(resource.getType());
}

// Rewrites unpartitioned resource reads and writes to partitioned resource
// reads and writes. The TPU computation from the frontend is generated in such
// a way that resource operations operate on the unpartitioned resource handle
// (from a `tf.TPUReplicatedInput`). This results in resource reads and writes
// on the unpartitioned resource handle post resource op decomposition/lifting.
// Here each unpartitioned resource read and write is expanded into individual
// resource reads and writes per associated partitioned resource handle.
void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
          "use_spmd_for_xla_partitioning"))
    use_spmd = use_spmd_attr.getValue();

  if (!use_spmd) return;

  OpBuilder builder(cluster_func);
  // Rewrite results before rewriting operands, as the resource handle results
  // of a `tf.TPUPartitionedInput` are an indicator of a partitioned resource
  // variable. These `tf.TPUPartitionedInput` ops will be removed when
  // rewriting the operands.
  for (Value result : cluster_func.results()) {
    if (!result.hasOneUse()) continue;
    auto assign_var =
        llvm::dyn_cast<TF::AssignVariableOp>(*result.getUsers().begin());
    if (!assign_var || assign_var.value() != result) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        assign_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

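    // Create a `tf.TPUPartitionedOutput` that splits the cluster result into
    // per-partition values, assign each value to its corresponding partitioned
    // resource handle, and erase the single unpartitioned assign.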
    builder.setInsertionPoint(assign_var);
    llvm::SmallVector<Type, 4> partitioned_output_types;
    partitioned_output_types.reserve(partitioned_input.N());
    for (Type input_type : partitioned_input.inputs().getTypes())
      partitioned_output_types.push_back(GetResourceSubtype(input_type));
    auto partitioned_output = builder.create<TF::TPUPartitionedOutputOp>(
        cluster_func->getLoc(), partitioned_output_types, result,
        partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    for (auto resource_write :
         llvm::zip(partitioned_input.inputs(), partitioned_output.output()))
      builder.create<TF::AssignVariableOp>(
          assign_var->getLoc(), /*resource=*/std::get<0>(resource_write),
          /*value=*/std::get<1>(resource_write));
    assign_var.erase();
  }

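  // Rewrite operands that are reads of partitioned resource variables. Each
  // read of the unpartitioned handle is replaced with per-partition
  // `tf.ReadVariableOp`s feeding a value-based `tf.TPUPartitionedInput`.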
  for (OpOperand& operand : cluster_func->getOpOperands()) {
    auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.get().getDefiningOp());
    if (!read_var || !read_var.value().hasOneUse()) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        read_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

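    // Read each partitioned resource handle individually, recombine the read
    // values with a new `tf.TPUPartitionedInput` that operates on tensors
    // instead of resources, and drop the now-unused original ops.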
    builder.setInsertionPoint(partitioned_input);
    llvm::SmallVector<Value, 4> partitioned_reads;
    for (Value input : partitioned_input.inputs()) {
      auto partitioned_read = builder.create<TF::ReadVariableOp>(
          read_var->getLoc(), GetResourceSubtype(input), input);
      partitioned_reads.push_back(partitioned_read.value());
    }
    auto partitioned_read = builder.create<TF::TPUPartitionedInputOp>(
        partitioned_input->getLoc(), read_var.value().getType(),
        partitioned_reads, partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    operand.set(partitioned_read);
    read_var->erase();
    if (partitioned_input->use_empty()) partitioned_input->erase();
  }
}

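// Collects all `tf_device.cluster_func` ops in the function and partitions
// their resource reads and writes.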
void TPUResourceReadsWritesPartitioningPass::runOnOperation() {
  llvm::SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
  getOperation()->walk([&cluster_funcs](tf_device::ClusterFuncOp cluster_func) {
    cluster_funcs.push_back(cluster_func);
  });
  for (tf_device::ClusterFuncOp cluster_func : cluster_funcs)
    PartitionResourceReadsWrites(cluster_func);
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUResourceReadsWritesPartitioningPass() {
  return std::make_unique<TPUResourceReadsWritesPartitioningPass>();
}

}  // namespace TFTPU
}  // namespace mlir