/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h"

#include <functional>
#include <utility>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/constraints.h"

namespace tensorflow {

using mlir::failure;
using mlir::LogicalResult;
using mlir::Operation;
using mlir::success;
using mlir::TensorType;
using mlir::Type;
using mlir::Value;

using mlir::TFDevice::Cluster;
using mlir::TFDevice::ClusteringPolicy;
using mlir::TFDevice::ClusteringPolicySet;
using mlir::TFDevice::ValueConstraint;
using mlir::TFDevice::ValuesConstraintSet;

using mlir::TF::_FusedMatMulOp;
using mlir::TF::BatchMatMulV2Op;
using mlir::TF::BroadcastToOp;
using mlir::TF::ConcatV2Op;
using mlir::TF::ConstOp;
using mlir::TF::ExpandDimsOp;
using mlir::TF::FillOp;
using mlir::TF::MatMulOp;
using mlir::TF::OneHotOp;
using mlir::TF::PackOp;
using mlir::TF::RangeOp;
using mlir::TF::ReshapeOp;
using mlir::TF::ShapeOp;
using mlir::TF::SliceOp;
using mlir::TF::SqueezeOp;
using mlir::TF::StopGradientOp;
using mlir::TF::StridedSliceOp;
using mlir::TF::TransposeOp;

namespace {

// A set of clustering constraints that allow the TF -> JitRt compilation
// pipeline to lower Tensorflow operations to MHLO and then to Linalg.
// Tensorflow dynamism is not fully representable at the Linalg level, so by
// providing a clustering policy we ensure that we can successfully compile
// all clustered operations (we have enough static information to lower to
// MHLO, or to build static Linalg indexing maps).
//
// Some of these constraints get resolved at constant-folding time, and the
// corresponding operations are completely removed from the IR, while other
// constraints just enable the TF->MHLO or MHLO->Linalg lowering.
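//
// Example (illustrative IR, not taken from a real model): for
//
//   %0 = "tf.Reshape"(%tensor, %shape)
//          : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>
//
// a `kShape` constraint on the result can be satisfied only if the value of
// the `%shape` operand is known at compile time, so the clustering policy
// propagates a `kValue` constraint to `%shape` (see
// ReshapeOpClusteringPolicy below).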

// Returns true if all types are supported by the Tensorflow -> JitRt
// compilation pipeline and TFRT JIT runtime integration (see jitrt.h).
template <typename TypeRange>
static bool IsSupportedDataTypes(TypeRange&& types) {
  return llvm::all_of(types, [](Type type) -> bool {
    if (auto tensor = type.dyn_cast<TensorType>()) {
      auto elt_type = tensor.getElementType();
      return elt_type.isF32() || elt_type.isInteger(1) ||
             elt_type.isInteger(32) || elt_type.isInteger(64);
    }
    return false;
  });
}

static bool IsSupportedOperandTypes(Operation* op) {
  return IsSupportedDataTypes(op->getOperandTypes());
}

static bool IsSupportedResultTypes(Operation* op) {
  return IsSupportedDataTypes(op->getResultTypes());
}

static bool IsSupportedOperandAndResultTypes(Operation* op) {
  return IsSupportedOperandTypes(op) && IsSupportedResultTypes(op);
}

// Clustering policy for a specific Tensorflow operation type that verifies
// that the operation's operand and result data types are supported.
template <typename OpTy>
class TensorflowOpClusteringPolicy : public ClusteringPolicy {
 public:
  LogicalResult MatchAndUpdateConstraints(
      Operation* operation, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    auto op = mlir::dyn_cast<OpTy>(operation);
    if (op && IsSupportedOperandAndResultTypes(op))
      return MatchAndUpdateConstraints(op, results, operands);
    return failure();
  }

  virtual LogicalResult MatchAndUpdateConstraints(
      OpTy op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const = 0;
};

// -------------------------------------------------------------------------- //
// Default clustering policy for TF -> JitRt compilation.
// -------------------------------------------------------------------------- //

// The default clustering policy for Tensorflow -> TFRT JIT compilation
// propagates the most restrictive constraint from the results to all
// operands. If the results do not have any constraints, it adds the default
// constraint to all operands if one is provided; otherwise it just returns
// `success` without adding any constraints.
class DefaultClusteringPolicy : public ClusteringPolicy {
 public:
  explicit DefaultClusteringPolicy(
      std::function<bool(Operation*)> filter,
      llvm::Optional<ValueConstraint> default_constraint = llvm::None)
      : filter_(std::move(filter)), default_constraint_(default_constraint) {}

  LogicalResult MatchAndUpdateConstraints(
      Operation* op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final;

 private:
  // A filter for operations that are supported.
  std::function<bool(Operation*)> filter_;
  // Default constraint for all operands.
  llvm::Optional<ValueConstraint> default_constraint_;
};

template <typename OpTy>
class OpDefaultClusteringPolicy : public DefaultClusteringPolicy {
 public:
  explicit OpDefaultClusteringPolicy(
      llvm::Optional<ValueConstraint> default_constraint = llvm::None)
      : DefaultClusteringPolicy(
            [](Operation* op) -> bool { return mlir::isa<OpTy>(op); },
            default_constraint) {}
};

LogicalResult DefaultClusteringPolicy::MatchAndUpdateConstraints(
    Operation* op, const ValuesConstraintSet& results,
    ValuesConstraintSet& operands) const {
  if (!filter_(op)) return failure();

  if (!IsSupportedOperandAndResultTypes(op)) return failure();

  // Find the most restrictive constraint from the operation results.
  llvm::Optional<ValueConstraint> default_constraint = default_constraint_;

  for (mlir::Value result : op->getResults()) {
    if (auto result_constraint = results.GetConstraint(result)) {
      // TODO(ezhulenev): We can safely propagate value constraints if we know
      // that the value is an integer scalar or a small vector, however in
      // practice all values that we are interested in are defined by constant
      // operations directly. Revisit if this becomes a problem.
      if (*result_constraint == ValueConstraint::kValue) return failure();

      default_constraint = default_constraint.has_value()
                               ? Merge(*default_constraint, *result_constraint)
                               : *result_constraint;
    }
  }

  // No constraints to propagate.
  if (!default_constraint.has_value()) return success();

  // Propagate the constraint to all operands.
  for (unsigned i = 0; i < op->getNumOperands(); ++i)
    operands.Insert(op->getOperand(i), *default_constraint);
  return success();
}

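// Example (illustrative): if one result of a matched operation carries a
// kRank constraint and another result carries a kShape constraint, `Merge`
// keeps the most restrictive one (a value constraint is stronger than shape,
// which is stronger than rank), so a kShape constraint is propagated to all
// operands.
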
// -------------------------------------------------------------------------- //
// tf.BatchMatMulV2
// -------------------------------------------------------------------------- //

class BatchMatMulV2OpClusteringPolicy
    : public OpDefaultClusteringPolicy<BatchMatMulV2Op> {};

// -------------------------------------------------------------------------- //
// tf.BroadcastTo
// -------------------------------------------------------------------------- //

class BroadcastToOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<BroadcastToOp> {
  LogicalResult MatchAndUpdateConstraints(
      BroadcastToOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Only ranked inputs are supported.
    operands.Insert(op.input(), ValueConstraint::kRank);

    if (auto result_constraint = results.GetConstraint(op.getResult())) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      // For a static output shape we need a constant shape operand.
      if (*result_constraint == ValueConstraint::kShape) {
        operands.Insert(op.shape(), ValueConstraint::kValue);
        return success();
      }
    }

    // Producing a ranked output requires a known shape for the shape operand.
    operands.Insert(op.shape(), ValueConstraint::kShape);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// Cwise Binary Operations.
// -------------------------------------------------------------------------- //

class CwiseBinaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseBinaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsBinaryOp(), ValueConstraint::kRank) {}

 private:
  // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
  std::function<bool(Operation* op)> IsBinaryOp() {
    llvm::StringSet<> binary_ops = {
        "tf.Add",
        "tf.AddV2",
        "tf.ApproximateEqual",
        "tf.Atan2",
        "tf.BiasAdd",
        "tf.BitwiseAnd",
        "tf.BitwiseOr",
        "tf.BitwiseXor",
        "tf.Div",
        "tf.DivNoNan",
        "tf.Equal",
        "tf.FloorDiv",
        "tf.FloorMod",
        "tf.Greater",
        "tf.GreaterEqual",
        "tf.Less",
        "tf.LessEqual",
        "tf.LogicalAnd",
        "tf.LogicalOr",
        "tf.Maximum",
        "tf.Minimum",
        "tf.Mod",
        "tf.Mul",
        "tf.MulNoNan",
        "tf.NotEqual",
        "tf.Pow",
        "tf.RealDiv",
        "tf.SquaredDifference",
        "tf.Sub",
        "tf.TruncateDiv",
        "tf.Xdivy",
        "tf.Xlogy",
    };
    return [binary_ops = std::move(binary_ops)](Operation* op) {
      return binary_ops.contains(op->getName().getStringRef());
    };
  }
};

// -------------------------------------------------------------------------- //
// Cwise Unary Operations.
// -------------------------------------------------------------------------- //

class CwiseUnaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseUnaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsUnaryOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsUnaryOp() {
    // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
    llvm::StringSet<> unary_ops = {
        "tf.Abs",      "tf.Acos",        "tf.Acosh",      "tf.Asin",
        "tf.Asinh",    "tf.Atan",        "tf.Atanh",      "tf.Cast",
        "tf.Ceil",     "tf.ClipByValue", "tf.ComplexAbs", "tf.Conj",
        "tf.Cos",      "tf.Cosh",        "tf.Elu",        "tf.Erf",
        "tf.Exp",      "tf.Floor",       "tf.Inv",        "tf.Invert",
        "tf.IsFinite", "tf.IsInf",       "tf.IsNan",      "tf.LeakyRelu",
        "tf.Log",      "tf.Log1p",       "tf.LogicalNot", "tf.Neg",
        "tf.Real",     "tf.Reciprocal",  "tf.Relu",       "tf.Relu6",
        "tf.Rint",     "tf.Round",       "tf.Rsqrt",      "tf.Selu",
        "tf.Sigmoid",  "tf.Sign",        "tf.Sin",        "tf.Sinh",
        "tf.Softplus", "tf.Softsign",    "tf.Sqrt",       "tf.Square",
        "tf.Tan",      "tf.Tanh",        "tf.ZerosLike",
    };
    return [unary_ops = std::move(unary_ops)](Operation* op) {
      return unary_ops.contains(op->getName().getStringRef());
    };
  }
};

// -------------------------------------------------------------------------- //
// Cwise Ternary Operations.
// -------------------------------------------------------------------------- //

class CwiseTernaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseTernaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsTernaryOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsTernaryOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::SelectOp, mlir::TF::SelectV2Op>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// Reduction Operations.
// -------------------------------------------------------------------------- //

// Clustering policy for Tensorflow reduction operations:
// - a shape constraint can be propagated from the result to the input
// - the reduction indices value must be known at compile time
//
// All operations that use this policy must have two operands (the input and
// the reduction indices) and a single result.
class ReductionOpClusteringPolicy : public ClusteringPolicy {
 public:
  LogicalResult MatchAndUpdateConstraints(
      Operation* op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final;

 private:
  bool IsSupported(Operation* op) const;
};

LogicalResult ReductionOpClusteringPolicy::MatchAndUpdateConstraints(
    Operation* op, const ValuesConstraintSet& results,
    ValuesConstraintSet& operands) const {
  // Verify that the operation is a reduction with supported operand and
  // result data types.
  if (!IsSupported(op) || !IsSupportedOperandAndResultTypes(op))
    return failure();

  assert(op->getNumOperands() == 2 && "expected two operands");
  assert(op->getNumResults() == 1 && "expected one result");

  // Propagate the constraint from the result to the input.
  if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
    if (*result_constraint == ValueConstraint::kValue) return failure();
    operands.Insert(op->getOperand(0), *result_constraint);
  } else {
    operands.Insert(op->getOperand(0), ValueConstraint::kRank);
  }

  // Reduction indices must be known at compile time.
  operands.Insert(op->getOperand(1), ValueConstraint::kValue);

  return success();
}

bool ReductionOpClusteringPolicy::IsSupported(Operation* op) const {
  return mlir::isa<mlir::TF::AllOp,   //
                   mlir::TF::AnyOp,   //
                   mlir::TF::MaxOp,   //
                   mlir::TF::MeanOp,  //
                   mlir::TF::MinOp,   //
                   mlir::TF::ProdOp,  //
                   mlir::TF::SumOp>(op);
}

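// Example (illustrative): for `%0 = "tf.Mean"(%input, %axes)` a kShape
// constraint on %0 propagates a kShape constraint to %input, and %axes
// always gets a kValue constraint so that the reduction dimensions are known
// at compile time.
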
// -------------------------------------------------------------------------- //
// tf.ConcatV2
// -------------------------------------------------------------------------- //

class ConcatV2OpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ConcatV2Op> {
  LogicalResult MatchAndUpdateConstraints(
      ConcatV2Op op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    auto result_constraint = results.GetConstraint(op->getResult(0));
    if (result_constraint && *result_constraint == ValueConstraint::kValue)
      return failure();

    // Propagate the constraint from the result to the inputs. All inputs
    // always need a known rank.
    for (auto value : op.values()) {
      operands.Insert(value,
                      result_constraint.getValueOr(ValueConstraint::kRank));
    }

    // Force the axis to be a constant.
    operands.Insert(op.axis(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Const
// -------------------------------------------------------------------------- //

class ConstOpClusteringPolicy : public TensorflowOpClusteringPolicy<ConstOp> {
  LogicalResult MatchAndUpdateConstraints(
      ConstOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // We cluster a constant operation only if it is required to resolve some
    // of the constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return failure();

    return IsCompilableConstant(op.value());
  }
};

// -------------------------------------------------------------------------- //
// tf.ExpandDims
// -------------------------------------------------------------------------- //

class ExpandDimsOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ExpandDimsOp> {
  LogicalResult MatchAndUpdateConstraints(
      ExpandDimsOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate the constraint from the result to the input.
    if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      operands.Insert(op.input(), *result_constraint);
    } else {
      operands.Insert(op.input(), ValueConstraint::kRank);
    }

    // The inserted dimension must always be known at compile time.
    operands.Insert(op.dim(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf._FusedMatMul
// -------------------------------------------------------------------------- //

class FusedMatMulOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<_FusedMatMulOp> {
  LogicalResult MatchAndUpdateConstraints(
      _FusedMatMulOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Check if the default policy accepts the operation.
    OpDefaultClusteringPolicy<_FusedMatMulOp> default_policy;
    if (failed(default_policy.MatchAndUpdateConstraints(op, results, operands)))
      return failure();

    // Check if we support the requested set of fused operations.
    size_t n = op.fused_ops().size();

    auto fusion =
        n > 0 ? op.fused_ops()[0].dyn_cast<mlir::StringAttr>() : nullptr;
    auto activation =
        n > 1 ? op.fused_ops()[1].dyn_cast<mlir::StringAttr>() : nullptr;

    if ((n > 0 && !fusion) || (n > 1 && !activation)) return failure();

    // TODO(ezhulenev): Update fission pass to support more fusions and
    // activations.

    // We only support BiasAdd fusion ...
    if (fusion && fusion.getValue() != "BiasAdd") return failure();

    // ... with Relu activation.
    if (activation && activation.getValue() != "Relu") return failure();

    return success();
  }
};

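// Example (illustrative): `fused_ops = ["BiasAdd", "Relu"]` is accepted by
// the policy above, while something like `fused_ops = ["BiasAdd", "Tanh"]`
// is rejected because only the Relu activation is currently supported.
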
// -------------------------------------------------------------------------- //
// tf.Fill
// -------------------------------------------------------------------------- //

class FillOpClusteringPolicy : public TensorflowOpClusteringPolicy<FillOp> {
  LogicalResult MatchAndUpdateConstraints(
      FillOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Fill operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op->getResult(0));
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need to know the shape operand value.
    if (*result_constraint == ValueConstraint::kShape)
      operands.Insert(op.dims(), ValueConstraint::kValue);

    // To know the result rank we need to know the shape operand shape.
    if (*result_constraint == ValueConstraint::kRank)
      operands.Insert(op.dims(), ValueConstraint::kShape);

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.MatMul
// -------------------------------------------------------------------------- //

class MatMulOpClusteringPolicy : public OpDefaultClusteringPolicy<MatMulOp> {};

// -------------------------------------------------------------------------- //
// tf.OneHot
// -------------------------------------------------------------------------- //

class OneHotOpClusteringPolicy : public TensorflowOpClusteringPolicy<OneHotOp> {
  LogicalResult MatchAndUpdateConstraints(
      OneHotOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Value constraint propagation is not supported.
    if (auto constraint = results.GetConstraint(op.getResult()))
      if (*constraint == ValueConstraint::kValue) return failure();

    // MHLO lowering needs a static shape for the indices and a constant depth.
    operands.Insert(op.indices(), ValueConstraint::kShape);
    operands.Insert(op.depth(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Pack
// -------------------------------------------------------------------------- //

class PackOpClusteringPolicy : public OpDefaultClusteringPolicy<PackOp> {};

// -------------------------------------------------------------------------- //
// tf.Range
// -------------------------------------------------------------------------- //

class RangeOpClusteringPolicy : public TensorflowOpClusteringPolicy<RangeOp> {
  LogicalResult MatchAndUpdateConstraints(
      RangeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Range operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need the input values.
    if (*result_constraint == ValueConstraint::kShape) {
      operands.Insert({op.start(), op.limit(), op.delta()},
                      ValueConstraint::kValue);
    }

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Reshape
// -------------------------------------------------------------------------- //

class ReshapeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ReshapeOp> {
  LogicalResult MatchAndUpdateConstraints(
      ReshapeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // The runtime only supports ranked tensors.
    operands.Insert(op.tensor(), ValueConstraint::kRank);

    // Reshape operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need to know the shape operand value. We
    // also require a static shape on the input in case there's a -1 in the
    // shape.
    if (*result_constraint == ValueConstraint::kShape) {
      operands.Insert(op.shape(), ValueConstraint::kValue);
      operands.Insert(op.tensor(), ValueConstraint::kShape);
    }

    // To know the result rank we need to know the shape operand shape.
    if (*result_constraint == ValueConstraint::kRank)
      operands.Insert(op.shape(), ValueConstraint::kShape);

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Shape
// -------------------------------------------------------------------------- //

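// Example (illustrative): a kValue constraint on the result of `tf.Shape`
// turns into a kShape constraint on the input, because knowing the input
// shape is enough to materialize the shape vector at compile time.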
class ShapeOpClusteringPolicy : public TensorflowOpClusteringPolicy<ShapeOp> {
  LogicalResult MatchAndUpdateConstraints(
      ShapeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Unranked inputs aren't supported by JitRt.
    operands.Insert(op.input(), ValueConstraint::kRank);

    // Check the constraint on the result value.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need only the rank of the input.
    if (*result_constraint == ValueConstraint::kShape)
      operands.Insert(op.input(), ValueConstraint::kRank);

    // To know the result value we need to know the shape of the input.
    if (*result_constraint == ValueConstraint::kValue)
      operands.Insert(op.input(), ValueConstraint::kShape);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Softmax
// -------------------------------------------------------------------------- //

class SoftmaxOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  SoftmaxOpClusteringPolicy()
      : DefaultClusteringPolicy(IsSoftmaxOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsSoftmaxOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::SoftmaxOp, mlir::TF::LogSoftmaxOp>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// tf.Squeeze
// -------------------------------------------------------------------------- //

class SqueezeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<SqueezeOp> {
  LogicalResult MatchAndUpdateConstraints(
      SqueezeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate static shape constraints.
    auto input_constraint = ValueConstraint::kRank;
    if (auto result_constraint = results.GetConstraint(op.getResult())) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      input_constraint = *result_constraint;
    }

    // If squeeze_dims is not present we need a static shape.
    if (op.squeeze_dims().empty()) input_constraint = ValueConstraint::kShape;

    operands.Insert(op.input(), input_constraint);
    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.StopGradient
// -------------------------------------------------------------------------- //

class StopGradientOpClusteringPolicy
    : public OpDefaultClusteringPolicy<StopGradientOp> {};

// -------------------------------------------------------------------------- //
// tf.Transpose
// -------------------------------------------------------------------------- //

class TransposeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<TransposeOp> {
  LogicalResult MatchAndUpdateConstraints(
      TransposeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate result constraints to the input; at minimum require a known
    // rank.
    if (auto constraint = results.GetConstraint(op.getResult())) {
      operands.Insert(op.x(), *constraint);
    } else {
      operands.Insert(op.x(), ValueConstraint::kRank);
    }

    // The permutation must always be known at compile time.
    operands.Insert(op.perm(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Slice
// -------------------------------------------------------------------------- //

class SliceOpClusteringPolicy : public TensorflowOpClusteringPolicy<SliceOp> {
  LogicalResult MatchAndUpdateConstraints(
      SliceOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Value constraint propagation is not supported.
    if (auto constraint = results.GetConstraint(op.getResult()))
      if (*constraint == ValueConstraint::kValue) return failure();

    // We must know the shape of the input.
    operands.Insert(op.input(), ValueConstraint::kShape);

    // Force begin and size to be constants. The restriction on begin could be
    // lifted if we know that there are no `-1` sizes.
    // TODO(kramerb): Revisit this when mhlo.real_dynamic_slice stabilizes.
    operands.Insert({op.begin(), op.size()}, ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.StridedSlice
// -------------------------------------------------------------------------- //

class StridedSliceOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<StridedSliceOp> {
  LogicalResult MatchAndUpdateConstraints(
      StridedSliceOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // We must know the shape of the input.
    operands.Insert(op.input(), ValueConstraint::kShape);

    // And values of operands that control the slice size.
    operands.Insert({op.begin(), op.end(), op.strides()},
                    ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// Gather Operations.
// -------------------------------------------------------------------------- //

class GatherOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  GatherOpClusteringPolicy()
      : DefaultClusteringPolicy(IsGatherOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsGatherOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::GatherNdOp, mlir::TF::GatherV2Op,
                       mlir::TF::GatherOp>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// Scatter Operations.
// -------------------------------------------------------------------------- //

class ScatterOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  ScatterOpClusteringPolicy()
      : DefaultClusteringPolicy(IsScatterOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsScatterOp() {
    return [](Operation* op) {
      return mlir::isa<
          mlir::TF::ScatterNdOp, mlir::TF::TensorScatterAddOp,
          mlir::TF::TensorScatterMaxOp, mlir::TF::TensorScatterMinOp,
          mlir::TF::TensorScatterSubOp, mlir::TF::TensorScatterUpdateOp>(op);
    };
  }
};

}  // namespace

void populateTfJitRtClusteringPolicies(ClusteringPolicySet& policies,
                                       JitRtClusteringTier tier) {
  // Returns true if the given jitrt compilation tier is enabled.
  auto is_enabled = [&](JitRtClusteringTier requested) -> bool {
    return (static_cast<uint8_t>(tier) & static_cast<uint8_t>(requested)) ==
           static_cast<uint8_t>(requested);
  };

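  // For example, `JitRtClusteringTier::kAll` has all tier bits set, so with
  // `tier == kAll` every `is_enabled(...)` check below succeeds and all
  // policies are added.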
  if (is_enabled(JitRtClusteringTier::kCwise)) {
    policies.Add<CwiseBinaryOpClusteringPolicy,   //
                 CwiseUnaryOpClusteringPolicy,    //
                 CwiseTernaryOpClusteringPolicy,  //
                 StopGradientOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kTranspose)) {
    policies.Add<TransposeOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kReductions)) {
    policies.Add<ReductionOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kMetadata)) {
    policies.Add<ExpandDimsOpClusteringPolicy,  //
                 ReshapeOpClusteringPolicy,     //
                 ShapeOpClusteringPolicy,       //
                 SqueezeOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kGatherScatter)) {
    policies.Add<GatherOpClusteringPolicy,  //
                 ScatterOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kAll)) {
    policies.Add<BatchMatMulV2OpClusteringPolicy,  //
                 BroadcastToOpClusteringPolicy,    //
                 ConcatV2OpClusteringPolicy,       //
                 FillOpClusteringPolicy,           //
                 FusedMatMulOpClusteringPolicy,    //
                 MatMulOpClusteringPolicy,         //
                 OneHotOpClusteringPolicy,         //
                 PackOpClusteringPolicy,           //
                 RangeOpClusteringPolicy,          //
                 SliceOpClusteringPolicy,          //
                 SoftmaxOpClusteringPolicy,        //
                 StridedSliceOpClusteringPolicy>();
  }
}

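// Example usage (illustrative): to cluster only cwise operations, populate a
// policy set with just the kCwise tier:
//
//   mlir::TFDevice::ClusteringPolicySet policies;
//   populateTfJitRtClusteringPolicies(policies, JitRtClusteringTier::kCwise);
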
void populateTfJitRtConstraintsPolicies(ClusteringPolicySet& policies,
                                        JitRtClusteringTier tier) {
  populateTfJitRtClusteringPolicies(policies, tier);
  policies.Add<ConstOpClusteringPolicy>();
}

// -------------------------------------------------------------------------- //
// Helper functions.
// -------------------------------------------------------------------------- //

mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value) {
  return success(value.getNumElements() <= 16 &&
                 value.getType().getElementType().isIntOrIndexOrFloat());
}

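// For example, a `tf.Const` holding `dense<[1, 2]> : tensor<2xi32>` (two i32
// elements) passes this check and can be sunk into the compiled function
// body, while a large floating point weight tensor does not.
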
static bool IsI1Integer(Type type) {
  return mlir::getElementTypeOrSelf(type).isInteger(1);
}

static bool IsUnsignedInteger(Type type) {
  return mlir::getElementTypeOrSelf(type).isUnsignedInteger();
}

mlir::LogicalResult VerifyCluster(const Cluster& cluster) {
  llvm::SmallDenseSet<Operation*> ops;
  for (Operation* op : cluster.operations) {
    auto inserted = ops.insert(op);
    assert(inserted.second && "clustered operations must be unique");
    (void)inserted;
  }

  // TODO(ezhulenev): Clusters that are too large and have dynamic shapes can
  // take a very long time to compile. Skip them for now.
  if (ops.size() > 20) return failure();

  // TODO(ezhulenev): This is a temporary workaround to disable forming
  // clusters with known compilation problems.
  for (Operation* op : ops) {
    // TODO(b/205714705): Memory layout of the `i1` data type is not defined,
    // and when vectorization is enabled it can lead to crashes.
    bool has_i1_integers = llvm::any_of(op->getOperandTypes(), IsI1Integer) ||
                           llvm::any_of(op->getResultTypes(), IsI1Integer);
    if (has_i1_integers && tensorflow::GetJitRtFlags().vectorize)
      return failure();

    // TODO(b/205905286): Unsigned integer support has a lot of gaps, and
    // similar to the handling of `i1` we need a type conversion to signless
    // integers.
    bool has_unsigned_integers =
        llvm::any_of(op->getOperandTypes(), IsUnsignedInteger) ||
        llvm::any_of(op->getResultTypes(), IsUnsignedInteger);
    if (has_unsigned_integers) return failure();
  }

  for (auto& pair : cluster.constraints) {
    Value value = pair.getFirst();
    ValueConstraint constraint = pair.getSecond();

    // We can satisfy shape and rank constraints on the compiled function
    // operands.
    if (constraint == ValueConstraint::kRank ||
        constraint == ValueConstraint::kShape)
      continue;

    if (constraint == ValueConstraint::kValue &&
        xla::runtime::SupportsValueSpecialization(value.getType()))
      continue;

    Operation* op = value.getDefiningOp();
    if (!op) return failure();  // we do not support block arguments

    // Operations defined inside the cluster will be constant folded before
    // the compilation. This property is guaranteed by the clustering policy.
    if (ops.contains(op)) continue;

    // Small constants will be sunk into the compiled function body.
    auto const_op = mlir::dyn_cast<mlir::TF::ConstOp>(op);
    if (!const_op || failed(IsCompilableConstant(const_op.value())))
      return failure();
  }

  return success();
}

}  // namespace tensorflow