1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h"
17
18 #include "llvm/Support/CommandLine.h"
19 #include "mlir/IR/Builders.h" // from @llvm-project
20 #include "mlir/IR/Location.h" // from @llvm-project
21 #include "mlir/Pass/Pass.h" // from @llvm-project
22 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
23 #include "tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.h"
24 #include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h"
25 #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h"
26 #include "tensorflow/core/common_runtime/graph_constructor.h"
27 #include "tensorflow/core/common_runtime/optimization_registry.h"
28 #include "tensorflow/core/framework/function.h"
29 #include "tensorflow/core/graph/graph.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
32 #include "tensorflow/core/public/session_options.h"
33 #include "tensorflow/stream_executor/lib/statusor.h"
34
35 #define DEBUG_TYPE "run-tf-graph-optimization"
36
37 namespace tensorflow {
38 namespace {
39 // Creates a pass to convert MLIR to Graph, run user-specified Graph
40 // Optimization Passes and convert back to MLIR.
41 // Constraints: This pass expects that all operations in the MLIR module either
42 // belong to 'tf' or '_tf' dialect. The output is in '_tf' dialect.
43 class GraphOptPass
44 : public mlir::PassWrapper<GraphOptPass,
45 mlir::OperationPass<mlir::ModuleOp>> {
getDependentDialects(mlir::DialectRegistry & registry) const46 void getDependentDialects(mlir::DialectRegistry& registry) const override {
47 mlir::RegisterAllTensorFlowDialects(registry);
48 }
49
50 public:
51 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(GraphOptPass)
52
GraphOptPass(std::vector<tensorflow::GraphOptimizationPass * > passes)53 explicit GraphOptPass(std::vector<tensorflow::GraphOptimizationPass*> passes)
54 : passes_(std::move(passes)) {}
55
56 protected:
57 void runOnOperation() override;
58
59 // The passes to run on the module.
60 std::vector<GraphOptimizationPass*> passes_;
61 };
62 } // anonymous namespace
63
runOnOperation()64 void GraphOptPass::runOnOperation() {
65 mlir::ModuleOp module_in = getOperation();
66 mlir::MLIRContext& ctx = getContext();
67
68 // Convert MLIR to Graph
69 FunctionLibraryDefinition flib_def(OpRegistry::Global(),
70 FunctionDefLibrary());
71 GraphExportConfig confs;
72 auto graph = std::make_unique<Graph>(flib_def);
73 Status status = ConvertMlirToGraph(module_in, confs, &graph, &flib_def);
74 if (!status.ok()) {
75 mlir::emitError(mlir::UnknownLoc::get(&ctx)) << status.error_message();
76 return signalPassFailure();
77 }
78
79 // Run each of the passes that were selected.
80 GraphConstructorOptions opts;
81 opts.allow_internal_ops = true;
82 opts.expect_device_spec = false;
83
84 GraphOptimizationPassOptions options;
85 SessionOptions sess_options;
86 options.graph = &graph;
87 options.flib_def = &flib_def;
88 options.session_options = &sess_options;
89
90 for (auto pass : passes_) {
91 assert(pass != nullptr);
92 Status status = pass->Run(options);
93 if (!status.ok()) {
94 mlir::emitError(mlir::UnknownLoc::get(&ctx))
95 << pass->name() << ": " << status.error_message();
96 return signalPassFailure();
97 }
98 }
99
100 // Convert Graph to MLIR
101 GraphDebugInfo debug_info;
102 GraphImportConfig specs;
103 auto module_or_status =
104 ConvertGraphToMlir(**options.graph, debug_info, flib_def, specs, &ctx);
105 if (!module_or_status.ok()) {
106 mlir::emitError(mlir::UnknownLoc::get(&ctx))
107 << module_or_status.status().error_message();
108 return signalPassFailure();
109 }
110 auto module_out = std::move(module_or_status).ValueOrDie();
111
112 // We cannot replace the module in a ModulePass. So we simply copy the
113 // operation list from module_out to module_in.
114 auto& module_in_ops = module_in.getBody()->getOperations();
115 module_in_ops.clear();
116 module_in_ops.splice(module_in_ops.end(),
117 module_out->getBody()->getOperations());
118 }
119
120 // Returns a vector of passes from their names. If a pass is not found, then the
121 // corresponding return entry is null.
FindRegisteredPassesByName(const std::vector<std::string> & pass_names)122 static std::vector<GraphOptimizationPass*> FindRegisteredPassesByName(
123 const std::vector<std::string>& pass_names) {
124 std::vector<GraphOptimizationPass*> pass_ids(pass_names.size(), nullptr);
125
126 for (const auto& group : OptimizationPassRegistry::Global()->groups()) {
127 for (const auto& phase : group.second) {
128 for (const auto& pass : phase.second) {
129 // Iterate over the pass_names_ and insert the pass pointer at all the
130 // corresponding indices in the pass_ids vector.
131 auto iter = pass_names.begin();
132 while ((iter = std::find(iter, pass_names.end(), pass->name())) !=
133 pass_names.end()) {
134 pass_ids[std::distance(pass_names.begin(), iter)] = pass.get();
135 iter++;
136 }
137 }
138 }
139 }
140 return pass_ids;
141 }
142
143 // TODO(prakalps): Move these flags and pass registration to a header file so
144 // that it is clear that this is a generic pass library and command line is used
145 // for testing only.
146
147 // NOLINTNEXTLINE
148 static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options");
149
150 // NOLINTNEXTLINE
151 static llvm::cl::list<std::string> cl_pass_list(
152 "graph-passes", llvm::cl::value_desc("list"),
153 llvm::cl::desc("comma separated list of GraphOptimizationPass to run."),
154 llvm::cl::CommaSeparated, llvm::cl::cat(clOptionsCategory));
155
156 class GraphOptByNamePass : public GraphOptPass {
157 public:
GraphOptByNamePass()158 explicit GraphOptByNamePass() : GraphOptByNamePass(cl_pass_list) {}
GraphOptByNamePass(const std::vector<std::string> & pass_names)159 explicit GraphOptByNamePass(const std::vector<std::string>& pass_names)
160 : GraphOptPass(FindRegisteredPassesByName(pass_names)) {}
161
getArgument() const162 llvm::StringRef getArgument() const final {
163 return "run-tf-graph-optimization";
164 }
165
getDescription() const166 llvm::StringRef getDescription() const final {
167 return "runs passes registered as tensorflow::GraphOptimizationPass";
168 }
169
170 private:
runOnOperation()171 void runOnOperation() override {
172 // Verify all passes requested were registered/found.
173 for (auto pass_it : llvm::enumerate(passes_)) {
174 if (pass_it.value() == nullptr) {
175 mlir::emitError(mlir::UnknownLoc::get(&getContext()))
176 << "could not find pass " << cl_pass_list[pass_it.index()];
177 return signalPassFailure();
178 }
179 }
180 return GraphOptPass::runOnOperation();
181 }
182 };
183
184 } // namespace tensorflow
185
186 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTensorFlowGraphOptimizationPass(std::vector<tensorflow::GraphOptimizationPass * > tf_passes)187 tensorflow::CreateTensorFlowGraphOptimizationPass(
188 std::vector<tensorflow::GraphOptimizationPass*> tf_passes) {
189 return std::make_unique<GraphOptPass>(std::move(tf_passes));
190 }
191
192 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateTensorFlowGraphOptimizationPass(const std::vector<std::string> & pass_names)193 tensorflow::CreateTensorFlowGraphOptimizationPass(
194 const std::vector<std::string>& pass_names) {
195 return std::make_unique<GraphOptByNamePass>(pass_names);
196 }
197
RegisterGraphOptimizationPasses()198 void tensorflow::RegisterGraphOptimizationPasses() {
199 ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
200 return std::make_unique<GraphOptByNamePass>();
201 });
202 }
203