1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <algorithm>
17 #include <cstdint>
18 #include <memory>
19 #include <numeric>
20 #include <vector>
21
22 #include "absl/types/span.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
28 #include "mlir/IR/Dialect.h" // from @llvm-project
29 #include "mlir/IR/OpDefinition.h" // from @llvm-project
30 #include "mlir/IR/Operation.h" // from @llvm-project
31 #include "mlir/IR/Types.h" // from @llvm-project
32 #include "mlir/Pass/Pass.h" // from @llvm-project
33 #include "mlir/Support/LogicalResult.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
35 #include "tensorflow/compiler/xla/layout.h"
36 #include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
37 #include "tensorflow/compiler/xla/shape.h"
38 #include "tensorflow/core/tpu/tpu_api.h"
39 #include "tensorflow/stream_executor/tpu/c_api_conversions.h"
40
41 namespace mlir {
42 namespace mhlo {
GetTPUInfeedLayoutFromAPI(RankedTensorType t)43 static FailureOr<std::vector<int64_t>> GetTPUInfeedLayoutFromAPI(
44 RankedTensorType t) {
45 // Call the TPU API to determine the right infeed layout. Note that
46 // this can fail if we're not running on a TPU-enabled node.
47 // TODO(kramm): Move this into a separate pass. See b/184944903
48 xla::Shape old_shape = xla::TypeToShape(t);
49 XLA_Shape old_shape_c = {};
50 XLA_Shape new_shape_c = {};
51 TfTpu_ExecutorApiFn *executor = tensorflow::tpu::ExecutorApiFn();
52 if (!tensorflow::tpu::IsInitialized(executor)) {
53 return failure();
54 }
55 ApiConverter::ToC(old_shape, &old_shape_c);
56 executor->TpuTransferManager_GetInfeedLayoutFn(&old_shape_c, &new_shape_c);
57 xla::Shape new_shape = ApiConverter::FromC(&new_shape_c);
58 ApiConverter::Destroy(&old_shape_c);
59 ApiConverter::Destroy(&new_shape_c);
60
61 auto minor_to_major = new_shape.layout().minor_to_major();
62 return std::vector<int64_t>(minor_to_major.begin(), minor_to_major.end());
63 }
64
GetTPUInfeedLayout(const ArrayRef<Type> types,OpBuilder & rewriter)65 FailureOr<Attribute> GetTPUInfeedLayout(const ArrayRef<Type> types,
66 OpBuilder &rewriter) {
67 auto i64_type = rewriter.getIntegerType(64);
68 if (types.size() > 1) {
69 llvm::SmallVector<mlir::Attribute> v;
70 v.reserve(types.size());
71 for (const mlir::Type &t : types) {
72 if (t.isa<TokenType>()) continue;
73 auto layout = GetTPUInfeedLayout({t}, rewriter);
74 if (failed(layout)) return failure();
75 v.push_back(layout.getValue());
76 }
77 ArrayRef<Attribute> shape(v);
78 return rewriter.getArrayAttr(shape);
79 } else if (types[0].isa<TupleType>()) {
80 auto tuple_type = types[0].dyn_cast<TupleType>();
81 const auto &types = tuple_type.getTypes();
82 llvm::SmallVector<mlir::Attribute> v;
83 v.reserve(types.size());
84 for (const mlir::Type &t : types) {
85 if (t.isa<TokenType>()) continue;
86 auto layout = GetTPUInfeedLayout({t}, rewriter);
87 if (failed(layout)) return failure();
88 v.push_back(layout.getValue());
89 }
90 ArrayRef<Attribute> shape(v);
91 return rewriter.getArrayAttr(shape);
92 } else if (auto t = types[0].dyn_cast<RankedTensorType>()) {
93 if (!t.hasStaticShape()) return failure();
94 auto layout = GetTPUInfeedLayoutFromAPI(t);
95 std::vector<int64_t> minor_to_major;
96 if (succeeded(layout)) {
97 minor_to_major = layout.getValue();
98 } else {
99 /* If we're not running on a TPU node, we might not be able to
100 * actually call the part of the TPU API that gives us layout.
101 * This happens e.g. for unit tests. Below we just create a reasonable
102 * layout. We sort by dimension size, which makes the layout agree with
103 * the "correct" TPU layout in surprisingly many cases.
104 * Note that the corresponding InfeedEnqueue op will be generated
105 * through another path, and might still generate an (incompatible)
106 * layout using the TPU API. Running legalize_tf.cc on non-TPU nodes
107 * thus is a potential source of bugs.
108 */
109 minor_to_major.resize(t.getRank());
110 std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
111 std::sort(minor_to_major.begin(), minor_to_major.end(),
112 [=](int64_t a, int64_t b) {
113 int64_t da = t.getDimSize(a);
114 int64_t db = t.getDimSize(b);
115 return da > db || (da == db && a > b);
116 });
117 }
118 std::vector<Attribute> elements;
119 elements.reserve(minor_to_major.size());
120 for (auto e : minor_to_major) {
121 elements.push_back(rewriter.getIntegerAttr(i64_type, e));
122 }
123 return rewriter.getArrayAttr(elements);
124 } else {
125 // types.size() == 1 and types[0] == TokenType
126 // For this case, we return an empty array attribute.
127 return rewriter.getArrayAttr({});
128 }
129 }
130
131 namespace {
132 class AdjustLayout
133 : public PassWrapper<AdjustLayout, OperationPass<func::FuncOp>> {
getDependentDialects(DialectRegistry & registry) const134 void getDependentDialects(DialectRegistry ®istry) const override {
135 registry.insert<mhlo::MhloDialect>();
136 }
137
138 public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AdjustLayout)139 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AdjustLayout)
140
141 StringRef getArgument() const final { return "xla-adjust-layout"; }
getDescription() const142 StringRef getDescription() const final {
143 return "Adjust layouts so infeed send & receive use the same format.";
144 }
145
runOnInfeedOp(::mlir::mhlo::InfeedOp op)146 static void runOnInfeedOp(::mlir::mhlo::InfeedOp op) {
147 OpBuilder builder(op.getContext());
148 SmallVector<Type> result_types(op.getResultTypes().begin(),
149 op.getResultTypes().end());
150 if (!op->getAttr("layout")) {
151 auto layout = GetTPUInfeedLayout(result_types, builder);
152 if (failed(layout)) return;
153
154 op->setAttr("layout", layout.getValue());
155 }
156 }
157
runOnOperation()158 void runOnOperation() override { getOperation().walk(runOnInfeedOp); }
159 };
160 } // anonymous namespace
161
162 // Header for this is in passes.h, which pulls into many deps. NOLINTNEXTLINE
CreateAdjustLayoutPass()163 std::unique_ptr<Pass> CreateAdjustLayoutPass() {
164 return std::make_unique<AdjustLayout>();
165 }
166
RegisterAdjustLayoutPass()167 void RegisterAdjustLayoutPass() { static PassRegistration<AdjustLayout> pass; }
168
169 } // namespace mhlo
170
171 } // namespace mlir
172