/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file creates a library that can run any registered optimization pass.
// The binary that uses this will be run in a form similar to:
// ./optimization_pass_runner --input_file_path=/tmp/input.pbtxt
//     --output_file_path=/tmp/output.pbtxt
//     --optimization_pass=NameOfGraphOptimizationPass
#include "tensorflow/tools/optimization/optimization_pass_runner.h"

#include <memory>
#include <string>
#include <vector>

#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace {
// A fake device used to populate a DeviceSet.
class FakeDevice : public Device {
 private:
  explicit FakeDevice(const DeviceAttributes& device_attributes)
      : Device(nullptr, device_attributes) {}

 public:
  Status Sync() override;
  static std::unique_ptr<Device> Make(const string& name, const string& type);
};

Status FakeDevice::Sync() {
  return errors::Unimplemented("FakeDevice::Sync()");
}

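// Builds a FakeDevice with the given fully qualified name and device type.
// Only the attributes needed to populate a DeviceSet are filled in.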
std::unique_ptr<Device> FakeDevice::Make(const string& name,
                                         const string& type) {
  DeviceAttributes device_attributes;
  device_attributes.set_name(name);
  device_attributes.set_device_type(DeviceType(type).type());
  return std::unique_ptr<Device>(new FakeDevice(device_attributes));
}

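// Locates the registered GraphOptimizationPass named `name` by scanning every
// group and phase in the global OptimizationPassRegistry. Returns an error if
// no pass, or more than one pass, has that name.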
Status FindPassWithName(absl::string_view name,
                        GraphOptimizationPass** result) {
  *result = nullptr;
  // Scan every registered group and phase for a pass with the requested name.
  for (const auto& groups_and_passes :
       OptimizationPassRegistry::Global()->groups()) {
    for (const auto& phase_and_passes : groups_and_passes.second) {
      for (const auto& pass : phase_and_passes.second) {
        if (pass->name() == name) {
          if (*result) {
            return errors::Internal("Found more than one pass with name ",
                                    name);
          }
          *result = pass.get();
        }
      }
    }
  }

  return *result == nullptr
             ? errors::Internal("Could not find pass with name ", name)
             : OkStatus();
}
}  // namespace

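// Runs the registered optimization pass named `pass_to_run` on `input` and
// writes the optimized graph to `result`. The pass sees the JIT level set via
// SetJitLevel and the devices registered via AddDevices.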
Status OptimizationPassRunner::Run(absl::string_view pass_to_run,
                                   GraphDef input, GraphDef* result) {
  auto session_options = absl::make_unique<SessionOptions>();
  session_options->config.mutable_graph_options()
      ->mutable_optimizer_options()
      ->set_global_jit_level(jit_level_);
  FunctionDefLibrary flib;
  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());

  GraphOptimizationPassOptions options;
  options.session_options = session_options.get();
  options.graph = &graph;
  std::unique_ptr<FunctionLibraryDefinition> flib_def(
      new FunctionLibraryDefinition((*options.graph)->op_registry(), flib));
  options.flib_def = flib_def.get();

  // Parse the input GraphDef into the Graph the pass will operate on.
  GraphConstructorOptions graph_opts;
  graph_opts.expect_device_spec = true;
  graph_opts.allow_internal_ops = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_opts, std::move(input),
                                            options.graph->get()));

  // Add all devices that were previously configured with AddDevices.
  DeviceSet device_set;
  for (auto& device : devices_) {
    device_set.AddDevice(device.get());
  }
  options.device_set = &device_set;

  GraphOptimizationPass* pass;
  TF_RETURN_IF_ERROR(FindPassWithName(pass_to_run, &pass));
  TF_RETURN_IF_ERROR(pass->Run(options));

  options.graph->get()->ToGraphDef(result);
  return OkStatus();
}

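// Records the global JIT level that Run will place in the session's
// OptimizerOptions.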
Status OptimizationPassRunner::SetJitLevel(
    OptimizerOptions::GlobalJitLevel jit_level) {
  jit_level_ = jit_level;
  return OkStatus();
}

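// Registers `count` fake devices of the given type. Each device is added
// twice: once under its plain name (e.g. device:GPU:0) and once under the
// corresponding XLA_ name (e.g. device:XLA_GPU:0), so the DeviceSet seen by
// the pass covers both variants.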
Status OptimizationPassRunner::AddDevices(absl::string_view type, int count) {
  for (int i = 0; i < count; i++) {
    devices_.push_back(FakeDevice::Make(
        absl::StrCat("/job:localhost/replica:0/task:0/device:", type, ":", i),
        absl::StrCat(type)));
    devices_.push_back(FakeDevice::Make(
        absl::StrCat("/job:localhost/replica:0/task:0/device:XLA_", type, ":",
                     i),
        absl::StrCat(type)));
  }

  return OkStatus();
}
}  // namespace tensorflow