xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/utils/update_tpu_metadata.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "mlir/Dialect/Func/IR/FuncOps.h"
19 #include "mlir/IR/Attributes.h"
20 #include "mlir/IR/Builders.h"
21 #include "mlir/IR/Operation.h"
22 #include "mlir/IR/Value.h"
23 #include "mlir/IR/Visitors.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Pass/PassManager.h"
26 #include "mlir/Support/LogicalResult.h"
27 #include "mlir/Transforms/Passes.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/str_util.h"
32 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
33 #include "tensorflow/dtensor/cc/constants.h"
34 #include "tensorflow/dtensor/cc/dtensor_utils.h"
35 #include "tensorflow/dtensor/cc/tensor_layout.h"
36 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
37 #include "tensorflow/dtensor/mlir/layout_parsing.h"
38 #include "tensorflow/dtensor/mlir/op_utils.h"
39 #include "tensorflow/dtensor/mlir/spmd_expander_common.h"
40 
41 namespace tensorflow {
42 namespace dtensor {
43 namespace {
44 
// Attribute carrying an op's explicit device placement.
constexpr char kDeviceAttr[] = "device";
// Function-argument attribute carrying a placeholder device placement.
constexpr char kFuncDeviceAttr[] = "tf.device";
47 
48 // Removes explicit device assignment on TPUExecute and _TPUCompileMlir ops.
49 // As TPU execution replication logic is delegated to DTensorDevice,
50 // DTensorDevice should handle replication and Placer would assign devices.
UpdateTPUDeviceAssignment(mlir::func::FuncOp function,mlir::OpBuilder * builder)51 void UpdateTPUDeviceAssignment(mlir::func::FuncOp function,
52                                mlir::OpBuilder* builder) {
53   function.walk([&](mlir::Operation* op) {
54     if (!llvm::isa<
55             mlir::TF::TPUExecuteOp, mlir::TF::TPUExecuteAndUpdateVariablesOp,
56             mlir::TF::_TPUCompileMlirOp, mlir::TF::TPUCompileSucceededAssertOp>(
57             op))
58       return;
59 
60     assert(!op->getAttrOfType<mlir::StringAttr>(kDeviceAttr));
61 
62     auto enclosing_launch = op->getParentOfType<mlir::tf_device::LaunchOp>();
63     if (!enclosing_launch) return;
64 
65     enclosing_launch.deviceAttr(builder->getStringAttr(""));
66 
67     // Remove placeholder device attributes of resource arguments to TPU
68     // computation.
69     for (int i = 0; i < function.getNumArguments(); ++i)
70       function.removeArgAttr(i, builder->getStringAttr(kFuncDeviceAttr));
71   });
72 }
73 
74 // Updates `num_replicas` section of TPUCompileMetadataProto to number of
75 // devices set by DTensor device.
UpdateTPUCompileMetadata(const Mesh & mesh_config,mlir::func::FuncOp function,mlir::OpBuilder * builder)76 mlir::LogicalResult UpdateTPUCompileMetadata(const Mesh& mesh_config,
77                                              mlir::func::FuncOp function,
78                                              mlir::OpBuilder* builder) {
79   auto result = function.walk([&](mlir::TF::_TPUCompileMlirOp compile) {
80     auto original_metadata = compile.metadata();
81     tpu::TPUCompileMetadataProto metadata_proto;
82     if (!metadata_proto.ParseFromString(original_metadata.str())) {
83       compile.emitOpError("unable to parse TPUCompileMetadata");
84       return mlir::WalkResult::interrupt();
85     }
86 
87     int num_replicas = mesh_config.num_devices();
88     metadata_proto.set_num_replicas(num_replicas);
89 
90     // We keep DTensor mesh global device IDs equal to XLA replica IDs, both
91     // sequentially increasing over mesh dimensions. Collective lowering has
92     // generated `replica_groups` using these IDs.
93     //
94     // We need to set appropriate XLA replica ID-to-core ID mappings here to get
95     // correct results, by being consistent with what the user Python program
96     // gets and assumes. There are three kinds of mesh:
97     //
98     // 1. The first mesh getting here is a one-of-a-kind mesh for merging core
99     //    IDs across hosts during TPU initialization. This mesh doesn't need any
100     //    mapping to be set. Mesh::tpu_core_ids() is empty when this happens.
101     // 2. Users can manually create meshes, with empty or non-empty names. These
102     //    meshes have global device IDs equal to TF task-device ordinals, and
103     //    they do not place any entry in Mesh::tpu_core_ids(). The default entry
104     //    in Mesh::tpu_core_ids(), stored under an empty name key by the mesh
105     //    computation in 1, works on these meshes.
106     // 3. Users can create ring reduction-optimized meshes using provided
107     //    helpers. These meshes must have non-empty names and store an entry in
108     //    Mesh::tpu_core_ids() when they are created, using their name as key.
109     //
110     // For any user-defined mesh, if users have manually specified device
111     // assignment, always respect that.
112     if (!Mesh::tpu_core_ids().empty() &&
113         !metadata_proto.has_device_assignment()) {
114       std::string mesh_name = mesh_config.name();
115       if (Mesh::tpu_core_ids().count(mesh_name) == 0) {
116         // This can happen only for manually created meshes (2 above) with
117         // non-empty names. This mesh should use the default mapping.
118         VLOG(1) << "mesh_name " << mesh_name << " not found, using empty name";
119         mesh_name = "";
120       }
121       const std::vector<int>& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name];
122       VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", ");
123 
124       xla::DeviceAssignmentProto device_assignment;
125       device_assignment.set_replica_count(num_replicas);
126       device_assignment.set_computation_count(1);
127       auto* computation_device = device_assignment.add_computation_devices();
128       // TODO(b/188076080): Clean up device id.
129       const int64_t start_device_id = mesh_config.min_global_device_id();
130       for (int i = 0; i < num_replicas; ++i) {
131         int tpu_core_id_index = i + start_device_id;
132         computation_device->add_replica_device_ids(
133             tpu_core_ids[tpu_core_id_index]);
134       }
135       *metadata_proto.mutable_device_assignment() = device_assignment;
136     }
137 
138     compile.metadataAttr(
139         builder->getStringAttr(metadata_proto.SerializeAsString()));
140     return mlir::WalkResult::advance();
141   });
142   return mlir::failure(result.wasInterrupted());
143 }
144 
145 // Pass that updates TPU specific metadata including `num_replicas` and device
146 // assignment of TPUCompileMlirOp and TPUExecute ops.
147 struct DTensorUpdateTPUMetadata
148     : public DTensorUpdateTPUMetadataBase<DTensorUpdateTPUMetadata> {
runOnOperationtensorflow::dtensor::__anon05d002cb0111::DTensorUpdateTPUMetadata149   void runOnOperation() override {
150     mlir::MLIRContext& context = getContext();
151     mlir::OpBuilder builder(&context);
152     auto module = getOperation();
153     mlir::func::FuncOp main_func = module.lookupSymbol<mlir::func::FuncOp>("main");
154     if (!main_func) return;
155 
156     auto result = main_func.walk([&](mlir::TF::StatefulPartitionedCallOp op) {
157       auto call_config = op.config();
158       auto mesh_or_status = Mesh::FromString(call_config.str());
159       if (!mesh_or_status.ok()) return mlir::WalkResult::advance();
160 
161       const auto mesh = mesh_or_status.ValueOrDie();
162       if (!mesh.is_tpu_mesh()) return mlir::WalkResult::advance();
163 
164       auto function = MaybeFindFunction(op);
165       if (!function) {
166         op.emitOpError(
167             "Could not find function definition for "
168             "StatefulPartitionedCall op running on TPU.");
169         return mlir::WalkResult::interrupt();
170       }
171 
172       if (mlir::failed(UpdateTPUCompileMetadata(mesh, *function, &builder)))
173         return mlir::WalkResult::interrupt();
174 
175       UpdateTPUDeviceAssignment(*function, &builder);
176       return mlir::WalkResult::advance();
177     });
178 
179     if (result.wasInterrupted()) return signalPassFailure();
180   };
181 };
182 
183 }  // namespace
184 
185 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorUpdateTPUMetadata()186 CreateDTensorUpdateTPUMetadata() {
187   return std::make_unique<DTensorUpdateTPUMetadata>();
188 }
189 
190 }  // namespace dtensor
191 }  // namespace tensorflow
192