1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This file implements logic for lowering the HLO XlaRngGetAndUpdateState op
// to operations of the arithmetic, memref and tensor dialects.
17
18 #include <memory>
19 #include <utility>
20
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/IR/BuiltinDialect.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/DialectConversion.h"
31
32 namespace mlir {
33 namespace mhlo {
34 namespace {
35
36 struct RngGetAndUpdateStatePattern
37 : public OpConversionPattern<mhlo::XlaRngGetAndUpdateStateOp> {
38 using OpConversionPattern<
39 mhlo::XlaRngGetAndUpdateStateOp>::OpConversionPattern;
40
matchAndRewritemlir::mhlo::__anon2d40a8040111::RngGetAndUpdateStatePattern41 LogicalResult matchAndRewrite(
42 mhlo::XlaRngGetAndUpdateStateOp op,
43 XlaRngGetAndUpdateStateOpAdaptor adaptor,
44 ConversionPatternRewriter& rewriter) const final {
45 // Get various type related information
46 auto loc = op->getLoc();
47
48 const auto globalName = rewriter.getStringAttr("rng_state");
49 constexpr auto initialSeed = 0x7012395ull;
50 auto seedType = rewriter.getIntegerType(128);
51 auto memrefType = MemRefType::get({}, seedType);
52
53 auto resultType = op.getType();
54 auto wordSize = resultType.getElementType().getIntOrFloatBitWidth();
55 auto smallerIntType = rewriter.getIntegerType(wordSize);
56 auto numElements = resultType.getNumElements();
57
58 // Get or define the global variable
59 auto* globalOp = mlir::SymbolTable::lookupNearestSymbolFrom(op, globalName);
60 if (!globalOp) {
61 auto* parent = mlir::SymbolTable::getNearestSymbolTable(op);
62 OpBuilder::InsertionGuard g(rewriter);
63 rewriter.setInsertionPointToStart(&parent->getRegions().front().front());
64
65 const auto priv = rewriter.getStringAttr("private");
66 auto initialValue = mlir::DenseElementsAttr::get(
67 mlir::RankedTensorType::get({}, seedType),
68 rewriter.getIntegerAttr(seedType, initialSeed));
69 globalOp = rewriter.create<memref::GlobalOp>(
70 loc, globalName, priv, memrefType, initialValue, /*constant=*/false,
71 /*alignment=*/IntegerAttr());
72 }
73 assert(isa<memref::GlobalOp>(globalOp) &&
74 "rng_state was defined somewhere else, not as a global op");
75
76 // Get and update
77 Value rngState =
78 rewriter.create<memref::GetGlobalOp>(loc, memrefType, globalName);
79 Value oldVal = rewriter.create<memref::LoadOp>(loc, rngState);
80 Value delta = rewriter.create<arith::ConstantOp>(
81 loc, rewriter.getIntegerAttr(seedType,
82 static_cast<int64_t>(adaptor.delta())));
83 Value newVal = rewriter.create<arith::AddIOp>(loc, oldVal, delta);
84 (void)rewriter.create<memref::StoreOp>(loc, newVal, rngState);
85
86 // Create the proper return type by packing the old seed into a tensor
87 SmallVector<Value> pieces;
88 for (int i = (numElements - 1) * wordSize; i >= 0; i -= wordSize) {
89 Value shiftDistance = rewriter.create<arith::ConstantOp>(
90 loc, rewriter.getIntegerAttr(seedType, i));
91 pieces.push_back(rewriter.create<arith::TruncIOp>(
92 loc, smallerIntType,
93 rewriter.create<arith::ShRUIOp>(loc, oldVal, shiftDistance)));
94 }
95
96 // Obtain a tensor with the correct shape and bit widths but the incorrect
97 // integer signedness, then cast the tensor to the correct signedness to
98 // ensure that unrealized casts will successfully lower later.
99 Value resultTensor = rewriter.create<tensor::FromElementsOp>(
100 loc, mlir::RankedTensorType::get(resultType.getShape(), smallerIntType),
101 pieces);
102 rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(op, resultType,
103 resultTensor);
104 return success();
105 }
106 };
107
108 struct HloLegalizeToArithmeticPass
109 : public HloLegalizeToArithmeticPassBase<HloLegalizeToArithmeticPass> {
getDependentDialectsmlir::mhlo::__anon2d40a8040111::HloLegalizeToArithmeticPass110 void getDependentDialects(DialectRegistry& registry) const override {
111 registry.insert<arith::ArithmeticDialect, memref::MemRefDialect,
112 tensor::TensorDialect>();
113 }
114
115 public:
runOnOperationmlir::mhlo::__anon2d40a8040111::HloLegalizeToArithmeticPass116 void runOnOperation() override {
117 auto& context = getContext();
118 RewritePatternSet patterns(&context);
119 ConversionTarget target(context);
120
121 populateHloToArithmeticConversionPatterns(&patterns);
122
123 target.addIllegalOp<XlaRngGetAndUpdateStateOp>();
124 target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
125 memref::MemRefDialect, tensor::TensorDialect>();
126
127 auto module = getOperation();
128 if (failed(applyPartialConversion(module, target, std::move(patterns))))
129 signalPassFailure();
130 }
131 };
132
133 } // namespace
134
populateHloToArithmeticConversionPatterns(RewritePatternSet * patterns)135 void populateHloToArithmeticConversionPatterns(RewritePatternSet* patterns) {
136 patterns->add<RngGetAndUpdateStatePattern>(patterns->getContext());
137 }
138
createLegalizeToArithmeticPass()139 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToArithmeticPass() {
140 return std::make_unique<HloLegalizeToArithmeticPass>();
141 }
142
143 } // namespace mhlo
144 } // namespace mlir
145