xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
18 
19 #include <cassert>
20 #include <string>
21 
22 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
23 #include "tensorflow/compiler/xla/executable_run_options.h"
24 #include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 // Forward-declare, rather than include, to reduce code size for users that
28 // never use this functionality.
namespace xla {
// Only referenced through pointers and references in this header, so an
// incomplete type is sufficient here.
class HloProfilePrinterData;
// Likewise only referenced through pointers in this header.
class ProgramShapeProto;
}  // namespace xla
33 
34 namespace tensorflow {
35 
36 // Represents a function compiled by XLA, produced via either JIT or AOT.
37 //
38 // The Run method invokes the actual computation, with inputs read from arg
39 // buffers, and outputs written to result buffers. Each Run call may also use a
40 // set of temporary buffers for the computation.
41 //
42 // By default each instance of this class manages its own arg, result and temp
43 // buffers. The AllocMode constructor parameter may be used to modify the buffer
44 // allocation strategy.
45 //
46 // Under the default allocation strategy, this class is thread-compatible:
47 // o Calls to non-const methods require exclusive access to the object.
48 // o Concurrent calls to const methods are OK, if those calls are made while it
49 //   is guaranteed that no thread may call a non-const method.
50 class XlaCompiledCpuFunction {
51  public:
52   // Type of the raw function, produced by either JIT or AOT.
53   using RawFunction = void (*)(void* result,
54                                const xla::ExecutableRunOptions* run_options,
55                                const void** args, void** temps,
56                                XlaCustomCallStatus*, int64_t* profile_counters);
57 
58   // StaticData represents the state necessary to run an XLA-compiled
59   // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
60   // AOT this is backed by data compiled into the object file.
61   //
62   // The contents of StaticData are XLA-internal implementation details and
63   // should not be relied on by clients (and therefore are private).
64   class StaticData {
65    private:
66     // The raw function to call.
67     RawFunction raw_function_;
68 
69     // Contains information about the buffers used by the XLA computation.
70     const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
71     size_t num_buffers_ = 0;
72 
73     // Entry parameter i is described by
74     // buffer_infos[arg_index_table[i]].
75     const int32* arg_index_table_ = nullptr;
76 
77     // There are num_args entry parameters.
78     int64_t num_args_ = 0;
79 
80     // There are num_variables variables.
81     int64_t num_variables_ = 0;
82 
83     // The 0-based index of the result tuple, in the temp buffers.
84     size_t result_index_ = 0;
85 
86     // [Optional] Arrays of arg and result names. These are arrays of C-style
87     // strings, where the array is terminated by nullptr.
88     const char** arg_names_ = nullptr;
89     const char** variable_names_ = nullptr;
90     const char** result_names_ = nullptr;
91 
92     // [Optional] Arg and result shapes.
93     const xla::ProgramShapeProto* program_shape_ = nullptr;
94 
95     // [Optional] Profile printer data.  Null if profiling is disabled.
96     const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
97 
98     // [Optional] The number of profile counters expected in the profile counter
99     // buffer by the generated code and hlo_profile_printer.  0 if profiling is
100     // disabled.  This information is already present in
101     // hlo_profile_printer_data but xla::HloProfilePrinterData is forward
102     // declared so we don't have access to that information here.
103     int64_t profile_counters_size_ = 0;
104 
105     // Only XlaCompiledCpuFunction is allowed to read and write the above
106     // fields.
107     friend class XlaCompiledCpuFunction;
108   };
109 
110   // AllocMode controls the buffer allocation mode.
111   enum class AllocMode {
112     // Allocate all buffers - args, results, profile and temps.
113     ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS,
114 
115     // Only allocate result, profile and temp buffers.
116     // Use set_arg_data to set argument buffers before Run is called.
117     RESULTS_PROFILES_AND_TEMPS_ONLY,
118   };
119 
120   explicit XlaCompiledCpuFunction(
121       const StaticData& static_data,
122       AllocMode alloc_mode =
123           AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS);
124   virtual ~XlaCompiledCpuFunction();
125 
126   XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
127   XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete;
128 
129   // Sets the intra-op thread pool used to run individual ops concurrently.
set_thread_pool(const Eigen::ThreadPoolDevice * pool)130   void set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
131     run_options_.set_intra_op_thread_pool(pool);
132   }
133 
134   // Runs the computation, with inputs read from arg buffers, and outputs
135   // written to result buffers. Returns true on success and false on failure.
136   bool Run();
137 
138   // Returns the error message from the previous failed Run call.
139   //
140   // TODO(fschneider): For now this always returns an empty string because there
141   // is no support for error reporting in XLA. Remove this once all callers are
142   // updated.
error_msg()143   string error_msg() const { return {}; }
144 
145   // ------------------------------
146   // Arg methods for managing input buffers. Buffers are in row-major order.
147 
148   // Returns the buffer for the positional argument at the given `index`.
arg_data(size_t index)149   void* arg_data(size_t index) {
150     return buffer_table_[arg_index_table_[index]];
151   }
arg_data(size_t index)152   const void* arg_data(size_t index) const {
153     return buffer_table_[arg_index_table_[index]];
154   }
155 
num_args()156   int num_args() const { return num_args_; }
157 
num_variables()158   int num_variables() const { return num_variables_; }
159 
160   // Returns the size of entry parameter `idx`.
161   //
162   // There is a static version of this method on tfcompile generated subclasses
163   // of XlaCompiledCpuFunction, but try to prefer this when possible since it
164   // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses.
arg_size(int idx)165   int arg_size(int idx) const {
166     assert(idx < num_args());
167     return buffer_infos_[arg_index_table_[idx]].size();
168   }
169 
170   // Sets the buffer for the positional argument at the given `index` to `data`.
171   // Must be called before Run to have an effect. May be called under any
172   // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be
173   // called for each positional argument, in order to set the argument buffers.
174   //
175   // Allocated memory must be aligned to the size specified by
176   // xla::cpu_function_runtime::MinAlign(). If possible, use the functions in
177   // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct
178   // alignment.
179   //
180   // Aliasing of argument and result buffers is not allowed, and results in
181   // undefined behavior.
set_arg_data(size_t index,const void * data)182   void set_arg_data(size_t index, const void* data) {
183     assert((arg_size(index) < xla::cpu_function_runtime::MinAlign() ||
184             (uintptr_t)data % xla::cpu_function_runtime::MinAlign() == 0) &&
185            "Underaligned pointer!");
186     // The const_cast is safe because the generated code does not write to arg
187     // buffers.
188     //
189     // buffer_table_ contains pointers to buffers that _will_ be written to by
190     // generated code so it would be misleading to make buffer_table_ a `const
191     // void**`.
192     buffer_table_[arg_index_table_[index]] = const_cast<void*>(data);
193   }
194 
195   // ------------------------------
196   // Result methods for managing output buffers. Buffers are in row-major order.
197   // Must only be called after a successful Run call. Unlike the arg methods,
198   // there is no set_resultN_data method. The result buffers are managed
199   // internally, and may change after each call to Run.
200 
201   // Returns the underlying array of result buffers, where results()[I] is the
202   // buffer for the positional result at index I.
results()203   void** results() { return static_cast<void**>(buffer_table_[result_index_]); }
results()204   const void* const* results() const {
205     return static_cast<const void* const*>(buffer_table_[result_index_]);
206   }
207 
208   // Profile counters for this XLA computation.
209   //
210   // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in
211   // this case) these counters are non-null and are automatically populated by
212   // `Run`.  The counters can then be pretty-printed using
213   // `hlo_profile_printer()`.
214   //
215   // When Hlo profiling is disabled, this accessor returns null.
profile_counters()216   const int64_t* profile_counters() const { return profile_counters_; }
217 
218   // Returns the buffer for the positional result at the given `index`.
result_data(size_t index)219   void* result_data(size_t index) { return results()[index]; }
result_data(size_t index)220   const void* result_data(size_t index) const { return results()[index]; }
221 
222   // ------------------------------
223   // Methods for extracting optional metadata.
224 
225   // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index
226   // methods. E.g. the data might not be compiled into the binary for AOT.
HasNameIndices()227   bool HasNameIndices() const {
228     return arg_names_ != nullptr && variable_names_ != nullptr &&
229            result_names_ != nullptr;
230   }
231 
232   // Returns the 0-based index for the argument with the given `name`.
233   // Returns -1 if the name wasn't found, or data isn't available.
234   //
235   // The index remains constant for every instance of XlaCompiledCpuFunction
236   // generated from the same static data, and might not be cheap to determine.
237   // Recommended usage is to capture this in a variable for re-use.
238   int LookupArgIndex(const string& name) const;
239 
240   // Returns the 0-based index for the variable with the given `name`.
241   // Returns -1 if the name wasn't found, or data isn't available.
242   //
243   // The index remains constant for every instance of XlaCompiledCpuFunction
244   // generated from the same static data, and might not be cheap to determine.
245   // Recommended usage is to capture this in a variable for re-use.
246   int LookupVariableIndex(const string& name) const;
247 
248   // Returns the 0-based index for the result with the given `name`.
249   // Returns -1 if the name wasn't found, or data isn't available.
250   //
251   // The index remains constant for every instance of XlaCompiledCpuFunction
252   // generated from the same static data, and might not be cheap to determine.
253   // Recommended usage is to capture this in a variable for re-use.
254   int LookupResultIndex(const string& name) const;
255 
256   // Returns the shape of the args and results. May return nullptr if the
257   // program shape isn't available.
ProgramShape()258   const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
259 
hlo_profiling_enabled()260   bool hlo_profiling_enabled() const {
261     return hlo_profile_printer_data_ != nullptr;
262   }
hlo_profile_printer_data()263   const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
264     assert(hlo_profiling_enabled());
265     return *hlo_profile_printer_data_;
266   }
267 
268  protected:
269   // ---------------------------------------------------------------------------
270   // Accessors for reading from and writing to instances of `StaticData`.
271   //
272   // Classes generated by tfcompile can call these because the generated classes
273   // inherit from `XlaCompiledCpuFunction`.  `XlaJitCompiledCpuFunction` can
274   // call these because it is explicitly added as a friend.
275 
set_static_data_raw_function(StaticData * static_data,RawFunction raw_function)276   static void set_static_data_raw_function(StaticData* static_data,
277                                            RawFunction raw_function) {
278     static_data->raw_function_ = raw_function;
279   }
280 
set_static_data_buffer_infos(StaticData * static_data,const xla::cpu_function_runtime::BufferInfo * buffer_infos)281   static void set_static_data_buffer_infos(
282       StaticData* static_data,
283       const xla::cpu_function_runtime::BufferInfo* buffer_infos) {
284     static_data->buffer_infos_ = buffer_infos;
285   }
286 
set_static_data_num_buffers(StaticData * static_data,size_t num_buffers)287   static void set_static_data_num_buffers(StaticData* static_data,
288                                           size_t num_buffers) {
289     static_data->num_buffers_ = num_buffers;
290   }
291 
set_static_data_arg_index_table(StaticData * static_data,const int32 * arg_index_table)292   static void set_static_data_arg_index_table(StaticData* static_data,
293                                               const int32* arg_index_table) {
294     static_data->arg_index_table_ = arg_index_table;
295   }
296 
set_static_data_num_args(StaticData * static_data,int64_t num_args)297   static void set_static_data_num_args(StaticData* static_data,
298                                        int64_t num_args) {
299     static_data->num_args_ = num_args;
300   }
301 
set_static_data_num_variables(StaticData * static_data,int64_t num_variables)302   static void set_static_data_num_variables(StaticData* static_data,
303                                             int64_t num_variables) {
304     static_data->num_variables_ = num_variables;
305   }
306 
set_static_data_result_index(StaticData * static_data,size_t result_index)307   static void set_static_data_result_index(StaticData* static_data,
308                                            size_t result_index) {
309     static_data->result_index_ = result_index;
310   }
311 
set_static_data_arg_names(StaticData * static_data,const char ** arg_names)312   static void set_static_data_arg_names(StaticData* static_data,
313                                         const char** arg_names) {
314     static_data->arg_names_ = arg_names;
315   }
316 
set_static_data_variable_names(StaticData * static_data,const char ** variable_names)317   static void set_static_data_variable_names(StaticData* static_data,
318                                              const char** variable_names) {
319     static_data->variable_names_ = variable_names;
320   }
321 
set_static_data_result_names(StaticData * static_data,const char ** result_names)322   static void set_static_data_result_names(StaticData* static_data,
323                                            const char** result_names) {
324     static_data->result_names_ = result_names;
325   }
326 
set_static_data_program_shape(StaticData * static_data,const xla::ProgramShapeProto * program_shape)327   static void set_static_data_program_shape(
328       StaticData* static_data, const xla::ProgramShapeProto* program_shape) {
329     static_data->program_shape_ = program_shape;
330   }
331 
set_static_data_hlo_profile_printer_data(StaticData * static_data,const xla::HloProfilePrinterData * hlo_profile_printer_data)332   static void set_static_data_hlo_profile_printer_data(
333       StaticData* static_data,
334       const xla::HloProfilePrinterData* hlo_profile_printer_data) {
335     static_data->hlo_profile_printer_data_ = hlo_profile_printer_data;
336   }
337 
338   static const xla::HloProfilePrinterData*
get_static_data_hlo_profile_printer_data(StaticData * static_data)339   get_static_data_hlo_profile_printer_data(StaticData* static_data) {
340     return static_data->hlo_profile_printer_data_;
341   }
342 
set_static_data_profile_counters_size(StaticData * static_data,int64_t profile_counters_size)343   static void set_static_data_profile_counters_size(
344       StaticData* static_data, int64_t profile_counters_size) {
345     static_data->profile_counters_size_ = profile_counters_size;
346   }
347 
348  private:
349   const RawFunction raw_function_;
350   const size_t result_index_;
351 
352   // Array containing pointers to argument and temp buffers (slots corresponding
353   // to constant and on-stack buffers are null).
354   void** const buffer_table_;
355 
356   // Describes the buffers used by the XLA computation.
357   const xla::cpu_function_runtime::BufferInfo* const buffer_infos_;
358 
359   // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]]
360   // for XLA generated code to be able to find it.
361   const int32* const arg_index_table_;
362 
363   // The number of incoming arguments.
364   const int32 num_args_;
365 
366   // The number of incoming variables.
367   const int32 num_variables_;
368 
369   // Backing memory for buffer_table_ and args_, the latter depending on
370   // AllocMode.
371   void* alloc_buffer_table_ = nullptr;
372 
373   // Backing memory for profiling counters.
374   int64_t* profile_counters_ = nullptr;
375 
376   // Options and context passed to the compiled function.
377   xla::ExecutableRunOptions run_options_;
378 
379   // Optional metadata.
380   const char** arg_names_ = nullptr;
381   const char** variable_names_ = nullptr;
382   const char** result_names_ = nullptr;
383   const xla::ProgramShapeProto* program_shape_ = nullptr;
384   const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
385 
386   // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the
387   // `set_static_data_*` static methods above.
388   friend class XlaJitCompiledCpuFunction;
389 };
390 
391 }  // namespace tensorflow
392 
393 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
394