/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/custom_call_status_internal.h"
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"

namespace xla {
namespace cpu {

// CPU-targeting implementation of the XLA Executable interface.
//
// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
// architecture, so JIT-ed code and host code share the same ABI.
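//
// An illustrative usage sketch, not a normative API contract: `executable`,
// `run_options`, and `arguments` are hypothetical locals, and error handling
// is elided.
//
//   std::unique_ptr<CpuExecutable> executable = ...;  // built by the CPU
//                                                     // compiler
//   std::vector<ExecutionInput> arguments = ...;
//   StatusOr<ExecutionOutput> result = executable->ExecuteAsyncOnStream(
//       &run_options, std::move(arguments),
//       /*hlo_execution_profile=*/nullptr);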
class CpuExecutable : public Executable {
 public:
  CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit,
                std::unique_ptr<const BufferAssignment> assignment,
                std::unique_ptr<HloModule> hlo_module,
                const std::string& entry_function_name,
                std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
                std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
  ~CpuExecutable() override;

  StatusOr<ExecutionOutput> ExecuteAsyncOnStream(
      const ServiceExecutableRunOptions* run_options,
      std::vector<ExecutionInput> arguments,
      HloExecutionProfile* hlo_execution_profile) override;

  // Calls the generated function performing the computation with the given
  // arguments using the supplied buffers.
  Status ExecuteComputeFunction(
      const ExecutableRunOptions* run_options,
      absl::Span<MaybeOwningDeviceMemory const> buffers,
      HloExecutionProfile* hlo_execution_profile);

  // This should only be called after set_ir_module_string.
  const std::string& ir_module_string() const { return ir_module_string_; }

  void set_ir_module_string(const std::string& ir_module_string) {
    ir_module_string_ = ir_module_string;
  }

  static int64_t ShapeSizeBytes(const Shape& shape);

  // Type of the computation function we expect in the JIT.
  using ComputeFunctionType =
      void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/,
               const void** /*args*/, void** /*buffer_table*/,
               XlaCustomCallStatus* /*status*/, int64_t* /*profile_counters*/);
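
  // An illustrative call through the JIT-compiled entry point (a sketch, not
  // part of this API): `fn`, `result`, `run_options`, `args`, `buffer_table`,
  // and `profile_counters` are hypothetical, and error handling is elided.
  //
  //   ComputeFunctionType fn = executable.compute_function();
  //   XlaCustomCallStatus status;
  //   fn(result, run_options, args, buffer_table, &status, profile_counters);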

  const ComputeFunctionType& compute_function() const {
    return compute_function_;
  }

  const BufferAssignment& buffer_assignment() const { return *assignment_; }

  int64_t SizeOfGeneratedCodeInBytes() const override;

 private:
  // Creates an array suitable for passing as the "buffer_table" argument to
  // the JIT compiled function pointer.
  //
  // Returns a vector of MaybeOwningDeviceMemory, one entry per buffer
  // allocation, where:
  //
  //  - the underlying device addresses can be passed as the buffer_table
  //    argument as-is; they include pointers to the scratch storage required
  //    by the computation, the live-out buffer into which the result will be
  //    written, and the entry computation parameters;
  //
  //  - entries allocated by this routine are owning. This routine allocates
  //    buffers for temporary storage and for the live-out buffer into which
  //    the computation writes its result;
  //
  //  - entries aliasing the entry computation parameters are unowning;
  //    parameter buffers whose ownership was donated by the caller are still
  //    freed by the caller.
  StatusOr<std::vector<MaybeOwningDeviceMemory>> CreateBufferTable(
      se::DeviceMemoryAllocator* memory_allocator, int device_ordinal,
      absl::Span<ExecutionInput const> arguments);
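
  // A minimal sketch (not part of this class) of how the raw pointer table
  // handed to the JIT-compiled function can be derived from the vector this
  // routine returns. It assumes MaybeOwningDeviceMemory::AsDeviceMemoryBase()
  // from tensorflow/stream_executor/device_memory_allocator.h; `buffers` and
  // `buffer_pointers` are hypothetical locals.
  //
  //   std::vector<void*> buffer_pointers;
  //   buffer_pointers.reserve(buffers.size());
  //   for (const MaybeOwningDeviceMemory& buffer : buffers) {
  //     buffer_pointers.push_back(buffer.AsDeviceMemoryBase().opaque());
  //   }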

  // Creates an ExecutionOutput holding a ScopedShapedBuffer for the result of
  // the computation, moving buffers out of `buffers` and into the result as
  // appropriate.  The addresses are set according to buffer assignment.
  StatusOr<ExecutionOutput> CreateResultShapedBuffer(
      const ServiceExecutableRunOptions* run_options,
      absl::Span<MaybeOwningDeviceMemory> buffers,
      absl::Span<ExecutionInput> arguments);

  // Returns the instruction value set of the root instruction of the entry
  // computation. Uses dataflow analysis from buffer assignment.
  const InstructionValueSet& GetRootValueSet() const;

  // The JIT containing compiled modules.
  const std::unique_ptr<SimpleOrcJIT> jit_;

  // Buffer assignment for the buffers we need to allocate.
  const std::unique_ptr<const BufferAssignment> assignment_;

  // The buffer assignment, in proto form.
  std::shared_ptr<const BufferAssignmentProto> buffer_assignment_;

  // The LLVM IR, in string format, of the unoptimized module generated for this
  // CpuExecutable. We save a string instead of an llvm::Module* because leaving
  // llvm::Module* in a singleton can cause the heap checker to emit false
  // positives.
  std::string ir_module_string_;

  // Unique identifier for the compiled module.
  std::string module_name_;

  // Entry point to the JIT-compiled compute function.
  ComputeFunctionType compute_function_;

  // Entry function name for the computation.
  const std::string entry_function_name_;

  CpuExecutable(const CpuExecutable&) = delete;
  CpuExecutable& operator=(const CpuExecutable&) = delete;
};

}  // namespace cpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_