1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // The TF dialect uses some TF types that are illegal in the MHLO dialect and
17 // some generic types that are legal in MHLO. This pass legalizes TF types into
18 // types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
19 // Rewrites here should run before TF to MHLO op legalizations are run.
20 // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
21 // rather than its own pass.
22
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/PatternMatch.h" // from @llvm-project
29 #include "mlir/Pass/Pass.h" // from @llvm-project
30 #include "mlir/Support/LLVM.h" // from @llvm-project
31 #include "mlir/Support/LogicalResult.h" // from @llvm-project
32 #include "mlir/Transforms/DialectConversion.h" // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
34 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
35
36 #define DEBUG_TYPE "xla-legalize-tf-types"
37
38 namespace mlir {
39 namespace mhlo {
40 namespace {
41
IsIllegalElementType(Type type)42 bool IsIllegalElementType(Type type) {
43 return type
44 .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
45 mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
46 }
47
ToLegalElementType(Type type)48 Type ToLegalElementType(Type type) {
49 return TypeSwitch<Type, Type>(type)
50 .Case<mlir::TF::Qint8Type>([&type](Type) {
51 return mlir::IntegerType::get(type.getContext(), 8);
52 })
53 .Case<mlir::TF::Qint16Type>([&type](Type) {
54 return mlir::IntegerType::get(type.getContext(), 16);
55 })
56 .Case<mlir::TF::Qint32Type>([&type](Type) {
57 return mlir::IntegerType::get(type.getContext(), 32);
58 })
59 .Case<mlir::TF::Quint8Type>([&type](Type) {
60 return mlir::IntegerType::get(
61 type.getContext(), 8,
62 mlir::IntegerType::SignednessSemantics::Unsigned);
63 })
64 .Case<mlir::TF::Quint16Type>([&type](Type) {
65 return mlir::IntegerType::get(
66 type.getContext(), 16,
67 mlir::IntegerType::SignednessSemantics::Unsigned);
68 })
69 .Default([&type](Type) { return type; });
70 }
71
72 // TODO(b/180234863): What's below this line is generic so convert it to a
73 // utility.
74
IsIllegalType(Type type)75 bool IsIllegalType(Type type) {
76 return IsIllegalElementType(getElementTypeOrSelf(type));
77 }
78
ToLegalType(Type type)79 Type ToLegalType(Type type) {
80 if (IsIllegalElementType(type)) return ToLegalElementType(type);
81 if (auto shaped = type.dyn_cast<ShapedType>()) {
82 Type elem = shaped.getElementType();
83 if (IsIllegalType(elem)) return shaped.clone(ToLegalType(elem));
84 }
85 return type;
86 }
87
88 class TfTypeConverter : public TypeConverter {
89 public:
TfTypeConverter()90 TfTypeConverter() {
91 addConversion([](Type type) -> Type {
92 return IsIllegalType(type) ? ToLegalType(type) : type;
93 });
94 }
95 };
96
97 // An Op is illegal iff it contains an illegalType.
98 class TfTypeConversionTarget : public ConversionTarget {
99 public:
TfTypeConversionTarget(MLIRContext & ctx,TfTypeConverter & converter)100 explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter)
101 : ConversionTarget(ctx), converter_(converter) {
102 markUnknownOpDynamicallyLegal([this](Operation *op) {
103 // The FuncOp type can contain types that the op's operand and result
104 // types do not contain.
105 if (auto func = dyn_cast<func::FuncOp>(op)) {
106 if (!converter_.isSignatureLegal(func.getFunctionType())) return false;
107 }
108 return converter_.isLegal(op);
109 });
110 }
111
112 private:
113 TfTypeConverter &converter_;
114 };
115
116 class TfTypePattern : public ConversionPattern {
117 public:
TfTypePattern(MLIRContext * ctx,TypeConverter & converter)118 TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
119 : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {}
120
121 // The dialect conversion framework will call this matchAndRewrite on each
122 // Operation in the IR tree. This call matchAndRewrite needs to update the
123 // Operation's results and child regions.
matchAndRewrite(Operation * op,ArrayRef<Value> operands,ConversionPatternRewriter & rewriter) const124 LogicalResult matchAndRewrite(
125 Operation *op, ArrayRef<Value> operands,
126 ConversionPatternRewriter &rewriter) const override {
127 // Update the results.
128 llvm::SmallVector<Type, 4> new_results;
129 if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
130 new_results)))
131 return failure();
132
133 // Update the regions. The dialect conversion framework wants new regions to
134 // be created and updated, rather than updating the old op. Thus we use an
135 // OperationState so we can add regions to the new up.
136 OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
137 new_results, op->getAttrs(), op->getSuccessors());
138 for (Region ®ion : op->getRegions()) {
139 Region &new_region = *state.addRegion();
140 rewriter.inlineRegionBefore(region, new_region, new_region.begin());
141 if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
142 return failure();
143 }
144 rewriter.replaceOp(op, rewriter.create(state)->getResults());
145
146 return success();
147 }
148 };
149
150 struct LegalizeTfTypesPass
151 : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
152 void runOnOperation() override;
153 };
154
runOnOperation()155 void LegalizeTfTypesPass::runOnOperation() {
156 TfTypeConverter converter;
157 RewritePatternSet patterns(&getContext());
158 patterns.add<TfTypePattern>(&getContext(), converter);
159 populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
160 converter);
161 TfTypeConversionTarget target(getContext(), converter);
162 if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
163 return signalPassFailure();
164 }
165
166 } // namespace
167
CreateLegalizeTfTypesPass()168 std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
169 return std::make_unique<LegalizeTfTypesPass>();
170 }
171
172 } // namespace mhlo
173 } // namespace mlir
174