1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 18 19 #include <cassert> 20 #include <string> 21 22 #include "tensorflow/compiler/xla/cpu_function_runtime.h" 23 #include "tensorflow/compiler/xla/executable_run_options.h" 24 #include "tensorflow/compiler/xla/service/custom_call_status_internal.h" 25 #include "tensorflow/core/platform/types.h" 26 27 // Forward-declare, rather than include, to reduce code size for users that 28 // never use this functionality. 29 namespace xla { 30 class ProgramShapeProto; 31 class HloProfilePrinterData; 32 } // namespace xla 33 34 namespace tensorflow { 35 36 // Represents a function compiled by XLA, produced via either JIT or AOT. 37 // 38 // The Run method invokes the actual computation, with inputs read from arg 39 // buffers, and outputs written to result buffers. Each Run call may also use a 40 // set of temporary buffers for the computation. 41 // 42 // By default each instance of this class manages its own arg, result and temp 43 // buffers. The AllocMode constructor parameter may be used to modify the buffer 44 // allocation strategy. 45 // 46 // Under the default allocation strategy, this class is thread-compatible: 47 // o Calls to non-const methods require exclusive access to the object. 48 // o Concurrent calls to const methods are OK, if those calls are made while it 49 // is guaranteed that no thread may call a non-const method. 50 class XlaCompiledCpuFunction { 51 public: 52 // Type of the raw function, produced by either JIT or AOT. 53 using RawFunction = void (*)(void* result, 54 const xla::ExecutableRunOptions* run_options, 55 const void** args, void** temps, 56 XlaCustomCallStatus*, int64_t* profile_counters); 57 58 // StaticData represents the state necessary to run an XLA-compiled 59 // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for 60 // AOT this is backed by data compiled into the object file. 61 // 62 // The contents of StaticData are XLA-internal implementation details and 63 // should not be relied on by clients (and therefore are private). 64 class StaticData { 65 private: 66 // The raw function to call. 67 RawFunction raw_function_; 68 69 // Contains information about the buffers used by the XLA computation. 70 const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; 71 size_t num_buffers_ = 0; 72 73 // Entry parameter i is described by 74 // buffer_infos[arg_index_table[i]]. 75 const int32* arg_index_table_ = nullptr; 76 77 // There are num_args entry parameters. 78 int64_t num_args_ = 0; 79 80 // There are num_variables variables. 81 int64_t num_variables_ = 0; 82 83 // The 0-based index of the result tuple, in the temp buffers. 84 size_t result_index_ = 0; 85 86 // [Optional] Arrays of arg and result names. These are arrays of C-style 87 // strings, where the array is terminated by nullptr. 88 const char** arg_names_ = nullptr; 89 const char** variable_names_ = nullptr; 90 const char** result_names_ = nullptr; 91 92 // [Optional] Arg and result shapes. 93 const xla::ProgramShapeProto* program_shape_ = nullptr; 94 95 // [Optional] Profile printer data. Null if profiling is disabled. 96 const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; 97 98 // [Optional] The number of profile counters expected in the profile counter 99 // buffer by the generated code and hlo_profile_printer. 0 if profiling is 100 // disabled. This information is already present in 101 // hlo_profile_printer_data but xla::HloProfilePrinterData is forward 102 // declared so we don't have access to that information here. 103 int64_t profile_counters_size_ = 0; 104 105 // Only XlaCompiledCpuFunction is allowed to read and write the above 106 // fields. 107 friend class XlaCompiledCpuFunction; 108 }; 109 110 // AllocMode controls the buffer allocation mode. 111 enum class AllocMode { 112 // Allocate all buffers - args, results, profile and temps. 113 ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS, 114 115 // Only allocate result, profile and temp buffers. 116 // Use set_arg_data to set argument buffers before Run is called. 117 RESULTS_PROFILES_AND_TEMPS_ONLY, 118 }; 119 120 explicit XlaCompiledCpuFunction( 121 const StaticData& static_data, 122 AllocMode alloc_mode = 123 AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS); 124 virtual ~XlaCompiledCpuFunction(); 125 126 XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; 127 XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; 128 129 // Sets the intra-op thread pool used to run individual ops concurrently. set_thread_pool(const Eigen::ThreadPoolDevice * pool)130 void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { 131 run_options_.set_intra_op_thread_pool(pool); 132 } 133 134 // Runs the computation, with inputs read from arg buffers, and outputs 135 // written to result buffers. Returns true on success and false on failure. 136 bool Run(); 137 138 // Returns the error message from the previous failed Run call. 139 // 140 // TODO(fschneider): For now this always returns an empty string because there 141 // is no support for error reporting in XLA. Remove this once all callers are 142 // updated. error_msg()143 string error_msg() const { return {}; } 144 145 // ------------------------------ 146 // Arg methods for managing input buffers. Buffers are in row-major order. 147 148 // Returns the buffer for the positional argument at the given `index`. arg_data(size_t index)149 void* arg_data(size_t index) { 150 return buffer_table_[arg_index_table_[index]]; 151 } arg_data(size_t index)152 const void* arg_data(size_t index) const { 153 return buffer_table_[arg_index_table_[index]]; 154 } 155 num_args()156 int num_args() const { return num_args_; } 157 num_variables()158 int num_variables() const { return num_variables_; } 159 160 // Returns the size of entry parameter `idx`. 161 // 162 // There is a static version of this method on tfcompile generated subclasses 163 // of XlaCompiledCpuFunction, but try to prefer this when possible since it 164 // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses. arg_size(int idx)165 int arg_size(int idx) const { 166 assert(idx < num_args()); 167 return buffer_infos_[arg_index_table_[idx]].size(); 168 } 169 170 // Sets the buffer for the positional argument at the given `index` to `data`. 171 // Must be called before Run to have an effect. May be called under any 172 // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be 173 // called for each positional argument, in order to set the argument buffers. 174 // 175 // Allocated memory must be aligned to the size specified by 176 // xla::cpu_function_runtime::MinAlign(). If possible, use the functions in 177 // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct 178 // alignment. 179 // 180 // Aliasing of argument and result buffers is not allowed, and results in 181 // undefined behavior. set_arg_data(size_t index,const void * data)182 void set_arg_data(size_t index, const void* data) { 183 assert((arg_size(index) < xla::cpu_function_runtime::MinAlign() || 184 (uintptr_t)data % xla::cpu_function_runtime::MinAlign() == 0) && 185 "Underaligned pointer!"); 186 // The const_cast is safe because the generated code does not write to arg 187 // buffers. 188 // 189 // buffer_table_ contains pointers to buffers that _will_ be written to by 190 // generated code so it would be misleading to make buffer_table_ a `const 191 // void**`. 192 buffer_table_[arg_index_table_[index]] = const_cast<void*>(data); 193 } 194 195 // ------------------------------ 196 // Result methods for managing output buffers. Buffers are in row-major order. 197 // Must only be called after a successful Run call. Unlike the arg methods, 198 // there is no set_resultN_data method. The result buffers are managed 199 // internally, and may change after each call to Run. 200 201 // Returns the underlying array of result buffers, where results()[I] is the 202 // buffer for the positional result at index I. results()203 void** results() { return static_cast<void**>(buffer_table_[result_index_]); } results()204 const void* const* results() const { 205 return static_cast<const void* const*>(buffer_table_[result_index_]); 206 } 207 208 // Profile counters for this XLA computation. 209 // 210 // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in 211 // this case) these counters are non-null and are automatically populated by 212 // `Run`. The counters can then be pretty-printed using 213 // `hlo_profile_printer()`. 214 // 215 // When Hlo profiling is disabled, this accessor returns null. profile_counters()216 const int64_t* profile_counters() const { return profile_counters_; } 217 218 // Returns the buffer for the positional result at the given `index`. result_data(size_t index)219 void* result_data(size_t index) { return results()[index]; } result_data(size_t index)220 const void* result_data(size_t index) const { return results()[index]; } 221 222 // ------------------------------ 223 // Methods for extracting optional metadata. 224 225 // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index 226 // methods. E.g. the data might not be compiled into the binary for AOT. HasNameIndices()227 bool HasNameIndices() const { 228 return arg_names_ != nullptr && variable_names_ != nullptr && 229 result_names_ != nullptr; 230 } 231 232 // Returns the 0-based index for the argument with the given `name`. 233 // Returns -1 if the name wasn't found, or data isn't available. 234 // 235 // The index remains constant for every instance of XlaCompiledCpuFunction 236 // generated from the same static data, and might not be cheap to determine. 237 // Recommended usage is to capture this in a variable for re-use. 238 int LookupArgIndex(const string& name) const; 239 240 // Returns the 0-based index for the variable with the given `name`. 241 // Returns -1 if the name wasn't found, or data isn't available. 242 // 243 // The index remains constant for every instance of XlaCompiledCpuFunction 244 // generated from the same static data, and might not be cheap to determine. 245 // Recommended usage is to capture this in a variable for re-use. 246 int LookupVariableIndex(const string& name) const; 247 248 // Returns the 0-based index for the result with the given `name`. 249 // Returns -1 if the name wasn't found, or data isn't available. 250 // 251 // The index remains constant for every instance of XlaCompiledCpuFunction 252 // generated from the same static data, and might not be cheap to determine. 253 // Recommended usage is to capture this in a variable for re-use. 254 int LookupResultIndex(const string& name) const; 255 256 // Returns the shape of the args and results. May return nullptr if the 257 // program shape isn't available. ProgramShape()258 const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } 259 hlo_profiling_enabled()260 bool hlo_profiling_enabled() const { 261 return hlo_profile_printer_data_ != nullptr; 262 } hlo_profile_printer_data()263 const xla::HloProfilePrinterData& hlo_profile_printer_data() const { 264 assert(hlo_profiling_enabled()); 265 return *hlo_profile_printer_data_; 266 } 267 268 protected: 269 // --------------------------------------------------------------------------- 270 // Accessors for reading from and writing to instances of `StaticData`. 271 // 272 // Classes generated by tfcompile can call these because the generated classes 273 // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can 274 // call these because it is explicitly added as a friend. 275 set_static_data_raw_function(StaticData * static_data,RawFunction raw_function)276 static void set_static_data_raw_function(StaticData* static_data, 277 RawFunction raw_function) { 278 static_data->raw_function_ = raw_function; 279 } 280 set_static_data_buffer_infos(StaticData * static_data,const xla::cpu_function_runtime::BufferInfo * buffer_infos)281 static void set_static_data_buffer_infos( 282 StaticData* static_data, 283 const xla::cpu_function_runtime::BufferInfo* buffer_infos) { 284 static_data->buffer_infos_ = buffer_infos; 285 } 286 set_static_data_num_buffers(StaticData * static_data,size_t num_buffers)287 static void set_static_data_num_buffers(StaticData* static_data, 288 size_t num_buffers) { 289 static_data->num_buffers_ = num_buffers; 290 } 291 set_static_data_arg_index_table(StaticData * static_data,const int32 * arg_index_table)292 static void set_static_data_arg_index_table(StaticData* static_data, 293 const int32* arg_index_table) { 294 static_data->arg_index_table_ = arg_index_table; 295 } 296 set_static_data_num_args(StaticData * static_data,int64_t num_args)297 static void set_static_data_num_args(StaticData* static_data, 298 int64_t num_args) { 299 static_data->num_args_ = num_args; 300 } 301 set_static_data_num_variables(StaticData * static_data,int64_t num_variables)302 static void set_static_data_num_variables(StaticData* static_data, 303 int64_t num_variables) { 304 static_data->num_variables_ = num_variables; 305 } 306 set_static_data_result_index(StaticData * static_data,size_t result_index)307 static void set_static_data_result_index(StaticData* static_data, 308 size_t result_index) { 309 static_data->result_index_ = result_index; 310 } 311 set_static_data_arg_names(StaticData * static_data,const char ** arg_names)312 static void set_static_data_arg_names(StaticData* static_data, 313 const char** arg_names) { 314 static_data->arg_names_ = arg_names; 315 } 316 set_static_data_variable_names(StaticData * static_data,const char ** variable_names)317 static void set_static_data_variable_names(StaticData* static_data, 318 const char** variable_names) { 319 static_data->variable_names_ = variable_names; 320 } 321 set_static_data_result_names(StaticData * static_data,const char ** result_names)322 static void set_static_data_result_names(StaticData* static_data, 323 const char** result_names) { 324 static_data->result_names_ = result_names; 325 } 326 set_static_data_program_shape(StaticData * static_data,const xla::ProgramShapeProto * program_shape)327 static void set_static_data_program_shape( 328 StaticData* static_data, const xla::ProgramShapeProto* program_shape) { 329 static_data->program_shape_ = program_shape; 330 } 331 set_static_data_hlo_profile_printer_data(StaticData * static_data,const xla::HloProfilePrinterData * hlo_profile_printer_data)332 static void set_static_data_hlo_profile_printer_data( 333 StaticData* static_data, 334 const xla::HloProfilePrinterData* hlo_profile_printer_data) { 335 static_data->hlo_profile_printer_data_ = hlo_profile_printer_data; 336 } 337 338 static const xla::HloProfilePrinterData* get_static_data_hlo_profile_printer_data(StaticData * static_data)339 get_static_data_hlo_profile_printer_data(StaticData* static_data) { 340 return static_data->hlo_profile_printer_data_; 341 } 342 set_static_data_profile_counters_size(StaticData * static_data,int64_t profile_counters_size)343 static void set_static_data_profile_counters_size( 344 StaticData* static_data, int64_t profile_counters_size) { 345 static_data->profile_counters_size_ = profile_counters_size; 346 } 347 348 private: 349 const RawFunction raw_function_; 350 const size_t result_index_; 351 352 // Array containing pointers to argument and temp buffers (slots corresponding 353 // to constant and on-stack buffers are null). 354 void** const buffer_table_; 355 356 // Describes the buffers used by the XLA computation. 357 const xla::cpu_function_runtime::BufferInfo* const buffer_infos_; 358 359 // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]] 360 // for XLA generated code to be able to find it. 361 const int32* const arg_index_table_; 362 363 // The number of incoming arguments. 364 const int32 num_args_; 365 366 // The number of incoming variables. 367 const int32 num_variables_; 368 369 // Backing memory for buffer_table_ and args_, the latter depending on 370 // AllocMode. 371 void* alloc_buffer_table_ = nullptr; 372 373 // Backing memory for profiling counters. 374 int64_t* profile_counters_ = nullptr; 375 376 // Options and context passed to the compiled function. 377 xla::ExecutableRunOptions run_options_; 378 379 // Optional metadata. 380 const char** arg_names_ = nullptr; 381 const char** variable_names_ = nullptr; 382 const char** result_names_ = nullptr; 383 const xla::ProgramShapeProto* program_shape_ = nullptr; 384 const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; 385 386 // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the 387 // `set_static_data_*` static methods above. 388 friend class XlaJitCompiledCpuFunction; 389 }; 390 391 } // namespace tensorflow 392 393 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ 394