xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/tpu_integration.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <utility>
18 
19 #include "llvm/ADT/APInt.h"
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SetVector.h"
24 #include "llvm/ADT/SmallVector.h"
25 #include "llvm/ADT/StringRef.h"
26 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Builders.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
30 #include "mlir/IR/Operation.h"  // from @llvm-project
31 #include "mlir/IR/Types.h"  // from @llvm-project
32 #include "mlir/IR/Visitors.h"  // from @llvm-project
33 #include "mlir/Pass/Pass.h"  // from @llvm-project
34 #include "mlir/Pass/PassManager.h"  // from @llvm-project
35 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
36 #include "mlir/Transforms/Passes.h"  // from @llvm-project
37 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
38 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
39 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
41 #include "tensorflow/compiler/xla/client/sharding_builder.h"
42 #include "tensorflow/dtensor/cc/constants.h"
43 #include "tensorflow/dtensor/cc/tensor_layout.h"
44 #include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
45 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
46 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
47 #include "tensorflow/dtensor/mlir/layout_parsing.h"
48 #include "tensorflow/dtensor/mlir/op_utils.h"
49 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
50 
51 namespace tensorflow {
52 namespace dtensor {
53 namespace {
54 
55 // Adds metadata used in TPU Compilation to `cluster` as attributes.
AddMetadataToTPUCluster(const Mesh & mesh_config,mlir::tf_device::ClusterOp cluster,mlir::OpBuilder * builder)56 void AddMetadataToTPUCluster(const Mesh& mesh_config,
57                              mlir::tf_device::ClusterOp cluster,
58                              mlir::OpBuilder* builder) {
59   cluster->setAttr("_tpu_replicate",
60                    builder->getStringAttr(mesh_config.ToString()));
61   cluster->setAttr("step_marker_location", builder->getStringAttr(""));
62   cluster->setAttr("padding_map", builder->getArrayAttr({}));
63   cluster->setAttr("use_spmd_for_xla_partitioning",
64                    builder->getBoolAttr(false));
65   cluster->setAttr(tensorflow::kTopologyAttr, builder->getStringAttr(""));
66   cluster->setAttr(tensorflow::kDeviceAssignmentAttr,
67                    builder->getArrayAttr({}));
68   cluster->setAttr(tensorflow::kNumCoresPerReplicaAttr,
69                    builder->getI64IntegerAttr(1));
70 }
71 
72 // TODO(hongjunchoi): Implement cluster inlining pass so that there are no
73 // nested tf_device.cluster ops with same mesh.
IdentifyTPUFunctions(mlir::ModuleOp module,llvm::SmallVectorImpl<Mesh> * tpu_meshs,llvm::SmallVectorImpl<mlir::TF::StatefulPartitionedCallOp> * tpu_functions)74 void IdentifyTPUFunctions(
75     mlir::ModuleOp module, llvm::SmallVectorImpl<Mesh>* tpu_meshs,
76     llvm::SmallVectorImpl<mlir::TF::StatefulPartitionedCallOp>* tpu_functions) {
77   auto main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
78   if (!main_func) return;
79 
80   for (auto call : main_func.getOps<mlir::TF::StatefulPartitionedCallOp>()) {
81     auto mesh_or_status = Mesh::FromString(string(call.config()));
82     // Function calls created by end users instead of being converted from
83     // tf_device.cluster do not have a serialized mesh as a config attribute. We
84     // ignore the error returned from parsing in this case.
85     if (!mesh_or_status.ok()) return;
86     bool skip_xla_compilation = false;
87     if (call->hasAttr(kSkipXlaCompilation)) {
88       skip_xla_compilation =
89           call->getAttrOfType<mlir::BoolAttr>(kSkipXlaCompilation).getValue();
90     }
91     if (mesh_or_status->is_tpu_mesh() && !skip_xla_compilation) {
92       tpu_functions->emplace_back(call);
93       tpu_meshs->emplace_back(std::move(mesh_or_status.ValueOrDie()));
94     }
95   }
96 }
97 
CreateTPUCluster(mlir::TF::StatefulPartitionedCallOp tpu_call,mlir::OpBuilder * builder,mlir::tf_device::ClusterOp * newly_created_cluster)98 mlir::LogicalResult CreateTPUCluster(
99     mlir::TF::StatefulPartitionedCallOp tpu_call, mlir::OpBuilder* builder,
100     mlir::tf_device::ClusterOp* newly_created_cluster) {
101   auto function = MaybeFindFunction(tpu_call);
102   if (!function)
103     return tpu_call.emitOpError(
104         "failed during TPU Integration as Func op TPU mesh was not found");
105 
106   auto& function_block = function->getCallableRegion()->front();
107   builder->setInsertionPointToStart(&function_block);
108 
109   auto cluster = builder->create<mlir::tf_device::ClusterOp>(
110       tpu_call.getLoc(), function->getCallableResults());
111   cluster.body().push_back(new mlir::Block);
112 
113   auto& function_body = function_block.getOperations();
114   cluster.GetBody().getOperations().splice(
115       cluster.GetBody().getOperations().begin(), function_body,
116       std::next(function_body.begin()), std::prev(function_body.end()));
117 
118   builder->setInsertionPointToEnd(&cluster.GetBody());
119   mlir::Operation* function_block_terminator = function_block.getTerminator();
120   builder->create<mlir::tf_device::ReturnOp>(
121       tpu_call.getLoc(), function_block_terminator->getOperands());
122 
123   function_block_terminator->setOperands(cluster.getResults());
124 
125   *newly_created_cluster = cluster;
126   return mlir::success();
127 }
128 
129 struct DTensorTPUIntegration
130     : public DTensorTPUIntegrationBase<DTensorTPUIntegration> {
getDependentDialectstensorflow::dtensor::__anon913eb44c0111::DTensorTPUIntegration131   void getDependentDialects(mlir::DialectRegistry& registry) const override {
132     registry.insert<mlir::dtensor::DTensorDialect>();
133     registry.insert<mlir::tf_device::TensorFlowDeviceDialect>();
134   }
135 
runOnOperationtensorflow::dtensor::__anon913eb44c0111::DTensorTPUIntegration136   void runOnOperation() override {
137     mlir::MLIRContext& context = getContext();
138     mlir::OpBuilder op_builder(&context);
139     auto module = getOperation();
140     llvm::SmallVector<mlir::TF::StatefulPartitionedCallOp, 4> tpu_functions;
141     llvm::SmallVector<Mesh, 4> tpu_meshes;
142     IdentifyTPUFunctions(module, &tpu_meshes, &tpu_functions);
143 
144     for (auto tpu_function_and_mesh : llvm::zip(tpu_meshes, tpu_functions)) {
145       mlir::tf_device::ClusterOp cluster;
146 
147       if (mlir::failed(CreateTPUCluster(std::get<1>(tpu_function_and_mesh),
148                                         &op_builder, &cluster)))
149         return signalPassFailure();
150 
151       AddMetadataToTPUCluster(std::get<0>(tpu_function_and_mesh), cluster,
152                               &op_builder);
153     }
154   };
155 };
156 
157 }  // namespace
158 
159 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorTPUIntegration()160 CreateDTensorTPUIntegration() {
161   return std::make_unique<DTensorTPUIntegration>();
162 }
163 
164 }  // namespace dtensor
165 }  // namespace tensorflow
166