/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Types.h"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
namespace mhlo {

namespace {

// Broadcasts the 1D value tensor 'value1d' to the shape of 'resultType'. If
// 'shapeValue' is initialized, creates a dynamic broadcast, otherwise creates
// a static broadcast.
Value broadcastToFeatureDim(Location loc, RankedTensorType resultType,
                            Value value1d, Value shapeValue, int64_t featureDim,
                            PatternRewriter& rewriter) {  // NOLINT
  auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dimsType, {featureDim});
  if (shapeValue) {
    return rewriter.createOrFold<mhlo::DynamicBroadcastInDimOp>(
        loc, resultType, value1d, shapeValue, dims);
  }
  assert(resultType.hasStaticShape());
  return rewriter.create<mhlo::BroadcastInDimOp>(loc, resultType, value1d,
                                                 dims);
}

// Get the shape of operand, assuming it is a dynamic shape with static rank.
Value getShapeValue(Location loc, Value operand,
                    PatternRewriter &rewriter) {  // NOLINT
  RankedTensorType resultType = operand.getType().dyn_cast<RankedTensorType>();
  return rewriter.create<mlir::shape::ShapeOfOp>(
      loc,
      RankedTensorType::get({resultType.getRank()}, rewriter.getIndexType()),
      operand);
}

Value materializeEpsilon(Operation *op, FloatAttr epsilonAttr, FloatType fpType,
                         Value broadcastTo, RankedTensorType broadcastToType,
                         PatternRewriter &rewriter) {  // NOLINT
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  if (epsilonAttr.getType() != fpType) {
    // Need to convert.
    bool losesInfo;
    APFloat epsilonFloat = epsilonAttr.getValue();
    auto status = epsilonFloat.convert(
        fpType.getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo);
    if ((status & (~APFloat::opInexact)) != APFloat::opOK) {
      op->emitWarning() << "Could not convert batch_norm epsilon to target fp "
                           "type: opStatus = "
                        << static_cast<int>(status);
      return nullptr;
    }
    if (losesInfo) {
      op->emitWarning("Conversion of epsilon loses precision");
    }
    epsilonAttr = b.getFloatAttr(fpType, epsilonFloat);
  }

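  // Materialize epsilon as a scalar constant and broadcast it to
  // 'broadcastToType': statically when the target shape is known, otherwise
  // via a shape-based dynamic broadcast.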
  auto scalarType = RankedTensorType::get({}, fpType);
  auto epsilonTensorAttr =
      DenseElementsAttr::get(scalarType, {epsilonAttr.cast<Attribute>()});
  Value epsilon = b.create<mhlo::ConstantOp>(epsilonTensorAttr);
  auto dimsType = RankedTensorType::get({0}, b.getIntegerType(64));
  auto dims = DenseIntElementsAttr::get(dimsType, SmallVector<int64_t, 1>{});
  if (broadcastToType.hasStaticShape()) {
    return b.create<mhlo::BroadcastInDimOp>(broadcastToType, epsilon,
                                            /*broadcast_dims=*/dims);
  }
  Value shapeValue = getShapeValue(op->getLoc(), broadcastTo, rewriter);
  return b.createOrFold<mhlo::DynamicBroadcastInDimOp>(broadcastToType, epsilon,
                                                       shapeValue,
                                                       /*broadcast_dims=*/dims);
}

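// BatchNormInference(X, scale, offset, mean, variance) =
//   ((X - mean) / Sqrt(variance + epsilon)) * scale + offset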
class UnfuseBatchNormInferencePattern
    : public OpRewritePattern<mhlo::BatchNormInferenceOp> {
 public:
  using OpRewritePattern<mhlo::BatchNormInferenceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::BatchNormInferenceOp bnOp,
                                PatternRewriter& rewriter) const override {
    // Enforce type invariants.
    // Note that we deduce the actual element type from the variance,
    // which should not be subject to quantization at a higher level.
    auto inputType = bnOp.operand().getType().dyn_cast<RankedTensorType>();
    auto varianceType = bnOp.variance().getType().dyn_cast<RankedTensorType>();
    if (!inputType || !varianceType) {
      return failure();
    }
    auto fpType = varianceType.getElementType().dyn_cast<FloatType>();
    if (!fpType) {
      return failure();
    }
    int64_t featureDim = bnOp.feature_index();

    // Add epsilon to the variance and sqrt to get stddev:
    // stddev = sqrt(variance + epsilon)
    auto epsilon =
        materializeEpsilon(bnOp.getOperation(), bnOp.epsilonAttr(), fpType,
                           bnOp.variance(), varianceType, rewriter);
    if (!epsilon) {
      return failure();
    }
    Value stddev =
        rewriter.create<mhlo::AddOp>(bnOp.getLoc(), bnOp.variance(), epsilon);
    stddev = rewriter.create<mhlo::SqrtOp>(bnOp.getLoc(), stddev);

    // Broadcast all terms.
    Value shapeValue;
    if (!inputType.hasStaticShape()) {
      shapeValue = getShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter);
    }
    auto broadcastScale =
        broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.scale(),
                              shapeValue, featureDim, rewriter);
    auto broadcastOffset =
        broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.offset(),
                              shapeValue, featureDim, rewriter);
    auto broadcastMean =
        broadcastToFeatureDim(bnOp.getLoc(), inputType, bnOp.mean(), shapeValue,
                              featureDim, rewriter);
    auto broadcastStddev = broadcastToFeatureDim(
        bnOp.getLoc(), inputType, stddev, shapeValue, featureDim, rewriter);

    // Compute:
    // scale * (input - mean) / stddev + offset
    Value result = rewriter.create<mhlo::SubtractOp>(
        bnOp.getLoc(), bnOp.operand(), broadcastMean);
    result =
        rewriter.create<mhlo::MulOp>(bnOp.getLoc(), result, broadcastScale);
    result =
        rewriter.create<mhlo::DivOp>(bnOp.getLoc(), result, broadcastStddev);
    rewriter.replaceOpWithNewOp<mhlo::AddOp>(bnOp, result, broadcastOffset);

    return success();
  }
};

// Creates an "mhlo.reduce" that sums 'operand' over 'reduceDims' (all
// dimensions except 'featureIndex'), using 'zero' as the init value. The
// result is a 1D tensor of size operand.shape[featureIndex].
Value createReduce(Location loc, Value operand, Value zero,
                   SmallVector<int64_t>& reduceDims, int64_t featureIndex,
                   PatternRewriter& rewriter) {
  auto operandType = operand.getType().cast<RankedTensorType>();
  Type reduceResultType = RankedTensorType::get(
      {operandType.getDimSize(featureIndex)}, operandType.getElementType());
  mhlo::ReduceOp reduce =
      rewriter.create<mhlo::ReduceOp>(loc, reduceResultType, operand, zero,
                                      rewriter.getI64TensorAttr(reduceDims));

  // Set up the "mhlo.reduce" body: a single block that adds its two scalar
  // arguments and returns the sum.
  Region& region = reduce.body();
  Block& block = region.emplaceBlock();
  RankedTensorType blockArgumentType =
      RankedTensorType::get({}, operandType.getElementType());
  block.addArgument(blockArgumentType, loc);
  block.addArgument(blockArgumentType, loc);
  auto* firstArgument = block.args_begin();
  auto secondArgument = block.args_rbegin();
  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&block);
    Value addResult =
        rewriter.create<mhlo::AddOp>(loc, *firstArgument, *secondArgument);
    rewriter.create<mhlo::ReturnOp>(loc, addResult);
  }

  return reduce.getResult(0);
}

// Calculates the total reduce size, i.e. the number of elements reduced per
// feature: numElements(operand) / numElements(scale). The result is
// materialized as a tensor with the shape of 'scale'. Dynamic shapes are
// supported as long as the rank is static.
Value calculateReduceSize(Operation *op, Value operand,
                          RankedTensorType operandType, Value scale,
                          RankedTensorType scaleType, int64_t featureIndex,
                          PatternRewriter &rewriter) {
  ImplicitLocOpBuilder b(op->getLoc(), rewriter);
  Type indexType = b.getIndexType();
  if (!operandType.hasStaticShape()) {
    // The "operand" has a dynamic shape with static rank; compute the reduce
    // size from the runtime shapes.
    Value operandShape = getShapeValue(op->getLoc(), operand, rewriter);
    Value scaleShape = getShapeValue(op->getLoc(), scale, rewriter);
    Value operandTotalSize =
        b.create<shape::NumElementsOp>(indexType, operandShape);
    Value scaleTotalSize =
        b.create<shape::NumElementsOp>(indexType, scaleShape);
    Value reduceSize =
        b.create<shape::DivOp>(indexType, operandTotalSize, scaleTotalSize);
    reduceSize = b.create<arith::IndexCastOp>(b.getI64Type(), reduceSize);
    reduceSize = b.create<tensor::FromElementsOp>(reduceSize);
    reduceSize = b.create<mhlo::ConvertOp>(
        RankedTensorType::get({1}, operandType.getElementType()), reduceSize);
    reduceSize = b.create<mhlo::ReshapeOp>(
        RankedTensorType::get({}, operandType.getElementType()), reduceSize);
    return b.createOrFold<mhlo::DynamicBroadcastInDimOp>(
        scaleType, reduceSize, scaleShape, b.getI64TensorAttr({}));
  }

  // The "operand" has a static shape: fold the reduce size into a constant.
  int64_t reduceDimsSize = 1;
  for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
    if (i != featureIndex) {
      reduceDimsSize *= operandType.getDimSize(i);
    }
  }
  llvm::APFloat floatValue(static_cast<double>(reduceDimsSize));
  bool losesInfo;
  floatValue.convert(
      scaleType.getElementType().cast<FloatType>().getFloatSemantics(),
      APFloat::rmNearestTiesToEven, &losesInfo);
  if (losesInfo) {
    op->emitWarning("Conversion of reduce_dims_size loses precision");
  }
  Value reduceSize = b.create<mhlo::ConstantOp>(
      DenseFPElementsAttr::get(scaleType, floatValue));
  return reduceSize;
}

// BatchNormTraining(X, scale, offset) =
//   ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale + offset.
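// Var[X] is computed from the batch statistics as E[X^2] - E[X]^2, which is
// what the expansion below materializes.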
class UnfuseBatchNormTrainingPattern
    : public OpRewritePattern<mhlo::BatchNormTrainingOp> {
 public:
  using OpRewritePattern<mhlo::BatchNormTrainingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(mhlo::BatchNormTrainingOp bnOp,
                                PatternRewriter& rewriter) const override {
    auto operandType = bnOp.operand().getType().dyn_cast<RankedTensorType>();
    auto scaleType = bnOp.scale().getType().dyn_cast<RankedTensorType>();
    if (!operandType || !scaleType) {
      return failure();
    }
    auto fpType = operandType.getElementType().dyn_cast<FloatType>();
    if (!fpType) {
      return failure();
    }
    int64_t featureIndex = bnOp.feature_index();
    SmallVector<int64_t> dimensionsWithoutFeature;
    for (int64_t i = 0, e = operandType.getRank(); i < e; i++) {
      if (i != featureIndex) {
        dimensionsWithoutFeature.push_back(i);
      }
    }

    // zero constant
    Value constZero = rewriter.create<mhlo::ConstantOp>(
        bnOp.getLoc(),
        DenseFPElementsAttr::get(RankedTensorType::get({}, fpType),
                                 APFloat::getZero(fpType.getFloatSemantics())));
    // epsilon
    auto epsilon =
        materializeEpsilon(bnOp.getOperation(), bnOp.epsilonAttr(), fpType,
                           bnOp.scale(), scaleType, rewriter);
    if (!epsilon) {
      return failure();
    }
    // reduce size constant
    Value reduceSize =
        calculateReduceSize(bnOp.getOperation(), bnOp.operand(), operandType,
                            bnOp.scale(), scaleType, featureIndex, rewriter);
    if (!reduceSize) {
      return failure();
    }
    // Sum[X]
    Value sum = createReduce(bnOp.getLoc(), bnOp.operand(), constZero,
                             dimensionsWithoutFeature, featureIndex, rewriter);
    // X^2
    Value operandSquare = rewriter.create<mhlo::MulOp>(
        bnOp.getLoc(), bnOp.operand(), bnOp.operand());
    // Sum[X^2]
    Value squareSum =
        createReduce(bnOp.getLoc(), operandSquare, constZero,
                     dimensionsWithoutFeature, featureIndex, rewriter);
    // E[X]
    Value mean = rewriter.create<mhlo::DivOp>(bnOp.getLoc(), sum, reduceSize);
    // E[X^2]
    Value squareMean =
        rewriter.create<mhlo::DivOp>(bnOp.getLoc(), squareSum, reduceSize);
    // E^2[X]
    Value meanSquare = rewriter.create<mhlo::MulOp>(bnOp.getLoc(), mean, mean);
    // Var[X]
    Value var = rewriter.create<mhlo::SubtractOp>(bnOp.getLoc(), squareMean,
                                                  meanSquare);
    // Var[X] + epsilon
    Value varAddEpsilon =
        rewriter.create<mhlo::AddOp>(bnOp.getLoc(), var, epsilon);
    // Sqrt(Var[X] + epsilon)
    Value sqrtVar = rewriter.create<mhlo::SqrtOp>(bnOp.getLoc(), varAddEpsilon);

    Value shapeValue;
    if (!operandType.hasStaticShape()) {
      shapeValue = getShapeValue(bnOp.getLoc(), bnOp.operand(), rewriter);
    }
    // X - E[X]
    Value meanBroadcast = broadcastToFeatureDim(
        bnOp.getLoc(), operandType, mean, shapeValue, featureIndex, rewriter);
    Value operandMinusMean = rewriter.create<mhlo::SubtractOp>(
        bnOp.getLoc(), bnOp.operand(), meanBroadcast);
    // (X - E[X]) / Sqrt(Var[X] + epsilon)
    Value sqrtVarBroadcast =
        broadcastToFeatureDim(bnOp.getLoc(), operandType, sqrtVar, shapeValue,
                              featureIndex, rewriter);
    Value normalized = rewriter.create<mhlo::DivOp>(
        bnOp.getLoc(), operandMinusMean, sqrtVarBroadcast);

    // ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale
    Value scaleBroadcast =
        broadcastToFeatureDim(bnOp.getLoc(), operandType, bnOp.scale(),
                              shapeValue, featureIndex, rewriter);
    Value scaledNormalized =
        rewriter.create<mhlo::MulOp>(bnOp.getLoc(), normalized, scaleBroadcast);
    // ((X - E[X]) / Sqrt(Var[X] + epsilon)) * scale + offset.
    Value offsetBroadcast =
        broadcastToFeatureDim(bnOp.getLoc(), operandType, bnOp.offset(),
                              shapeValue, featureIndex, rewriter);
    Value shiftedNormalized = rewriter.create<mhlo::AddOp>(
        bnOp.getLoc(), scaledNormalized, offsetBroadcast);

    // results
    SmallVector<Value> results = {shiftedNormalized, mean, var};
    rewriter.replaceOp(bnOp, results);

    return success();
  }
};

}  // namespace

// Populates conversion patterns to unfuse batch normalization operations.
// In combination with marking such ops as illegal, this allows backends that
// do not have special support for fused batchnorm to use simpler arithmetic
// primitives.
void populateUnfuseBatchNormInferencePattern(MLIRContext *context,
                                             RewritePatternSet *patterns) {
  patterns->add<UnfuseBatchNormInferencePattern>(context);
}

void populateUnfuseBatchNormTrainingPattern(MLIRContext *context,
                                            RewritePatternSet *patterns) {
  patterns->add<UnfuseBatchNormTrainingPattern>(context);
}
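
// Minimal usage sketch (not part of this file, shown only for context): a
// pass that wants to expand fused batch norm would typically collect these
// patterns and apply them with a greedy rewrite driver, e.g.:
//   RewritePatternSet patterns(context);
//   populateUnfuseBatchNormInferencePattern(context, &patterns);
//   populateUnfuseBatchNormTrainingPattern(context, &patterns);
//   (void)applyPatternsAndFoldGreedily(op, std::move(patterns));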

}  // namespace mhlo
}  // namespace mlir