/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/STLExtras.h"  // from @llvm-project
#include "llvm/ADT/SmallPtrSet.h"  // from @llvm-project
#include "llvm/ADT/SmallVector.h"  // from @llvm-project
#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/OwningOpRef.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/SymbolTable.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"

namespace mlir {
namespace TF {
namespace {

// Returns true if the given op is a TF/XLA communication op in the old bridge.
bool IsCommunicationOp(Operation* op) {
  return isa<TF::XlaHostComputeOp, TF::XlaSendToHostOp, TF::XlaRecvFromHostOp>(
      op);
}

// Returns true if the given op is one of the ops supported to have a
// communication subcomputation in the TF/XLA bridge.
bool SupportsCommunicationComputation(Operation* op) {
  return isa<TF::IfRegionOp, TF::WhileRegionOp, TF::CaseRegionOp,
             TF::StatefulPartitionedCallOp, TF::PartitionedCallOp,
             TF::LegacyCallOp>(op);
}

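// Pass that prepares TPU computations in the module for export to a
// TensorFlow graph: communication ops emitted by the MLIR-based bridge are
// rewritten to their old-bridge equivalents, and token input node name
// attributes are attached so the exported ops can be ordered on the device.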
class PrepareTpuComputationForTfExportPass
    : public PrepareTpuComputationForTfExportPassBase<
          PrepareTpuComputationForTfExportPass> {
  void runOnOperation() override;
};

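// Rewrites `tf._XlaHostComputeMlir` to `tf.XlaHostCompute`. If the op carries
// a serialized `host_mlir_module`, its `host_func` is cloned into the module
// and referenced by the `shape_inference_graph` attribute of the new op.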
class RewriteXlaHostComputeMlir
    : public OpRewritePattern<TF::_XlaHostComputeMlirOp> {
 public:
  using OpRewritePattern<TF::_XlaHostComputeMlirOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op,
                                PatternRewriter& rewriter) const override {
    llvm::SmallVector<Attribute> shape_attrs;
    shape_attrs.reserve(op.getNumResults());
    for (Type ty : op.getResultTypes()) {
      shape_attrs.push_back(
          TF::ShapeAttr::get(rewriter.getContext(), ty.cast<ShapedType>()));
    }

    // Clone the `host_func` in the `host_mlir_module` attribute if it exists
    // and use it for the `shape_inference_graph` attribute on XlaHostCompute.
    func::FuncOp cloned_func;
    SymbolTable manager(op->getParentOfType<ModuleOp>());
    StringRef host_module = op.host_mlir_module();
    if (!host_module.empty()) {
      mlir::OwningOpRef<mlir::ModuleOp> module_for_func;

      func::FuncOp func = op.GetHostFunc(&module_for_func);

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfter(op->getParentOfType<func::FuncOp>());
      cloned_func = llvm::dyn_cast_or_null<func::FuncOp>(
          rewriter.clone(*func.getOperation()));
      manager.insert(cloned_func);
      rewriter.setInsertionPointToStart(&cloned_func.getBody().front());
      auto result_type =
          RankedTensorType::get({3}, rewriter.getType<TF::StringType>());
      auto dynamic_key =
          rewriter.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
              func.getLoc(), /*program=*/result_type, llvm::ArrayRef<Value>{});

      auto recv_at_host = rewriter.create<TF::_XlaRecvAtHostOp>(
          func.getLoc(), op.getOperandTypes(), /*dynamic_key=*/dynamic_key,
          op.send_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
      for (auto result :
           llvm::zip(cloned_func.getArguments(), recv_at_host->getResults())) {
        std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }

      rewriter.setInsertionPoint(cloned_func.getBody().front().getTerminator());
      rewriter.create<TF::_XlaSendFromHostOp>(
          func.getLoc(),
          cloned_func.getBody().front().getTerminator()->getOperands(),
          /*dynamic_key=*/dynamic_key, op.recv_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
    }

    constexpr int64_t kDefaultCostEstimate = 1000000;
    rewriter.replaceOpWithNewOp<TF::XlaHostComputeOp>(
        op, op.getResultTypes(), op.inputs(),
        /*ancestors=*/rewriter.getArrayAttr({}),
        rewriter.getArrayAttr(shape_attrs),
        /*shape_inference_graph=*/
        cloned_func ? SymbolRefAttr::get(cloned_func) : SymbolRefAttr(),
        /*key=*/rewriter.getStringAttr(""), op.send_keyAttr(),
        op.recv_keyAttr(),
        /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate),
        /*tpu_core=*/rewriter.getI64IntegerAttr(0));
    return success();
  }
};

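// For each function argument carrying a non-empty `mhlo.sharding` attribute,
// inserts a `tf.XlaSharding` op on that argument, redirects all other uses to
// the new op, and then removes the argument attribute.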
void UpdateArgAttributes(mlir::func::FuncOp func) {
  OpBuilder builder(func.getBody());
  for (int i = 0; i < func.getNumArguments(); ++i) {
    constexpr char kShardingAttr[] = "mhlo.sharding";
    if (auto sharding =
            func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
      if (!sharding.getValue().empty()) {
        BlockArgument arg = func.getArgument(i);
        // TODO(hinsu): Instead of setting both 'sharding' and '_XlaSharding'
        // attributes, only set the 'sharding' attribute. Both attributes are
        // currently required as the XlaSharding xla op kernel doesn't use the
        // 'sharding' attribute.
        auto updated_arg = builder.create<TF::XlaShardingOp>(
            func.getLoc(), arg.getType(), arg, sharding, sharding);
        func.getArgument(i).replaceAllUsesExcept(
            updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
      }

      func.removeArgAttr(i, builder.getStringAttr(kShardingAttr));
    }
  }
}

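// Rewrites communication ops used by the MLIR-based bridge into the forms
// expected by the old bridge and appends the transfer direction suffixes
// ("_dtoh_0" / "_htod_0") to the keys of XlaSendToHost / XlaRecvFromHost ops.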
LogicalResult RewriteCommunicationOps(ModuleOp module) {
  MLIRContext* ctx = module.getContext();
  mlir::RewritePatternSet patterns(ctx);
  patterns.add<RewriteXlaHostComputeMlir>(ctx);
  if (failed(mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    return module.emitError("failed to apply tf export preparation patterns");
  }

  // TODO(hinsu): Investigate if the semantics of keys for these communication
  // ops between the old bridge and new bridge can be reconciled.
  module.walk([&](Operation* op) {
    if (isa<TF::XlaSendToHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_dtoh_0");
      op->setAttr("key", new_key);
    } else if (isa<TF::XlaRecvFromHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_htod_0");
      op->setAttr("key", new_key);
    }
  });
  return success();
}

// Sets the token input node names attribute and the corresponding original
// node name attribute on TF/XLA communication related ops. These attributes
// are used to order operations on the device. The first op in a region gets a
// special argument token; each subsequent op gets the node name of the
// previous communication op.
LogicalResult SetTokenInputAttrs(ModuleOp module) {
  // Collect all the ops that need to have the token input names attribute.
  // These are communication ops and all of their parent ops, reached via
  // nesting or function calls (for example, IfRegion and PartitionedCall ops).
  std::vector<Operation*> worklist;
  absl::flat_hash_set<Operation*> ops_with_tokens;
  module.walk([&](Operation* op) {
    if (IsCommunicationOp(op)) {
      ops_with_tokens.insert(op);
      worklist.push_back(op);
    }
  });

  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);

  // Regions that contain ops requiring token input attributes.
  absl::flat_hash_set<Region*> regions_with_token;
  while (!worklist.empty()) {
    Operation* op = worklist.back();
    worklist.pop_back();

    Region* region = op->getParentRegion();
    regions_with_token.insert(region);

    // If the parent is not a FuncOp, add the parent op containing this region
    // to the worklist.
    Operation* parent = region->getParentOp();
    if (!isa<func::FuncOp>(parent)) {
      if (ops_with_tokens.insert(parent).second) {
        worklist.push_back(parent);
      }
      continue;
    }

    // For functions, get all the users and add them to the worklist.
    for (auto& user : symbol_map.getUsers(parent)) {
      if (ops_with_tokens.insert(user).second) {
        worklist.push_back(user);
      }
    }
  }

  // Use a name mapper to uniquely name all relevant ops in the module, as
  // export to a TensorFlow graph may change node names. The names used here
  // don't need to match the actual names in the graph because the original
  // node name attribute is set on all the relevant nodes.
  tensorflow::OpOrArgLocNameMapper name_mapper;
  MLIRContext* ctx = module.getContext();
  for (Region* region : regions_with_token) {
    // Initialize the token with the special argument token. This gets mapped
    // to the input token in the parent op or to a new token for the entry
    // computation.
    auto token = StringAttr::get(ctx, tensorflow::kXlaTokenArgNodeName);
    for (Operation& op : region->getOps()) {
      // Only communication related ops that need to have a token should get
      // the extra attribute.
      if (!ops_with_tokens.contains(&op)) continue;

      if (!IsCommunicationOp(&op) && !SupportsCommunicationComputation(&op)) {
        return op.emitOpError(
            "does not support subcomputations with tf/xla communication ops");
      }

      op.setAttr(tensorflow::kXlaTokenInputNodesAttrName,
                 ArrayAttr::get(ctx, {token}));

      auto node_name = StringAttr::get(ctx, name_mapper.GetUniqueName(&op));
      op.setAttr(tensorflow::kXlaOriginalOutsideCompilationNodeName, node_name);
      token = node_name;
    }
  }
  return success();
}

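// Pass entry point: updates argument attributes on every function, then runs
// the communication op rewrite and token attribute steps on the module.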
void PrepareTpuComputationForTfExportPass::runOnOperation() {
  ModuleOp module = getOperation();

  for (func::FuncOp func : module.getOps<func::FuncOp>()) {
    UpdateArgAttributes(func);
  }

  // First rewrite communication ops used in the new bridge to match old bridge
  // semantics and then set token input node name attributes on the supported
  // ops.
  if (failed(RewriteCommunicationOps(module)) ||
      failed(SetTokenInputAttrs(module))) {
    signalPassFailure();
    return;
  }
}

}  // namespace

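// Creates a pass that prepares the TPU computation module for export to a
// TensorFlow graph.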
std::unique_ptr<OperationPass<ModuleOp>>
CreatePrepareTpuComputationForTfExportPass() {
  return std::make_unique<PrepareTpuComputationForTfExportPass>();
}

}  // namespace TF
}  // namespace mlir