xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // The TF dialect uses some TF types that are illegal in the MHLO dialect and
17 // some generic types that are legal in MHLO. This pass legalizes TF types into
18 // types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
19 // Rewrites here should run before TF to MHLO op legalizations are run.
20 // TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
21 // rather than its own pass.
22 
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
28 #include "mlir/IR/PatternMatch.h"  // from @llvm-project
29 #include "mlir/Pass/Pass.h"  // from @llvm-project
30 #include "mlir/Support/LLVM.h"  // from @llvm-project
31 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
32 #include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
33 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
34 #include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
35 
36 #define DEBUG_TYPE "xla-legalize-tf-types"
37 
38 namespace mlir {
39 namespace mhlo {
40 namespace {
41 
IsIllegalElementType(Type type)42 bool IsIllegalElementType(Type type) {
43   return type
44       .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
45            mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
46 }
47 
ToLegalElementType(Type type)48 Type ToLegalElementType(Type type) {
49   return TypeSwitch<Type, Type>(type)
50       .Case<mlir::TF::Qint8Type>([&type](Type) {
51         return mlir::IntegerType::get(type.getContext(), 8);
52       })
53       .Case<mlir::TF::Qint16Type>([&type](Type) {
54         return mlir::IntegerType::get(type.getContext(), 16);
55       })
56       .Case<mlir::TF::Qint32Type>([&type](Type) {
57         return mlir::IntegerType::get(type.getContext(), 32);
58       })
59       .Case<mlir::TF::Quint8Type>([&type](Type) {
60         return mlir::IntegerType::get(
61             type.getContext(), 8,
62             mlir::IntegerType::SignednessSemantics::Unsigned);
63       })
64       .Case<mlir::TF::Quint16Type>([&type](Type) {
65         return mlir::IntegerType::get(
66             type.getContext(), 16,
67             mlir::IntegerType::SignednessSemantics::Unsigned);
68       })
69       .Default([&type](Type) { return type; });
70 }
71 
72 // TODO(b/180234863): What's below this line is generic so convert it to a
73 // utility.
74 
IsIllegalType(Type type)75 bool IsIllegalType(Type type) {
76   return IsIllegalElementType(getElementTypeOrSelf(type));
77 }
78 
ToLegalType(Type type)79 Type ToLegalType(Type type) {
80   if (IsIllegalElementType(type)) return ToLegalElementType(type);
81   if (auto shaped = type.dyn_cast<ShapedType>()) {
82     Type elem = shaped.getElementType();
83     if (IsIllegalType(elem)) return shaped.clone(ToLegalType(elem));
84   }
85   return type;
86 }
87 
88 class TfTypeConverter : public TypeConverter {
89  public:
TfTypeConverter()90   TfTypeConverter() {
91     addConversion([](Type type) -> Type {
92       return IsIllegalType(type) ? ToLegalType(type) : type;
93     });
94   }
95 };
96 
97 // An Op is illegal iff it contains an illegalType.
98 class TfTypeConversionTarget : public ConversionTarget {
99  public:
TfTypeConversionTarget(MLIRContext & ctx,TfTypeConverter & converter)100   explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter)
101       : ConversionTarget(ctx), converter_(converter) {
102     markUnknownOpDynamicallyLegal([this](Operation *op) {
103       // The FuncOp type can contain types that the op's operand and result
104       // types do not contain.
105       if (auto func = dyn_cast<func::FuncOp>(op)) {
106         if (!converter_.isSignatureLegal(func.getFunctionType())) return false;
107       }
108       return converter_.isLegal(op);
109     });
110   }
111 
112  private:
113   TfTypeConverter &converter_;
114 };
115 
// Conversion pattern that matches any op (MatchAnyOpTypeTag, benefit 1). It
// recreates the op with type-converted result types and type-converted
// regions, carrying over the op name, attributes and successors unchanged.
class TfTypePattern : public ConversionPattern {
 public:
  TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
      : ConversionPattern(converter, MatchAnyOpTypeTag(), 1, ctx) {}

  // The dialect conversion framework will call this matchAndRewrite on each
  // Operation in the IR tree. This call matchAndRewrite needs to update the
  // Operation's results and child regions.
  //
  // `operands` are the operand values handed over by the conversion framework
  // (presumably already type-converted); they are used verbatim for the new
  // op. Returns failure if any result type or any region's types cannot be
  // converted.
  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Update the results.
    llvm::SmallVector<Type, 4> new_results;
    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
                                                new_results)))
      return failure();

    // Update the regions. The dialect conversion framework wants new regions to
    // be created and updated, rather than updating the old op. Thus we use an
    // OperationState so we can add regions to the new op.
    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                         new_results, op->getAttrs(), op->getSuccessors());
    for (Region &region : op->getRegions()) {
      Region &new_region = *state.addRegion();
      // Move the old region's blocks into the new region first, then convert
      // the block argument types of the moved region.
      rewriter.inlineRegionBefore(region, new_region, new_region.begin());
      if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
        return failure();
    }
    // Materialize the new op and route all uses of the old results to it.
    rewriter.replaceOp(op, rewriter.create(state)->getResults());

    return success();
  }
};
149 
150 struct LegalizeTfTypesPass
151     : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
152   void runOnOperation() override;
153 };
154 
runOnOperation()155 void LegalizeTfTypesPass::runOnOperation() {
156   TfTypeConverter converter;
157   RewritePatternSet patterns(&getContext());
158   patterns.add<TfTypePattern>(&getContext(), converter);
159   populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
160                                                                  converter);
161   TfTypeConversionTarget target(getContext(), converter);
162   if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
163     return signalPassFailure();
164 }
165 
166 }  // namespace
167 
CreateLegalizeTfTypesPass()168 std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
169   return std::make_unique<LegalizeTfTypesPass>();
170 }
171 
172 }  // namespace mhlo
173 }  // namespace mlir
174