/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>
#include <utility>

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Adds metadata used in TPU Compilation to `cluster` as attributes.
void AddMetadataToTPUCluster(const Mesh& mesh_config,
                             mlir::tf_device::ClusterOp cluster,
                             mlir::OpBuilder* builder) {
  cluster->setAttr("_tpu_replicate",
                   builder->getStringAttr(mesh_config.ToString()));
  cluster->setAttr("step_marker_location", builder->getStringAttr(""));
  cluster->setAttr("padding_map", builder->getArrayAttr({}));
  cluster->setAttr("use_spmd_for_xla_partitioning",
                   builder->getBoolAttr(false));
  cluster->setAttr(tensorflow::kTopologyAttr, builder->getStringAttr(""));
  cluster->setAttr(tensorflow::kDeviceAssignmentAttr,
                   builder->getArrayAttr({}));
  cluster->setAttr(tensorflow::kNumCoresPerReplicaAttr,
                   builder->getI64IntegerAttr(1));
}

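// Collects, from the `main` function, every StatefulPartitionedCall op whose
// `config` attribute holds a serialized TPU mesh (and that is not marked to
// skip XLA compilation), along with the parsed mesh for each such call.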
// TODO(hongjunchoi): Implement cluster inlining pass so that there are no
// nested tf_device.cluster ops with the same mesh.
void IdentifyTPUFunctions(
    mlir::ModuleOp module, llvm::SmallVectorImpl<Mesh>* tpu_meshes,
    llvm::SmallVectorImpl<mlir::TF::StatefulPartitionedCallOp>* tpu_functions) {
  auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
  if (!main_func) return;

  for (auto call : main_func.getOps<mlir::TF::StatefulPartitionedCallOp>()) {
    auto mesh_or_status = Mesh::FromString(string(call.config()));
    // Function calls created directly by end users, rather than converted
    // from tf_device.cluster, do not carry a serialized mesh in their config
    // attribute; the parsing error is ignored in that case.
    if (!mesh_or_status.ok()) return;
    bool skip_xla_compilation = false;
    if (call->hasAttr(kSkipXlaCompilation)) {
      skip_xla_compilation =
          call->getAttrOfType<mlir::BoolAttr>(kSkipXlaCompilation).getValue();
    }
    if (mesh_or_status->is_tpu_mesh() && !skip_xla_compilation) {
      tpu_functions->emplace_back(call);
      tpu_meshes->emplace_back(std::move(mesh_or_status.ValueOrDie()));
    }
  }
}

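// Moves the body of the function called by `tpu_call` into a newly created
// tf_device.cluster op and rewires the function terminator to return the
// cluster's results. The cluster is returned through `newly_created_cluster`.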
mlir::LogicalResult CreateTPUCluster(
    mlir::TF::StatefulPartitionedCallOp tpu_call, mlir::OpBuilder* builder,
    mlir::tf_device::ClusterOp* newly_created_cluster) {
  auto function = MaybeFindFunction(tpu_call);
  if (!function)
    return tpu_call.emitOpError(
        "failed during TPU Integration as Func op TPU mesh was not found");

  auto& function_block = function->getCallableRegion()->front();
  builder->setInsertionPointToStart(&function_block);

  auto cluster = builder->create<mlir::tf_device::ClusterOp>(
      tpu_call.getLoc(), function->getCallableResults());
  cluster.body().push_back(new mlir::Block);

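  // Move all ops of the function body into the cluster's block, skipping the
  // first op (the cluster op itself, which was inserted at the start of the
  // block above) and the function terminator.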
  auto& function_body = function_block.getOperations();
  cluster.GetBody().getOperations().splice(
      cluster.GetBody().getOperations().begin(), function_body,
      std::next(function_body.begin()), std::prev(function_body.end()));

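  // Terminate the cluster body with a tf_device.return yielding the values
  // the function originally returned, then make the function terminator
  // return the cluster's results instead.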
  builder->setInsertionPointToEnd(&cluster.GetBody());
  mlir::Operation* function_block_terminator = function_block.getTerminator();
  builder->create<mlir::tf_device::ReturnOp>(
      tpu_call.getLoc(), function_block_terminator->getOperands());

  function_block_terminator->setOperands(cluster.getResults());

  *newly_created_cluster = cluster;
  return mlir::success();
}

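// Pass that wraps the body of each function called on a TPU mesh in a
// tf_device.cluster op and attaches the metadata attributes required for TPU
// compilation.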
struct DTensorTPUIntegration
    : public DTensorTPUIntegrationBase<DTensorTPUIntegration> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
    registry.insert<mlir::tf_device::TensorFlowDeviceDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder op_builder(&context);
    auto module = getOperation();
    llvm::SmallVector<mlir::TF::StatefulPartitionedCallOp, 4> tpu_functions;
    llvm::SmallVector<Mesh, 4> tpu_meshes;
    IdentifyTPUFunctions(module, &tpu_meshes, &tpu_functions);

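    // For every identified TPU function, wrap its body in a tf_device.cluster
    // and attach TPU compilation metadata derived from its mesh.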
    for (auto tpu_function_and_mesh : llvm::zip(tpu_meshes, tpu_functions)) {
      mlir::tf_device::ClusterOp cluster;

      if (mlir::failed(CreateTPUCluster(std::get<1>(tpu_function_and_mesh),
                                        &op_builder, &cluster)))
        return signalPassFailure();

      AddMetadataToTPUCluster(std::get<0>(tpu_function_and_mesh), cluster,
                              &op_builder);
    }
  }
};

}  // namespace

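// Illustrative usage sketch (assumption, not part of this file): the pass
// created by the factory below would typically be added to a module-level
// pass manager, e.g.
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateDTensorTPUIntegration());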
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorTPUIntegration() {
  return std::make_unique<DTensorTPUIntegration>();
}

}  // namespace dtensor
}  // namespace tensorflow