/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tfrt_stub {
namespace {

Status CheckOpDefCompatibility(const tensorflow::OpDef& op_def) {
  auto check_arg_def = [&](const auto& arg_def) {
    if (arg_def.is_ref())
      return tensorflow::errors::Internal(
          "TFRT kernel fallback error: Unsupported ref args in ",
          op_def.name());
    return OkStatus();
  };

  for (const auto& arg_def : op_def.input_arg())
    TF_RETURN_IF_ERROR(check_arg_def(arg_def));
  for (const auto& arg_def : op_def.output_arg())
    TF_RETURN_IF_ERROR(check_arg_def(arg_def));

  return OkStatus();
}

// Creates a tensorflow::NodeDef from the tensorflow::OpDef and the attributes.
StatusOr<tensorflow::NodeDef> BuildNodeDef(
    const tensorflow::OpDef& op_def, int num_args,
    const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder) {
  tensorflow::NodeDef node_def;
  node_def.set_name(op_def.name());
  node_def.set_op(op_def.name());
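  // The dummy input names only need to make the NodeDef well-formed: kernel
  // construction below consults the op's attrs and signature, not real inputs.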
  for (int i = 0; i < num_args; ++i) {
    node_def.add_input("dummy_input");
  }

  auto* attr_value_map = node_def.mutable_attr();
  TF_RETURN_IF_ERROR(attr_builder(attr_value_map));

  // For any attr-value pairs that exist in the op def (from op registry)
  // but not in `attr_value_map`, fill them into `attr_value_map`, so that we
  // can run a TFE_Op without having to specify all the default attr values
  // (e.g. for matmul, the `transpose_a` attr defaults to false).
  for (const auto& attr_def : op_def.attr()) {
    if (attr_def.has_default_value()) {
      // Insertion will fail if this attribute already has a value.
      attr_value_map->insert({attr_def.name(), attr_def.default_value()});
    }
  }
  return node_def;
}

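// For illustration, a hypothetical call (not part of this file): building a
// NodeDef for MatMul with an attr_builder that only sets "T" still ends up
// with the registry defaults filled in:
//
//   TF_ASSIGN_OR_RETURN(
//       auto node_def,
//       BuildNodeDef(*matmul_op_def, /*num_args=*/2,
//                    [](tensorflow::AttrValueMap* attrs) {
//                      (*attrs)["T"].set_type(tensorflow::DT_FLOAT);
//                      return OkStatus();
//                    }));
//   // node_def.attr() now also holds the defaults from the MatMul OpDef,
//   // e.g. "transpose_a" = false and "transpose_b" = false.
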
tensorflow::Status CreateOpKernel(
    tensorflow::FunctionLibraryRuntime* flr, tensorflow::NodeDef ndef,
    std::unique_ptr<tensorflow::OpKernel>* result) {
  std::shared_ptr<const tensorflow::NodeProperties> props;
  TF_RETURN_IF_ERROR(tensorflow::NodeProperties::CreateFromNodeDef(
      ndef, flr->GetFunctionLibraryDefinition(), &props));
  tensorflow::OpKernel* k = nullptr;
  TF_RETURN_IF_ERROR(flr->CreateKernel(props, &k));
  result->reset(k);
  return OkStatus();
}

}  // namespace

StatusOr<OpKernelRunner> OpKernelRunner::Create(
    absl::string_view op_name, absl::string_view device_name, int num_args,
    const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder,
    const tensorflow::DeviceMgr& device_manager,
    const tensorflow::ProcessFunctionLibraryRuntime&
        process_function_library_runtime) {
  tensorflow::Device* device = nullptr;
  Status s = device_manager.LookupDevice(device_name, &device);

  // Fall back to the host device if the specified device cannot be found.
  if (!s.ok()) {
    LOG(WARNING) << "Failed to find device " << device_name
                 << " when creating OpKernel: " << op_name << ". Error: " << s;
    LOG(WARNING) << "Falling back to host device instead";
    device = device_manager.HostCPU();
  }

  return Create(op_name, num_args, attr_builder,
                process_function_library_runtime, device);
}

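// A typical call site might look like the following sketch (the device name
// and attr values below are illustrative assumptions, not taken from this
// file):
//
//   TF_ASSIGN_OR_RETURN(
//       auto runner,
//       OpKernelRunner::Create(
//           "MatMul", "/job:localhost/replica:0/task:0/device:CPU:0",
//           /*num_args=*/2,
//           [](tensorflow::AttrValueMap* attrs) {
//             (*attrs)["T"].set_type(tensorflow::DT_FLOAT);
//             return OkStatus();
//           },
//           device_manager, process_function_library_runtime));
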
StatusOr<OpKernelRunner> OpKernelRunner::Create(
    absl::string_view op_name, int num_args,
    const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder,
    const tensorflow::ProcessFunctionLibraryRuntime&
        process_function_library_runtime,
    tensorflow::Device* device) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(tensorflow::OpRegistry::Global()->LookUpOpDef(
      std::string(op_name), &op_def));
  TF_RETURN_IF_ERROR(CheckOpDefCompatibility(*op_def));
  VLOG(1) << "KernelFallbackExecuteCompat creating op from OpDef: "
          << op_def->DebugString();

  TF_ASSIGN_OR_RETURN(auto node_def,
                      BuildNodeDef(*op_def, num_args, attr_builder));

  VLOG(1) << "KernelFallbackExecuteCompat created NodeDef: "
          << node_def.DebugString();

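  // The FunctionLibraryRuntime is per-device; fetch the one owned by the
  // ProcessFunctionLibraryRuntime for this device.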
  tensorflow::FunctionLibraryRuntime* function_library_runtime =
      process_function_library_runtime.GetFLR(device->name());

  std::unique_ptr<OpKernel> op_kernel;
  TF_RETURN_IF_ERROR(CreateOpKernel(function_library_runtime,
                                    std::move(node_def), &op_kernel));
  return OpKernelRunner(device, function_library_runtime, std::move(op_kernel));
}

OpKernelRunner::OpKernelRunner(
    tensorflow::Device* device,
    tensorflow::FunctionLibraryRuntime* function_library_runtime,
    std::unique_ptr<tensorflow::OpKernel> op_kernel)
    : device_(device),
      function_library_runtime_(function_library_runtime),
      resource_manager_(device->resource_manager()),
      op_kernel_(std::move(op_kernel)),
      is_async_(op_kernel_->AsAsync() != nullptr) {
  DCHECK(device_);
  DCHECK(function_library_runtime_);

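  // Cache per-argument allocator attributes up front: inputs and outputs
  // whose memory type is HOST_MEMORY must live in host memory even when the
  // kernel runs on a non-host device.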
  const auto& input_memory_types = op_kernel_->input_memory_types();
  input_alloc_attrs_.resize(op_kernel_->num_inputs());
  for (size_t i = 0, e = op_kernel_->num_inputs(); i < e; ++i) {
    input_alloc_attrs_[i].set_on_host(input_memory_types[i] ==
                                      tensorflow::HOST_MEMORY);
  }
  const auto& output_memory_types = op_kernel_->output_memory_types();
  output_alloc_attrs_.resize(op_kernel_->num_outputs());
  for (size_t i = 0, e = output_alloc_attrs_.size(); i < e; ++i) {
    output_alloc_attrs_[i].set_on_host(output_memory_types[i] ==
                                       tensorflow::HOST_MEMORY);
  }
}

void OpKernelRunner::Run(OpKernelContext* context) const {
  DVLOG(1) << "KernelFallbackExecuteCompat Running Op: "
           << op_kernel_->def().DebugString()
           << ", on Device: " << context->device()->name();

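  // context->device() returns a DeviceBase*; downcast so the call dispatches
  // through tensorflow::Device::Compute.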
  static_cast<tensorflow::Device*>(context->device())
      ->Compute(op_kernel_.get(), context);
}

void OpKernelRunner::RunAsync(OpKernelContext* context,
                              AsyncOpKernel::DoneCallback done_callback) const {
  DVLOG(1) << "KernelFallbackExecuteCompat Running Async Op: "
           << op_kernel_->def().DebugString()
           << ", on Device: " << context->device()->name();

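  // RunAsync() is only valid for kernels that wrap an AsyncOpKernel; see
  // is_async_ in the constructor.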
  AsyncOpKernel* async = op_kernel_->AsAsync();
  DCHECK(async);

  static_cast<tensorflow::Device*>(context->device())
      ->ComputeAsync(async, context, std::move(done_callback));
}

}  // namespace tfrt_stub
}  // namespace tensorflow