/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/jit/tf_cpurt_clustering.h"

#include <functional>
#include <utility>

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSet.h"
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h"
#include "tfrt/cpu/jit/cpurt_support.h"  // from @tf_runtime

namespace tensorflow {

using mlir::failure;
using mlir::LogicalResult;
using mlir::Operation;
using mlir::success;
using mlir::TensorType;
using mlir::Type;
using mlir::Value;

using mlir::TFDevice::Cluster;
using mlir::TFDevice::ClusteringPolicy;
using mlir::TFDevice::ClusteringPolicySet;
using mlir::TFDevice::ValueConstraint;
using mlir::TFDevice::ValuesConstraintSet;

using mlir::TF::_FusedMatMulOp;
using mlir::TF::BroadcastToOp;
using mlir::TF::ConcatV2Op;
using mlir::TF::ConstOp;
using mlir::TF::ExpandDimsOp;
using mlir::TF::FillOp;
using mlir::TF::MatMulOp;
using mlir::TF::PackOp;
using mlir::TF::RangeOp;
using mlir::TF::ReshapeOp;
using mlir::TF::ShapeOp;
using mlir::TF::StopGradientOp;
using mlir::TF::StridedSliceOp;
using mlir::TF::TransposeOp;

namespace {

// A set of clustering constraints that allows the TF -> CPURT compilation
// pipeline to lower Tensorflow operations to MHLO and then to Linalg.
// Tensorflow dynamism is not fully representable at the Linalg level, so by
// providing a clustering policy we ensure that all clustered operations can
// be compiled successfully (we have enough static information to lower to
// MHLO, or to build static Linalg indexing maps).
//
// Some of these constraints are resolved at constant-folding time, and the
// corresponding operations are removed from the IR entirely; other
// constraints simply enable the TF->MHLO or MHLO->Linalg lowering.
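//
// For example (a hypothetical IR sketch, using the policies defined below),
// given:
//
//   %shape = "tf.Shape"(%input)
//   %result = "tf.Reshape"(%tensor, %shape)
//
// a `kShape` constraint on %result forces a `kValue` constraint on %shape
// (tf.Reshape must know the target shape), which tf.Shape can in turn
// satisfy with a `kShape` constraint on %input.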

// Returns true if all types are supported by the Tensorflow -> CPURT
// compilation pipeline and TFRT JIT runtime integration (see cpurt.h).
template <typename TypeRange>
static bool IsSupportedDataTypes(TypeRange&& types) {
  return llvm::all_of(types, [](Type type) -> bool {
    if (auto tensor = type.dyn_cast<TensorType>()) {
      auto elt_type = tensor.getElementType();
      return elt_type.isF32() || elt_type.isInteger(1) ||
             elt_type.isInteger(32) || elt_type.isInteger(64);
    }
    return false;
  });
}

static bool IsSupportedOperandTypes(Operation* op) {
  return IsSupportedDataTypes(op->getOperandTypes());
}

static bool IsSupportedResultTypes(Operation* op) {
  return IsSupportedDataTypes(op->getResultTypes());
}

static bool IsSupportedOperandAndResultTypes(Operation* op) {
  return IsSupportedOperandTypes(op) && IsSupportedResultTypes(op);
}

// Clustering policy for a specific Tensorflow operation type that verifies
// that the operation's operand and result data types are supported.
template <typename OpTy>
class TensorflowOpClusteringPolicy : public ClusteringPolicy {
 public:
  LogicalResult MatchAndUpdateConstraints(
      Operation* operation, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    auto op = mlir::dyn_cast<OpTy>(operation);
    if (op && IsSupportedOperandAndResultTypes(op))
      return MatchAndUpdateConstraints(op, results, operands);
    return failure();
  }

  virtual LogicalResult MatchAndUpdateConstraints(
      OpTy op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const = 0;
};

// -------------------------------------------------------------------------- //
// Default clustering policy for TF -> CPURT compilation.
// -------------------------------------------------------------------------- //

// The default clustering policy for Tensorflow -> TFRT JIT compilation
// propagates the most restrictive constraint from the results to all
// operands. If the results do not have any constraints, it adds the default
// constraint (when one is provided) to all operands; otherwise it returns
// `success` without adding any constraints.
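//
// For example, assuming `Merge` picks the more restrictive of two constraints
// (`kValue` over `kShape` over `kRank`): a policy constructed with a default
// `kRank` constraint whose result is constrained to `kShape` propagates the
// merged, more restrictive `kShape` constraint to every operand.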
class DefaultClusteringPolicy : public ClusteringPolicy {
 public:
  explicit DefaultClusteringPolicy(
      std::function<bool(Operation*)> filter,
      llvm::Optional<ValueConstraint> default_constraint = llvm::None)
      : filter_(std::move(filter)), default_constraint_(default_constraint) {}

  LogicalResult MatchAndUpdateConstraints(
      Operation* op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final;

 private:
  // A filter for operations that are supported.
  std::function<bool(Operation*)> filter_;
  // Default constraint for all operands.
  llvm::Optional<ValueConstraint> default_constraint_;
};

template <typename OpTy>
class OpDefaultClusteringPolicy : public DefaultClusteringPolicy {
 public:
  explicit OpDefaultClusteringPolicy(
      llvm::Optional<ValueConstraint> default_constraint = llvm::None)
      : DefaultClusteringPolicy(
            [](Operation* op) -> bool { return mlir::isa<OpTy>(op); },
            default_constraint) {}
};

LogicalResult DefaultClusteringPolicy::MatchAndUpdateConstraints(
    Operation* op, const ValuesConstraintSet& results,
    ValuesConstraintSet& operands) const {
  if (!filter_(op)) return failure();

  if (!IsSupportedOperandAndResultTypes(op)) return failure();

  // Find the most restrictive constraint from the operation results.
  llvm::Optional<ValueConstraint> default_constraint = default_constraint_;

  for (mlir::Value result : op->getResults()) {
    if (auto result_constraint = results.GetConstraint(result)) {
      // TODO(ezhulenev): We can safely propagate value constraints if we know
      // that the value is an integer scalar or a small vector, however in
      // practice all values that we are interested in are defined by constant
      // operations directly. Revisit if this becomes a problem.
      if (*result_constraint == ValueConstraint::kValue) return failure();

      default_constraint = default_constraint.hasValue()
                               ? Merge(*default_constraint, *result_constraint)
                               : *result_constraint;
    }
  }

  // No constraints to propagate.
  if (!default_constraint.hasValue()) return success();

  // Propagate the constraint to all operands.
  for (unsigned i = 0; i < op->getNumOperands(); ++i)
    operands.Insert(op->getOperand(i), *default_constraint);
  return success();
}

// -------------------------------------------------------------------------- //
// tf.BroadcastTo
// -------------------------------------------------------------------------- //

class BroadcastToOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<BroadcastToOp> {
  LogicalResult MatchAndUpdateConstraints(
      BroadcastToOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Only ranked inputs are supported.
    operands.Insert(op.input(), ValueConstraint::kRank);

    if (auto result_constraint = results.GetConstraint(op.getResult())) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      // For a static output shape we need a constant shape operand.
      if (*result_constraint == ValueConstraint::kShape) {
        operands.Insert(op.shape(), ValueConstraint::kValue);
        return success();
      }
    }

    // Producing a ranked output requires a known shape for the shape operand.
    operands.Insert(op.shape(), ValueConstraint::kShape);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// Cwise Binary Operations.
// -------------------------------------------------------------------------- //

class CwiseBinaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseBinaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsBinaryOp(), ValueConstraint::kRank) {}

 private:
  // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
  std::function<bool(Operation* op)> IsBinaryOp() {
    llvm::StringSet<> binary_ops = {
        "tf.Add",
        "tf.AddV2",
        "tf.ApproximateEqual",
        "tf.Atan2",
        "tf.BiasAdd",
        "tf.BitwiseAnd",
        "tf.BitwiseOr",
        "tf.BitwiseXor",
        "tf.Div",
        "tf.DivNoNan",
        "tf.Equal",
        "tf.FloorDiv",
        "tf.FloorMod",
        "tf.Greater",
        "tf.GreaterEqual",
        "tf.Less",
        "tf.LessEqual",
        "tf.LogicalAnd",
        "tf.LogicalOr",
        "tf.Maximum",
        "tf.Minimum",
        "tf.Mod",
        "tf.Mul",
        "tf.MulNoNan",
        "tf.NotEqual",
        "tf.Pow",
        "tf.RealDiv",
        "tf.SquaredDifference",
        "tf.Sub",
        "tf.Xdivy",
        "tf.Xlogy",
    };
    return [binary_ops = std::move(binary_ops)](Operation* op) {
      return binary_ops.contains(op->getName().getStringRef());
    };
  }
};

// -------------------------------------------------------------------------- //
// Cwise Unary Operations.
// -------------------------------------------------------------------------- //

class CwiseUnaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseUnaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsUnaryOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsUnaryOp() {
    // TODO(ezhulenev): Use mlir::isa<>() to filter operations.
    llvm::StringSet<> unary_ops = {
        "tf.Abs",      "tf.Acos",        "tf.Acosh",      "tf.Asin",
        "tf.Asinh",    "tf.Atan",        "tf.Atanh",      "tf.Cast",
        "tf.Ceil",     "tf.ClipByValue", "tf.ComplexAbs", "tf.Conj",
        "tf.Cos",      "tf.Cosh",        "tf.Elu",        "tf.Erf",
        "tf.Exp",      "tf.Floor",       "tf.Inv",        "tf.Invert",
        "tf.IsFinite", "tf.IsInf",       "tf.IsNan",      "tf.LeakyRelu",
        "tf.Log",      "tf.Log1p",       "tf.LogicalNot", "tf.Neg",
        "tf.Real",     "tf.Reciprocal",  "tf.Relu",       "tf.Relu6",
        "tf.Rint",     "tf.Round",       "tf.Rsqrt",      "tf.Selu",
        "tf.Sigmoid",  "tf.Sign",        "tf.Sin",        "tf.Sinh",
        "tf.Softplus", "tf.Softsign",    "tf.Sqrt",       "tf.Square",
        "tf.Tan",      "tf.Tanh",        "tf.ZerosLike",
    };
    return [unary_ops = std::move(unary_ops)](Operation* op) {
      return unary_ops.contains(op->getName().getStringRef());
    };
  }
};

// -------------------------------------------------------------------------- //
// Cwise Ternary Operations.
// -------------------------------------------------------------------------- //

class CwiseTernaryOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  CwiseTernaryOpClusteringPolicy()
      : DefaultClusteringPolicy(IsTernaryOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsTernaryOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::SelectOp, mlir::TF::SelectV2Op>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// Reduction Operations.
// -------------------------------------------------------------------------- //

// Clustering policy for Tensorflow reduction operations:
//   - a shape constraint can be propagated from the result to the input
//   - the reduction indices value must be known at compile time
//
// All operations that use this policy must have two operands (input and
// reduction indices) and a single result.
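//
// For example (a hypothetical IR sketch): given
//
//   %sum = "tf.Sum"(%input, %indices)
//
// a `kShape` constraint on %sum is propagated as a `kShape` constraint on
// %input, and %indices is always constrained to `kValue`.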
class ReductionOpClusteringPolicy : public ClusteringPolicy {
 public:
  LogicalResult MatchAndUpdateConstraints(
      Operation* op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final;

 private:
  bool IsSupported(Operation* op) const;
};

LogicalResult ReductionOpClusteringPolicy::MatchAndUpdateConstraints(
    Operation* op, const ValuesConstraintSet& results,
    ValuesConstraintSet& operands) const {
  // Verify that the operation is a reduction with supported operand
  // and result data types.
  if (!IsSupported(op) || !IsSupportedOperandAndResultTypes(op))
    return failure();

  assert(op->getNumOperands() == 2 && "expected two operands");
  assert(op->getNumResults() == 1 && "expected one result");

  // Propagate constraint from the result to the input.
  if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
    if (*result_constraint == ValueConstraint::kValue) return failure();
    operands.Insert(op->getOperand(0), *result_constraint);
  } else {
    operands.Insert(op->getOperand(0), ValueConstraint::kRank);
  }

  // Reduction indices must be known at compile time.
  operands.Insert(op->getOperand(1), ValueConstraint::kValue);

  return success();
}

bool ReductionOpClusteringPolicy::IsSupported(Operation* op) const {
  return mlir::isa<mlir::TF::AllOp,   //
                   mlir::TF::AnyOp,   //
                   mlir::TF::MaxOp,   //
                   mlir::TF::MeanOp,  //
                   mlir::TF::MinOp,   //
                   mlir::TF::ProdOp,  //
                   mlir::TF::SumOp>(op);
}

// -------------------------------------------------------------------------- //
// tf.ConcatV2
// -------------------------------------------------------------------------- //

class ConcatV2OpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ConcatV2Op> {
  LogicalResult MatchAndUpdateConstraints(
      ConcatV2Op op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    auto result_constraint = results.GetConstraint(op->getResult(0));
    if (result_constraint && *result_constraint == ValueConstraint::kValue)
      return failure();

    // Propagate constraint from the result to the input. All inputs always
    // need a known rank.
    for (auto value : op.values()) {
      operands.Insert(value,
                      result_constraint.getValueOr(ValueConstraint::kRank));
    }

    // Force axis to be a constant.
    operands.Insert(op.axis(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Const
// -------------------------------------------------------------------------- //

class ConstOpClusteringPolicy : public TensorflowOpClusteringPolicy<ConstOp> {
  LogicalResult MatchAndUpdateConstraints(
      ConstOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // We cluster a constant operation only if it is required to resolve some
    // of the constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.hasValue()) return failure();

    return IsCompilableConstant(op.value());
  }
};

// -------------------------------------------------------------------------- //
// tf.ExpandDims
// -------------------------------------------------------------------------- //

class ExpandDimsOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ExpandDimsOp> {
  LogicalResult MatchAndUpdateConstraints(
      ExpandDimsOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate constraint from the result to the input.
    if (auto result_constraint = results.GetConstraint(op->getResult(0))) {
      if (*result_constraint == ValueConstraint::kValue) return failure();
      operands.Insert(op.input(), *result_constraint);
    } else {
      operands.Insert(op.input(), ValueConstraint::kRank);
    }

    // The inserted dimension must always be known at compile time.
    operands.Insert(op.dim(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf._FusedMatMul
// -------------------------------------------------------------------------- //
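
// Clusters tf._FusedMatMul only when the fused_ops attribute matches the
// fusions supported by the compilation pipeline. For example (a hypothetical
// attribute sketch):
//
//   fused_ops = ["BiasAdd"]          // accepted
//   fused_ops = ["BiasAdd", "Relu"]  // accepted
//   fused_ops = ["BiasAdd", "Elu"]   // rejected: unsupported activation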

class FusedMatMulOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<_FusedMatMulOp> {
  LogicalResult MatchAndUpdateConstraints(
      _FusedMatMulOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Check if the default policy accepts the operation.
    OpDefaultClusteringPolicy<_FusedMatMulOp> default_policy;
    if (failed(default_policy.MatchAndUpdateConstraints(op, results, operands)))
      return failure();

    // Check that we support the set of fused operations.
    size_t n = op.fused_ops().size();

    auto fusion =
        n > 0 ? op.fused_ops()[0].dyn_cast<mlir::StringAttr>() : nullptr;
    auto activation =
        n > 1 ? op.fused_ops()[1].dyn_cast<mlir::StringAttr>() : nullptr;

    if ((n > 0 && !fusion) || (n > 1 && !activation)) return failure();

    // TODO(ezhulenev): Update the fission pass to support more fusions and
    // activations.

    // We only support BiasAdd fusion ...
    if (fusion && fusion.getValue() != "BiasAdd") return failure();

    // ... with Relu activation.
    if (activation && activation.getValue() != "Relu") return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Fill
// -------------------------------------------------------------------------- //

class FillOpClusteringPolicy : public TensorflowOpClusteringPolicy<FillOp> {
  LogicalResult MatchAndUpdateConstraints(
      FillOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Fill operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op->getResult(0));
    if (!result_constraint.hasValue()) return success();

    // To know the result shape we need to know the shape operand value.
    if (*result_constraint == ValueConstraint::kShape)
      operands.Insert(op.dims(), ValueConstraint::kValue);

    // To know the result rank we need to know the shape operand shape.
    if (*result_constraint == ValueConstraint::kRank)
      operands.Insert(op.dims(), ValueConstraint::kShape);

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.MatMul
// -------------------------------------------------------------------------- //

class MatMulOpClusteringPolicy : public OpDefaultClusteringPolicy<MatMulOp> {};

// -------------------------------------------------------------------------- //
// tf.Pack
// -------------------------------------------------------------------------- //

class PackOpClusteringPolicy : public OpDefaultClusteringPolicy<PackOp> {};

// -------------------------------------------------------------------------- //
// tf.Range
// -------------------------------------------------------------------------- //

class RangeOpClusteringPolicy : public TensorflowOpClusteringPolicy<RangeOp> {
  LogicalResult MatchAndUpdateConstraints(
      RangeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Range operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.hasValue()) return success();

    // To know the result shape we need the input values.
    if (*result_constraint == ValueConstraint::kShape) {
      operands.Insert({op.start(), op.limit(), op.delta()},
                      ValueConstraint::kValue);
    }

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Reshape
// -------------------------------------------------------------------------- //

class ReshapeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<ReshapeOp> {
  LogicalResult MatchAndUpdateConstraints(
      ReshapeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // The runtime only supports ranked tensors.
    operands.Insert(op.tensor(), ValueConstraint::kRank);

    // Reshape operation does not have any default constraints.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.hasValue()) return success();

    // To know the result shape we need to know the shape operand value. We
    // also require a static shape on the input in case there's a -1 in the
    // shape.
    if (*result_constraint == ValueConstraint::kShape) {
      operands.Insert(op.shape(), ValueConstraint::kValue);
      operands.Insert(op.tensor(), ValueConstraint::kShape);
    }

    // To know the result rank we need to know the shape operand shape.
    if (*result_constraint == ValueConstraint::kRank)
      operands.Insert(op.shape(), ValueConstraint::kShape);

    // Value constraint propagation is not supported.
    if (*result_constraint == ValueConstraint::kValue) return failure();

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Shape
// -------------------------------------------------------------------------- //

class ShapeOpClusteringPolicy : public TensorflowOpClusteringPolicy<ShapeOp> {
  LogicalResult MatchAndUpdateConstraints(
      ShapeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Check constraint on the result value.
    auto result_constraint = results.GetConstraint(op.getResult());
    if (!result_constraint.hasValue()) return success();

    // To know the result shape we need only the rank of the input.
    if (*result_constraint == ValueConstraint::kShape)
      operands.Insert(op.input(), ValueConstraint::kRank);

    // To know the result value we need to know the shape of the input.
    if (*result_constraint == ValueConstraint::kValue)
      operands.Insert(op.input(), ValueConstraint::kShape);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.Softmax
// -------------------------------------------------------------------------- //

class SoftmaxOpClusteringPolicy : public DefaultClusteringPolicy {
 public:
  SoftmaxOpClusteringPolicy()
      : DefaultClusteringPolicy(IsSoftmaxOp(), ValueConstraint::kRank) {}

 private:
  std::function<bool(Operation* op)> IsSoftmaxOp() {
    return [](Operation* op) {
      return mlir::isa<mlir::TF::SoftmaxOp, mlir::TF::LogSoftmaxOp>(op);
    };
  }
};

// -------------------------------------------------------------------------- //
// tf.StopGradient
// -------------------------------------------------------------------------- //

class StopGradientOpClusteringPolicy
    : public OpDefaultClusteringPolicy<StopGradientOp> {};

// -------------------------------------------------------------------------- //
// tf.Transpose
// -------------------------------------------------------------------------- //

class TransposeOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<TransposeOp> {
  LogicalResult MatchAndUpdateConstraints(
      TransposeOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // Propagate result constraints to the input; at minimum require a known
    // rank.
    if (auto constraint = results.GetConstraint(op.getResult())) {
      operands.Insert(op.x(), *constraint);
    } else {
      operands.Insert(op.x(), ValueConstraint::kRank);
    }

    // The permutation must always be known at compile time.
    operands.Insert(op.perm(), ValueConstraint::kValue);

    return success();
  }
};

// -------------------------------------------------------------------------- //
// tf.StridedSlice
// -------------------------------------------------------------------------- //

class StridedSliceOpClusteringPolicy
    : public TensorflowOpClusteringPolicy<StridedSliceOp> {
  LogicalResult MatchAndUpdateConstraints(
      StridedSliceOp op, const ValuesConstraintSet& results,
      ValuesConstraintSet& operands) const final {
    // We must know the shape of the input.
    operands.Insert(op.input(), ValueConstraint::kShape);

    // And the values of the operands that control the slice size.
    operands.Insert({op.begin(), op.end(), op.strides()},
                    ValueConstraint::kValue);

    return success();
  }
};

}  // namespace

void populateTfCpurtClusteringPolicies(ClusteringPolicySet& policies,
                                       CpurtClusteringTier tier) {
  // Returns true if the given cpurt compilation tier is enabled.
  auto is_enabled = [&](CpurtClusteringTier requested) -> bool {
    return static_cast<uint8_t>(requested) <= static_cast<uint8_t>(tier);
  };

  if (is_enabled(CpurtClusteringTier::kTier1)) {
    policies.Add<CwiseBinaryOpClusteringPolicy,   //
                 CwiseUnaryOpClusteringPolicy,    //
                 CwiseTernaryOpClusteringPolicy,  //
                 StopGradientOpClusteringPolicy,  //
                 TransposeOpClusteringPolicy>();
  }

  if (is_enabled(CpurtClusteringTier::kAll)) {
    policies.Add<BroadcastToOpClusteringPolicy,   //
                 ConcatV2OpClusteringPolicy,      //
                 ExpandDimsOpClusteringPolicy,    //
                 FillOpClusteringPolicy,          //
                 FusedMatMulOpClusteringPolicy,   //
                 MatMulOpClusteringPolicy,        //
                 PackOpClusteringPolicy,          //
                 RangeOpClusteringPolicy,         //
                 ReductionOpClusteringPolicy,     //
                 ReshapeOpClusteringPolicy,       //
                 ShapeOpClusteringPolicy,         //
                 SoftmaxOpClusteringPolicy,       //
                 StridedSliceOpClusteringPolicy>();
  }
}

void populateTfCpurtConstraintsPolicies(ClusteringPolicySet& policies,
                                        CpurtClusteringTier tier) {
  populateTfCpurtClusteringPolicies(policies, tier);
  policies.Add<ConstOpClusteringPolicy>();
}
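
// Example usage (a hypothetical sketch; the driver code that builds the
// policy set lives outside this file):
//
//   mlir::TFDevice::ClusteringPolicySet policies;
//   populateTfCpurtClusteringPolicies(policies, CpurtClusteringTier::kTier1);
//   // `policies` can now drive the generic clustering machinery declared in
//   // cluster_ops_by_policy.h.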

// -------------------------------------------------------------------------- //
// Helper functions.
// -------------------------------------------------------------------------- //

mlir::LogicalResult IsCompilableConstant(mlir::ElementsAttr value) {
  return success(value.getNumElements() <= 16 &&
                 value.getType().getElementType().isIntOrIndexOrFloat());
}

mlir::LogicalResult VerifyCluster(const Cluster& cluster) {
  llvm::SmallDenseSet<Operation*> ops;
  for (Operation* op : cluster.operations) {
    auto inserted = ops.insert(op);
    assert(inserted.second && "clustered operations must be unique");
    (void)inserted;
  }

  for (auto& pair : cluster.constraints) {
    Value value = pair.getFirst();
    ValueConstraint constraint = pair.getSecond();

    // We can satisfy shape and rank constraints on the compiled function
    // operands.
    if (constraint == ValueConstraint::kRank ||
        constraint == ValueConstraint::kShape)
      continue;

    if (constraint == ValueConstraint::kValue &&
        tfrt::cpu::jit::SupportsValueSpecialization(value.getType()))
      continue;

    Operation* op = value.getDefiningOp();
    if (!op) return failure();  // we do not support block arguments

    // Operations defined inside the cluster will be constant folded before
    // the compilation. This property is guaranteed by the clustering policy.
    if (ops.contains(op)) continue;

    // Small constants will be sunk into the compiled function body.
    auto const_op = mlir::dyn_cast<mlir::TF::ConstOp>(op);
    if (!const_op || failed(IsCompilableConstant(const_op.value())))
      return failure();
  }

  return success();
}

}  // namespace tensorflow