/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/transforms/tf_jitrt_clustering.h"

#include <functional>
#include <utility>

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
#include "tensorflow/compiler/xla/mlir/utils/runtime/constraints.h"

namespace tensorflow {

using mlir::failure;
using mlir::LogicalResult;
using mlir::Operation;
using mlir::success;
using mlir::TensorType;
using mlir::Type;
using mlir::Value;

using mlir::TFDevice::Cluster;
using mlir::TFDevice::ClusteringPolicy;
using mlir::TFDevice::ClusteringPolicySet;
using mlir::TFDevice::ValueConstraint;
using mlir::TFDevice::ValuesConstraintSet;

using mlir::TF::_FusedMatMulOp;
using mlir::TF::BatchMatMulV2Op;
using mlir::TF::BroadcastToOp;
using mlir::TF::ConcatV2Op;
using mlir::TF::ConstOp;
using mlir::TF::ExpandDimsOp;
using mlir::TF::FillOp;
using mlir::TF::MatMulOp;
using mlir::TF::OneHotOp;
using mlir::TF::PackOp;
using mlir::TF::RangeOp;
using mlir::TF::ReshapeOp;
using mlir::TF::ShapeOp;
using mlir::TF::SliceOp;
using mlir::TF::SqueezeOp;
using mlir::TF::StopGradientOp;
using mlir::TF::StridedSliceOp;
using mlir::TF::TransposeOp;

namespace {

// A set of clustering constraints that allows the TF -> JitRt compilation
// pipeline to lower Tensorflow operations to MHLO and then to Linalg.
// Tensorflow dynamism is not fully representable at the Linalg level, so by
// providing a clustering policy we ensure that we can successfully compile all
// clustered operations (we have enough static information to lower to MHLO,
// or to build static Linalg indexing maps).
//
// Some of these constraints get resolved at constant folding time, and the
// corresponding operations are completely removed from the IR, while other
// constraints simply enable the TF->MHLO or MHLO->Linalg lowering.

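// For example (a sketch of the mechanism, not a specific test case): in
//
//   %shape = "tf.Const"() {value = dense<[4, 4]> : tensor<2xi32>} : ...
//   %0 = "tf.Reshape"(%tensor, %shape) : ...
//
// a `kValue` constraint on `%shape` is resolved by folding the `tf.Const`,
// so `tf.Reshape` can be lowered with a static output shape.
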
// Returns true if all types are supported by the Tensorflow -> JitRt
// compilation pipeline and TFRT JIT runtime integration (see jitrt.h).
template <typename TypeRange>
static bool IsSupportedDataTypes(TypeRange&& types) {
  return llvm::all_of(types, [](Type type) -> bool {
    if (auto tensor = type.dyn_cast<TensorType>()) {
      auto elt_type = tensor.getElementType();
      return elt_type.isF32() || elt_type.isInteger(1) ||
             elt_type.isInteger(32) || elt_type.isInteger(64);
    }
    return false;
  });
}

static bool IsSupportedOperandTypes(Operation* op) {
  return IsSupportedDataTypes(op->getOperandTypes());
}

static bool IsSupportedResultTypes(Operation* op) {
  return IsSupportedDataTypes(op->getResultTypes());
}

static bool IsSupportedOperandAndResultTypes(Operation* op) {
  return IsSupportedOperandTypes(op) && IsSupportedResultTypes(op);
}

// Clustering policy for a specific Tensorflow operation type that verifies
// that the operation's operand and result data types are supported.
template <typename OpTy>
class TensorflowOpClusteringPolicy : public ClusteringPolicy {
 public:
  LogicalResult MatchAndUpdateConstraints(
      Operation* operation, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    auto op = mlir::dyn_cast<OpTy>(operation);
    if (op && IsSupportedOperandAndResultTypes(op))
      return MatchAndUpdateConstraints(op, results, operands);
    return failure();
  }

  virtual LogicalResult MatchAndUpdateConstraints(
      OpTy op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const = 0;
};

// -------------------------------------------------------------------------- //
// Default clustering policy for TF -> JitRt compilation.
// -------------------------------------------------------------------------- //

// The default clustering policy for Tensorflow -> TFRT JIT compilation
// propagates the most restrictive constraint from the results to all operands.
// If the results do not have any constraints, it adds the default constraint
// (if one was provided) to all operands; otherwise it just returns `success`
// without adding any constraints.
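// For example (a sketch): if the result of a matched op carries a `kShape`
// constraint, every operand receives `kShape`; with no result constraints,
// operands receive the policy's default constraint (e.g. `kRank`), if any.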
class DefaultClusteringPolicy : public ClusteringPolicy {
 public:
  explicit DefaultClusteringPolicy(
      std::function<bool(Operation*)> filter,
      llvm::Optional<ValueConstraint> default_constraint = llvm::None)
      : filter_(std::move(filter)), default_constraint_(default_constraint) {}

  LogicalResult MatchAndUpdateConstraints(
      Operation* op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final;

 private:
  // A filter for operations that are supported.
  std::function<bool(Operation*)> filter_;
  // Default constraint for all operands.
  llvm::Optional<ValueConstraint> default_constraint_;
};

template <typename OpTy>
class OpDefaultClusteringPolicy : public DefaultClusteringPolicy {
 public:
  explicit OpDefaultClusteringPolicy(
      llvm::Optional<ValueConstraint> default_constraint = llvm::None)
      : DefaultClusteringPolicy(
            [](Operation* op) -> bool { return mlir::isa<OpTy>(op); },
            default_constraint) {}
};

LogicalResult DefaultClusteringPolicy::MatchAndUpdateConstraints(
    Operation* op, const ValuesConstraintSet& results,
    ValuesConstraintSet& operands) const {
  if (!filter_(op)) return failure();

  if (!IsSupportedOperandAndResultTypes(op)) return failure();

  // Find the most restrictive constraint from the operation results.
  llvm::Optional<ValueConstraint> default_constraint = default_constraint_;

  for (mlir::Value result : op->getResults()) {
    if (auto result_constraint = results.GetConstraint(result)) {
      // TODO(ezhulenev): We can safely propagate value constraints if we know
      // that the value is an integer scalar or a small vector, however in
      // practice all values that we are interested in are defined by constant
      // operations directly. Revisit if this becomes a problem.
      if (*result_constraint == ValueConstraint::kValue) return failure();

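      // `Merge` is expected to return the more restrictive of the two
      // constraints (kRank < kShape < kValue in restrictiveness).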
      default_constraint = default_constraint.has_value()
                               ? Merge(*default_constraint, *result_constraint)
                               : *result_constraint;
    }
  }

  // No constraints to propagate.
  if (!default_constraint.has_value()) return success();

  // Propagate the constraint to all operands.
  for (unsigned i = 0; i < op->getNumOperands(); ++i)
    operands.Insert(op->getOperand(i), *default_constraint);
  return success();
}

// -------------------------------------------------------------------------- //
// tf.BatchMatMulV2
// -------------------------------------------------------------------------- //

class BatchMatMulV2OpClusteringPolicy
    : public OpDefaultClusteringPolicy<BatchMatMulV2Op> {};

// -------------------------------------------------------------------------- //
// tf.BroadcastTo
// -------------------------------------------------------------------------- //

class BroadcastToOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<BroadcastToOp> {
  LogicalResult MatchAndUpdateConstraints(
      BroadcastToOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Only ranked inputs are supported.
    operands.Insert(op.input(), ValueConstraint::kRank);

    if (auto result_constraint = results.GetConstraint(op.getResult())) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      // For a static output shape we need a constant shape operand.
      if (*result_constraint == ValueConstraint::kShape) {
        operands.Insert(op.shape(), ValueConstraint::kValue);
        return success();
      }
    }

    // Producing a ranked output requires a known shape for the shape operand.
    operands.Insert(op.shape(), ValueConstraint::kShape);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// Cwise Binary Operations.
// -------------------------------------------------------------------------- //

class CwiseBinaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseBinaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsBinaryOp(), ValueConstraint::kRank) {}

 private:
  // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
  std::function<bool(Operation* op)> IsBinaryOp() {
    llvm::StringSet<> binary_ops = {
        "tf.Add",
        "tf.AddV2",
        "tf.ApproximateEqual",
        "tf.Atan2",
        "tf.BiasAdd",
        "tf.BitwiseAnd",
        "tf.BitwiseOr",
        "tf.BitwiseXor",
        "tf.Div",
        "tf.DivNoNan",
        "tf.Equal",
        "tf.FloorDiv",
        "tf.FloorMod",
        "tf.Greater",
        "tf.GreaterEqual",
        "tf.Less",
        "tf.LessEqual",
        "tf.LogicalAnd",
        "tf.LogicalOr",
        "tf.Maximum",
        "tf.Minimum",
        "tf.Mod",
        "tf.Mul",
        "tf.MulNoNan",
        "tf.NotEqual",
        "tf.Pow",
        "tf.RealDiv",
        "tf.SquaredDifference",
        "tf.Sub",
        "tf.TruncateDiv",
        "tf.Xdivy",
        "tf.Xlogy",
    };
    return [binary_ops = std::move(binary_ops)](Operation* op) {
      return binary_ops.contains(op->getName().getStringRef());
    };
  }
};

// -------------------------------------------------------------------------- //
// Cwise Unary Operations.
// -------------------------------------------------------------------------- //

class CwiseUnaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseUnaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsUnaryOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsUnaryOp() {
    // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
    llvm::StringSet<> unary_ops = {
        "tf.Abs",      "tf.Acos",        "tf.Acosh",      "tf.Asin",
        "tf.Asinh",    "tf.Atan",        "tf.Atanh",      "tf.Cast",
        "tf.Ceil",     "tf.ClipByValue", "tf.ComplexAbs", "tf.Conj",
        "tf.Cos",      "tf.Cosh",        "tf.Elu",        "tf.Erf",
        "tf.Exp",      "tf.Floor",       "tf.Inv",        "tf.Invert",
        "tf.IsFinite", "tf.IsInf",       "tf.IsNan",      "tf.LeakyRelu",
        "tf.Log",      "tf.Log1p",       "tf.LogicalNot", "tf.Neg",
        "tf.Real",     "tf.Reciprocal",  "tf.Relu",       "tf.Relu6",
        "tf.Rint",     "tf.Round",       "tf.Rsqrt",      "tf.Selu",
        "tf.Sigmoid",  "tf.Sign",        "tf.Sin",        "tf.Sinh",
        "tf.Softplus", "tf.Softsign",    "tf.Sqrt",       "tf.Square",
        "tf.Tan",      "tf.Tanh",        "tf.ZerosLike",
    };
    return [unary_ops = std::move(unary_ops)](Operation* op) {
      return unary_ops.contains(op->getName().getStringRef());
    };
  }
};

// -------------------------------------------------------------------------- //
// Cwise Ternary Operations.
// -------------------------------------------------------------------------- //

class CwiseTernaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseTernaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsTernaryOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsTernaryOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::SelectOp, mlir::TF::SelectV2Op>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// Reduction Operations.
// -------------------------------------------------------------------------- //

// Clustering policy for Tensorflow reduction operations:
//   - shape constraint can be propagated from the result to the input
//   - reduction indices value must be known at compile time
//
// All operations that use this policy must have two operands (input and
// reduction indices) and a single result.
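//
// For example, `tf.Sum(%input, %reduction_indices)` is clustered only if
// `%reduction_indices` can be resolved to a compile-time constant value.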
class ReductionOpClusteringPolicy : public ClusteringPolicy {
 public:
  LogicalResult MatchAndUpdateConstraints(
      Operation* op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final;

 private:
  bool IsSupported(Operation* op) const;
};

LogicalResult ReductionOpClusteringPolicy::MatchAndUpdateConstraints(
    Operation* op, const ValuesConstraintSet& results,
    ValuesConstraintSet& operands) const {
  // Verify that the operation is a reduction with supported operand
  // and result data types.
  if (!IsSupported(op) || !IsSupportedOperandAndResultTypes(op))
    return failure();

  assert(op->getNumOperands() == 2 && "expected two operands");
  assert(op->getNumResults() == 1 && "expected one result");

  // Propagate constraint from the result to the input.
  if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
    if (*result_constraint == ValueConstraint::kValue) return failure();
    operands.Insert(op->getOperand(0), *result_constraint);
  } else {
    operands.Insert(op->getOperand(0), ValueConstraint::kRank);
  }

  // Reduction indices must be known at compile time.
  operands.Insert(op->getOperand(1), ValueConstraint::kValue);

  return success();
}

bool ReductionOpClusteringPolicy::IsSupported(Operation* op) const {
  return mlir::isa<mlir::TF::AllOp,   //
                   mlir::TF::AnyOp,   //
                   mlir::TF::MaxOp,   //
                   mlir::TF::MeanOp,  //
                   mlir::TF::MinOp,   //
                   mlir::TF::ProdOp,  //
                   mlir::TF::SumOp>(op);
}

// -------------------------------------------------------------------------- //
// tf.ConcatV2
// -------------------------------------------------------------------------- //

class ConcatV2OpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ConcatV2Op> {
  LogicalResult MatchAndUpdateConstraints(
      ConcatV2Op op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    auto result_constraint = results.GetConstraint(op->getResult(0));
    if (result_constraint && *result_constraint == ValueConstraint::kValue)
      return failure();

    // Propagate the constraint from the result to the inputs. All inputs
    // always need a known rank.
    for (auto value : op.values()) {
      operands.Insert(value,
                      result_constraint.getValueOr(ValueConstraint::kRank));
    }

    // Force axis to be a constant.
    operands.Insert(op.axis(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Const
// -------------------------------------------------------------------------- //

class ConstOpClusteringPolicy : public TensorflowOpClusteringPolicy<ConstOp> {
  LogicalResult MatchAndUpdateConstraints(
      ConstOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // We cluster a constant operation only if it is required to resolve some
    // of the constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return failure();

    return IsCompilableConstant(op.value());
  }
};

// -------------------------------------------------------------------------- //
// tf.ExpandDims
// -------------------------------------------------------------------------- //

class ExpandDimsOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ExpandDimsOp> {
  LogicalResult MatchAndUpdateConstraints(
      ExpandDimsOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate constraint from the result to the input.
    if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      operands.Insert(op.input(), *result_constraint);
    } else {
      operands.Insert(op.input(), ValueConstraint::kRank);
    }

    // The inserted dimension must always be known at compile time.
    operands.Insert(op.dim(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf._FusedMatMul
// -------------------------------------------------------------------------- //

class FusedMatMulOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<_FusedMatMulOp> {
  LogicalResult MatchAndUpdateConstraints(
      _FusedMatMulOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Check if the default policy accepts the operation.
    OpDefaultClusteringPolicy<_FusedMatMulOp> default_policy;
    if (failed(default_policy.MatchAndUpdateConstraints(op, results, operands)))
      return failure();

    // Check if we support the set of fused operations.
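    // The `fused_ops` attribute is an array of fusion names, e.g.
    // ["BiasAdd"] or ["BiasAdd", "Relu"].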
    size_t n = op.fused_ops().size();

    auto fusion =
        n > 0 ? op.fused_ops()[0].dyn_cast<mlir::StringAttr>() : nullptr;
    auto activation =
        n > 1 ? op.fused_ops()[1].dyn_cast<mlir::StringAttr>() : nullptr;

    if ((n > 0 && !fusion) || (n > 1 && !activation)) return failure();

    // TODO(ezhulenev): Update fission pass to support more fusions and
    // activations.

    // We only support BiasAdd fusion ...
    if (fusion && fusion.getValue() != "BiasAdd") return failure();

    // ... with Relu activation.
    if (activation && activation.getValue() != "Relu") return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Fill
// -------------------------------------------------------------------------- //

class FillOpClusteringPolicy : public TensorflowOpClusteringPolicy<FillOp> {
  LogicalResult MatchAndUpdateConstraints(
      FillOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Fill operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op->getResult(0));
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need to know the shape operand value.
    if (*result_constraint == ValueConstraint::kShape)
      operands.Insert(op.dims(), ValueConstraint::kValue);

    // To know the result rank we need to know the shape operand shape.
    if (*result_constraint == ValueConstraint::kRank)
      operands.Insert(op.dims(), ValueConstraint::kShape);

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.MatMul
// -------------------------------------------------------------------------- //

class MatMulOpClusteringPolicy : public OpDefaultClusteringPolicy<MatMulOp> {};

// -------------------------------------------------------------------------- //
// tf.OneHot
// -------------------------------------------------------------------------- //

class OneHotOpClusteringPolicy : public TensorflowOpClusteringPolicy<OneHotOp> {
  LogicalResult MatchAndUpdateConstraints(
      OneHotOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Value constraint propagation is not supported.
    if (auto constraint = results.GetConstraint(op.getResult()))
      if (*constraint == ValueConstraint::kValue) return failure();

    // MHLO lowering needs a static shape for the indices and a constant depth.
    operands.Insert(op.indices(), ValueConstraint::kShape);
    operands.Insert(op.depth(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Pack
// -------------------------------------------------------------------------- //

class PackOpClusteringPolicy : public OpDefaultClusteringPolicy<PackOp> {};

// -------------------------------------------------------------------------- //
// tf.Range
// -------------------------------------------------------------------------- //

class RangeOpClusteringPolicy : public TensorflowOpClusteringPolicy<RangeOp> {
  LogicalResult MatchAndUpdateConstraints(
      RangeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Range operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need the input values.
    if (*result_constraint == ValueConstraint::kShape) {
      operands.Insert({op.start(), op.limit(), op.delta()},
                      ValueConstraint::kValue);
    }

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Reshape
// -------------------------------------------------------------------------- //

class ReshapeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ReshapeOp> {
  LogicalResult MatchAndUpdateConstraints(
      ReshapeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // The runtime only supports ranked tensors.
    operands.Insert(op.tensor(), ValueConstraint::kRank);

    // Reshape operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need to know the shape operand value. We
    // also require a static shape on the input in case there's a -1 in the
    // shape.
    if (*result_constraint == ValueConstraint::kShape) {
      operands.Insert(op.shape(), ValueConstraint::kValue);
      operands.Insert(op.tensor(), ValueConstraint::kShape);
    }

    // To know the result rank we need to know the shape operand shape.
    if (*result_constraint == ValueConstraint::kRank)
      operands.Insert(op.shape(), ValueConstraint::kShape);

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Shape
// -------------------------------------------------------------------------- //

class ShapeOpClusteringPolicy : public TensorflowOpClusteringPolicy<ShapeOp> {
  LogicalResult MatchAndUpdateConstraints(
      ShapeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Unranked inputs aren't supported by JitRt.
    operands.Insert(op.input(), ValueConstraint::kRank);

    // Check constraint on the result value.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.has_value()) return success();

    // To know the result shape we need only the rank of the input.
    if (*result_constraint == ValueConstraint::kShape)
      operands.Insert(op.input(), ValueConstraint::kRank);

    // To know the result value we need to know the shape of the input.
    if (*result_constraint == ValueConstraint::kValue)
      operands.Insert(op.input(), ValueConstraint::kShape);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Softmax
// -------------------------------------------------------------------------- //

class SoftmaxOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  SoftmaxOpClusteringPolicy()
      : DefaultClusteringPolicy(IsSoftmaxOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsSoftmaxOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::SoftmaxOp, mlir::TF::LogSoftmaxOp>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// tf.Squeeze
// -------------------------------------------------------------------------- //

class SqueezeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<SqueezeOp> {
  LogicalResult MatchAndUpdateConstraints(
      SqueezeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate static shape constraints.
    auto input_constraint = ValueConstraint::kRank;
    if (auto result_constraint = results.GetConstraint(op.getResult())) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      input_constraint = *result_constraint;
    }

    // If squeeze_dims is not present we need a static shape.
    if (op.squeeze_dims().empty()) input_constraint = ValueConstraint::kShape;

    operands.Insert(op.input(), input_constraint);
    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.StopGradient
// -------------------------------------------------------------------------- //

class StopGradientOpClusteringPolicy
    : public OpDefaultClusteringPolicy<StopGradientOp> {};

// -------------------------------------------------------------------------- //
// tf.Transpose
// -------------------------------------------------------------------------- //

class TransposeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<TransposeOp> {
  LogicalResult MatchAndUpdateConstraints(
      TransposeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate result constraints to the input; at minimum require a known
    // rank.
    if (auto constraint = results.GetConstraint(op.getResult())) {
      operands.Insert(op.x(), *constraint);
    } else {
      operands.Insert(op.x(), ValueConstraint::kRank);
    }

    // The permutation must always be known at compile time.
    operands.Insert(op.perm(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Slice
// -------------------------------------------------------------------------- //

class SliceOpClusteringPolicy : public TensorflowOpClusteringPolicy<SliceOp> {
  LogicalResult MatchAndUpdateConstraints(
      SliceOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Value constraint propagation is not supported.
    if (auto constraint = results.GetConstraint(op.getResult()))
      if (*constraint == ValueConstraint::kValue) return failure();

    // We must know the shape of the input.
    operands.Insert(op.input(), ValueConstraint::kShape);

    // Force begin and size to be constants. The restriction on begin could be
    // lifted if we know that there are no `-1` sizes.
    // TODO(kramerb): Revisit this when mhlo.real_dynamic_slice stabilizes.
    operands.Insert({op.begin(), op.size()}, ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.StridedSlice
// -------------------------------------------------------------------------- //

class StridedSliceOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<StridedSliceOp> {
  LogicalResult MatchAndUpdateConstraints(
      StridedSliceOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // We must know the shape of the input.
    operands.Insert(op.input(), ValueConstraint::kShape);

    // We also need the values of the operands that control the slice size.
    operands.Insert({op.begin(), op.end(), op.strides()},
                    ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// Gather Operations.
// -------------------------------------------------------------------------- //

class GatherOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  GatherOpClusteringPolicy()
      : DefaultClusteringPolicy(IsGatherOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsGatherOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::GatherNdOp, mlir::TF::GatherV2Op,
                       mlir::TF::GatherOp>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// Scatter Operations.
// -------------------------------------------------------------------------- //

class ScatterOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  ScatterOpClusteringPolicy()
      : DefaultClusteringPolicy(IsScatterOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsScatterOp() {
    return [](Operation* op) {
      return mlir::isa<
          mlir::TF::ScatterNdOp, mlir::TF::TensorScatterAddOp,
          mlir::TF::TensorScatterMaxOp, mlir::TF::TensorScatterMinOp,
          mlir::TF::TensorScatterSubOp, mlir::TF::TensorScatterUpdateOp>(op);
    };
  }
};

}  // namespace

void populateTfJitRtClusteringPolicies(ClusteringPolicySet& policies,
                                       JitRtClusteringTier tier) {
  // Returns true if the given jitrt compilation tier is enabled.
  auto is_enabled = [&](JitRtClusteringTier requested) -> bool {
    return (static_cast<uint8_t>(tier) & static_cast<uint8_t>(requested)) ==
           static_cast<uint8_t>(requested);
  };

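  // E.g. requesting `kCwise | kMetadata` enables only those two groups of
  // policies below (this assumes the tiers are defined as bit flags, with
  // `kAll` setting every tier bit).
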
  if (is_enabled(JitRtClusteringTier::kCwise)) {
    policies.Add<CwiseBinaryOpClusteringPolicy,   //
                 CwiseUnaryOpClusteringPolicy,    //
                 CwiseTernaryOpClusteringPolicy,  //
                 StopGradientOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kTranspose)) {
    policies.Add<TransposeOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kReductions)) {
    policies.Add<ReductionOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kMetadata)) {
    policies.Add<ExpandDimsOpClusteringPolicy,  //
                 ReshapeOpClusteringPolicy,     //
                 ShapeOpClusteringPolicy,       //
                 SqueezeOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kGatherScatter)) {
    policies.Add<GatherOpClusteringPolicy,  //
                 ScatterOpClusteringPolicy>();
  }

  if (is_enabled(JitRtClusteringTier::kAll)) {
    policies.Add<BatchMatMulV2OpClusteringPolicy,  //
                 BroadcastToOpClusteringPolicy,    //
                 ConcatV2OpClusteringPolicy,       //
                 FillOpClusteringPolicy,           //
                 FusedMatMulOpClusteringPolicy,    //
                 MatMulOpClusteringPolicy,         //
                 OneHotOpClusteringPolicy,         //
                 PackOpClusteringPolicy,           //
                 RangeOpClusteringPolicy,          //
                 SliceOpClusteringPolicy,          //
                 SoftmaxOpClusteringPolicy,        //
                 StridedSliceOpClusteringPolicy>();
  }
}

841 
populateTfJitRtConstraintsPolicies(ClusteringPolicySet & policies,JitRtClusteringTier tier)842 void populateTfJitRtConstraintsPolicies(ClusteringPolicySet& policies,
843                                         JitRtClusteringTier tier) {
844   populateTfJitRtClusteringPolicies(policies, tier);
845   policies.Add<ConstOpClusteringPolicy>();
846 }
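
// Example usage (a sketch; the real call site is the tf-jitrt clustering
// pass):
//
//   mlir::TFDevice::ClusteringPolicySet policies;
//   populateTfJitRtClusteringPolicies(policies, JitRtClusteringTier::kAll);
//   // `policies` is then used to seed and grow clusters in a function body.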

// -------------------------------------------------------------------------- //
// Helper functions.
// -------------------------------------------------------------------------- //

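// A constant can be sunk into the compiled function only if it is small
// enough and has an integer, index, or floating point element type.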
mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value) {
  return success(value.getNumElements() <= 16 &&
                 value.getType().getElementType().isIntOrIndexOrFloat());
}

static bool IsI1Integer(Type type) {
  return mlir::getElementTypeOrSelf(type).isInteger(1);
}

static bool IsUnsignedInteger(Type type) {
  return mlir::getElementTypeOrSelf(type).isUnsignedInteger();
}

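// Verifies that the formed cluster can be compiled: it is not too large, it
// avoids data types with known compilation problems, and every value
// constraint is satisfiable at runtime, by in-cluster constant folding, or by
// sinking a small constant into the compiled function.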
mlir::LogicalResult VerifyCluster(const Cluster& cluster) {
  llvm::SmallDenseSet<Operation*> ops;
  for (Operation* op : cluster.operations) {
    auto inserted = ops.insert(op);
    assert(inserted.second && "clustered operations must be unique");
    (void)inserted;
  }

  // TODO(ezhulenev): Large clusters with dynamic shapes can take a very long
  // time to compile. Skip them for now.
  if (ops.size() > 20) return failure();

  // TODO(ezhulenev): This is a temporary workaround to disable forming
  // clusters with known compilation problems.
  for (Operation* op : ops) {
    // TODO(b/205714705): Memory layout of `i1` data type is not defined, and
    // when vectorization is enabled it can lead to crashes.
    bool has_i1_integers = llvm::any_of(op->getOperandTypes(), IsI1Integer) ||
                           llvm::any_of(op->getResultTypes(), IsI1Integer);
    if (has_i1_integers && tensorflow::GetJitRtFlags().vectorize)
      return failure();

    // TODO(b/205905286): Unsigned integer support has a lot of gaps, and
    // similar to handling `i1` we need a type conversion to signless integers.
    bool has_unsigned_integers =
        llvm::any_of(op->getOperandTypes(), IsUnsignedInteger) ||
        llvm::any_of(op->getResultTypes(), IsUnsignedInteger);
    if (has_unsigned_integers) return failure();
  }

  for (auto& pair : cluster.constraints) {
    Value value = pair.getFirst();
    ValueConstraint constraint = pair.getSecond();

    // We can satisfy shape and rank constraints on the compiled function
    // operands.
    if (constraint == ValueConstraint::kRank ||
        constraint == ValueConstraint::kShape)
      continue;

    if (constraint == ValueConstraint::kValue &&
        xla::runtime::SupportsValueSpecialization(value.getType()))
      continue;

    Operation* op = value.getDefiningOp();
    if (!op) return failure();  // we do not support block arguments

    // Operations defined inside the cluster will be constant folded before
    // compilation. This property is guaranteed by the clustering policy.
    if (ops.contains(op)) continue;

    // Small constants will be sunk into the compiled function body.
    auto const_op = mlir::dyn_cast<mlir::TF::ConstOp>(op);
    if (!const_op || failed(IsCompilableConstant(const_op.value())))
      return failure();
  }

  return success();
}

}  // namespace tensorflow