xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h"
17 
18 #include <functional>
19 #include <utility>
20 
21 #include "llvm/ADT/DenseSet.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/StringSet.h"
24 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
25 #include "mlir/IR/Operation.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
30 #include "tfrt/cpu/jit/cpurt_support.h"  // from @tf_runtime
31 
32 namespace tensorflow {
33 
34 using mlir::failure;
35 using mlir::LogicalResult;
36 using mlir::Operation;
37 using mlir::success;
38 using mlir::TensorType;
39 using mlir::Type;
40 using mlir::Value;
41 
42 using mlir::TFDevice::Cluster;
43 using mlir::TFDevice::ClusteringPolicy;
44 using mlir::TFDevice::ClusteringPolicySet;
45 using mlir::TFDevice::ValueConstraint;
46 using mlir::TFDevice::ValuesConstraintSet;
47 
48 using mlir::TF::_FusedMatMulOp;
49 using mlir::TF::BroadcastToOp;
50 using mlir::TF::ConcatV2Op;
51 using mlir::TF::ConstOp;
52 using mlir::TF::ExpandDimsOp;
53 using mlir::TF::FillOp;
54 using mlir::TF::MatMulOp;
55 using mlir::TF::PackOp;
56 using mlir::TF::RangeOp;
57 using mlir::TF::ReshapeOp;
58 using mlir::TF::ShapeOp;
59 using mlir::TF::StopGradientOp;
60 using mlir::TF::StridedSliceOp;
61 using mlir::TF::TransposeOp;
62 
63 namespace {
64 
// A set of clustering constraints that allows the TF -> CPURT compilation
// pipeline to lower Tensorflow operations to MHLO and then to Linalg.
// Tensorflow dynamism is not fully representable at the Linalg level, so by
// providing a clustering policy we ensure that we can successfully compile all
// clustered operations (we have enough static information to lower to MHLO,
// or to build static Linalg indexing maps).
//
// Some of these constraints get resolved at constant folding time, and the
// corresponding operations are completely removed from the IR; other
// constraints just enable TF->MHLO or MHLO->Linalg lowering.
75 
76 // Returns true if all types are supported by the Tensorflow -> CPURT
77 // compilation pipeline and TFRT JIT runtime integration (see cpurt.h).
78 template <typename TypeRange>
IsSupportedDataTypes(TypeRange && types)79 static bool IsSupportedDataTypes(TypeRange&& types) {
80   return llvm::all_of(types, [](Type type) -> bool {
81     if (auto tensor = type.dyn_cast<TensorType>()) {
82       auto elt_type = tensor.getElementType();
83       return elt_type.isF32() || elt_type.isInteger(1) ||
84              elt_type.isInteger(32) || elt_type.isInteger(64);
85     }
86     return false;
87   });
88 }
89 
IsSupportedOperandTypes(Operation * op)90 static bool IsSupportedOperandTypes(Operation* op) {
91   return IsSupportedDataTypes(op->getOperandTypes());
92 }
93 
IsSupportedResultTypes(Operation * op)94 static bool IsSupportedResultTypes(Operation* op) {
95   return IsSupportedDataTypes(op->getResultTypes());
96 }
97 
IsSupportedOperandAndResultTypes(Operation * op)98 static bool IsSupportedOperandAndResultTypes(Operation* op) {
99   return IsSupportedOperandTypes(op) && IsSupportedResultTypes(op);
100 }
101 
102 // Clustering policy for a specific Tensorflow operation type that verifies
103 // that operation operands and results data types are supported.
104 template <typename OpTy>
105 class TensorflowOpClusteringPolicy : public ClusteringPolicy {
106  public:
MatchAndUpdateConstraints(Operation * operation,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const107   LogicalResult MatchAndUpdateConstraints(
108       Operation* operation, const ValuesConstraintSet& results,
109       ValuesConstraintSet& operands) const final {
110     auto op = mlir::dyn_cast<OpTy>(operation);
111     if (op && IsSupportedOperandAndResultTypes(op))
112       return MatchAndUpdateConstraints(op, results, operands);
113     return failure();
114   }
115 
116   virtual LogicalResult MatchAndUpdateConstraints(
117       OpTy op, const ValuesConstraintSet& results,
118       ValuesConstraintSet& operands) const = 0;
119 };
120 
121 // -------------------------------------------------------------------------- //
122 // Default clustering policy for TF -> CPURT compilation.
123 // -------------------------------------------------------------------------- //
124 
125 // Default clustering policy for Tensorflow -> TFRT JIT compilation propagates
126 // the most restrictive constraint from the results to all operands. If results
127 // do not have any constraints it adds default constraint to all operands if it
128 // is provided, otherwise just returns `success` without adding any constraints.
129 class DefaultClusteringPolicy : public ClusteringPolicy {
130  public:
DefaultClusteringPolicy(std::function<bool (Operation *)> filter,llvm::Optional<ValueConstraint> default_constraint=llvm::None)131   explicit DefaultClusteringPolicy(
132       std::function<bool(Operation*)> filter,
133       llvm::Optional<ValueConstraint> default_constraint = llvm::None)
134       : filter_(std::move(filter)), default_constraint_(default_constraint) {}
135 
136   LogicalResult MatchAndUpdateConstraints(
137       Operation* op, const ValuesConstraintSet& results,
138       ValuesConstraintSet& operands) const final;
139 
140  private:
141   // A filter for operations that are supported.
142   std::function<bool(Operation*)> filter_;
143   // Default constraint for all operands.
144   llvm::Optional<ValueConstraint> default_constraint_;
145 };
146 
147 template <typename OpTy>
148 class OpDefaultClusteringPolicy : public DefaultClusteringPolicy {
149  public:
OpDefaultClusteringPolicy(llvm::Optional<ValueConstraint> default_constraint=llvm::None)150   explicit OpDefaultClusteringPolicy(
151       llvm::Optional<ValueConstraint> default_constraint = llvm::None)
152       : DefaultClusteringPolicy(
153             [](Operation* op) -> bool { return mlir::isa<OpTy>(op); },
154             default_constraint) {}
155 };
156 
MatchAndUpdateConstraints(Operation * op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const157 LogicalResult DefaultClusteringPolicy::MatchAndUpdateConstraints(
158     Operation* op, const ValuesConstraintSet& results,
159     ValuesConstraintSet& operands) const {
160   if (!filter_(op)) return failure();
161 
162   if (!IsSupportedOperandAndResultTypes(op)) return failure();
163 
164   // Find the most restrictive constraint from the operation results.
165   llvm::Optional<ValueConstraint> default_constraint = default_constraint_;
166 
167   for (mlir::Value result : op->getResults()) {
168     if (auto result_constraint = results.GetConstraint(result)) {
169       // TODO(ezhulenev): We can safely propagate value constraints if we know
170       // that the value is an integer scalar or a small vector, however in
171       // practice all values that we are interested in are defined by constant
172       // operations directly. Revisit if this becomes a problem.
173       if (*result_constraint == ValueConstraint::kValue) return failure();
174 
175       default_constraint = default_constraint.hasValue()
176                                ? Merge(*default_constraint, *result_constraint)
177                                : *result_constraint;
178     }
179   }
180 
181   // No constraints to propagate.
182   if (!default_constraint.hasValue()) return success();
183 
184   // Propage constraint to all operands.
185   for (unsigned i = 0; i < op->getNumOperands(); ++i)
186     operands.Insert(op->getOperand(i), *default_constraint);
187   return success();
188 }
189 
190 // -------------------------------------------------------------------------- //
191 // tf.BroadcastTo
192 // -------------------------------------------------------------------------- //
193 
194 class BroadcastToOpClusteringPolicy
195     : public TensorflowOpClusteringPolicy<BroadcastToOp> {
MatchAndUpdateConstraints(BroadcastToOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const196   LogicalResult MatchAndUpdateConstraints(
197       BroadcastToOp op, const ValuesConstraintSet& results,
198       ValuesConstraintSet& operands) const final {
199     // Only ranked inputs are supported.
200     operands.Insert(op.input(), ValueConstraint::kRank);
201 
202     if (auto result_constraint = results.GetConstraint(op.getResult())) {
203       if (*result_constraint == ValueConstraint::kValue) return failure();
204       // For a static output shape we need a constant shape operand.
205       if (*result_constraint == ValueConstraint::kShape) {
206         operands.Insert(op.shape(), ValueConstraint::kValue);
207         return success();
208       }
209     }
210 
211     // Producing a ranked output requires a known shape for the shape operand.
212     operands.Insert(op.shape(), ValueConstraint::kShape);
213 
214     return success();
215   }
216 };
217 
218 // -------------------------------------------------------------------------- //
219 // Cwise Binary Operations.
220 // -------------------------------------------------------------------------- //
221 
222 class CwiseBinaryOpClusteringPolicy : public DefaultClusteringPolicy {
223  public:
CwiseBinaryOpClusteringPolicy()224   CwiseBinaryOpClusteringPolicy()
225       : DefaultClusteringPolicy(IsBinaryOp(), ValueConstraint::kRank) {}
226 
227  private:
228   // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
IsBinaryOp()229   std::function<bool(Operation* op)> IsBinaryOp() {
230     llvm::StringSet<> binary_ops = {
231         "tf.Add",
232         "tf.AddV2",
233         "tf.ApproximateEqual",
234         "tf.Atan2",
235         "tf.BiasAdd",
236         "tf.BitwiseAnd",
237         "tf.BitwiseOr",
238         "tf.BitwiseXor",
239         "tf.Div",
240         "tf.DivNoNan",
241         "tf.Equal",
242         "tf.FloorDiv",
243         "tf.FloorMod",
244         "tf.Greater",
245         "tf.GreaterEqual",
246         "tf.Less",
247         "tf.LessEqual",
248         "tf.LogicalAnd",
249         "tf.LogicalOr",
250         "tf.Maximum",
251         "tf.Minimum",
252         "tf.Mod",
253         "tf.Mul",
254         "tf.MulNoNan",
255         "tf.NotEqual",
256         "tf.Pow",
257         "tf.RealDiv",
258         "tf.SquaredDifference",
259         "tf.Sub",
260         "tf.Xdivy",
261         "tf.Xlogy",
262     };
263     return [binary_ops = std::move(binary_ops)](Operation* op) {
264       return binary_ops.contains(op->getName().getStringRef());
265     };
266   }
267 };
268 
269 // -------------------------------------------------------------------------- //
270 // Cwise Unary Operations.
271 // -------------------------------------------------------------------------- //
272 
273 class CwiseUnaryOpClusteringPolicy : public DefaultClusteringPolicy {
274  public:
CwiseUnaryOpClusteringPolicy()275   CwiseUnaryOpClusteringPolicy()
276       : DefaultClusteringPolicy(IsUnaryOp(), ValueConstraint::kRank) {}
277 
278  private:
IsUnaryOp()279   std::function<bool(Operation* op)> IsUnaryOp() {
280     // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
281     llvm::StringSet<> unary_ops = {
282         "tf.Abs",      "tf.Acos",        "tf.Acosh",      "tf.Asin",
283         "tf.Asinh",    "tf.Atan",        "tf.Atanh",      "tf.Cast",
284         "tf.Ceil",     "tf.ClipByValue", "tf.ComplexAbs", "tf.Conj",
285         "tf.Cos",      "tf.Cosh",        "tf.Elu",        "tf.Erf",
286         "tf.Exp",      "tf.Floor",       "tf.Inv",        "tf.Invert",
287         "tf.IsFinite", "tf.IsInf",       "tf.IsNan",      "tf.LeakyRelu",
288         "tf.Log",      "tf.Log1p",       "tf.LogicalNot", "tf.Neg",
289         "tf.Real",     "tf.Reciprocal",  "tf.Relu",       "tf.Relu6",
290         "tf.Rint",     "tf.Round",       "tf.Rsqrt",      "tf.Selu",
291         "tf.Sigmoid",  "tf.Sign",        "tf.Sin",        "tf.Sinh",
292         "tf.Softplus", "tf.Softsign",    "tf.Sqrt",       "tf.Square",
293         "tf.Tan",      "tf.Tanh",        "tf.ZerosLike",
294     };
295     return [unary_ops = std::move(unary_ops)](Operation* op) {
296       return unary_ops.contains(op->getName().getStringRef());
297     };
298   }
299 };
300 
301 // -------------------------------------------------------------------------- //
302 // Cwise Ternary Operations.
303 // -------------------------------------------------------------------------- //
304 
305 class CwiseTernaryOpClusteringPolicy : public DefaultClusteringPolicy {
306  public:
CwiseTernaryOpClusteringPolicy()307   CwiseTernaryOpClusteringPolicy()
308       : DefaultClusteringPolicy(IsTernaryOp(), ValueConstraint::kRank) {}
309 
310  private:
IsTernaryOp()311   std::function<bool(Operation* op)> IsTernaryOp() {
312     return [](Operation* op) {
313       return mlir::isa<mlir::TF::SelectOp, mlir::TF::SelectV2Op>(op);
314     };
315   }
316 };
317 
318 // -------------------------------------------------------------------------- //
319 // Reduction Operations.
320 // -------------------------------------------------------------------------- //
321 
322 // Clustering policy for Tensorflow reduction operations:
323 //   - shape constraint can be propagated from the result to the input
324 //   - reduction indices value must be known at compile time
325 //
326 // All operations that use this policy must have two operands (input and
327 // reduction indices) and a single result.
328 class ReductionOpClusteringPolicy : public ClusteringPolicy {
329  public:
330   LogicalResult MatchAndUpdateConstraints(
331       Operation* op, const ValuesConstraintSet& results,
332       ValuesConstraintSet& operands) const final;
333 
334  private:
335   bool IsSupported(Operation* op) const;
336 };
337 
MatchAndUpdateConstraints(Operation * op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const338 LogicalResult ReductionOpClusteringPolicy::MatchAndUpdateConstraints(
339     Operation* op, const ValuesConstraintSet& results,
340     ValuesConstraintSet& operands) const {
341   // Verify that the operation is a reduction with supported operands
342   // and results data types.
343   if (!IsSupported(op) || !IsSupportedOperandAndResultTypes(op))
344     return failure();
345 
346   assert(op->getNumOperands() == 2 && "expected two operands");
347   assert(op->getNumResults() == 1 && "expected one result");
348 
349   // Propagate constraint from the result to the input.
350   if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
351     if (*result_constraint == ValueConstraint::kValue) return failure();
352     operands.Insert(op->getOperand(0), *result_constraint);
353   } else {
354     operands.Insert(op->getOperand(0), ValueConstraint::kRank);
355   }
356 
357   // Reduction indices must be known at compile time.
358   operands.Insert(op->getOperand(1), ValueConstraint::kValue);
359 
360   return success();
361 }
362 
IsSupported(Operation * op) const363 bool ReductionOpClusteringPolicy::IsSupported(Operation* op) const {
364   return mlir::isa<mlir::TF::AllOp,   //
365                    mlir::TF::AnyOp,   //
366                    mlir::TF::MaxOp,   //
367                    mlir::TF::MeanOp,  //
368                    mlir::TF::MinOp,   //
369                    mlir::TF::ProdOp,  //
370                    mlir::TF::SumOp>(op);
371 }
372 
373 // -------------------------------------------------------------------------- //
374 // tf.ConcatV2
375 // -------------------------------------------------------------------------- //
376 
377 class ConcatV2OpClusteringPolicy
378     : public TensorflowOpClusteringPolicy<ConcatV2Op> {
MatchAndUpdateConstraints(ConcatV2Op op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const379   LogicalResult MatchAndUpdateConstraints(
380       ConcatV2Op op, const ValuesConstraintSet& results,
381       ValuesConstraintSet& operands) const final {
382     auto result_constraint = results.GetConstraint(op->getResult(0));
383     if (result_constraint && *result_constraint == ValueConstraint::kValue)
384       return failure();
385 
386     // Propagate constraint from the result to the input. All inputs always need
387     // a known rank.
388     for (auto value : op.values()) {
389       operands.Insert(value,
390                       result_constraint.getValueOr(ValueConstraint::kRank));
391     }
392 
393     // Force axis to be a constant.
394     operands.Insert(op.axis(), ValueConstraint::kValue);
395 
396     return success();
397   }
398 };
399 
400 // -------------------------------------------------------------------------- //
401 // tf.Const
402 // -------------------------------------------------------------------------- //
403 
404 class ConstOpClusteringPolicy : public TensorflowOpClusteringPolicy<ConstOp> {
MatchAndUpdateConstraints(ConstOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const405   LogicalResult MatchAndUpdateConstraints(
406       ConstOp op, const ValuesConstraintSet& results,
407       ValuesConstraintSet& operands) const final {
408     // We cluster constant operation only if it is required to resolve some of
409     // the constraints.
410     auto result_constraint = results.GetConstraint(op.getResult());
411     if (!result_constraint.hasValue()) return failure();
412 
413     return IsCompilableConstant(op.value());
414   }
415 };
416 
417 // -------------------------------------------------------------------------- //
418 // tf.ExpandDims
419 // -------------------------------------------------------------------------- //
420 
421 class ExpandDimsOpClusteringPolicy
422     : public TensorflowOpClusteringPolicy<ExpandDimsOp> {
MatchAndUpdateConstraints(ExpandDimsOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const423   LogicalResult MatchAndUpdateConstraints(
424       ExpandDimsOp op, const ValuesConstraintSet& results,
425       ValuesConstraintSet& operands) const final {
426     // Propagate constraint from the result to the input.
427     if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
428       if (*result_constraint == ValueConstraint::kValue) return failure();
429       operands.Insert(op.input(), *result_constraint);
430     } else {
431       operands.Insert(op.input(), ValueConstraint::kRank);
432     }
433 
434     // The inserted dimension must be always known at compile time.
435     operands.Insert(op.dim(), ValueConstraint::kValue);
436 
437     return success();
438   }
439 };
440 
441 // -------------------------------------------------------------------------- //
442 // tf._FusedMatMul
443 // -------------------------------------------------------------------------- //
444 
445 class FusedMatMulOpClusteringPolicy
446     : public TensorflowOpClusteringPolicy<_FusedMatMulOp> {
MatchAndUpdateConstraints(_FusedMatMulOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const447   LogicalResult MatchAndUpdateConstraints(
448       _FusedMatMulOp op, const ValuesConstraintSet& results,
449       ValuesConstraintSet& operands) const final {
450     // Check if the default policy accepts the operation.
451     OpDefaultClusteringPolicy<_FusedMatMulOp> default_policy;
452     if (failed(default_policy.MatchAndUpdateConstraints(op, results, operands)))
453       return failure();
454 
455     // Check if we do support a set of fused operations.
456     size_t n = op.fused_ops().size();
457 
458     auto fusion =
459         n > 0 ? op.fused_ops()[0].dyn_cast<mlir::StringAttr>() : nullptr;
460     auto activation =
461         n > 1 ? op.fused_ops()[1].dyn_cast<mlir::StringAttr>() : nullptr;
462 
463     if ((n > 0 && !fusion) || (n > 1 && !activation)) return failure();
464 
465     // TODO(ezhulenev): Update fission pass to support more fusions and
466     // activations.
467 
468     // We only support BiasAdd fusion ...
469     if (fusion && fusion.getValue() != "BiasAdd") return failure();
470 
471     // ... with Relu activation.
472     if (activation && activation.getValue() != "Relu") return failure();
473 
474     return success();
475   }
476 };
477 
478 // -------------------------------------------------------------------------- //
479 // tf.Fill
480 // -------------------------------------------------------------------------- //
481 
482 class FillOpClusteringPolicy : public TensorflowOpClusteringPolicy<FillOp> {
MatchAndUpdateConstraints(FillOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const483   LogicalResult MatchAndUpdateConstraints(
484       FillOp op, const ValuesConstraintSet& results,
485       ValuesConstraintSet& operands) const final {
486     // Fill operation does not have any default constraints.
487     auto result_constraint = results.GetConstraint(op->getResult(0));
488     if (!result_constraint.hasValue()) return success();
489 
490     // To know the result shape we need to know the shape operand value.
491     if (*result_constraint == ValueConstraint::kShape)
492       operands.Insert(op.dims(), ValueConstraint::kValue);
493 
494     // To know the result rank we need to know the shape operand shape.
495     if (*result_constraint == ValueConstraint::kRank)
496       operands.Insert(op.dims(), ValueConstraint::kShape);
497 
498     // Value constraint propagation is not supported.
499     if (*result_constraint == ValueConstraint::kValue) return failure();
500 
501     return success();
502   }
503 };
504 
505 // -------------------------------------------------------------------------- //
506 // tf.MatMul
507 // -------------------------------------------------------------------------- //
508 
509 class MatMulOpClusteringPolicy : public OpDefaultClusteringPolicy<MatMulOp> {};
510 
511 // -------------------------------------------------------------------------- //
512 // tf.Pack
513 // -------------------------------------------------------------------------- //
514 
515 class PackOpClusteringPolicy : public OpDefaultClusteringPolicy<PackOp> {};
516 
517 // -------------------------------------------------------------------------- //
518 // tf.Range
519 // -------------------------------------------------------------------------- //
520 
521 class RangeOpClusteringPolicy : public TensorflowOpClusteringPolicy<RangeOp> {
MatchAndUpdateConstraints(RangeOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const522   LogicalResult MatchAndUpdateConstraints(
523       RangeOp op, const ValuesConstraintSet& results,
524       ValuesConstraintSet& operands) const final {
525     // Range operation does not have any default constraints.
526     auto result_constraint = results.GetConstraint(op.getResult());
527     if (!result_constraint.hasValue()) return success();
528 
529     // To know the result shape we need the input values.
530     if (*result_constraint == ValueConstraint::kShape) {
531       operands.Insert({op.start(), op.limit(), op.delta()},
532                       ValueConstraint::kValue);
533     }
534 
535     // Value constraint propagation is not supported.
536     if (*result_constraint == ValueConstraint::kValue) return failure();
537 
538     return success();
539   }
540 };
541 
542 // -------------------------------------------------------------------------- //
543 // tf.Reshape
544 // -------------------------------------------------------------------------- //
545 
546 class ReshapeOpClusteringPolicy
547     : public TensorflowOpClusteringPolicy<ReshapeOp> {
MatchAndUpdateConstraints(ReshapeOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const548   LogicalResult MatchAndUpdateConstraints(
549       ReshapeOp op, const ValuesConstraintSet& results,
550       ValuesConstraintSet& operands) const final {
551     // The runtime only supports ranked tensors.
552     operands.Insert(op.tensor(), ValueConstraint::kRank);
553 
554     // Reshape operation does not have any default constraints.
555     auto result_constraint = results.GetConstraint(op.getResult());
556     if (!result_constraint.hasValue()) return success();
557 
558     // To know the result shape we need to know the shape operand value. We also
559     // require a static shape on the input in case there's a -1 in the shape.
560     if (*result_constraint == ValueConstraint::kShape) {
561       operands.Insert(op.shape(), ValueConstraint::kValue);
562       operands.Insert(op.tensor(), ValueConstraint::kShape);
563     }
564 
565     // To know the result rank we need to know the shape operand shape.
566     if (*result_constraint == ValueConstraint::kRank)
567       operands.Insert(op.shape(), ValueConstraint::kShape);
568 
569     // Value constraint propagation is not supported.
570     if (*result_constraint == ValueConstraint::kValue) return failure();
571 
572     return success();
573   }
574 };
575 
576 // -------------------------------------------------------------------------- //
577 // tf.Shape
578 // -------------------------------------------------------------------------- //
579 
580 class ShapeOpClusteringPolicy : public TensorflowOpClusteringPolicy<ShapeOp> {
MatchAndUpdateConstraints(ShapeOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const581   LogicalResult MatchAndUpdateConstraints(
582       ShapeOp op, const ValuesConstraintSet& results,
583       ValuesConstraintSet& operands) const final {
584     // Check constraint on the result value.
585     auto result_constraint = results.GetConstraint(op.getResult());
586     if (!result_constraint.hasValue()) return success();
587 
588     // To know the result shape we need only the rank of the input.
589     if (*result_constraint == ValueConstraint::kShape)
590       operands.Insert(op.input(), ValueConstraint::kRank);
591 
592     // To know the result value we need to know the shape of the input.
593     if (*result_constraint == ValueConstraint::kValue)
594       operands.Insert(op.input(), ValueConstraint::kShape);
595 
596     return success();
597   }
598 };
599 
600 // -------------------------------------------------------------------------- //
601 // tf.Softmax
602 // -------------------------------------------------------------------------- //
603 
604 class SoftmaxOpClusteringPolicy : public DefaultClusteringPolicy {
605  public:
SoftmaxOpClusteringPolicy()606   SoftmaxOpClusteringPolicy()
607       : DefaultClusteringPolicy(IsSoftmaxOp(), ValueConstraint::kRank) {}
608 
609  private:
IsSoftmaxOp()610   std::function<bool(Operation* op)> IsSoftmaxOp() {
611     return [](Operation* op) {
612       return mlir::isa<mlir::TF::SoftmaxOp, mlir::TF::LogSoftmaxOp>(op);
613     };
614   }
615 };
616 
617 // -------------------------------------------------------------------------- //
618 // tf.StopGradient
619 // -------------------------------------------------------------------------- //
620 
621 class StopGradientOpClusteringPolicy
622     : public OpDefaultClusteringPolicy<StopGradientOp> {};
623 
624 // -------------------------------------------------------------------------- //
625 // tf.Transpose
626 // -------------------------------------------------------------------------- //
627 
628 class TransposeOpClusteringPolicy
629     : public TensorflowOpClusteringPolicy<TransposeOp> {
MatchAndUpdateConstraints(TransposeOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const630   LogicalResult MatchAndUpdateConstraints(
631       TransposeOp op, const ValuesConstraintSet& results,
632       ValuesConstraintSet& operands) const final {
633     // Propagate result constraints to the input, at minimum require known rank.
634     if (auto constraint = results.GetConstraint(op.getResult())) {
635       operands.Insert(op.x(), *constraint);
636     } else {
637       operands.Insert(op.x(), ValueConstraint::kRank);
638     }
639 
640     // Permutation must be always known at compile time.
641     operands.Insert(op.perm(), ValueConstraint::kValue);
642 
643     return success();
644   }
645 };
646 
647 // -------------------------------------------------------------------------- //
648 // tf.StridedSlice
649 // -------------------------------------------------------------------------- //
650 
651 class StridedSliceOpClusteringPolicy
652     : public TensorflowOpClusteringPolicy<StridedSliceOp> {
MatchAndUpdateConstraints(StridedSliceOp op,const ValuesConstraintSet & results,ValuesConstraintSet & operands) const653   LogicalResult MatchAndUpdateConstraints(
654       StridedSliceOp op, const ValuesConstraintSet& results,
655       ValuesConstraintSet& operands) const final {
656     // We must know the shape of the input.
657     operands.Insert(op.input(), ValueConstraint::kShape);
658 
659     // And values of operands that control the slice size.
660     operands.Insert({op.begin(), op.end(), op.strides()},
661                     ValueConstraint::kValue);
662 
663     return success();
664   }
665 };
666 
667 }  // namespace
668 
populateTfCpurtClusteringPolicies(ClusteringPolicySet & policies,CpurtClusteringTier tier)669 void populateTfCpurtClusteringPolicies(ClusteringPolicySet& policies,
670                                        CpurtClusteringTier tier) {
671   // Returns true if the given cpurt compilation tier is enabled.
672   auto is_enabled = [&](CpurtClusteringTier requested) -> bool {
673     return static_cast<uint8_t>(requested) <= static_cast<uint8_t>(tier);
674   };
675 
676   if (is_enabled(CpurtClusteringTier::kTier1)) {
677     policies.Add<CwiseBinaryOpClusteringPolicy,   //
678                  CwiseUnaryOpClusteringPolicy,    //
679                  CwiseTernaryOpClusteringPolicy,  //
680                  StopGradientOpClusteringPolicy,  //
681                  TransposeOpClusteringPolicy>();
682   }
683 
684   if (is_enabled(CpurtClusteringTier::kAll)) {
685     policies.Add<BroadcastToOpClusteringPolicy,  //
686                  ConcatV2OpClusteringPolicy,     //
687                  ExpandDimsOpClusteringPolicy,   //
688                  FillOpClusteringPolicy,         //
689                  FusedMatMulOpClusteringPolicy,  //
690                  MatMulOpClusteringPolicy,       //
691                  PackOpClusteringPolicy,         //
692                  RangeOpClusteringPolicy,        //
693                  ReductionOpClusteringPolicy,    //
694                  ReshapeOpClusteringPolicy,      //
695                  ShapeOpClusteringPolicy,        //
696                  SoftmaxOpClusteringPolicy,      //
697                  StridedSliceOpClusteringPolicy>();
698   }
699 }
700 
populateTfCpurtConstraintsPolicies(ClusteringPolicySet & policies,CpurtClusteringTier tier)701 void populateTfCpurtConstraintsPolicies(ClusteringPolicySet& policies,
702                                         CpurtClusteringTier tier) {
703   populateTfCpurtClusteringPolicies(policies, tier);
704   policies.Add<ConstOpClusteringPolicy>();
705 }
706 
707 // -------------------------------------------------------------------------- //
708 // Helper functions.
709 // -------------------------------------------------------------------------- //
710 
IsCompilableConstant(mlir::ElementsAttr value)711 mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value) {
712   return success(value.getNumElements() <= 16 &&
713                  value.getType().getElementType().isIntOrIndexOrFloat());
714 }
715 
VerifyCluster(const Cluster & cluster)716 mlir::LogicalResult VerifyCluster(const Cluster& cluster) {
717   llvm::SmallDenseSet<Operation*> ops;
718   for (Operation* op : cluster.operations) {
719     auto inserted = ops.insert(op);
720     assert(inserted.second && "clustered operations must be unique");
721     (void)inserted;
722   }
723 
724   for (auto& pair : cluster.constraints) {
725     Value value = pair.getFirst();
726     ValueConstraint constraint = pair.getSecond();
727 
728     // We can satisfy shape and rank constraints on the compiled function
729     // operands.
730     if (constraint == ValueConstraint::kRank ||
731         constraint == ValueConstraint::kShape)
732       continue;
733 
734     if (constraint == ValueConstraint::kValue &&
735         tfrt::cpu::jit::SupportsValueSpecialization(value.getType()))
736       continue;
737 
738     Operation* op = value.getDefiningOp();
739     if (!op) return failure();  // we do not support block arguments
740 
741     // Operations defined inside the cluster will be constant folded before the
742     // compilation. This property is guaranteed by the clustering policy.
743     if (ops.contains(op)) continue;
744 
745     // Small constants will be sunk into the compiled function body.
746     auto const_op = mlir::dyn_cast<mlir::TF::ConstOp>(op);
747     if (!const_op || failed(IsCompilableConstant(const_op.value())))
748       return failure();
749   }
750 
751   return success();
752 }
753 
754 }  // namespace tensorflow
755