/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/string_view.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Extracts the reduction group size from the group_assignment operand of the
// reduce op. group_assignment is a 2-dimensional array where each row lists
// the devices that belong to one reduction group.
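// For example, a group_assignment of [[0, 1], [2, 3]] encodes two reduction
// groups of two devices each, so the group size (the second dimension of the
// shape) is 2.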
template <class ReduceOpType>
mlir::LogicalResult GetAllReduceGroupSize(ReduceOpType reduce_op,
                                          int32* group_size) {
  mlir::DenseIntElementsAttr group_assignment_attr;
  if (!matchPattern(reduce_op.group_assignment(),
                    m_Constant(&group_assignment_attr)))
    return mlir::emitError(reduce_op.getLoc(),
                           "group_assignment must be a constant.");
  if (group_assignment_attr.getType().getRank() != 2)
    return mlir::emitError(reduce_op.getLoc(),
                           "group_assignment should have two dimensions.");

  *group_size = group_assignment_attr.getType().getShape()[1];
  return mlir::success();
}

// For large enough reduction groups, we compute reductions in a higher
// precision type to ensure accuracy is not lost with sequential addition
// of large numbers in a lower precision type. If the given reduce op meets
// both of the following criteria:
//   - the tensors being reduced are of type bfloat16, and
//   - the reduction group is larger than the configurable threshold
//     DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE,
// then the tensors are upcast to float32 for the reduction and downcast
// again afterwards.
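// Roughly, for a qualifying op the rewrite looks like this (a sketch with
// abbreviated types, not exact IR):
//   %y = DTensorAllReduce(%x : bf16)
// becomes
//   %c = Cast(%x : bf16) : f32
//   %y = DTensorAllReduce(%c : f32)
//   %d = Cast(%y : f32) : bf16   // carries the original reduce op's layout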
template <class ReduceOpType>
mlir::LogicalResult MaybeUpcastForReduction(ReduceOpType reduce_op,
                                            bool* changed) {
  const mlir::RankedTensorType input_type =
      reduce_op.input().getType().template dyn_cast<mlir::RankedTensorType>();
  if (!input_type || !input_type.getElementType().isBF16()) {
    // Upcast only applies to ranked bfloat16 inputs.
    return mlir::success();
  }

  mlir::OpBuilder builder(reduce_op);
  const mlir::Location loc = reduce_op.getLoc();

  int32 group_size;
  if (mlir::failed(GetAllReduceGroupSize(reduce_op, &group_size)))
    return mlir::failure();
  if (group_size <= ReduceInBfloat16MaxGroupSize())
    // The reduction group is not large enough to require upcasting, so we do
    // not modify the op.
    return mlir::success();

  const auto reduce_layout = ExtractRequiredSingleLayoutFromOp(reduce_op);
  if (!reduce_layout.ok())
    return reduce_op.emitOpError(llvm::formatv(
        "Malformed layout specification for DTensor reduce op found: {0}",
        reduce_layout.status().error_message()));

  // The original output tensor type that would have been used by all users of
  // the reduce op.
  const mlir::RankedTensorType output_type =
      reduce_op.output().getType().template dyn_cast<mlir::RankedTensorType>();

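  // Upcast the input to float32 and retype the reduce op's result to match;
  // the reduction itself now happens in float32.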
  mlir::TF::CastOp upcast = builder.create<mlir::TF::CastOp>(
      loc,
      mlir::RankedTensorType::get(input_type.getShape(), builder.getF32Type()),
      reduce_op.input());
  reduce_op->setOperand(0, upcast.y());
  reduce_op.output().setType(upcast.y().getType());

  builder.setInsertionPointAfter(reduce_op);
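  // Downcast the float32 reduction result back to the original element type
  // (bfloat16) so downstream users see the type they expect.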
  mlir::TF::CastOp downcast = builder.create<mlir::TF::CastOp>(
      loc,
      mlir::RankedTensorType::get(output_type.getShape(),
                                  output_type.getElementType()),
      reduce_op);
  // Match the layout of the downcast with the reduce op; this is required by
  // later passes.
  SetSingleLayoutOnOp(downcast, *reduce_layout);
  reduce_op.output().replaceAllUsesExcept(downcast.y(), downcast);

  *changed = true;
  return mlir::success();
}

template <class ReduceOpType>
mlir::LogicalResult TryMixedPrecisionReduce(mlir::func::FuncOp function,
                                            absl::string_view opName) {
  int32_t reduceOpsCounter = 0;
  int32_t changedReduceOpsCounter = 0;

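  // Walk every reduce op of the given type; only "Add" reductions are
  // candidates, since sequential summation is where bfloat16 loses accuracy.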
  mlir::WalkResult walk_result = function.walk([&](ReduceOpType reduce_op) {
    if (reduce_op.reduce_op().str() == kReduceOpAdd) {
      reduceOpsCounter += 1;
      bool changed = false;
      if (mlir::failed(MaybeUpcastForReduction(reduce_op, &changed)))
        return mlir::WalkResult::interrupt();
      if (changed) changedReduceOpsCounter += 1;
    }
    return mlir::WalkResult::advance();
  });
  if (walk_result.wasInterrupted()) return mlir::failure();

  VLOG(2) << "Applied mixed precision to " << changedReduceOpsCounter << " of "
          << reduceOpsCounter << " Add " << opName << " ops.";

  return mlir::success();
}

// MLIR pass that enables mixed-precision reduction by upcasting eligible
// bfloat16 reduce ops to float32.
struct DTensorMixedPrecisionReducePass
    : public DTensorMixedPrecisionReduceBase<DTensorMixedPrecisionReducePass> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();

    if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorAllReduceOp>(
            function, "DTensorAllReduce")))
      return signalPassFailure();
    if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorReduceScatterOp>(
            function, "DTensorReduceScatter")))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorMixedPrecisionReducePass() {
  return std::make_unique<DTensorMixedPrecisionReducePass>();
}

}  // namespace dtensor
}  // namespace tensorflow