xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
17 
18 #include <cassert>
19 
20 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
21 
22 namespace tensorflow {
23 
XlaCompiledCpuFunction(const StaticData & static_data,AllocMode alloc_mode)24 XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
25                                                AllocMode alloc_mode)
26     : raw_function_(static_data.raw_function_),
27       result_index_(static_data.result_index_),
28       buffer_table_(new void*[static_data.num_buffers_]),
29       buffer_infos_(static_data.buffer_infos_),
30       arg_index_table_(static_data.arg_index_table_),
31       num_args_(static_data.num_args_),
32       num_variables_(static_data.num_variables_),
33       arg_names_(static_data.arg_names_),
34       variable_names_(static_data.variable_names_),
35       result_names_(static_data.result_names_),
36       program_shape_(static_data.program_shape_),
37       hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
38   bool allocate_entry_params =
39       alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS;
40   // Allocate arg and temp buffers.
41   alloc_buffer_table_ = xla::cpu_function_runtime::MallocContiguousBuffers(
42       static_data.buffer_infos_, static_data.num_buffers_,
43       /*allocate_entry_params=*/allocate_entry_params, buffer_table_,
44       /*annotate_initialized=*/true);
45   // If Hlo profiling is enabled the generated code expects an appropriately
46   // sized buffer to be passed in as the last argument.  If Hlo profiling is
47   // disabled the last function argument is still present in the function
48   // signature, but it is ignored by the generated code and we pass in null for
49   // it.
50   if (hlo_profiling_enabled()) {
51     profile_counters_ = new int64_t[static_data.profile_counters_size_]();
52   }
53 }
54 
Run()55 bool XlaCompiledCpuFunction::Run() {
56   XlaCustomCallStatus status;
57   raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
58                 buffer_table_, &status, profile_counters_);
59   return !xla::CustomCallStatusGetMessage(&status).has_value();
60 }
61 
~XlaCompiledCpuFunction()62 XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
63   xla::cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
64   delete[] buffer_table_;
65   delete[] profile_counters_;
66 }
67 
namespace {

// Sentinel returned by the lookup helpers when no index can be produced.
constexpr int kNotFound = -1;

// Linear search through `names` looking for an entry equal to `name`.
//
// Returns the 0-based position of the first match, or kNotFound (-1) when
// `name` is empty, when no entry matches, or when `names` is null.
//
// REQUIRES: a non-null `names` is a nullptr-terminated array.
int LookupNameIndex(const string& name, const char** names) {
  // Hitting this assert means that there is no name-to-index data available;
  // for AOT try setting the tfcompile --gen_name_to_index flag.
  assert(names != nullptr);

  // assert() is compiled out when NDEBUG is defined, so guard explicitly:
  // without this check a missing table would be dereferenced in the loop
  // below in release builds.
  if (names == nullptr || name.empty()) {
    return kNotFound;
  }
  for (int index = 0; names[index] != nullptr; ++index) {
    if (name == names[index]) {
      return index;
    }
  }
  return kNotFound;
}

}  // namespace
93 
LookupArgIndex(const string & name) const94 int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const {
95   return LookupNameIndex(name, arg_names_);
96 }
97 
LookupVariableIndex(const string & name) const98 int XlaCompiledCpuFunction::LookupVariableIndex(const string& name) const {
99   int index = LookupNameIndex(name, variable_names_);
100   if (index == kNotFound) {
101     return kNotFound;
102   }
103   return num_args_ - num_variables_ + index;
104 }
105 
LookupResultIndex(const string & name) const106 int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
107   return LookupNameIndex(name, result_names_);
108 }
109 
110 }  // namespace tensorflow
111