/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/utils/host_context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

namespace tensorflow {

using ::tfrt::AsyncValue;
using ::tfrt::BEFFile;
using ::tfrt::ExecutionContext;
using ::tfrt::Function;
using ::tfrt::HostContext;
using ::tfrt::MakeAvailableAsyncValueRef;
using ::tfrt::RCReference;
using ::tfrt::RequestContext;
using ::tfrt::RequestContextBuilder;
using ::tfrt::ResourceContext;

using ::tensorflow::Env;
using ::tensorflow::thread::ThreadPool;
using ::tensorflow::thread::ThreadPoolInterface;

using ::tensorflow::tfrt_stub::FallbackTensor;

// -------------------------------------------------------------------------- //
// Run function via the TF->TFRT fallback lowering.
// -------------------------------------------------------------------------- //

namespace {
// Thread pool for running `intra-op` tasks scheduled by the fallback kernels.
class IntraOpThreadPool : public ThreadPoolInterface {
 public:
  explicit IntraOpThreadPool(int64_t num_threads)
      : tpool_(Env::Default(), "intra-op",
               std::max(1, static_cast<int32_t>(num_threads))) {}

  void Schedule(std::function<void()> fn) override {
    tpool_.Schedule(std::move(fn));
  }

  int NumThreads() const override { return tpool_.NumThreads(); }
  int CurrentThreadId() const override { return tpool_.CurrentThreadId(); }
  void Cancel() override {}

 private:
  ThreadPool tpool_;
};
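
// A minimal usage sketch of the adapter above (hypothetical; within this file
// the pool is only driven indirectly, by fallback kernels scheduling work):
//
//   IntraOpThreadPool pool(/*num_threads=*/4);
//   pool.Schedule([] { /* an intra-op task from a fallback kernel */ });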
}  // namespace

RuntimeFallbackExecutor::RuntimeFallbackExecutor(int64_t num_threads)
    : intra_op_(std::make_unique<IntraOpThreadPool>(num_threads)) {
  // Create a HostContext for running TFRT functions. The concurrent work queue
  // acts similarly to the Tensorflow `inter-op` thread pool, so we match its
  // size.
  host_context_ = num_threads ? CreateMultiThreadedHostContext(num_threads)
                              : CreateSingleThreadedHostContext();
  tfrt::RegisterStaticKernels(host_context_->GetMutableRegistry());

  // Build an ExecutionContext from the HostContext.
  auto builder = RequestContextBuilder(host_context_.get(), &resource_context_);

  // Get the tensorflow::EagerContext for the kernel fallback.
  auto* eager_context_resource =
      resource_context_
          .GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
              tensorflow::tfd::kEagerContextResourceName);
  auto expected_eager_context = eager_context_resource->GetTFEagerContext();
  auto* eager_context = expected_eager_context.get();

  // Initialize the fallback kernels state with a custom intra-op thread pool.
  auto status = tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &builder, /*runner_table=*/nullptr, eager_context, intra_op_.get());
  CHECK(status.ok()) << "Failed to set up request context: "
                     << status.error_message();

  auto req_ctx = std::move(builder).build();
  if (auto err = req_ctx.takeError())
    LOG(FATAL) << "Failed to build a request context";

  exec_ctx_ = std::make_unique<tfrt::ExecutionContext>(std::move(*req_ctx));
}

void RuntimeFallbackExecutor::Prepare(llvm::StringRef mlir_input) {
  // We only support IR written in the Tensorflow dialect.
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);

  llvm::SourceMgr source_mgr;
  source_mgr.AddNewSourceBuffer(
      llvm::MemoryBuffer::getMemBuffer(mlir_input, "test_ir"), llvm::SMLoc());

  // Parse the kernel source code into an MLIR module.
  mlir::OwningOpRef<mlir::ModuleOp> module(
      mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, &context));
  CHECK(module) << "failed to parse mlir module";

  // Collect all diagnostics emitted while lowering the parsed kernel module.
  std::string diagnostic_str;
  llvm::raw_string_ostream os(diagnostic_str);
  mlir::SourceMgrDiagnosticHandler handler(source_mgr, module->getContext(),
                                           os);

  // Convert TF to the TFRT fallback dialect.
  TfrtPipelineOptions pipeline_opts;
  pipeline_opts.default_device = kDefaultHostDeviceName;
  pipeline_opts.hoist_invariant_ops = true;
  pipeline_opts.enable_native_ops = false;
  pipeline_opts.cost_threshold = 1024;
  pipeline_opts.upper_cost_threshold = 100000;
  pipeline_opts.merge_inter_dependent_streams = true;
  pipeline_opts.func_use_fallback_tensor = true;

  mlir::PassManager pm(module->getContext());
  pm.addPass(CreateTfToTfrtConversionPass(pipeline_opts));

  CHECK(mlir::succeeded(pm.run(*module)))
      << "Failed to lower module to TFRT: " << os.str();

  // Convert the module to BEF.
  bef_buffer_ =
      tfrt::ConvertMLIRToBEF(*module, /*disable_optional_sections=*/false);
  CHECK(!bef_buffer_.empty()) << "Failed to convert module to BEF";

  bef_file_ =
      BEFFile::Open(bef_buffer_, host_context_->GetKernelRegistry(),
                    host_context_->diag_handler(), host_context_->allocator());
  CHECK(bef_file_) << "Failed to open BEF";

  // Run the TFRT initialization function to pre-instantiate fallback kernels.
  RunTfrtInitializer();
}
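
// A sketch of the kind of input Prepare() accepts: a module written in the
// Tensorflow dialect. The function name and op below are hypothetical
// examples, not part of this file:
//
//   func.func @compute(%arg0: tensor<?xf32>) -> tensor<?xf32> {
//     %0 = "tf.Abs"(%arg0) : (tensor<?xf32>) -> tensor<?xf32>
//     func.return %0 : tensor<?xf32>
//   }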

llvm::SmallVector<Tensor> RuntimeFallbackExecutor::Execute(
    llvm::StringRef function_name, llvm::ArrayRef<Tensor> arguments) {
  // Get the kernel entrypoint function.
  const Function* compute = bef_file_->GetFunction(function_name);
  CHECK(compute) << "Entrypoint function not found";
  CHECK_EQ(arguments.size() + 1, compute->num_arguments())
      << "Wrong number of arguments for function " << function_name.str();

  // Prepare function arguments from a ready Chain and the input Tensors.
  llvm::SmallVector<tfrt::AsyncValue*> exec_arguments;
  exec_arguments.reserve(compute->num_arguments());
  exec_arguments.push_back(tfrt::GetReadyChain().release());
  for (const Tensor& input_tensor : arguments) {
    auto av = MakeAvailableAsyncValueRef<FallbackTensor>(input_tensor);
    exec_arguments.push_back(av.release());
  }

  // Space for returned values.
  llvm::SmallVector<RCReference<AsyncValue>> results(compute->num_results());

  compute->Execute(*exec_ctx_, exec_arguments, results);

  // Wait for the function execution to finish, as well as its side effects.
  host_context_->Await(results);

  // Check the results for errors and copy out the returned tensors (index 0
  // holds the returned Chain, so it is skipped).
  llvm::SmallVector<Tensor> ret_values;
  for (unsigned i = 1; i < results.size(); ++i) {
    if (auto* error = results[i]->GetErrorIfPresent())
      LOG(FATAL) << "Failed to execute a function: " << StrCat(*error);
    ret_values.push_back(results[i]->get<tfrt_stub::FallbackTensor>().tensor());
  }

  // Deallocate arguments.
  for (auto* argument : exec_arguments) argument->DropRef();
  return ret_values;
}
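
// An end-to-end usage sketch of this executor. The MLIR string and the
// entrypoint name are assumptions for illustration only:
//
//   RuntimeFallbackExecutor executor(/*num_threads=*/4);
//   executor.Prepare(mlir_module_string);  // Tensorflow-dialect MLIR.
//   llvm::SmallVector<Tensor> outputs = executor.Execute("compute", {input});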

// Run TFRT fallback initialization function to instantiate all fallback
// kernels ahead of executing the compute function.
void RuntimeFallbackExecutor::RunTfrtInitializer() {
  const Function* func = bef_file_->GetFunction("_tfrt_fallback_init");
  CHECK(func) << "TFRT initialization function was not found";
  CHECK_EQ(func->argument_types().size(), 1);

  llvm::SmallVector<RCReference<AsyncValue>, 1> results;
  results.resize(func->result_types().size());
  CHECK_EQ(results.size(), 1);

  func->Execute(*exec_ctx_, tfrt::GetReadyChain().GetAsyncValue(), results);

  host_context_->Await(results);

  CHECK(!results[0]->IsError()) << "Failed to run TFRT initialization function";
}

}  // namespace tensorflow