xref: /aosp_15_r20/external/tensorflow/tensorflow/tools/optimization/optimization_pass_runner.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file creates a library that can run any registered optimization pass.
16 // The binary that uses this will be run in a form similar to:
17 // ./optimization_pass_runner  --input_file_path=/tmp/input.pbtxt
18 // --output_file_path=/tmp/output.pbtxt
19 // --optimization_pass=NameOfGraphOptimizationPass
20 #include "tensorflow/tools/optimization/optimization_pass_runner.h"
21 
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "tensorflow/core/common_runtime/device.h"
27 #include "tensorflow/core/common_runtime/device_set.h"
28 #include "tensorflow/core/common_runtime/graph_constructor.h"
29 #include "tensorflow/core/common_runtime/optimization_registry.h"
30 #include "tensorflow/core/framework/allocator.h"
31 #include "tensorflow/core/framework/device_attributes.pb.h"
32 #include "tensorflow/core/framework/function.h"
33 #include "tensorflow/core/framework/function.pb.h"
34 #include "tensorflow/core/framework/graph.pb.h"
35 #include "tensorflow/core/framework/op.h"
36 #include "tensorflow/core/framework/types.h"
37 #include "tensorflow/core/graph/graph.h"
38 #include "tensorflow/core/lib/core/errors.h"
39 #include "tensorflow/core/lib/core/status.h"
40 #include "tensorflow/core/platform/env.h"
41 #include "tensorflow/core/protobuf/config.pb.h"
42 #include "tensorflow/core/public/session_options.h"
43 
44 namespace tensorflow {
45 namespace {
46 // A fake device used to populate a DeviceSet.
47 class FakeDevice : public Device {
48  private:
FakeDevice(const DeviceAttributes & device_attributes)49   explicit FakeDevice(const DeviceAttributes& device_attributes)
50       : Device(nullptr, device_attributes) {}
51 
52  public:
53   Status Sync() override;
54   static std::unique_ptr<Device> Make(const string& name, const string& type);
55 };
56 
Sync()57 Status FakeDevice::Sync() {
58   return errors::Unimplemented("FakeDevice::Sync()");
59 }
60 
Make(const string & name,const string & type)61 std::unique_ptr<Device> FakeDevice::Make(const string& name,
62                                          const string& type) {
63   DeviceAttributes device_attributes;
64   device_attributes.set_name(name);
65   device_attributes.set_device_type(DeviceType(type).type());
66   return std::unique_ptr<Device>(new FakeDevice(device_attributes));
67 }
68 
FindPassWithName(absl::string_view name,GraphOptimizationPass ** result)69 Status FindPassWithName(absl::string_view name,
70                         GraphOptimizationPass** result) {
71   *result = nullptr;
72   // Run the optimization pass specified by the command line flag.
73   for (const auto& groups_and_passes :
74        OptimizationPassRegistry::Global()->groups()) {
75     for (const auto& phase_and_passes : groups_and_passes.second) {
76       for (const auto& pass : phase_and_passes.second) {
77         if (pass->name() == name) {
78           if (*result) {
79             return errors::Internal("Found more than one pass with name ",
80                                     name);
81           }
82           *result = pass.get();
83         }
84       }
85     }
86   }
87 
88   return *result == nullptr
89              ? errors::Internal("Could not find pass with name ", name)
90              : OkStatus();
91 }
92 }  // namespace
93 
Run(absl::string_view pass_to_run,GraphDef input,GraphDef * result)94 Status OptimizationPassRunner::Run(absl::string_view pass_to_run,
95                                    GraphDef input, GraphDef* result) {
96   auto session_options = absl::make_unique<SessionOptions>();
97   session_options->config.mutable_graph_options()
98       ->mutable_optimizer_options()
99       ->set_global_jit_level(jit_level_);
100   FunctionDefLibrary flib;
101   std::unique_ptr<Graph> graph = absl::make_unique<Graph>(OpRegistry::Global());
102 
103   GraphOptimizationPassOptions options;
104   options.session_options = session_options.get();
105   options.graph = &graph;
106   std::unique_ptr<FunctionLibraryDefinition> flib_def(
107       new FunctionLibraryDefinition((*options.graph)->op_registry(), flib));
108   options.flib_def = flib_def.get();
109 
110   // Grab the data
111   GraphConstructorOptions graph_opts;
112   graph_opts.expect_device_spec = true;
113   graph_opts.allow_internal_ops = true;
114   TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_opts, std::move(input),
115                                             options.graph->get()));
116 
117   // Add all devices that were previously configured with AddDevice.
118   DeviceSet device_set;
119   for (auto& device : devices_) {
120     device_set.AddDevice(device.get());
121   }
122   options.device_set = &device_set;
123 
124   GraphOptimizationPass* pass;
125   TF_RETURN_IF_ERROR(FindPassWithName(pass_to_run, &pass));
126   TF_RETURN_IF_ERROR(pass->Run(options));
127 
128   options.graph->get()->ToGraphDef(result);
129   return OkStatus();
130 }
131 
SetJitLevel(OptimizerOptions::GlobalJitLevel jit_level)132 Status OptimizationPassRunner::SetJitLevel(
133     OptimizerOptions::GlobalJitLevel jit_level) {
134   jit_level_ = jit_level;
135   return OkStatus();
136 }
137 
AddDevices(absl::string_view type,int count)138 Status OptimizationPassRunner::AddDevices(absl::string_view type, int count) {
139   for (int i = 0; i < count; i++) {
140     devices_.push_back(FakeDevice::Make(
141         absl::StrCat("/job:localhost/replica:0/task:0/device:", type, ":", i),
142         absl::StrCat(type)));
143     devices_.push_back(FakeDevice::Make(
144         absl::StrCat("/job:localhost/replica:0/task:0/device:XLA_", type, ":",
145                      i),
146         absl::StrCat(type)));
147   }
148 
149   return OkStatus();
150 }
151 }  // namespace tensorflow
152