xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "llvm/ADT/StringRef.h"
19 #include "mlir/Dialect/Quant/QuantTypes.h"  // from @llvm-project
20 #include "mlir/IR/Block.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/Value.h"  // from @llvm-project
25 #include "mlir/Pass/Pass.h"  // from @llvm-project
26 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h"
28 #include "tensorflow/compiler/mlir/tfr/passes/passes.h"
29 
30 namespace mlir {
31 namespace TFR {
32 
33 class RewriteQuantizedIOPass
34     : public PassWrapper<RewriteQuantizedIOPass, OperationPass<ModuleOp>> {
35  public:
getArgument() const36   StringRef getArgument() const final { return "tfr-rewrite-quantized-io"; }
37 
getDescription() const38   StringRef getDescription() const final {
39     return "Replaces operands and results that has quantized type with their "
40            "storage types.";
41   }
42   void runOnOperation() override;
43 };
44 
runOnOperation()45 void RewriteQuantizedIOPass::runOnOperation() {
46   ModuleOp module = getOperation();
47   OpBuilder builder(module);
48   module.walk([&](func::FuncOp func) {
49     Block& block = func.front();
50     Operation* terminator = block.getTerminator();
51 
52     // Replace input_arg(tensor<quant_type>) -> tfr.cast
53     // with input_arg(tensor<storage_type>) -> tfr.cast
54     for (BlockArgument arg : block.getArguments()) {
55       Type arg_type = arg.getType();
56       if (auto quant_type = arg_type.cast<TensorType>()
57                                 .getElementType()
58                                 .dyn_cast<quant::QuantizedType>()) {
59         if (arg.hasOneUse() && llvm::isa<TFR::CastOp>(*arg.user_begin())) {
60           arg.setType(
61               arg_type.cast<TensorType>().clone(quant_type.getStorageType()));
62         } else {
63           std::string error_message;
64           llvm::raw_string_ostream os{error_message};
65           os << "The argument with type ";
66           arg.getType().print(os);
67           os << " should have one user, which should be tfr.cast.";
68           func->emitError(error_message);
69           return;
70         }
71       }
72     }
73 
74     builder.setInsertionPoint(terminator);
75     // Replace tfr.cast(tensor<quant_type>) -> output
76     // with tfr.cast(tensor<storage_type>) -> output
77     for (OpOperand& returned_value : terminator->getOpOperands()) {
78       auto returned_type =
79           returned_value.get().getType().dyn_cast<TensorType>();
80       if (!returned_type ||
81           !returned_type.getElementType().isa<quant::QuantizedType>()) {
82         continue;
83       }
84 
85       if (auto returned_op =
86               returned_value.get().getDefiningOp<TFR::CastOp>()) {
87         auto new_type = returned_type.clone(returned_type.getElementType()
88                                                 .cast<quant::QuantizedType>()
89                                                 .getStorageType());
90         auto new_op = builder.create<TFR::CastOp>(returned_op->getLoc(),
91                                                   new_type, returned_op.arg());
92         returned_value.set(new_op.getResult());
93         if (returned_op.use_empty()) {
94           returned_op.erase();
95         }
96       } else {
97         returned_value.get().getDefiningOp()->emitError(
98             "The producer of quantized type result should be a tfr.cast op.");
99         return;
100       }
101     }
102 
103     auto new_func_type = builder.getFunctionType(block.getArgumentTypes(),
104                                                  terminator->getOperandTypes());
105     func.setType(new_func_type);
106   });
107 }
108 
109 // Creates an instance of the pass to decompose the TF ops.
CreateRewriteQuantizedIOPass()110 std::unique_ptr<OperationPass<ModuleOp>> CreateRewriteQuantizedIOPass() {
111   return std::make_unique<RewriteQuantizedIOPass>();
112 }
113 
__anona3c06dd20202null114 static PassRegistration<RewriteQuantizedIOPass> pass([] {
115   return CreateRewriteQuantizedIOPass();
116 });
117 
118 }  // namespace TFR
119 }  // namespace mlir
120