1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ 18 19 #include <cstddef> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 26 #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" 27 #include "tensorflow/compiler/xla/service/custom_call_status_internal.h" 28 #include "tensorflow/compiler/xla/service/executable.h" 29 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" 30 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" 31 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 32 #include "tensorflow/compiler/xla/service/hlo_module.h" 33 #include "tensorflow/compiler/xla/service/shaped_buffer.h" 34 #include "tensorflow/compiler/xla/statusor.h" 35 #include "tensorflow/compiler/xla/types.h" 36 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 37 #include "tensorflow/stream_executor/device_memory_allocator.h" 38 39 namespace xla { 40 namespace cpu { 41 42 // CPU-targeting implementation of the XLA Executable interface. 43 // 44 // Wraps a JIT-ed object that can be executed "on device". We JIT for the host 45 // architecture, so JIT-ed code and host code share the same ABI. 46 class CpuExecutable : public Executable { 47 public: 48 CpuExecutable(std::unique_ptr<SimpleOrcJIT> jit, 49 std::unique_ptr<const BufferAssignment> assignment, 50 std::unique_ptr<HloModule> hlo_module, 51 const std::string& entry_function_name, 52 std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data, 53 std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map); 54 ~CpuExecutable() override; 55 56 StatusOr<ExecutionOutput> ExecuteAsyncOnStream( 57 const ServiceExecutableRunOptions* run_options, 58 std::vector<ExecutionInput> arguments, 59 HloExecutionProfile* hlo_execution_profile) override; 60 61 // Calls the generated function performing the computation with the given 62 // arguments using the supplied buffers. 63 Status ExecuteComputeFunction( 64 const ExecutableRunOptions* run_options, 65 absl::Span<MaybeOwningDeviceMemory const> buffers, 66 HloExecutionProfile* hlo_execution_profile); 67 68 // This should be called after set_ir_module_string. ir_module_string()69 const std::string& ir_module_string() const { return ir_module_string_; } 70 set_ir_module_string(const std::string & ir_module_string)71 void set_ir_module_string(const std::string& ir_module_string) { 72 ir_module_string_ = ir_module_string; 73 } 74 75 static int64_t ShapeSizeBytes(const Shape& shape); 76 77 // Type of the computation function we expect in the JIT. 78 using ComputeFunctionType = 79 void (*)(void* /*result*/, const ExecutableRunOptions* /*run_options*/, 80 const void** /*args*/, void** /*buffer_table*/, 81 XlaCustomCallStatus* /*status*/, int64_t* /*profile_counters*/); 82 compute_function()83 const ComputeFunctionType& compute_function() const { 84 return compute_function_; 85 } 86 buffer_assignment()87 const BufferAssignment& buffer_assignment() const { return *assignment_; } 88 89 int64_t SizeOfGeneratedCodeInBytes() const override; 90 91 private: 92 // Creates an array suitable for passing as the "buffer_table" argument to the 93 // JIT compiled function pointer. 94 // 95 // Returns (unowning_buffers, owning_buffers) where: 96 // 97 // - unowning_buffers.data() can be passed as the buffer_table argument as-is 98 // and includes pointers to the scratch storage required by the 99 // computation, the live-out buffer into which the result will be written 100 // and entry computation parameters. 101 // 102 // - owning_buffers contains owning pointers to the buffers that were 103 // allocated by this routine. This routine allocates buffers for temporary 104 // storage and the live-out buffer into which the computation writes it 105 // result. 106 // 107 // - buffers_to_free: buffers whose ownership was donated by the caller that 108 // are to be freed by the caller. 109 StatusOr<std::vector<MaybeOwningDeviceMemory>> CreateBufferTable( 110 se::DeviceMemoryAllocator* memory_allocator, int device_ordinal, 111 absl::Span<ExecutionInput const> arguments); 112 113 // Creates an Execution output holding ScopedShapedBuffer for holding the 114 // result of the computation, moving buffers out of allocated_buffers and into 115 // the result as appropriate. The addresses are set according to buffer 116 // assignment. 117 StatusOr<ExecutionOutput> CreateResultShapedBuffer( 118 const ServiceExecutableRunOptions* run_options, 119 absl::Span<MaybeOwningDeviceMemory> buffers, 120 absl::Span<ExecutionInput> arguments); 121 122 // Returns the instruction value set of the root instruction of the entry 123 // computation. Uses dataflow analysis from buffer assignment. 124 const InstructionValueSet& GetRootValueSet() const; 125 126 // The JIT containing compiled modules. 127 const std::unique_ptr<SimpleOrcJIT> jit_; 128 129 // Buffer assignment for the buffers we need to allocate. 130 const std::unique_ptr<const BufferAssignment> assignment_; 131 132 std::shared_ptr<const BufferAssignmentProto> buffer_assignment_; 133 134 // The LLVM IR, in string format, of the unoptimized module generated for this 135 // CpuExecutable. We save a string instead of an llvm::Module* because leaving 136 // llvm::Module* in a singleton can cause the heap checker to emit false 137 // positives. 138 std::string ir_module_string_; 139 140 // Unique identifier. 141 std::string module_name_; 142 143 ComputeFunctionType compute_function_; 144 145 // Entry function name for the computation. 146 const std::string entry_function_name_; 147 148 CpuExecutable(const CpuExecutable&) = delete; 149 CpuExecutable& operator=(const CpuExecutable&) = delete; 150 }; 151 152 } // namespace cpu 153 } // namespace xla 154 155 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_EXECUTABLE_H_ 156