xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/mlir/dtensor_send_recv.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"

#include <string>

#include "llvm/ADT/SmallVector.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Returns a compilation key placeholder. This placeholder will be replaced
// with the output of the TPUCompile op during the TPURewrite pass. The
// program key (output of the TPUCompile op) identifies the TPU computation
// from which to receive data.
mlir::Value GetOrCreateCompilationKey(mlir::Operation* op) {
  mlir::Value key;
  auto cluster = op->getParentOfType<mlir::tf_device::ClusterOp>();
  assert(cluster);
  cluster.walk(
      [&](mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp compilation_key) {
        key = compilation_key.program();
      });
  if (key) return key;

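  // No placeholder exists in this cluster yet; create one at the top of the
  // cluster body. The program key is a rank-1 string tensor of 3 elements.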
  mlir::OpBuilder builder(&cluster.GetBody().front());
  auto result_type =
      mlir::RankedTensorType::get({3}, builder.getType<mlir::TF::StringType>());
  auto new_compilation_key =
      builder.create<mlir::TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
          cluster.getLoc(), /*program=*/result_type,
          llvm::ArrayRef<mlir::Value>{});
  return new_compilation_key.program();
}

}  // namespace

StatusOr<mlir::Value> GetDeviceOrdinal(const Mesh& mesh,
                                       const mlir::Location& loc,
                                       mlir::func::FuncOp function,
                                       mlir::OpBuilder* builder,
                                       bool return_int64_type) {
  // Create as many entries as the number of devices in the entire mesh.
  llvm::SmallVector<int32, 4> device_id_to_ordinal(mesh.num_devices(), 0);
  // Only fill in entries with indices equal to local device IDs. For TPUs,
  // there are usually 8 local devices.
  for (int i = 0; i < mesh.local_device_ids().size(); ++i) {
    device_id_to_ordinal[mesh.local_device_ids()[i]] = i;
  }
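  // For example, on a host whose local device IDs are {4, 5, 6, 7} in an
  // 8-device mesh, device_id_to_ordinal is [0, 0, 0, 0, 0, 1, 2, 3], so
  // slicing at device ID 6 yields local ordinal 2.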
  // Slice out the device ordinal using the device ID as index.
  TF_ASSIGN_OR_RETURN(mlir::Value device_id, DeviceId(function));
  mlir::TF::SliceOp device_ordinal = builder->create<mlir::TF::SliceOp>(
      loc,
      /*output=*/EffectivelyScalarR1Type(builder->getIntegerType(32)),
      /*input=*/IntConst(*builder, loc, device_id_to_ordinal),
      /*begin=*/
      mlir::TF::collection_ops_util::ReshapeScalarToSizeType(*builder,
                                                             device_id, loc),
      /*size=*/IntConst(*builder, loc, {1}));
  mlir::Value device_ordinal_scalar =
      ReshapeSizeTypeToScalar(*builder, loc, device_ordinal);
  if (return_int64_type) {
    device_ordinal_scalar = builder->create<mlir::TF::CastOp>(
        loc, mlir::RankedTensorType::get({}, builder->getI64Type()),
        device_ordinal_scalar);
  }
  return device_ordinal_scalar;
}

// Lowers a DTensorSend op to either an _XlaSendFromHostV2 op or an
// XlaSendToHost op, depending on the source mesh cluster.
StatusOr<mlir::Operation*> LowerDTensorSendToXlaOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send, bool send_from_device_zero) {
  const bool send_from_cpu = !send_input_layout.mesh().is_tpu_mesh();
  mlir::OpBuilder builder(dtensor_send);

  mlir::Location loc = dtensor_send.getLoc();
  mlir::Operation* lowered_send_op;
  if (send_from_cpu) {
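    // _XlaSendFromHostV2 takes a list of values to transfer; wrap the single
    // input accordingly.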
    llvm::SmallVector<mlir::Value, 4> value_to_send{send_input};
    mlir::OpBuilder::InsertPoint insertion_point = builder.saveInsertionPoint();
    mlir::Value program_key = GetOrCreateCompilationKey(dtensor_send);
    builder.restoreInsertionPoint(insertion_point);

    mlir::Value device_ordinal;
    if (send_from_device_zero) {
      // For CopyToMesh, we currently only support sending from host device 0
      // to target TPUs.
      device_ordinal = CreateIntScalarConst(0, builder, loc);
    } else {
      // For special topologies, always send from CPU device i to TPU device i.
      auto send_cluster =
          dtensor_send->getParentOfType<mlir::tf_device::ClusterOp>();
      if (!send_cluster) {
        return errors::InvalidArgument("DTensorSend is not inside a ClusterOp");
      }
      auto send_func = send_cluster->getParentOfType<mlir::func::FuncOp>();
      if (!send_func) {
        return errors::InvalidArgument("DTensorSend is not inside a FuncOp");
      }
      TF_ASSIGN_OR_RETURN(
          device_ordinal,
          GetDeviceOrdinal(send_input_layout.mesh(), loc, send_func, &builder));
    }
    // Create an _XlaSendFromHostV2 op that transfers the value from the host
    // to the device identified by device_ordinal.
    lowered_send_op = builder.create<mlir::TF::_XlaSendFromHostV2Op>(
        loc, value_to_send, program_key, device_ordinal, dtensor_send.key());
  } else {
    // Note that for ops running in XLA/TPU, the device ordinal input is not
    // needed.
    lowered_send_op = builder.create<mlir::TF::XlaSendToHostOp>(
        loc, send_input, dtensor_send.key());
  }

  dtensor_send.erase();
  return lowered_send_op;
}

// Lowers a DTensorRecv op to either an _XlaRecvAtHostV2 op or an
// XlaRecvFromHost op, depending on the source mesh cluster configuration.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv) {
  return LowerDTensorRecvToXlaOp(dtensor_recv, dtensor_recv.getType());
}

// Lowers a DTensorRecv op to either an _XlaRecvAtHostV2 op or an
// XlaRecvFromHost op, depending on the source mesh cluster configuration.
// `output_type` can be set to the specific local tensor type needed, if it
// differs from the recv op's output type.
StatusOr<mlir::Operation*> LowerDTensorRecvToXlaOp(
    mlir::TF::DTensorRecv dtensor_recv, mlir::Type output_type) {
  const bool recv_at_cpu = dtensor_recv.layout().mesh().is_cpu_mesh();
  mlir::Operation* recv_xla_op = nullptr;
  mlir::OpBuilder builder(dtensor_recv);

  if (recv_at_cpu) {
    // Create an _XlaRecvAtHostV2 op.
    llvm::SmallVector<mlir::Type, 4> output_types{output_type};
    auto recv_cluster =
        dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>();

    TF_ASSIGN_OR_RETURN(absl::optional<Mesh> mesh,
                        ExtractDeviceMeshFromOp(recv_cluster));
    if (!mesh.has_value())
      return errors::InvalidArgument(
          "failed to get device ordinal because the mesh for the operation "
          "is not specified.");

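    // Compute the device ordinal at the top of the cluster body so it is
    // available ahead of the recv.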
    mlir::OpBuilder builder(&recv_cluster.GetBody().front());
    TF_ASSIGN_OR_RETURN(
        mlir::Value device_ordinal,
        GetDeviceOrdinal(*mesh, recv_cluster.getLoc(),
                         recv_cluster->getParentOfType<mlir::func::FuncOp>(),
                         &builder));

    auto program_key = GetOrCreateCompilationKey(dtensor_recv);
    builder.setInsertionPoint(dtensor_recv);
    recv_xla_op = builder.create<mlir::TF::_XlaRecvAtHostV2Op>(
        dtensor_recv.getLoc(), output_types,
        /*dynamic_key=*/program_key, device_ordinal, dtensor_recv.keyAttr());
  } else {
    // Create an XlaRecvFromHost op.
    recv_xla_op = builder.create<mlir::TF::XlaRecvFromHostOp>(
        dtensor_recv.getLoc(), output_type,
        ConvertTypeToTensorShapeAttr(dtensor_recv.getType()),
        dtensor_recv.keyAttr());
  }

  assert(recv_xla_op);

  // TODO(hongjunchoi): After receiving the tensor, convert it to the
  // requested layout with EmitRelayout.
  return recv_xla_op;
}

// Lowers a DTensorSend op from a CPU mesh to TF _HostSend ops.
StatusOr<mlir::Operation*> LowerDTensorSendFromCPUToTFOp(
    const Layout& send_input_layout, mlir::Value send_input,
    mlir::TF::DTensorSend dtensor_send) {
  mlir::OpBuilder builder(dtensor_send);
  builder.setInsertionPointAfter(send_input.getDefiningOp());

  llvm::SmallVector<mlir::Value, 4> value_to_send{send_input};

  // Create multiple sends from the host: one send per local device in the
  // target mesh.
  absl::Span<const std::string> sending_devices =
      send_input_layout.mesh().local_devices();

  Layout target_layout = dtensor_send.target_layout();
  absl::Span<const std::string> receiving_devices =
      target_layout.mesh().local_devices();

  std::string tensor_name = dtensor_send.key().str();

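  // All sends originate from the first sending device; each iteration targets
  // a different receiving device.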
  mlir::Operation* lowered_send_op = nullptr;
  for (size_t i = 0; i < receiving_devices.size(); ++i)
    lowered_send_op = builder.create<mlir::TF::_HostSendOp>(
        send_input.getLoc(), dtensor_send.input(), tensor_name,
        sending_devices[0],
        /*send_device_incarnation=*/0, receiving_devices[i]);

  dtensor_send.erase();
  return lowered_send_op;
}

// Lowers a DTensorRecv op to TF _HostRecv ops.
StatusOr<mlir::Operation*> LowerDTensorRecvFromCPUToTFOp(
    const Mesh& send_mesh, mlir::TF::DTensorRecv dtensor_recv) {
  const Layout& recv_layout = dtensor_recv.layout();

  auto recv_cluster =
      dtensor_recv->getParentOfType<mlir::tf_device::ClusterOp>();

  mlir::OpBuilder builder(&recv_cluster.GetBody().front());
  llvm::SmallVector<mlir::Type, 4> output_types{dtensor_recv.getType()};
  builder.setInsertionPoint(dtensor_recv);
  std::string tensor_name = dtensor_recv.key().str();
  absl::Span<const std::string> sending_devices = send_mesh.local_devices();
  absl::Span<const std::string> receiving_devices =
      recv_layout.mesh().local_devices();

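  // Emit one _HostRecv per local receiving device; each recv is paired with
  // the first sending device, mirroring the send lowering above.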
  mlir::Operation* lowered_recv_op = nullptr;
  mlir::Location loc = dtensor_recv.getLoc();
  for (size_t i = 0; i < receiving_devices.size(); ++i)
    lowered_recv_op = builder.create<mlir::TF::_HostRecvOp>(
        loc, dtensor_recv.getType(), tensor_name, sending_devices[0],
        /*send_device_incarnation=*/0, receiving_devices[i]);

  // Replace dtensor_recv with the newly created recv op and remove the
  // DTensorRecv op.
  assert(lowered_recv_op);
  dtensor_recv.replaceAllUsesWith(lowered_recv_op);
  dtensor_recv.erase();
  return lowered_recv_op;
}

}  // namespace dtensor
}  // namespace tensorflow