/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_executor.h"

#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/Parser/Parser.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
#include "tensorflow/compiler/mlir/tfrt/transforms/passes.h"
#include "tensorflow/compiler/mlir/tfrt/utils/host_context.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/threadpool.h"
#include "tensorflow/core/platform/threadpool_interface.h"
#include "tensorflow/core/runtime_fallback/kernel/kernel_fallback_execute_compat.h"
#include "tensorflow/core/runtime_fallback/runtime/kernel_utils.h"
#include "tensorflow/core/tfrt/utils/fallback_tensor.h"
#include "tfrt/bef/bef_buffer.h"  // from @tf_runtime
#include "tfrt/bef_converter/mlir_to_bef.h"  // from @tf_runtime
#include "tfrt/bef_executor/bef_file.h"  // from @tf_runtime
#include "tfrt/host_context/async_value.h"  // from @tf_runtime
#include "tfrt/host_context/execution_context.h"  // from @tf_runtime
#include "tfrt/host_context/function.h"  // from @tf_runtime
#include "tfrt/host_context/host_context.h"  // from @tf_runtime
#include "tfrt/host_context/resource_context.h"  // from @tf_runtime
#include "tfrt/support/ref_count.h"  // from @tf_runtime

namespace tensorflow {

using ::tfrt::AsyncValue;
using ::tfrt::BEFFile;
using ::tfrt::ExecutionContext;
using ::tfrt::Function;
using ::tfrt::HostContext;
using ::tfrt::MakeAvailableAsyncValueRef;
using ::tfrt::RCReference;
using ::tfrt::RequestContext;
using ::tfrt::RequestContextBuilder;
using ::tfrt::ResourceContext;

using ::tensorflow::Env;
using ::tensorflow::thread::ThreadPool;
using ::tensorflow::thread::ThreadPoolInterface;

using ::tensorflow::tfrt_stub::FallbackTensor;

// -------------------------------------------------------------------------- //
// Run function via the TF->TFRT fallback lowering.
// -------------------------------------------------------------------------- //
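//
// A minimal usage sketch (the `times_two` function name and its MLIR body are
// illustrative, not part of this file):
//
//   RuntimeFallbackExecutor executor(/*num_threads=*/4);
//   executor.Prepare(R"(
//     func.func @times_two(%arg: tensor<?xf32>) -> tensor<?xf32> {
//       %0 = "tf.AddV2"(%arg, %arg)
//            : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
//       func.return %0 : tensor<?xf32>
//     })");
//   llvm::SmallVector<Tensor> results =
//       executor.Execute("times_two", {input_tensor});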

namespace {
// Thread pool for running `intra-op` tasks scheduled by the fallback kernels.
class IntraOpThreadPool : public ThreadPoolInterface {
 public:
  explicit IntraOpThreadPool(int64_t num_threads)
      : tpool_(Env::Default(), "intra-op",
               std::max(1, static_cast<int32_t>(num_threads))) {}

  void Schedule(std::function<void()> fn) override {
    tpool_.Schedule(std::move(fn));
  }

  int NumThreads() const override { return tpool_.NumThreads(); }
  int CurrentThreadId() const override { return tpool_.CurrentThreadId(); }
  void Cancel() override {}

 private:
  ThreadPool tpool_;
};
}  // namespace

RuntimeFallbackExecutor::RuntimeFallbackExecutor(int64_t num_threads)
    : intra_op_(std::make_unique<IntraOpThreadPool>(num_threads)) {
  // Create a HostContext for running TFRT functions. The concurrent work
  // queue acts similarly to the TensorFlow `inter-op` thread pool, so we
  // match its size.
  host_context_ = num_threads ? CreateMultiThreadedHostContext(num_threads)
                              : CreateSingleThreadedHostContext();
  tfrt::RegisterStaticKernels(host_context_->GetMutableRegistry());

  // Build an ExecutionContext from the HostContext.
  auto builder = RequestContextBuilder(host_context_.get(), &resource_context_);

  // Get tensorflow::EagerContext for the kernel fallback.
  auto* eager_context_resource =
      resource_context_
          .GetOrCreateResource<tensorflow::tfd::EagerContextResource>(
              tensorflow::tfd::kEagerContextResourceName);
  auto expected_eager_context = eager_context_resource->GetTFEagerContext();
  auto* eager_context = expected_eager_context.get();

  // Initialize fallback kernels state with a custom intra-op thread pool.
  auto status = tensorflow::tfd::SetUpKernelFallbackCompatRequestContext(
      &builder, /*runner_table=*/nullptr, eager_context, intra_op_.get());
  CHECK(status.ok()) << "Failed to set up request context: "
                     << status.error_message();

  auto req_ctx = std::move(builder).build();
  if (auto err = req_ctx.takeError())
    LOG(FATAL) << "Failed to build a request context";

  exec_ctx_ = std::make_unique<tfrt::ExecutionContext>(std::move(*req_ctx));
}

void RuntimeFallbackExecutor::Prepare(llvm::StringRef mlir_input) {
  // We only support IR written in the TensorFlow dialect.
  mlir::DialectRegistry registry;
  mlir::RegisterAllTensorFlowDialects(registry);
  mlir::MLIRContext context(registry);

  llvm::SourceMgr source_mgr;
  source_mgr.AddNewSourceBuffer(
      llvm::MemoryBuffer::getMemBuffer(mlir_input, "test_ir"), llvm::SMLoc());

  // Parse the kernel source code into an MLIR module.
  mlir::OwningOpRef<mlir::ModuleOp> module(
      mlir::parseSourceFile<mlir::ModuleOp>(source_mgr, &context));
  CHECK(module) << "Failed to parse MLIR module";

  // Collect all diagnostics emitted while lowering the parsed kernel module.
  std::string diagnostic_str;
  llvm::raw_string_ostream os(diagnostic_str);
  mlir::SourceMgrDiagnosticHandler handler(source_mgr, module->getContext(),
                                           os);

  // Convert TF to the TFRT fallback dialect.
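  // The cost thresholds below tune TFRT's stream analysis; roughly, ops whose
  // estimated cost falls under `cost_threshold` are treated as cheap and
  // grouped into a sequential stream instead of being dispatched
  // concurrently.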
  TfrtPipelineOptions pipeline_opts;
  pipeline_opts.default_device = kDefaultHostDeviceName;
  pipeline_opts.hoist_invariant_ops = true;
  pipeline_opts.enable_native_ops = false;
  pipeline_opts.cost_threshold = 1024;
  pipeline_opts.upper_cost_threshold = 100000;
  pipeline_opts.merge_inter_dependent_streams = true;
  pipeline_opts.func_use_fallback_tensor = true;

  mlir::PassManager pm(module->getContext());
  pm.addPass(CreateTfToTfrtConversionPass(pipeline_opts));

  CHECK(mlir::succeeded(pm.run(*module)))
      << "Failed to lower module to TFRT: " << os.str();

  // Convert the module to BEF.
  bef_buffer_ =
      tfrt::ConvertMLIRToBEF(*module, /*disable_optional_sections=*/false);
  CHECK(!bef_buffer_.empty()) << "Failed to convert module to BEF";

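  // Note: the opened `BEFFile` keeps referencing `bef_buffer_`, so the buffer
  // must stay alive for as long as `bef_file_` is used (both are members, so
  // their lifetimes match here).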
  bef_file_ =
      BEFFile::Open(bef_buffer_, host_context_->GetKernelRegistry(),
                    host_context_->diag_handler(), host_context_->allocator());
  CHECK(bef_file_) << "Failed to open BEF";

  // Run the TFRT initialization function to pre-instantiate fallback kernels.
  RunTfrtInitializer();
}

llvm::SmallVector<Tensor> RuntimeFallbackExecutor::Execute(
    llvm::StringRef function_name, llvm::ArrayRef<Tensor> arguments) {
  // Get the kernel entrypoint function.
  const Function* compute = bef_file_->GetFunction(function_name);
  CHECK(compute) << "Entrypoint function not found";
  CHECK_EQ(arguments.size() + 1, compute->num_arguments())
      << "Wrong number of arguments for function " << function_name.str();

  // Prepare function arguments from a ready Chain and the input Tensors.
  llvm::SmallVector<tfrt::AsyncValue*> exec_arguments;
  exec_arguments.reserve(compute->num_arguments());
  exec_arguments.push_back(tfrt::GetReadyChain().release());
  for (const Tensor& input_tensor : arguments) {
    auto av = MakeAvailableAsyncValueRef<FallbackTensor>(input_tensor);
    exec_arguments.push_back(av.release());
  }

  // Space for returned values.
  llvm::SmallVector<RCReference<AsyncValue>> results(compute->num_results());

  compute->Execute(*exec_ctx_, exec_arguments, results);

  // Wait for the function execution to finish, as well as its side effects.
  host_context_->Await(results);

  // Convert the returned tensors, skipping results[0], which is the output
  // Chain, and aborting if any result holds an error.
  llvm::SmallVector<Tensor> ret_values;
  for (unsigned i = 1; i < results.size(); ++i) {
    if (auto* error = results[i]->GetErrorIfPresent())
      LOG(FATAL) << "Failed to execute a function: " << StrCat(*error);
    ret_values.push_back(results[i]->get<tfrt_stub::FallbackTensor>().tensor());
  }

  // Deallocate arguments.
  for (auto* argument : exec_arguments) argument->DropRef();
  return ret_values;
}

// Runs the TFRT fallback initialization function to instantiate all fallback
// kernels ahead of executing the compute function.
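// By TFRT convention the initializer takes a single input chain and returns a
// single output chain, which the CHECKs below verify.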
void RuntimeFallbackExecutor::RunTfrtInitializer() {
  const Function* func = bef_file_->GetFunction("_tfrt_fallback_init");
  CHECK(func) << "TFRT initialization function was not found";
  CHECK_EQ(func->argument_types().size(), 1);

  llvm::SmallVector<RCReference<AsyncValue>, 1> results;
  results.resize(func->result_types().size());
  CHECK_EQ(results.size(), 1);

  func->Execute(*exec_ctx_, tfrt::GetReadyChain().GetAsyncValue(), results);

  host_context_->Await(results);

  CHECK(!results[0]->IsError()) << "Failed to run TFRT initialization function";
}

}  // namespace tensorflow