1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <string> 17 18 #include "llvm/ADT/StringRef.h" 19 #include "mlir/Dialect/Quant/QuantTypes.h" // from @llvm-project 20 #include "mlir/IR/Block.h" // from @llvm-project 21 #include "mlir/IR/Builders.h" // from @llvm-project 22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 23 #include "mlir/IR/Operation.h" // from @llvm-project 24 #include "mlir/IR/Value.h" // from @llvm-project 25 #include "mlir/Pass/Pass.h" // from @llvm-project 26 #include "mlir/Pass/PassRegistry.h" // from @llvm-project 27 #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" 28 #include "tensorflow/compiler/mlir/tfr/passes/passes.h" 29 30 namespace mlir { 31 namespace TFR { 32 33 class RewriteQuantizedIOPass 34 : public PassWrapper<RewriteQuantizedIOPass, OperationPass<ModuleOp>> { 35 public: getArgument() const36 StringRef getArgument() const final { return "tfr-rewrite-quantized-io"; } 37 getDescription() const38 StringRef getDescription() const final { 39 return "Replaces operands and results that has quantized type with their " 40 "storage types."; 41 } 42 void runOnOperation() override; 43 }; 44 runOnOperation()45void RewriteQuantizedIOPass::runOnOperation() { 46 ModuleOp module = getOperation(); 47 OpBuilder builder(module); 48 module.walk([&](func::FuncOp func) { 49 Block& block = func.front(); 50 Operation* terminator = block.getTerminator(); 51 52 // Replace input_arg(tensor<quant_type>) -> tfr.cast 53 // with input_arg(tensor<storage_type>) -> tfr.cast 54 for (BlockArgument arg : block.getArguments()) { 55 Type arg_type = arg.getType(); 56 if (auto quant_type = arg_type.cast<TensorType>() 57 .getElementType() 58 .dyn_cast<quant::QuantizedType>()) { 59 if (arg.hasOneUse() && llvm::isa<TFR::CastOp>(*arg.user_begin())) { 60 arg.setType( 61 arg_type.cast<TensorType>().clone(quant_type.getStorageType())); 62 } else { 63 std::string error_message; 64 llvm::raw_string_ostream os{error_message}; 65 os << "The argument with type "; 66 arg.getType().print(os); 67 os << " should have one user, which should be tfr.cast."; 68 func->emitError(error_message); 69 return; 70 } 71 } 72 } 73 74 builder.setInsertionPoint(terminator); 75 // Replace tfr.cast(tensor<quant_type>) -> output 76 // with tfr.cast(tensor<storage_type>) -> output 77 for (OpOperand& returned_value : terminator->getOpOperands()) { 78 auto returned_type = 79 returned_value.get().getType().dyn_cast<TensorType>(); 80 if (!returned_type || 81 !returned_type.getElementType().isa<quant::QuantizedType>()) { 82 continue; 83 } 84 85 if (auto returned_op = 86 returned_value.get().getDefiningOp<TFR::CastOp>()) { 87 auto new_type = returned_type.clone(returned_type.getElementType() 88 .cast<quant::QuantizedType>() 89 .getStorageType()); 90 auto new_op = builder.create<TFR::CastOp>(returned_op->getLoc(), 91 new_type, returned_op.arg()); 92 returned_value.set(new_op.getResult()); 93 if (returned_op.use_empty()) { 94 returned_op.erase(); 95 } 96 } else { 97 returned_value.get().getDefiningOp()->emitError( 98 "The producer of quantized type result should be a tfr.cast op."); 99 return; 100 } 101 } 102 103 auto new_func_type = builder.getFunctionType(block.getArgumentTypes(), 104 terminator->getOperandTypes()); 105 func.setType(new_func_type); 106 }); 107 } 108 109 // Creates an instance of the pass to decompose the TF ops. CreateRewriteQuantizedIOPass()110std::unique_ptr<OperationPass<ModuleOp>> CreateRewriteQuantizedIOPass() { 111 return std::make_unique<RewriteQuantizedIOPass>(); 112 } 113 __anona3c06dd20202null114static PassRegistration<RewriteQuantizedIOPass> pass([] { 115 return CreateRewriteQuantizedIOPass(); 116 }); 117 118 } // namespace TFR 119 } // namespace mlir 120