/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/Passes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/op_utils.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kFuncDeviceAttr[] = "tf.device";

// Removes explicit device assignment from TPUExecute and _TPUCompileMlir ops.
// Because TPU execution replication logic is delegated to DTensorDevice,
// DTensorDevice handles replication and the Placer assigns devices.
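//
// Illustrative effect (a sketch with a hypothetical device string): a
// wrapping launch such as
//
//   "tf_device.launch"() ({...}) {device = "/job:worker/task:0/device:TPU:0"}
//
// has its `device` attribute reset to "", and any `tf.device` placeholder
// attributes on the function's arguments are removed.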
void UpdateTPUDeviceAssignment(mlir::func::FuncOp function,
                               mlir::OpBuilder* builder) {
  function.walk([&](mlir::Operation* op) {
    if (!llvm::isa<
            mlir::TF::TPUExecuteOp, mlir::TF::TPUExecuteAndUpdateVariablesOp,
            mlir::TF::_TPUCompileMlirOp,
            mlir::TF::TPUCompileSucceededAssertOp>(op))
      return;

    assert(!op->getAttrOfType<mlir::StringAttr>(kDeviceAttr));

    auto enclosing_launch = op->getParentOfType<mlir::tf_device::LaunchOp>();
    if (!enclosing_launch) return;

    enclosing_launch.deviceAttr(builder->getStringAttr(""));

    // Remove placeholder device attributes of resource arguments to the TPU
    // computation.
    for (int i = 0; i < function.getNumArguments(); ++i)
      function.removeArgAttr(i, builder->getStringAttr(kFuncDeviceAttr));
  });
}

// Updates the `num_replicas` field of TPUCompileMetadataProto to the number
// of devices set by the DTensor device.
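//
// For example (hypothetical mesh): a 2x2 TPU mesh has num_devices() == 4, so
// `num_replicas` is set to 4 and, when core IDs are available, a matching
// 4-replica device assignment is built below.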
mlir::LogicalResult UpdateTPUCompileMetadata(const Mesh& mesh_config,
                                             mlir::func::FuncOp function,
                                             mlir::OpBuilder* builder) {
  auto result = function.walk([&](mlir::TF::_TPUCompileMlirOp compile) {
    auto original_metadata = compile.metadata();
    tpu::TPUCompileMetadataProto metadata_proto;
    if (!metadata_proto.ParseFromString(original_metadata.str())) {
      compile.emitOpError("unable to parse TPUCompileMetadata");
      return mlir::WalkResult::interrupt();
    }

    int num_replicas = mesh_config.num_devices();
    metadata_proto.set_num_replicas(num_replicas);

    // We keep DTensor mesh global device IDs equal to XLA replica IDs, both
    // sequentially increasing over mesh dimensions. Collective lowering has
    // generated `replica_groups` using these IDs.
    //
    // We need to set an appropriate XLA replica ID-to-core ID mapping here to
    // get correct results, consistent with what the user's Python program
    // sees and assumes. There are three kinds of meshes:
    //
    // 1. The first mesh getting here is a one-of-a-kind mesh for merging core
    //    IDs across hosts during TPU initialization. This mesh doesn't need
    //    any mapping to be set; Mesh::tpu_core_ids() is empty when this
    //    happens.
    // 2. Users can manually create meshes, with empty or non-empty names.
    //    These meshes have global device IDs equal to TF task-device
    //    ordinals, and they do not place any entry in Mesh::tpu_core_ids().
    //    The default entry in Mesh::tpu_core_ids(), stored under an empty
    //    name key by the mesh computation in 1, works on these meshes.
    // 3. Users can create ring reduction-optimized meshes using the provided
    //    helpers. These meshes must have non-empty names and store an entry
    //    in Mesh::tpu_core_ids() when they are created, using their name as
    //    the key.
    //
    // For any user-defined mesh, if the user has manually specified a device
    // assignment, always respect that.
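    //
    // Worked example (hypothetical values): suppose num_replicas == 4,
    // min_global_device_id() == 0, and Mesh::tpu_core_ids()[mesh_name] is
    // {2, 3, 0, 1}. Then replica i is assigned physical core
    // tpu_core_ids[i], producing a DeviceAssignmentProto like:
    //   replica_count: 4
    //   computation_count: 1
    //   computation_devices { replica_device_ids: 2 replica_device_ids: 3
    //                         replica_device_ids: 0 replica_device_ids: 1 }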
    if (!Mesh::tpu_core_ids().empty() &&
        !metadata_proto.has_device_assignment()) {
      std::string mesh_name = mesh_config.name();
      if (Mesh::tpu_core_ids().count(mesh_name) == 0) {
        // This can happen only for manually created meshes (2 above) with
        // non-empty names. This mesh should use the default mapping.
        VLOG(1) << "mesh_name " << mesh_name << " not found, using empty name";
        mesh_name = "";
      }
      const std::vector<int>& tpu_core_ids = Mesh::tpu_core_ids()[mesh_name];
      VLOG(1) << "tpu_core_ids: " << str_util::Join(tpu_core_ids, ", ");

      xla::DeviceAssignmentProto device_assignment;
      device_assignment.set_replica_count(num_replicas);
      device_assignment.set_computation_count(1);
      auto* computation_device = device_assignment.add_computation_devices();
      // TODO(b/188076080): Clean up device id.
      const int64_t start_device_id = mesh_config.min_global_device_id();
      for (int i = 0; i < num_replicas; ++i) {
        int tpu_core_id_index = i + start_device_id;
        computation_device->add_replica_device_ids(
            tpu_core_ids[tpu_core_id_index]);
      }
      *metadata_proto.mutable_device_assignment() = device_assignment;
    }

    compile.metadataAttr(
        builder->getStringAttr(metadata_proto.SerializeAsString()));
    return mlir::WalkResult::advance();
  });
  return mlir::failure(result.wasInterrupted());
}

// Pass that updates TPU-specific metadata, including `num_replicas` and the
// device assignment of TPUCompileMlirOp and TPUExecute ops.
struct DTensorUpdateTPUMetadata
    : public DTensorUpdateTPUMetadataBase<DTensorUpdateTPUMetadata> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder builder(&context);
    auto module = getOperation();
    mlir::func::FuncOp main_func =
        module.lookupSymbol<mlir::func::FuncOp>("main");
    if (!main_func) return;

    auto result = main_func.walk([&](mlir::TF::StatefulPartitionedCallOp op) {
      auto call_config = op.config();
      auto mesh_or_status = Mesh::FromString(call_config.str());
      if (!mesh_or_status.ok()) return mlir::WalkResult::advance();

      const auto mesh = mesh_or_status.ValueOrDie();
      if (!mesh.is_tpu_mesh()) return mlir::WalkResult::advance();

      auto function = MaybeFindFunction(op);
      if (!function) {
        op.emitOpError(
            "Could not find function definition for "
            "StatefulPartitionedCall op running on TPU.");
        return mlir::WalkResult::interrupt();
      }

      if (mlir::failed(UpdateTPUCompileMetadata(mesh, *function, &builder)))
        return mlir::WalkResult::interrupt();

      UpdateTPUDeviceAssignment(*function, &builder);
      return mlir::WalkResult::advance();
    });

    if (result.wasInterrupted()) return signalPassFailure();
  }
};

}  // namespace

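// Sketch of how this pass might be registered in a pipeline (illustrative;
// the pass manager setup shown here is assumed, not taken from this file):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(CreateDTensorUpdateTPUMetadata());
//   if (mlir::failed(pm.run(module))) { /* handle failure */ }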
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorUpdateTPUMetadata() {
  return std::make_unique<DTensorUpdateTPUMetadata>();
}

}  // namespace dtensor
}  // namespace tensorflow