/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/strings/string_view.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {
namespace {

// Extracts the reduction group size from the group_assignment operand of the
// reduce op. group_assignment is a 2-dimensional array in which each row
// lists the devices that form one reduction group.
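// For example, a constant group_assignment of
//   dense<[[0, 1, 2, 3], [4, 5, 6, 7]]> : tensor<2x4xi32>
// describes two reduction groups of four devices each; the group size is the
// second dimension of the shape, here 4.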
template <class ReduceOpType>
mlir::LogicalResult GetAllReduceGroupSize(ReduceOpType reduce_op,
                                          int32* group_size) {
  mlir::DenseIntElementsAttr group_assignment_attr;
  if (!matchPattern(reduce_op.group_assignment(),
                    m_Constant(&group_assignment_attr)))
    return mlir::emitError(reduce_op.getLoc(),
                           "group_assignment must be a constant.");
  if (group_assignment_attr.getType().getRank() != 2)
    return mlir::emitError(reduce_op.getLoc(),
                           "group_assignment should have two dimensions.");

  *group_size = group_assignment_attr.getType().getShape()[1];
  return mlir::success();
}

// For large enough reduction groups, we compute reductions in a higher
// precision type to ensure accuracy is not lost with sequential addition
// of large numbers in a lower precision type. If the given reduce op meets
// both of the following criteria:
//  - the tensors being reduced are of type bfloat16, and
//  - the reduction group is larger than the configurable env var
//    DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE,
// then the tensors are upcast to float32 for the reduction and downcast back
// afterwards.
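//
// Schematically, using DTensorAllReduce as an example, the rewrite is:
//
//   %out = DTensorAllReduce(%in)          : tensor<...xbf16>
//
// becomes
//
//   %up  = tf.Cast(%in)                   : tensor<...xf32>
//   %red = DTensorAllReduce(%up)          : tensor<...xf32>
//   %out = tf.Cast(%red)                  : tensor<...xbf16>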
template <class ReduceOpType>
mlir::LogicalResult MaybeUpcastForReduction(ReduceOpType reduce_op,
                                            bool* changed) {
  const mlir::RankedTensorType& input_type =
      reduce_op.input().getType().template dyn_cast<mlir::RankedTensorType>();
  if (!input_type || !input_type.getElementType().isBF16()) {
    // Upcast only applies for ranked bfloat16 input.
    return mlir::success();
  }

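  // Constructing the builder from the op anchors its insertion point
  // immediately before the reduce op, where the upcast below is created.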
  mlir::OpBuilder builder(reduce_op);
  const mlir::Location loc = reduce_op.getLoc();

  int32 group_size;
  if (mlir::failed(GetAllReduceGroupSize(reduce_op, &group_size)))
    return mlir::failure();
  if (group_size <= ReduceInBfloat16MaxGroupSize())
    // Reduce group size is not sufficient, so we do not modify the ops.
    return mlir::success();

  const auto reduce_layout = ExtractRequiredSingleLayoutFromOp(reduce_op);
  if (!reduce_layout.ok())
    return reduce_op.emitOpError(llvm::formatv(
        "Malformed layout specification for DTensor reduce op found: {0}",
        reduce_layout.status().error_message()));

  // The original output tensor type that would have been used by all users of
  // the reduce op.
  const mlir::RankedTensorType& output_type =
      reduce_op.output().getType().template dyn_cast<mlir::RankedTensorType>();

  mlir::TF::CastOp upcast = builder.create<mlir::TF::CastOp>(
      loc,
      mlir::RankedTensorType::get(input_type.getShape(), builder.getF32Type()),
      reduce_op.input());
  reduce_op->setOperand(0, upcast.y());
  reduce_op.output().setType(upcast.y().getType());

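  // Cast the float32 reduction result back to the original element type for
  // all existing users of the reduce op.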
  builder.setInsertionPointAfter(reduce_op);
  mlir::TF::CastOp downcast = builder.create<mlir::TF::CastOp>(
      loc,
      mlir::RankedTensorType::get(output_type.getShape(),
                                  output_type.getElementType()),
      reduce_op);
  // Match the layout of the downcast with that of the reduce op; this is
  // required by later passes.
  SetSingleLayoutOnOp(downcast, *reduce_layout);
  reduce_op.output().replaceAllUsesExcept(downcast.y(), downcast);

  *changed = true;
  return mlir::success();
}

template <class ReduceOpType>
mlir::LogicalResult TryMixedPrecisionReduce(mlir::func::FuncOp function,
                                            absl::string_view opName) {
  int32_t reduceOpsCounter = 0;
  int32_t changedReduceOpsCounter = 0;

  mlir::WalkResult walk_result = function.walk([&](ReduceOpType reduce_op) {
    if (reduce_op.reduce_op().str() == kReduceOpAdd) {
      reduceOpsCounter += 1;
      bool changed = false;
      if (mlir::failed(MaybeUpcastForReduction(reduce_op, &changed)))
        return mlir::WalkResult::interrupt();
      if (changed) changedReduceOpsCounter += 1;
    }
    return mlir::WalkResult::advance();
  });
  if (walk_result.wasInterrupted()) return mlir::failure();

  VLOG(2) << "Applied mixed precision to " << changedReduceOpsCounter << " of "
          << reduceOpsCounter << " Add " << opName << " ops.";

  return mlir::success();
}

// MLIR pass that enables tensor upcasting within mixed-precision reduction.
struct DTensorMixedPrecisionReducePass
    : public DTensorMixedPrecisionReduceBase<DTensorMixedPrecisionReducePass> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();

    if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorAllReduceOp>(
            function, "DTensorAllReduce")))
      return signalPassFailure();
    if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorReduceScatterOp>(
            function, "DTensorReduceScatter")))
      return signalPassFailure();
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorMixedPrecisionReducePass() {
  return std::make_unique<DTensorMixedPrecisionReducePass>();
}
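
// A minimal usage sketch (hypothetical driver code, not part of this file):
// since this is a pass over func.func ops, it would typically be scheduled as
// a nested pass on a module-level pass manager, e.g.
//
//   mlir::PassManager pm(&context);
//   pm.addNestedPass<mlir::func::FuncOp>(
//       CreateDTensorMixedPrecisionReducePass());
//   (void)pm.run(module);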

}  // namespace dtensor
}  // namespace tensorflow