/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>

#include "absl/algorithm/container.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/compiler/xla/client/sharding_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace mlir {
namespace TFTPU {
namespace {

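// An empty string deserializes to the default xla::OpSharding proto, whose
// type is REPLICATED; it is used below to denote replicate sharding.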
constexpr char kReplicateSharding[] = "";
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kUseSpmdAttr[] = "use_spmd_for_xla_partitioning";
constexpr char kAliasingAttr[] = "tf.aliasing_output";

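// Pass that identifies the XLA sharding of every input and output of each
// `tf_device.cluster_func` in the module and records it as argument/result
// attributes on the attached function and as input/output sharding attributes
// on the `tf_device.cluster_func` op itself.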
struct TPUShardingIdentificationPass
    : public TF::TPUShardingIdentificationPassBase<
          TPUShardingIdentificationPass> {
  void runOnOperation() final;
};

// Returns the XLA sharding from a TPUPartitionedInput op connected to a
// `tf_device.cluster_func` operand value. If the value is of resource type,
// the TPUPartitionedInput op is connected through a ReadVariable op that feeds
// into the `tf_device.cluster_func`.
llvm::Optional<llvm::StringRef> GetXlaShardingFromOperand(Value value) {
  Value value_to_visit = value;
  if (auto read_var = value_to_visit.getDefiningOp<TF::ReadVariableOp>())
    value_to_visit = read_var.resource();

  if (auto partitioned_input =
          value_to_visit.getDefiningOp<TF::TPUPartitionedInputOp>())
    return partitioned_input._XlaSharding();

  return llvm::None;
}

// Given a `tf_device.cluster_func` operand value, returns true iff it is a
// device variable that should default to MAXIMAL sharding. Device variables
// that are per-replica or distributed default to MAXIMAL sharding, which
// corresponds to arguments of the `tf_device.replicate`. Otherwise the
// variable is broadcast, which corresponds to edges that are implicitly
// captured by the `replicate`.
bool IsMaximalVariable(Value value) {
  auto read_var = value.getDefiningOp<TF::ReadVariableOp>();
  return read_var && read_var->getParentOfType<tf_device::ReplicateOp>();
}

// Verify whether the given sharding can be applied to the given (tensor) type.
// (A bad sharding might mean failing tf.Split ops if the graph later executes
// on CPU.)
// If the sharding is incorrect, return failure. If it's good, or if we can't
// verify it, return success.
LogicalResult VerifySharding(Type type, StringRef sharding_string) {
  xla::OpSharding sharding;
  if (!sharding.ParseFromString(sharding_string.str())) {
    // Some test cases use \01\02\03 as sharding, to test propagation. Treat
    // a non-proto sharding as valid, and don't verify further.
    return success();
  }
  if (sharding.type() != xla::OpSharding::OTHER) {
    // We currently only verify shardings that actually break a tensor apart.
    return success();
  }
  if (RankedTensorType ranked_type = type.dyn_cast<RankedTensorType>()) {
    if (ranked_type.getRank() < sharding.tile_assignment_dimensions_size()) {
      return failure();
    }
  }
  return success();
}

// Verify sharding for all arguments and return values.
LogicalResult VerifyShardings(
    mlir::func::FuncOp func,
    const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args,
    const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) {
  Block& function_block = func.front();
  for (auto sharding_and_arg :
       llvm::zip(sharding_for_args, function_block.getArguments())) {
    StringRef sharding = std::get<0>(sharding_and_arg);
    BlockArgument arg = std::get<1>(sharding_and_arg);
    if (failed(VerifySharding(arg.getType(), sharding))) return failure();
  }
  Operation* terminator = function_block.getTerminator();
  for (auto sharding_and_retval :
       llvm::zip(sharding_for_rets, terminator->getOpOperands())) {
    StringRef sharding = std::get<0>(sharding_and_retval);
    OpOperand& retval = std::get<1>(sharding_and_retval);
    if (failed(VerifySharding(retval.get().getType(), sharding)))
      return failure();
  }
  return success();
}

// Returns the XLA sharding from an XlaSharding op connected to an argument
// value. If the value is of resource type, the XlaSharding op will be
// connected to a ReadVariable op. The XlaSharding op may be a direct user of
// the input, but it may also come after an Identity op and, in the case where
// the bfloat16 type is used, a Cast op may be added right after the input.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
// Case, While) ops and Caller return values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
// inputs.
llvm::Optional<llvm::StringRef> GetXlaShardingFromArg(Value value) {
  llvm::SmallPtrSet<Value, 4> visited_values;
  llvm::SmallVector<Value, 4> values_to_visit{value};
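  // Walk forward through users breadth-first: pass through ops that simply
  // forward the value (Identity, Cast, ReadVariable) and into the callees of
  // call ops, stopping at the first XlaSharding op found.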
  while (!values_to_visit.empty()) {
    llvm::SmallVector<Value, 4> next_values_to_visit;
    for (Value value_to_visit : values_to_visit) {
      if (!visited_values.insert(value_to_visit).second) continue;

      for (auto& use : value_to_visit.getUses()) {
        Operation* owner = use.getOwner();
        if (auto sharding = llvm::dyn_cast<TF::XlaShardingOp>(owner))
          return sharding._XlaSharding();

        if (llvm::isa<TF::IdentityOp, TF::CastOp, TF::ReadVariableOp>(owner)) {
          next_values_to_visit.push_back(use.getOwner()->getResult(0));
          continue;
        }

        if (auto call_op = llvm::dyn_cast<CallOpInterface>(owner)) {
          func::FuncOp func =
              llvm::dyn_cast<func::FuncOp>(call_op.resolveCallable());
          if (!func) continue;
          next_values_to_visit.push_back(
              func.getArgument(use.getOperandNumber()));
        }
      }
    }

    values_to_visit.swap(next_values_to_visit);
  }

  return llvm::None;
}

// Extracts sharding configurations for all inputs by parsing the XlaSharding/
// TPUPartitionedInput ops connected to the operands/arguments. If an argument
// to the `cluster_func` directly feeds into another function call op, then
// recursively walk the function definition to find the connected XlaSharding
// op.
void IdentifyXlaShardingForComputationInputs(
    StringRef logical_core_0_sharding, bool use_spmd,
    bool infer_from_computation, tf_device::ClusterFuncOp cluster_func,
    func::FuncOp func, Builder* builder,
    llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) {
  // `func` is the function attached to `cluster_func`; its entry block
  // arguments correspond 1:1 to the `cluster_func` operands.
  Block& function_block = func.front();

  sharding_for_args.reserve(function_block.getNumArguments());

  // Iterate through operands of `cluster_func`.
  // The computation operand can either be:
  //   1) a TPUPartitionedInput op if the input has a non-resource type;
  //   2) a ReadVariableOp otherwise (i.e. if the input is of resource type).
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through input arguments to the entry block of
  // tf_device.ClusterFunc. For input ops, look for XlaSharding ops.
  // XlaSharding ops can:
  //   1) Directly follow the input argument if the input argument has a
  //      non-resource type.
  //   2) Follow a ReadVariableOp if the input type is of resource type.
  //   3) Follow an IdentityOp or CastOp after the above cases (1), (2).
  //
  // Sharding configurations are added to the tf_device.ClusterFunc as an
  // attribute and the function as an argument attribute.
  for (auto operand_and_arg :
       llvm::zip(cluster_func.operands(), function_block.getArguments())) {
    Value operand = std::get<0>(operand_and_arg);
    BlockArgument arg = std::get<1>(operand_and_arg);

    if (auto operand_sharding = GetXlaShardingFromOperand(operand)) {
      sharding_for_args.push_back(operand_sharding.getValue());
      continue;
    }

    if (infer_from_computation) {
      auto arg_sharding = GetXlaShardingFromArg(arg);
      if (arg_sharding) {
        sharding_for_args.push_back(arg_sharding.getValue());
        continue;
      }
    }

    if (use_spmd && !IsMaximalVariable(operand)) {
      // If XLA SPMD is enabled, host variables or non-variable per-replica
      // inputs should take on replicate sharding, so that every device gets
      // the whole tensor(s) (and can slice them up later). Exclude device
      // variables, which always should take maximal sharding.
      sharding_for_args.push_back(kReplicateSharding);
      continue;
    }

    // Otherwise, default to maximal sharding assigned to logical core 0.
    sharding_for_args.push_back(logical_core_0_sharding);
  }
}

// Returns XLA sharding from TPUPartitionedOutput or TPUPartitionedInput (via
// AssignVariableOp/resource write) op connected to a `tf_device.cluster_func`
// result value.
llvm::Optional<llvm::StringRef> GetXlaShardingFromResult(Value value) {
  if (!value.hasOneUse()) return llvm::None;

  Operation* user = *value.getUsers().begin();
  if (auto partitioned_output =
          llvm::dyn_cast<TF::TPUPartitionedOutputOp>(user))
    return partitioned_output._XlaSharding();

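  // A resource result is written back with an AssignVariableOp; in that case
  // the sharding comes from the TPUPartitionedInput op that defines the
  // resource being assigned.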
  if (auto assign_var = llvm::dyn_cast<TF::AssignVariableOp>(user))
    if (auto partitioned_input =
            assign_var.resource().getDefiningOp<TF::TPUPartitionedInputOp>())
      return partitioned_input._XlaSharding();

  return llvm::None;
}

// Looks up arg->retval aliases for every argument, and builds a reverse map.
void ExtractAliases(func::FuncOp func, llvm::SmallVectorImpl<int>& aliases) {
  aliases.resize(func.getNumResults(), -1);
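  // The `tf.aliasing_output` argument attribute holds the index of the result
  // that aliases argument `i`; record the reverse (result -> argument) mapping.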
  for (int i = 0; i < func.getNumArguments(); i++) {
    if (auto v = func.getArgAttrOfType<mlir::IntegerAttr>(i, kAliasingAttr)) {
      int retval_index = v.getInt();
      if (retval_index >= 0 && retval_index < aliases.size()) {
        aliases[retval_index] = i;
      }
    }
  }
}

// Returns XLA sharding from argument connected via tf.aliasing_output.
llvm::Optional<StringRef> GetXlaShardingFromAlias(
    Value value, llvm::SmallVectorImpl<int>& aliases,
    const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args) {
  int retval_index = value.cast<OpResult>().getResultNumber();
  if (retval_index >= 0 && retval_index < aliases.size()) {
    int arg_index = aliases[retval_index];
    if (arg_index >= 0 && arg_index < sharding_for_args.size()) {
      return sharding_for_args[arg_index];
    }
  }
  return llvm::None;
}

// Returns the XLA sharding from an XlaSharding op connected to a result value.
// The XlaSharding op may be directly connected to the output, but it may also
// be separated from it by Identity or simple arithmetic ops. In cases where
// the bfloat16 type is used, we might see a Cast op.
//
// TODO(hongjunchoi): Add logic to parse XlaSharding op inside control flow (If,
// Case, While) ops and Caller argument values.
// TODO(hongjunchoi): Consider explicitly checking op patterns to detect sharded
// inputs.
llvm::Optional<StringRef> GetXlaShardingFromRetval(Value value) {
  llvm::SmallPtrSet<Value, 4> visited_values;
  llvm::SmallVector<Value, 4> values_to_visit;
  values_to_visit.push_back(value);

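  // Walk backward through defining ops: pass through ops that preserve their
  // operand's shape/type (Identity, Cast, element-wise binary ops, etc.) and
  // into the terminator operands of called functions, stopping at the first
  // XlaSharding op or `_XlaSharding` attribute found.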
  while (!values_to_visit.empty()) {
    Value value_to_visit = values_to_visit.pop_back_val();

    if (!visited_values.insert(value_to_visit).second) {
      continue;
    }

    Operation* def = value_to_visit.getDefiningOp();
    if (!def) {
      continue;
    }

    if (auto sharding = llvm::dyn_cast_or_null<TF::XlaShardingOp>(def))
      return sharding._XlaSharding();

    if (auto sharding = def->getAttrOfType<StringAttr>("_XlaSharding")) {
      return sharding.strref();
    }

    if (  // Cast, real/imag, etc.
        def->hasTrait<mlir::OpTrait::SameOperandsAndResultShape>() ||
        // Exp, ceil, etc.
        def->hasTrait<mlir::OpTrait::SameOperandsAndResultType>() ||
        // Identity
        def->hasTrait<mlir::OpTrait::TF::OperandsSameAsResultsTypeOrRef>() ||
        // AddV2, Sub, etc.
        (def->hasTrait<
             mlir::OpTrait::TF::SameOperandsAndResultElementTypeResolveRef>() &&
         def->hasTrait<mlir::OpTrait::TF::CwiseBinary>())) {
      for (auto operand : def->getOperands()) {
        values_to_visit.push_back(operand);
      }
      continue;
    }

    if (auto call_op = llvm::dyn_cast_or_null<CallOpInterface>(def)) {
      func::FuncOp func =
          llvm::dyn_cast<func::FuncOp>(call_op.resolveCallable());
      if (!func) continue;
      value_to_visit = func.front().getTerminator()->getOperand(
          value_to_visit.cast<OpResult>().getResultNumber());
      values_to_visit.push_back(value_to_visit);
      continue;
    }
  }

  return llvm::None;
}

// Extracts sharding configurations for all outputs by parsing the XlaSharding/
// TPUPartitionedOutput ops connected to the retvals/results.
void IdentifyXlaShardingForComputationOutputs(
    StringRef logical_core_0_sharding, bool use_spmd,
    bool infer_from_computation, tf_device::ClusterFuncOp cluster_func,
    func::FuncOp func, Builder* builder,
    const llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_args,
    llvm::SmallVectorImpl<llvm::StringRef>& sharding_for_rets) {
  Block& function_block = func.front();
  Operation* terminator = function_block.getTerminator();
  sharding_for_rets.reserve(terminator->getNumOperands());

  llvm::SmallVector<int, 8> aliases;  // maps return value index to arg index
  ExtractAliases(func, aliases);

  // Iterate through results of `cluster_func`. For output ops, look for
  // TPUPartitionedOutput ops.
  //
  // Replicate sharding is used if `use_spmd` is set.
  //
  // Iterate through operands of the terminator. If the preceding op is an
  // XlaShardingOp, then the provided sharding configuration is added to the
  // tf_device.ClusterFunc as an attribute and the function as a result
  // attribute.
  for (auto result_and_retval :
       llvm::zip(cluster_func.results(), terminator->getOpOperands())) {
    Value result = std::get<0>(result_and_retval);
    OpOperand& retval = std::get<1>(result_and_retval);

    if (auto result_sharding = GetXlaShardingFromResult(result)) {
      sharding_for_rets.push_back(result_sharding.getValue());
      continue;
    }

    if (auto from_alias =
            GetXlaShardingFromAlias(result, aliases, sharding_for_args)) {
      sharding_for_rets.push_back(from_alias.getValue());
      continue;
    }

    if (infer_from_computation) {
      if (auto retval_sharding = GetXlaShardingFromRetval(retval.get())) {
        sharding_for_rets.push_back(retval_sharding.getValue());
        continue;
      }
    }

    if (use_spmd) {
      // If XLA SPMD is enabled, we default to replicate sharding. This way,
      // all devices get the whole tensor(s), but if there's an XlaSharding op
      // deeper in the function, they can use dynamic-slice to slice off their
      // part of the computation.
      sharding_for_rets.push_back(kReplicateSharding);
      continue;
    }

    // Otherwise, default to maximal sharding assigned to logical core 0.
    sharding_for_rets.push_back(logical_core_0_sharding);
  }
}

// Extracts input/output sharding configuration of `cluster_func` by parsing
// XlaSharding ops inside the `cluster_func`.
void IdentifyXlaShardingForTPUComputation(
    Builder* builder, tf_device::ClusterFuncOp cluster_func) {
  // Look up function definition from module.
  func::FuncOp func =
      cluster_func->getParentOfType<ModuleOp>().lookupSymbol<func::FuncOp>(
          cluster_func.func());

  // By default inputs/outputs have maximal sharding and are assigned to
  // logical core 0 if no sharding is defined.
  const std::string logical_core_0_sharding =
      xla::sharding_builder::AssignDevice(0).SerializeAsString();

  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(kUseSpmdAttr))
    use_spmd = use_spmd_attr.getValue();

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_args;
  IdentifyXlaShardingForComputationInputs(logical_core_0_sharding, use_spmd,
                                          /*infer_from_computation=*/true,
                                          cluster_func, func, builder,
                                          sharding_for_args);

  llvm::SmallVector<llvm::StringRef, 8> sharding_for_rets;
  IdentifyXlaShardingForComputationOutputs(
      logical_core_0_sharding, use_spmd, /*infer_from_computation=*/true,
      cluster_func, func, builder, sharding_for_args, sharding_for_rets);

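  // Returns true if `sharding_string` parses as an OpSharding proto of
  // MAXIMAL type, i.e. the tensor is pinned to a single device.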
  auto has_maximal_sharding = [](llvm::StringRef sharding_string) -> bool {
    xla::OpSharding sharding;
    sharding.ParseFromString(sharding_string.str());
    return sharding.type() == xla::OpSharding::MAXIMAL;
  };

  // XLA SPMD only supports cases where all inputs/outputs exist on every
  // partition (sharded or replicated). If any of the inputs/outputs have
  // maximal sharding, then fall back to MPMD. Also fall back if any of the
  // shardings aren't compatible with the rank of their tensor.
  if ((use_spmd && (absl::c_any_of(sharding_for_args, has_maximal_sharding) ||
                    absl::c_any_of(sharding_for_rets, has_maximal_sharding))) ||
      failed(VerifyShardings(func, sharding_for_args, sharding_for_rets))) {
    LOG(WARNING) << "XLA SPMD only supports cases where all inputs/outputs "
                    "exist on every partition (sharded or replicated). If any "
                    "of the inputs/outputs have maximal sharding, then "
                    "fallback to MPMD.";
    sharding_for_args.clear();
    sharding_for_rets.clear();
    cluster_func->setAttr(kUseSpmdAttr, builder->getBoolAttr(false));

    IdentifyXlaShardingForComputationInputs(
        logical_core_0_sharding,
        /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func,
        func, builder, sharding_for_args);
    IdentifyXlaShardingForComputationOutputs(
        logical_core_0_sharding,
        /*use_spmd=*/false, /*infer_from_computation=*/false, cluster_func,
        func, builder, sharding_for_args, sharding_for_rets);
  }

  // Update sharding on function arguments and returns.
  Block& function_block = func.front();
  for (auto sharding_and_arg :
       llvm::zip(sharding_for_args, function_block.getArguments())) {
    StringRef sharding = std::get<0>(sharding_and_arg);
    BlockArgument arg = std::get<1>(sharding_and_arg);
    func.setArgAttr(arg.getArgNumber(), kShardingAttr,
                    builder->getStringAttr(sharding));
  }

  Operation* terminator = function_block.getTerminator();
  for (auto sharding_and_retval :
       llvm::zip(sharding_for_rets, terminator->getOpOperands())) {
    StringRef sharding = std::get<0>(sharding_and_retval);
    OpOperand& retval = std::get<1>(sharding_and_retval);
    func.setResultAttr(retval.getOperandNumber(), kShardingAttr,
                       builder->getStringAttr(sharding));
  }

  // Update input/output sharding attributes on tf_device.cluster_func op.
  cluster_func->setAttr(tensorflow::kInputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_args));
  cluster_func->setAttr(tensorflow::kOutputShardingAttr,
                        builder->getStrArrayAttr(sharding_for_rets));
}

void TPUShardingIdentificationPass::runOnOperation() {
  Builder builder(getOperation().getContext());

  getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
    IdentifyXlaShardingForTPUComputation(&builder, cluster_func);
  });
}

}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUShardingIdentificationPass() {
  return std::make_unique<TPUShardingIdentificationPass>();
}

}  // namespace TFTPU
}  // namespace mlir