/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <cstdint>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/types/span.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Dialect.h"  // from @llvm-project
#include "mlir/IR/OpDefinition.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout.h"
#include "tensorflow/compiler/xla/mlir_hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/core/tpu/tpu_api.h"
#include "tensorflow/stream_executor/tpu/c_api_conversions.h"

namespace mlir {
namespace mhlo {
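// Queries the TPU C API for the preferred infeed layout of the given
// statically shaped tensor type. Returns the minor-to-major dimension
// ordering, or failure if the TPU runtime has not been initialized (e.g.
// when this code is not running on a TPU-enabled host).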
static FailureOr<std::vector<int64_t>> GetTPUInfeedLayoutFromAPI(
    RankedTensorType t) {
  // Call the TPU API to determine the right infeed layout. Note that
  // this can fail if we're not running on a TPU-enabled node.
  // TODO(kramm): Move this into a separate pass. See b/184944903
  xla::Shape old_shape = xla::TypeToShape(t);
  XLA_Shape old_shape_c = {};
  XLA_Shape new_shape_c = {};
  TfTpu_ExecutorApiFn *executor = tensorflow::tpu::ExecutorApiFn();
  if (!tensorflow::tpu::IsInitialized(executor)) {
    return failure();
  }
  ApiConverter::ToC(old_shape, &old_shape_c);
  executor->TpuTransferManager_GetInfeedLayoutFn(&old_shape_c, &new_shape_c);
  xla::Shape new_shape = ApiConverter::FromC(&new_shape_c);
  ApiConverter::Destroy(&old_shape_c);
  ApiConverter::Destroy(&new_shape_c);

  auto minor_to_major = new_shape.layout().minor_to_major();
  return std::vector<int64_t>(minor_to_major.begin(), minor_to_major.end());
}

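// Builds the value of the infeed `layout` attribute for the given result
// types: one minor-to-major ordering per non-token type, nested inside an
// outer array when there are multiple results or a tuple. For example, on
// the fallback path (no TPU runtime available) a single tensor<2x4xi32>
// type yields [1, 0], because the size-4 dimension is sorted ahead of the
// size-2 dimension; an infeed producing that tensor plus a token yields
// [[1, 0]].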
FailureOr<Attribute> GetTPUInfeedLayout(const ArrayRef<Type> types,
                                        OpBuilder &rewriter) {
  auto i64_type = rewriter.getIntegerType(64);
  if (types.size() > 1) {
    // Multiple results: compute a layout for each non-token type and wrap
    // them in an outer array.
    llvm::SmallVector<mlir::Attribute> v;
    v.reserve(types.size());
    for (const mlir::Type &t : types) {
      if (t.isa<TokenType>()) continue;
      auto layout = GetTPUInfeedLayout({t}, rewriter);
      if (failed(layout)) return failure();
      v.push_back(layout.getValue());
    }
    ArrayRef<Attribute> shape(v);
    return rewriter.getArrayAttr(shape);
  } else if (types[0].isa<TupleType>()) {
    // A single tuple result: compute a layout for each non-token element.
    auto tuple_type = types[0].dyn_cast<TupleType>();
    const auto &types = tuple_type.getTypes();
    llvm::SmallVector<mlir::Attribute> v;
    v.reserve(types.size());
    for (const mlir::Type &t : types) {
      if (t.isa<TokenType>()) continue;
      auto layout = GetTPUInfeedLayout({t}, rewriter);
      if (failed(layout)) return failure();
      v.push_back(layout.getValue());
    }
    ArrayRef<Attribute> shape(v);
    return rewriter.getArrayAttr(shape);
  } else if (auto t = types[0].dyn_cast<RankedTensorType>()) {
    // A single ranked tensor: ask the TPU API for its layout, falling back
    // to a size-sorted ordering when the API is unavailable.
    if (!t.hasStaticShape()) return failure();
    auto layout = GetTPUInfeedLayoutFromAPI(t);
    std::vector<int64_t> minor_to_major;
    if (succeeded(layout)) {
      minor_to_major = layout.getValue();
    } else {
      /* If we're not running on a TPU node, we might not be able to
       * actually call the part of the TPU API that gives us layout.
       * This happens e.g. for unit tests. Below we just create a reasonable
       * layout.  We sort by dimension size, which makes the layout agree with
       * the "correct" TPU layout in surprisingly many cases.
       * Note that the corresponding InfeedEnqueue op will be generated
       * through another path, and might still generate an (incompatible)
       * layout using the TPU API. Running legalize_tf.cc on non-TPU nodes
       * thus is a potential source of bugs.
       */
      minor_to_major.resize(t.getRank());
      std::iota(minor_to_major.begin(), minor_to_major.end(), 0);
      std::sort(minor_to_major.begin(), minor_to_major.end(),
                [=](int64_t a, int64_t b) {
                  int64_t da = t.getDimSize(a);
                  int64_t db = t.getDimSize(b);
                  return da > db || (da == db && a > b);
                });
    }
    std::vector<Attribute> elements;
    elements.reserve(minor_to_major.size());
    for (auto e : minor_to_major) {
      elements.push_back(rewriter.getIntegerAttr(i64_type, e));
    }
    return rewriter.getArrayAttr(elements);
  } else {
    // types.size() == 1 and types[0] == TokenType
    // For this case, we return an empty array attribute.
    return rewriter.getArrayAttr({});
  }
}

namespace {
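// A pass that attaches a `layout` attribute to every mhlo.infeed op that
// does not already carry one, so that infeed send and receive use the same
// format.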
class AdjustLayout
    : public PassWrapper<AdjustLayout, OperationPass<func::FuncOp>> {
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<mhlo::MhloDialect>();
  }

 public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AdjustLayout)

  StringRef getArgument() const final { return "xla-adjust-layout"; }
  StringRef getDescription() const final {
    return "Adjust layouts so infeed send & receive use the same format.";
  }

  // Computes the infeed layout for the op's result types and attaches it as
  // the `layout` attribute, unless the op already carries one.
  static void runOnInfeedOp(::mlir::mhlo::InfeedOp op) {
    OpBuilder builder(op.getContext());
    SmallVector<Type> result_types(op.getResultTypes().begin(),
                                   op.getResultTypes().end());
    if (!op->getAttr("layout")) {
      auto layout = GetTPUInfeedLayout(result_types, builder);
      if (failed(layout)) return;

      op->setAttr("layout", layout.getValue());
    }
  }

  void runOnOperation() override { getOperation().walk(runOnInfeedOp); }
};
}  // anonymous namespace
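// Illustrative only: one way to run this pass over a module, assuming a
// ModuleOp `module` is already loaded (the actual pipelines that use this
// pass are assembled elsewhere):
//
//   mlir::PassManager pm(module.getContext());
//   pm.addNestedPass<mlir::func::FuncOp>(mlir::mhlo::CreateAdjustLayoutPass());
//   if (mlir::failed(pm.run(module))) { /* handle the error */ }
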
// Header for this is in passes.h, which pulls in many other dependencies. NOLINTNEXTLINE
std::unique_ptr<Pass> CreateAdjustLayoutPass() {
  return std::make_unique<AdjustLayout>();
}

// Registers the pass with the global MLIR pass registry under the
// "xla-adjust-layout" argument.
void RegisterAdjustLayoutPass() { static PassRegistration<AdjustLayout> pass; }

}  // namespace mhlo

}  // namespace mlir