/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/fallback/op_kernel_runner.h"

#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace tfrt_stub {
namespace {

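// Returns an error if `op_def` declares any ref-typed input or output args,
// which the TFRT kernel fallback does not support.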
Status CheckOpDefCompatibility(const tensorflow::OpDef& op_def) {
  auto check_arg_def = [&](const auto& arg_def) {
    if (arg_def.is_ref())
      return tensorflow::errors::Internal(
          "TFRT kernel fallback error: Unsupported ref args in ",
          op_def.name());
    return OkStatus();
  };

  for (const auto& arg_def : op_def.input_arg())
    TF_RETURN_IF_ERROR(check_arg_def(arg_def));
  for (const auto& arg_def : op_def.output_arg())
    TF_RETURN_IF_ERROR(check_arg_def(arg_def));

  return OkStatus();
}

// Creates a tensorflow::NodeDef from the tensorflow::OpDef and the attributes.
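//
// A minimal usage sketch (hypothetical; the MatMul op, the two-input arity,
// and the float dtype are assumptions for illustration, not taken from this
// file):
//
//   const tensorflow::OpDef* op_def = nullptr;
//   TF_RETURN_IF_ERROR(
//       tensorflow::OpRegistry::Global()->LookUpOpDef("MatMul", &op_def));
//   TF_ASSIGN_OR_RETURN(
//       tensorflow::NodeDef node_def,
//       BuildNodeDef(*op_def, /*num_args=*/2,
//                    [](tensorflow::AttrValueMap* attrs) {
//                      tensorflow::AttrValue dtype;
//                      dtype.set_type(tensorflow::DT_FLOAT);
//                      attrs->insert({"T", dtype});
//                      return OkStatus();
//                    }));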
StatusOr<tensorflow::NodeDef> BuildNodeDef(
    const tensorflow::OpDef& op_def, int num_args,
    const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder) {
  tensorflow::NodeDef node_def;
  node_def.set_name(op_def.name());
  node_def.set_op(op_def.name());
  for (int i = 0; i < num_args; ++i) {
    node_def.add_input("dummy_input");
  }

  auto* attr_value_map = node_def.mutable_attr();
  TF_RETURN_IF_ERROR(attr_builder(attr_value_map));

  // For any attr-value pairs that exist in the op def (from the op registry)
  // but not in `attr_value_map`, fill them into `attr_value_map`, so that we
  // can run a TFE_Op without having to specify all the default attr values
  // (e.g. for MatMul, the `transpose_a` attr defaults to false).
  for (const auto& attr_def : op_def.attr()) {
    if (attr_def.has_default_value()) {
      // Insertion is a no-op if the attribute already has a value, so attrs
      // set by `attr_builder` take precedence over op-def defaults.
      attr_value_map->insert({attr_def.name(), attr_def.default_value()});
    }
  }
  return node_def;
}

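// Instantiates an OpKernel for `ndef` via the given FunctionLibraryRuntime
// and transfers ownership of the kernel to `result`.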
tensorflow::Status CreateOpKernel(
    tensorflow::FunctionLibraryRuntime* flr, tensorflow::NodeDef ndef,
    std::unique_ptr<tensorflow::OpKernel>* result) {
  std::shared_ptr<const tensorflow::NodeProperties> props;
  TF_RETURN_IF_ERROR(tensorflow::NodeProperties::CreateFromNodeDef(
      ndef, flr->GetFunctionLibraryDefinition(), &props));
  tensorflow::OpKernel* k = nullptr;
  TF_RETURN_IF_ERROR(flr->CreateKernel(props, &k));
  result->reset(k);
  return OkStatus();
}

}  // namespace

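// Creates an OpKernelRunner for `op_name` on the device named `device_name`,
// falling back to the host CPU device if the device lookup fails.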
StatusOr<OpKernelRunner> OpKernelRunner::Create(
    absl::string_view op_name, absl::string_view device_name, int num_args,
    const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder,
    const tensorflow::DeviceMgr& device_manager,
    const tensorflow::ProcessFunctionLibraryRuntime&
        process_function_library_runtime) {
  tensorflow::Device* device = nullptr;
  Status s = device_manager.LookupDevice(device_name, &device);

  // Fall back to the host device if the specified device cannot be found.
  if (!s.ok()) {
    LOG(WARNING) << "Failed to find device " << device_name
                 << " when creating OpKernel: " << op_name << ". Error: " << s;
    LOG(WARNING) << "Falling back to host device instead.";
    device = device_manager.HostCPU();
  }

  return Create(op_name, num_args, attr_builder,
                process_function_library_runtime, device);
}

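// Creates an OpKernelRunner for `op_name` directly on the given `device`.
//
// A hedged usage sketch (hypothetical; `attr_builder`, `pflr`, and
// `cpu_device` are placeholders for objects the caller already owns):
//
//   TF_ASSIGN_OR_RETURN(
//       OpKernelRunner runner,
//       OpKernelRunner::Create("Relu", /*num_args=*/1, attr_builder, pflr,
//                              cpu_device));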
StatusOr<OpKernelRunner> OpKernelRunner::Create(
    absl::string_view op_name, int num_args,
    const std::function<Status(tensorflow::AttrValueMap*)>& attr_builder,
    const tensorflow::ProcessFunctionLibraryRuntime&
        process_function_library_runtime,
    tensorflow::Device* device) {
  const OpDef* op_def = nullptr;
  TF_RETURN_IF_ERROR(tensorflow::OpRegistry::Global()->LookUpOpDef(
      std::string(op_name), &op_def));
  TF_RETURN_IF_ERROR(CheckOpDefCompatibility(*op_def));
  VLOG(1) << "KernelFallbackExecuteCompat creating op from OpDef: "
          << op_def->DebugString();

  TF_ASSIGN_OR_RETURN(auto node_def,
                      BuildNodeDef(*op_def, num_args, attr_builder));

  VLOG(1) << "KernelFallbackExecuteCompat created NodeDef: "
          << node_def.DebugString();

  tensorflow::FunctionLibraryRuntime* function_library_runtime =
      process_function_library_runtime.GetFLR(device->name());

  std::unique_ptr<OpKernel> op_kernel;
  TF_RETURN_IF_ERROR(CreateOpKernel(function_library_runtime,
                                    std::move(node_def), &op_kernel));
  return OpKernelRunner(device, function_library_runtime,
                        std::move(op_kernel));
}

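// Caches the device, function library runtime, and resource manager, and
// precomputes per-input and per-output AllocatorAttributes from the kernel's
// declared memory types so they can be reused on every invocation.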
OpKernelRunner::OpKernelRunner(
    tensorflow::Device* device,
    tensorflow::FunctionLibraryRuntime* function_library_runtime,
    std::unique_ptr<tensorflow::OpKernel> op_kernel)
    : device_(device),
      function_library_runtime_(function_library_runtime),
      resource_manager_(device->resource_manager()),
      op_kernel_(std::move(op_kernel)),
      is_async_(op_kernel_->AsAsync() != nullptr) {
  DCHECK(device_);
  DCHECK(function_library_runtime_);

  const auto& input_memory_types = op_kernel_->input_memory_types();
  input_alloc_attrs_.resize(op_kernel_->num_inputs());
  for (size_t i = 0, e = op_kernel_->num_inputs(); i < e; ++i) {
    input_alloc_attrs_[i].set_on_host(input_memory_types[i] ==
                                      tensorflow::HOST_MEMORY);
  }
  const auto& output_memory_types = op_kernel_->output_memory_types();
  output_alloc_attrs_.resize(op_kernel_->num_outputs());
  for (size_t i = 0, e = output_alloc_attrs_.size(); i < e; ++i) {
    output_alloc_attrs_[i].set_on_host(output_memory_types[i] ==
                                       tensorflow::HOST_MEMORY);
  }
}

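// Runs the wrapped kernel synchronously on the device associated with
// `context`.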
void OpKernelRunner::Run(OpKernelContext* context) const {
  DVLOG(1) << "KernelFallbackExecuteCompat Running Op: "
           << op_kernel_->def().DebugString()
           << ", on Device: " << context->device()->name();

  static_cast<tensorflow::Device*>(context->device())
      ->Compute(op_kernel_.get(), context);
}

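// Runs the wrapped kernel asynchronously; `done_callback` is invoked when the
// kernel completes. Requires that the kernel was created from an
// AsyncOpKernel (i.e. `is_async_` is true).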
void OpKernelRunner::RunAsync(OpKernelContext* context,
                              AsyncOpKernel::DoneCallback done_callback) const {
  DVLOG(1) << "KernelFallbackExecuteCompat Running Async Op: "
           << op_kernel_->def().DebugString()
           << ", on Device: " << context->device()->name();

  AsyncOpKernel* async = op_kernel_->AsAsync();
  DCHECK(async);

  static_cast<tensorflow::Device*>(context->device())
      ->ComputeAsync(async, context, std::move(done_callback));
}

}  // namespace tfrt_stub
}  // namespace tensorflow