1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_MLIR_GRAPPLER_GRAPPLER_HOOK_H_ 17 #define TENSORFLOW_CORE_MLIR_GRAPPLER_GRAPPLER_HOOK_H_ 18 19 #include <functional> 20 #include <string> 21 22 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h" 23 24 namespace mlir { 25 class PassManager; 26 27 namespace tfg { 28 29 // A function that builds the TFG pass pipeline. 30 using TFGPassPipelineBuilder = std::function<void(PassManager& pm)>; 31 32 // This class implements a Grappler optimizer wrapping a pipeline of passes 33 // implemented with TFG. 34 class TFGGrapplerOptimizer : public tensorflow::grappler::GraphOptimizer { 35 public: 36 // Constructs a TFG optimizer using the provided pipeline builder. By default, 37 // the optimizer will not use multi-threading. If `num_tfg_threads` is 38 // non-zero, then TFG will use threading with the specified number of threads. 39 explicit TFGGrapplerOptimizer(TFGPassPipelineBuilder builder, 40 unsigned num_tfg_threads = 0); 41 // Explicit destructor to defer instantiation of Impl. 42 ~TFGGrapplerOptimizer() override; 43 44 // Constructs a name for the optimizer using the registered passes. 45 std::string name() const override; 46 // The TFG optimizer requires access to the function library. UsesFunctionLibrary()47 bool UsesFunctionLibrary() const override { return true; } 48 49 // Runs the optimizer on the GraphDef. The optimizer converts the GraphDef to 50 // TFG using the importer, runs the passes on the MLIR, and exports back to 51 // GraphDef. The result is stored in `optimized_graph`. 52 tensorflow::Status Optimize(tensorflow::grappler::Cluster* cluster, 53 const tensorflow::grappler::GrapplerItem& item, 54 tensorflow::GraphDef* optimized_graph) override; 55 56 private: 57 // Hide the implementation details. 58 class Impl; 59 std::unique_ptr<Impl> impl_; 60 }; 61 62 } // end namespace tfg 63 } // end namespace mlir 64 65 #endif // TENSORFLOW_CORE_MLIR_GRAPPLER_GRAPPLER_HOOK_H_ 66