xref: /aosp_15_r20/external/tensorflow/tensorflow/c/experimental/stream_executor/stream_executor.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // This file extends/implements core stream executor base classes in terms of
16 // the C API defined in stream_executor.h. A class "CSomething" represents a
17 // "Something" that can be manipulated via calls in the C interface and a C
18 // struct called "SP_Something".
19 //
20 // This file also contains stream_executor::Platform registration for pluggable
21 // device.
22 #include "tensorflow/c/experimental/stream_executor/stream_executor.h"
23 
24 #include <string>
25 
26 #include "tensorflow/c/c_api_macros.h"
27 #include "tensorflow/c/c_api_macros_internal.h"
28 #include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
29 #include "tensorflow/c/tf_status_helper.h"
30 #include "tensorflow/core/common_runtime/device/device_utils.h"
31 #include "tensorflow/core/platform/env.h"
32 #include "tensorflow/core/platform/errors.h"
33 #include "tensorflow/core/platform/logging.h"
34 #include "tensorflow/core/platform/status.h"
35 #include "tensorflow/core/platform/strcat.h"
36 #include "tensorflow/core/platform/stringpiece.h"
37 #include "tensorflow/stream_executor/executor_cache.h"
38 #include "tensorflow/stream_executor/multi_platform_manager.h"
39 #include "tensorflow/stream_executor/platform.h"
40 #include "tensorflow/stream_executor/stream.h"
41 #include "tensorflow/stream_executor/stream_executor_internal.h"
42 #include "tensorflow/stream_executor/stream_executor_pimpl.h"
43 #include "tensorflow/stream_executor/timer.h"
44 
45 using tensorflow::StatusFromTF_Status;
46 
47 namespace stream_executor {
48 using tensorflow::StringPiece;
49 
50 // TODO(penporn): Remove OwnedTFStatus.
51 using OwnedTFStatus = tensorflow::TF_StatusPtr;
52 
53 namespace {
ValidateSPPlatform(const SP_Platform & platform)54 port::Status ValidateSPPlatform(const SP_Platform& platform) {
55   TF_VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
56   TF_VALIDATE_NOT_NULL(SP_Platform, platform, name);
57   TF_VALIDATE_NOT_NULL(SP_Platform, platform, type);
58   TF_RETURN_IF_ERROR(
59       tensorflow::device_utils::ValidateDeviceType(platform.name));
60   TF_RETURN_IF_ERROR(
61       tensorflow::device_utils::ValidateDeviceType(platform.type));
62   // `visible_device_count` could be 0 at initialization time.
63   return ::tensorflow::OkStatus();
64 }
65 
ValidateSPPlatformFns(const SP_PlatformFns & platform_fns)66 port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
67   TF_VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
68                           SP_PLATFORM_FNS_STRUCT_SIZE);
69   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device);
70   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_device);
71   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_stream_executor);
72   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_stream_executor);
73   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_timer_fns);
74   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_timer_fns);
75   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device_fns);
76   TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_device_fns);
77   return ::tensorflow::OkStatus();
78 }
79 
ValidateSPTimerFns(const SP_TimerFns & timer_fns)80 port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
81   TF_VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
82   TF_VALIDATE_NOT_NULL(SP_TimerFns, timer_fns, nanoseconds);
83   return ::tensorflow::OkStatus();
84 }
85 
ValidateSPAllocatorStats(const SP_AllocatorStats & stats)86 port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
87   TF_VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats,
88                           SP_ALLOCATORSTATS_STRUCT_SIZE);
89   // All other fields could theoretically be zero/null.
90   return ::tensorflow::OkStatus();
91 }
92 
ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase & mem)93 port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
94   TF_VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
95                           SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
96   // All other fields could theoretically be zero/null.
97   return ::tensorflow::OkStatus();
98 }
99 
ValidateSPDevice(const SP_Device & device)100 port::Status ValidateSPDevice(const SP_Device& device) {
101   TF_VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
102   // All other fields could theoretically be zero/null.
103   return ::tensorflow::OkStatus();
104 }
105 
ValidateSPDeviceFns(const SP_DeviceFns & device_fns)106 port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
107   TF_VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
108   // All other fields could theoretically be zero/null.
109   return ::tensorflow::OkStatus();
110 }
111 
ValidateSPStreamExecutor(const SP_StreamExecutor & se,const SP_Platform & platform)112 port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
113                                       const SP_Platform& platform) {
114   TF_VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se,
115                           SP_STREAM_EXECUTOR_STRUCT_SIZE);
116   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, allocate);
117   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, deallocate);
118   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, get_allocator_stats);
119   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, host_memory_allocate);
120   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, host_memory_deallocate);
121   if (platform.supports_unified_memory) {
122     TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, unified_memory_allocate);
123     TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, unified_memory_deallocate);
124   }
125   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, device_memory_usage);
126   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_stream);
127   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, destroy_stream);
128   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_stream_dependency);
129   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, get_stream_status);
130   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_event);
131   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, destroy_event);
132   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, get_event_status);
133   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, record_event);
134   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, wait_for_event);
135   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_timer);
136   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, destroy_timer);
137   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, start_timer);
138   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, stop_timer);
139   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memcpy_dtoh);
140   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memcpy_htod);
141   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, sync_memcpy_dtoh);
142   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, sync_memcpy_htod);
143   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, block_host_for_event);
144   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, synchronize_all_activity);
145   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, host_callback);
146   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, mem_zero);
147   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memset);
148   TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memset32);
149   return ::tensorflow::OkStatus();
150 }
151 
ValidateSEPlatformRegistrationParams(const SE_PlatformRegistrationParams & params)152 port::Status ValidateSEPlatformRegistrationParams(
153     const SE_PlatformRegistrationParams& params) {
154   TF_VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
155                           SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
156   TF_VALIDATE_NOT_NULL(SE_PlatformRegistrationParams, params, destroy_platform);
157   TF_VALIDATE_NOT_NULL(SE_PlatformRegistrationParams, params,
158                        destroy_platform_fns);
159   return ::tensorflow::OkStatus();
160 }
161 #undef TF_VALIDATE_NOT_NULL
162 
163 // Converts SE_EventStatus to Event::Status.
SEEventStatusToEventStatus(SE_EventStatus s)164 Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
165   switch (s) {
166     case SE_EVENT_ERROR:
167       return Event::Status::kError;
168     case SE_EVENT_PENDING:
169       return Event::Status::kPending;
170     case SE_EVENT_COMPLETE:
171       return Event::Status::kComplete;
172     default:
173       return Event::Status::kUnknown;
174   }
175 }
176 
177 // Converts DeviceMemoryBase to a C struct.
DeviceMemoryBaseToC(const DeviceMemoryBase * mem)178 SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
179   SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
180   // `opaque` field inside SP_DeviceMemoryBase is not const.
181   // Therefore, we need to cast away the constness before setting it.
182   device_memory_base.opaque = const_cast<void*>(mem->opaque());
183   device_memory_base.size = mem->size();
184   device_memory_base.payload = mem->payload();
185   return device_memory_base;
186 }
187 
DeviceMemoryBaseFromC(const SP_DeviceMemoryBase & mem)188 DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
189   DeviceMemoryBase base(mem.opaque, mem.size);
190   base.SetPayload(mem.payload);
191   return base;
192 }
193 
194 // Wrapper that allows passing std::function across C API.
195 struct HostCallbackContext {
196   std::function<port::Status()> callback;
197 };
198 
199 // This wrapper allows calling `HostCallbackContext::callback` across C API.
200 // This function matches `SE_StatusCallbackFn` signature and will be passed as
201 // `callback_fn` to `host_callback` in `SP_StreamExecutor`.
HostCallbackTrampoline(void * ctx,TF_Status * status)202 void HostCallbackTrampoline(void* ctx, TF_Status* status) {
203   HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
204   port::Status s = host_ctx->callback();
205   Set_TF_Status_from_Status(status, s);
206   delete host_ctx;
207 }
208 
209 class CStreamExecutor : public internal::StreamExecutorInterface {
210  public:
CStreamExecutor(SP_Device device,SP_DeviceFns * device_fns,SP_StreamExecutor * stream_executor,SP_Platform * platform,SP_PlatformFns * platform_fns,SP_TimerFns * timer_fns,const std::string & name,int visible_device_count)211   explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
212                            SP_StreamExecutor* stream_executor,
213                            SP_Platform* platform, SP_PlatformFns* platform_fns,
214                            SP_TimerFns* timer_fns, const std::string& name,
215                            int visible_device_count)
216       : device_(std::move(device)),
217         device_fns_(device_fns),
218         stream_executor_(stream_executor),
219         platform_(platform),
220         platform_fns_(platform_fns),
221         timer_fns_(timer_fns),
222         platform_name_(name),
223         visible_device_count_(visible_device_count) {}
224 
// Hands the stored SP_Device back to the plugin's `destroy_device` hook.
~CStreamExecutor() override {
  SP_Device* const device = &device_;
  platform_fns_->destroy_device(platform_, device);
}
228 
Init(int device_ordinal,DeviceOptions device_options)229   port::Status Init(int device_ordinal, DeviceOptions device_options) override {
230     return ::tensorflow::OkStatus();
231   }
232 
Allocate(uint64 size,int64_t memory_space)233   DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override {
234     SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
235     stream_executor_->allocate(&device_, size, memory_space, &mem);
236     port::Status status = ValidateSPDeviceMemoryBase(mem);
237     if (!status.ok()) {
238       LOG(ERROR) << status.error_message();
239     }
240     return DeviceMemoryBaseFromC(mem);
241   }
Allocate(uint64 size)242   DeviceMemoryBase Allocate(uint64 size) {
243     return Allocate(size, /*memory_space=*/0);
244   }
GetSubBuffer(DeviceMemoryBase * parent,uint64 offset,uint64 size)245   void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
246                      uint64 size) override {
247     LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
248   }
249 
Deallocate(DeviceMemoryBase * mem)250   void Deallocate(DeviceMemoryBase* mem) override {
251     SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
252     stream_executor_->deallocate(&device_, &device_memory_base);
253   }
254 
HostMemoryAllocate(uint64 size)255   void* HostMemoryAllocate(uint64 size) override {
256     return stream_executor_->host_memory_allocate(&device_, size);
257   }
258 
HostMemoryDeallocate(void * mem)259   void HostMemoryDeallocate(void* mem) override {
260     stream_executor_->host_memory_deallocate(&device_, mem);
261   }
262 
HostMemoryRegister(void * mem,uint64 size)263   bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
HostMemoryUnregister(void * mem)264   bool HostMemoryUnregister(void* mem) override { return false; }
265 
UnifiedMemoryAllocate(uint64 size)266   void* UnifiedMemoryAllocate(uint64 size) override {
267     CHECK(stream_executor_->unified_memory_allocate);
268     return stream_executor_->unified_memory_allocate(&device_, size);
269   }
270 
UnifiedMemoryDeallocate(void * mem)271   void UnifiedMemoryDeallocate(void* mem) override {
272     CHECK(stream_executor_->unified_memory_deallocate);
273     stream_executor_->unified_memory_deallocate(&device_, mem);
274   }
275 
GetAllocatorStats()276   absl::optional<AllocatorStats> GetAllocatorStats() override {
277     SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
278     TF_Bool has_stats =
279         stream_executor_->get_allocator_stats(&device_, &c_stats);
280     if (!has_stats) {
281       return absl::nullopt;
282     }
283     port::Status status = ValidateSPAllocatorStats(c_stats);
284     if (!status.ok()) {
285       LOG(ERROR) << status.error_message();
286       return absl::nullopt;
287     }
288     ::stream_executor::AllocatorStats stats;
289     stats.num_allocs = c_stats.num_allocs;
290     stats.bytes_in_use = c_stats.bytes_in_use;
291     stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
292     stats.largest_alloc_size = c_stats.largest_alloc_size;
293     if (c_stats.has_bytes_limit) {
294       stats.bytes_limit = c_stats.bytes_limit;
295     }
296     stats.bytes_reserved = c_stats.bytes_reserved;
297     stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
298     if (c_stats.has_bytes_reservable_limit) {
299       stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
300     }
301     stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
302     return stats;
303   }
SynchronizeAllActivity()304   bool SynchronizeAllActivity() override {
305     OwnedTFStatus c_status(TF_NewStatus());
306     stream_executor_->synchronize_all_activity(&device_, c_status.get());
307     if (TF_GetCode(c_status.get()) != TF_OK) {
308       LOG(ERROR) << TF_Message(c_status.get());
309       return false;
310     }
311     return true;
312   }
SynchronousMemZero(DeviceMemoryBase * location,uint64 size)313   port::Status SynchronousMemZero(DeviceMemoryBase* location,
314                                   uint64 size) override {
315     // TODO(annarev): figure out if we should support memzero/memset
316     // functionality by allocating on host and then copying to device.
317     return port::UnimplementedError(
318         "SynchronousMemZero is not supported by pluggable device.");
319   }
SynchronousMemSet(DeviceMemoryBase * location,int value,uint64 size)320   port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
321                                  uint64 size) override {
322     return port::UnimplementedError(
323         "SynchronousMemSet is not supported by pluggable device.");
324   }
SynchronousMemcpy(DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)325   port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
326                                  const void* host_src, uint64 size) override {
327     OwnedTFStatus c_status(TF_NewStatus());
328     SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
329     stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
330                                        size, c_status.get());
331     return StatusFromTF_Status(c_status.get());
332   }
SynchronousMemcpy(void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)333   port::Status SynchronousMemcpy(void* host_dst,
334                                  const DeviceMemoryBase& gpu_src,
335                                  uint64 size) override {
336     OwnedTFStatus c_status(TF_NewStatus());
337     SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
338     stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
339                                        size, c_status.get());
340     return StatusFromTF_Status(c_status.get());
341   }
SynchronousMemcpyDeviceToDevice(DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)342   port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
343                                                const DeviceMemoryBase& gpu_src,
344                                                uint64 size) override {
345     OwnedTFStatus c_status(TF_NewStatus());
346     SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
347     SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
348     stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
349                                        &device_mem_src, size, c_status.get());
350     return StatusFromTF_Status(c_status.get());
351   }
MemZero(Stream * stream,DeviceMemoryBase * location,uint64 size)352   port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
353                        uint64 size) override {
354     OwnedTFStatus c_status(TF_NewStatus());
355     SP_Stream stream_handle =
356         static_cast<CStream*>(stream->implementation())->Handle();
357     SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
358     stream_executor_->mem_zero(&device_, stream_handle, &device_mem, size,
359                                c_status.get());
360     return StatusFromTF_Status(c_status.get());
361   }
Memset(Stream * stream,DeviceMemoryBase * location,uint8 pattern,uint64 size)362   port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
363                       uint64 size) override {
364     OwnedTFStatus c_status(TF_NewStatus());
365     SP_Stream stream_handle =
366         static_cast<CStream*>(stream->implementation())->Handle();
367     SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
368     stream_executor_->memset(&device_, stream_handle, &device_mem, pattern,
369                              size, c_status.get());
370     return StatusFromTF_Status(c_status.get());
371   }
Memset32(Stream * stream,DeviceMemoryBase * location,uint32 pattern,uint64 size)372   port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
373                         uint32 pattern, uint64 size) override {
374     OwnedTFStatus c_status(TF_NewStatus());
375     SP_Stream stream_handle =
376         static_cast<CStream*>(stream->implementation())->Handle();
377     SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
378     stream_executor_->memset32(&device_, stream_handle, &device_mem, pattern,
379                                size, c_status.get());
380     return StatusFromTF_Status(c_status.get());
381   }
Memcpy(Stream * stream,void * host_dst,const DeviceMemoryBase & gpu_src,uint64 size)382   bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
383               uint64 size) override {
384     OwnedTFStatus c_status(TF_NewStatus());
385     SP_Stream stream_handle =
386         static_cast<CStream*>(stream->implementation())->Handle();
387     SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
388     stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
389                                   &device_mem_src, size, c_status.get());
390     if (TF_GetCode(c_status.get()) != TF_OK) {
391       LOG(ERROR) << TF_Message(c_status.get());
392       return false;
393     }
394     return true;
395   }
Memcpy(Stream * stream,DeviceMemoryBase * gpu_dst,const void * host_src,uint64 size)396   bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
397               uint64 size) override {
398     OwnedTFStatus c_status(TF_NewStatus());
399     SP_Stream stream_handle =
400         static_cast<CStream*>(stream->implementation())->Handle();
401     SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
402     stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
403                                   host_src, size, c_status.get());
404     if (TF_GetCode(c_status.get()) != TF_OK) {
405       LOG(ERROR) << TF_Message(c_status.get());
406       return false;
407     }
408     return true;
409   }
MemcpyDeviceToDevice(Stream * stream,DeviceMemoryBase * gpu_dst,const DeviceMemoryBase & gpu_src,uint64 size)410   bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
411                             const DeviceMemoryBase& gpu_src,
412                             uint64 size) override {
413     OwnedTFStatus c_status(TF_NewStatus());
414     SP_Stream stream_handle =
415         static_cast<CStream*>(stream->implementation())->Handle();
416     SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
417     SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
418     stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
419                                   &device_mem_src, size, c_status.get());
420     if (TF_GetCode(c_status.get()) != TF_OK) {
421       LOG(ERROR) << TF_Message(c_status.get());
422       return false;
423     }
424     return true;
425   }
HostCallback(Stream * stream,std::function<port::Status ()> callback)426   bool HostCallback(Stream* stream,
427                     std::function<port::Status()> callback) override {
428     SP_Stream stream_handle =
429         static_cast<CStream*>(stream->implementation())->Handle();
430     HostCallbackContext* ctx = new HostCallbackContext{callback};
431     return stream_executor_->host_callback(&device_, stream_handle,
432                                            &HostCallbackTrampoline, ctx);
433   }
AllocateEvent(Event * event)434   port::Status AllocateEvent(Event* event) override {
435     DCHECK(event != nullptr);
436     return static_cast<CEvent*>(event->implementation())->Create();
437   }
DeallocateEvent(Event * event)438   port::Status DeallocateEvent(Event* event) override {
439     static_cast<CEvent*>(event->implementation())->Destroy();
440     return ::tensorflow::OkStatus();
441   }
RecordEvent(Stream * stream,Event * event)442   port::Status RecordEvent(Stream* stream, Event* event) override {
443     SP_Stream stream_handle =
444         static_cast<CStream*>(stream->implementation())->Handle();
445     return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
446   }
WaitForEvent(Stream * stream,Event * event)447   port::Status WaitForEvent(Stream* stream, Event* event) override {
448     SP_Stream stream_handle =
449         static_cast<CStream*>(stream->implementation())->Handle();
450     SP_Event event_handle =
451         static_cast<CEvent*>(event->implementation())->Handle();
452     OwnedTFStatus c_status(TF_NewStatus());
453     stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
454                                      c_status.get());
455     port::Status s = StatusFromTF_Status(c_status.get());
456     return s;
457   }
PollForEventStatus(Event * event)458   Event::Status PollForEventStatus(Event* event) override {
459     SP_Event event_handle =
460         static_cast<CEvent*>(event->implementation())->Handle();
461     SE_EventStatus event_status =
462         stream_executor_->get_event_status(&device_, event_handle);
463     return SEEventStatusToEventStatus(event_status);
464   }
AllocateStream(Stream * stream)465   bool AllocateStream(Stream* stream) override {
466     DCHECK(stream != nullptr);
467     port::Status status =
468         static_cast<CStream*>(stream->implementation())->Create();
469     // TODO(annarev): update AllocateStream to return status instead
470     // (similar to AllocateEvent).
471     return status.ok();
472   }
DeallocateStream(Stream * stream)473   void DeallocateStream(Stream* stream) override {
474     static_cast<CStream*>(stream->implementation())->Destroy();
475   }
CreateStreamDependency(Stream * dependent,Stream * other)476   bool CreateStreamDependency(Stream* dependent, Stream* other) override {
477     OwnedTFStatus c_status(TF_NewStatus());
478     SP_Stream dependent_handle =
479         static_cast<CStream*>(dependent->implementation())->Handle();
480     SP_Stream other_handle =
481         static_cast<CStream*>(other->implementation())->Handle();
482     stream_executor_->create_stream_dependency(&device_, dependent_handle,
483                                                other_handle, c_status.get());
484     if (TF_GetCode(c_status.get()) != TF_OK) {
485       LOG(ERROR) << TF_Message(c_status.get());
486       return false;
487     }
488     return true;
489   }
AllocateTimer(Timer * timer)490   bool AllocateTimer(Timer* timer) override {
491     port::Status status =
492         static_cast<CTimer*>(timer->implementation())->Create();
493     // TODO(annarev): change return value of AllocateTimer
494     // to status (similar to AllocateEvent).
495     return status.ok();
496   }
DeallocateTimer(Timer * timer)497   void DeallocateTimer(Timer* timer) override {
498     static_cast<CTimer*>(timer->implementation())->Destroy();
499   }
StartTimer(Stream * stream,Timer * timer)500   bool StartTimer(Stream* stream, Timer* timer) override {
501     OwnedTFStatus c_status(TF_NewStatus());
502     SP_Stream stream_handle =
503         static_cast<CStream*>(stream->implementation())->Handle();
504     SP_Timer timer_handle =
505         static_cast<CTimer*>(timer->implementation())->Handle();
506     stream_executor_->start_timer(&device_, stream_handle, timer_handle,
507                                   c_status.get());
508     if (TF_GetCode(c_status.get()) != TF_OK) {
509       LOG(ERROR) << TF_Message(c_status.get());
510       return false;
511     }
512     return true;
513   }
StopTimer(Stream * stream,Timer * timer)514   bool StopTimer(Stream* stream, Timer* timer) override {
515     OwnedTFStatus c_status(TF_NewStatus());
516     SP_Stream stream_handle =
517         static_cast<CStream*>(stream->implementation())->Handle();
518     SP_Timer timer_handle =
519         static_cast<CTimer*>(timer->implementation())->Handle();
520     stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
521                                  c_status.get());
522     if (TF_GetCode(c_status.get()) != TF_OK) {
523       LOG(ERROR) << TF_Message(c_status.get());
524       return false;
525     }
526     return true;
527   }
BlockHostForEvent(Stream * stream,Event * event)528   port::Status BlockHostForEvent(Stream* stream, Event* event) {
529     OwnedTFStatus c_status(TF_NewStatus());
530     SP_Event event_handle =
531         static_cast<CEvent*>(event->implementation())->Handle();
532     stream_executor_->block_host_for_event(&device_, event_handle,
533                                            c_status.get());
534     return StatusFromTF_Status(c_status.get());
535   }
536 
BlockHostUntilDone(Stream * stream)537   port::Status BlockHostUntilDone(Stream* stream) override {
538     OwnedTFStatus c_status(TF_NewStatus());
539     SP_Stream stream_handle =
540         static_cast<CStream*>(stream->implementation())->Handle();
541 
542     // If `block_host_until_done` is set, use it.
543     if (stream_executor_->block_host_until_done != nullptr) {
544       stream_executor_->block_host_until_done(&device_, stream_handle,
545                                               c_status.get());
546       return StatusFromTF_Status(c_status.get());
547     }
548     // Create and record an event and then wait for it.
549     SP_Event event_handle;
550     stream_executor_->create_event(&device_, &event_handle, c_status.get());
551     TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
552     stream_executor_->record_event(&device_, stream_handle, event_handle,
553                                    c_status.get());
554     port::Status s = StatusFromTF_Status(c_status.get());
555     if (!s.ok()) {
556       stream_executor_->destroy_event(&device_, event_handle);
557       return s;
558     }
559     stream_executor_->block_host_for_event(&device_, event_handle,
560                                            c_status.get());
561     stream_executor_->destroy_event(&device_, event_handle);
562     return StatusFromTF_Status(c_status.get());
563   }
564 
GetStatus(Stream * stream)565   port::Status GetStatus(Stream* stream) override {
566     OwnedTFStatus c_status(TF_NewStatus());
567     SP_Stream stream_handle =
568         static_cast<CStream*>(stream->implementation())->Handle();
569     stream_executor_->get_stream_status(&device_, stream_handle,
570                                         c_status.get());
571     return StatusFromTF_Status(c_status.get());
572   }
PlatformDeviceCount()573   int PlatformDeviceCount() override { return visible_device_count_; }
EnablePeerAccessTo(StreamExecutorInterface * other)574   port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
575     return port::UnimplementedError(
576         "EnablePeerAccessTo is not supported by pluggable device.");
577   }
CanEnablePeerAccessTo(StreamExecutorInterface * other)578   bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
579     return false;
580   }
581 
DeviceMemoryUsage(int64_t * free,int64_t * total) const582   bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override {
583     return stream_executor_->device_memory_usage(
584         &device_, reinterpret_cast<int64_t*>(free),
585         reinterpret_cast<int64_t*>(total));
586   }
587 
588   // Creates a new DeviceDescription object.
589   // Ownership is transferred to the caller.
CreateDeviceDescription() const590   port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
591       const override {
592     OwnedTFStatus c_status(TF_NewStatus());
593 
594     internal::DeviceDescriptionBuilder builder;
595     if (device_.hardware_name != nullptr) {
596       builder.set_name(device_.hardware_name);
597     }
598     if (device_.device_vendor != nullptr) {
599       builder.set_device_vendor(device_.device_vendor);
600     }
601     if (device_.pci_bus_id != nullptr) {
602       builder.set_pci_bus_id(device_.pci_bus_id);
603     }
604 
605     if (device_fns_->get_numa_node != nullptr) {
606       int32_t numa_node = device_fns_->get_numa_node(&device_);
607       if (numa_node >= 0) {
608         builder.set_numa_node(numa_node);
609       }
610     }
611 
612     if (device_fns_->get_memory_bandwidth != nullptr) {
613       int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
614       if (memory_bandwidth >= 0) {
615         builder.set_memory_bandwidth(memory_bandwidth);
616       }
617     }
618     // TODO(annarev): Add gflops field in DeviceDescription and set it here.
619     // TODO(annarev): Perhaps add `supports_unified_memory` in
620     // DeviceDescription.
621     return builder.Build();
622   }
623 
624   // Each call creates a new instance of the platform-specific implementation of
625   // the corresponding interface type.
CreateEventImplementation()626   std::unique_ptr<internal::EventInterface> CreateEventImplementation()
627       override {
628     return std::unique_ptr<internal::EventInterface>(
629         new CEvent(&device_, stream_executor_));
630   }
CreateKernelImplementation()631   std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
632       override {
633     LOG(FATAL)
634         << "CreateKernelImplementation is not supported by pluggable device.";
635   }
GetStreamImplementation()636   std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
637       override {
638     return std::unique_ptr<internal::StreamInterface>(
639         new CStream(&device_, stream_executor_));
640   }
GetTimerImplementation()641   std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
642     return std::unique_ptr<internal::TimerInterface>(
643         new CTimer(&device_, stream_executor_, timer_fns_));
644   }
645 
646  private:
647   SP_Device device_;
648   SP_DeviceFns* device_fns_;
649   SP_StreamExecutor* stream_executor_;
650   SP_Platform* platform_;
651   SP_PlatformFns* platform_fns_;
652   SP_TimerFns* timer_fns_;
653   std::string platform_name_;
654   int visible_device_count_;
655 };
656 }  // namespace
657 
CPlatform(SP_Platform platform,void (* destroy_platform)(SP_Platform *),SP_PlatformFns platform_fns,void (* destroy_platform_fns)(SP_PlatformFns *),SP_DeviceFns device_fns,SP_StreamExecutor stream_executor,SP_TimerFns timer_fns)658 CPlatform::CPlatform(SP_Platform platform,
659                      void (*destroy_platform)(SP_Platform*),
660                      SP_PlatformFns platform_fns,
661                      void (*destroy_platform_fns)(SP_PlatformFns*),
662                      SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
663                      SP_TimerFns timer_fns)
664     : platform_(std::move(platform)),
665       destroy_platform_(destroy_platform),
666       platform_fns_(std::move(platform_fns)),
667       destroy_platform_fns_(destroy_platform_fns),
668       device_fns_(std::move(device_fns)),
669       stream_executor_(std::move(stream_executor)),
670       timer_fns_(std::move(timer_fns)),
671       name_(platform.name) {}
672 
~CPlatform()673 CPlatform::~CPlatform() {
674   executor_cache_.DestroyAllExecutors();
675   platform_fns_.destroy_device_fns(&platform_, &device_fns_);
676   platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
677   platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
678   destroy_platform_(&platform_);
679   destroy_platform_fns_(&platform_fns_);
680 }
681 
682 port::StatusOr<std::unique_ptr<DeviceDescription>>
DescriptionForDevice(int ordinal) const683 CPlatform::DescriptionForDevice(int ordinal) const {
684   // TODO(annarev): see if we can get StreamExecutor instance
685   // and call GetDeviceDescription. executor_cache_.Get would need
686   // to be made const for it to work.
687   internal::DeviceDescriptionBuilder builder;
688   builder.set_name(name_);
689   return builder.Build();
690 }
ExecutorForDevice(int ordinal)691 port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
692   stream_executor::StreamExecutorConfig config;
693   config.ordinal = ordinal;
694   return GetExecutor(config);
695 }
ExecutorForDeviceWithPluginConfig(int ordinal,const PluginConfig & plugin_config)696 port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
697     int ordinal, const PluginConfig& plugin_config) {
698   StreamExecutorConfig config;
699   config.ordinal = ordinal;
700   config.plugin_config = plugin_config;
701   return GetExecutor(config);
702 }
GetExecutor(const StreamExecutorConfig & config)703 port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
704     const StreamExecutorConfig& config) {
705   return executor_cache_.GetOrCreate(
706       config, [&]() { return GetUncachedExecutor(config); });
707 }
GetUncachedExecutor(const StreamExecutorConfig & config)708 port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
709     const StreamExecutorConfig& config) {
710   // Fill device creation params
711   SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
712   SP_Device device{SP_DEVICE_STRUCT_SIZE};
713   device_params.device = &device;
714   device_params.ext = nullptr;
715   device_params.ordinal = config.ordinal;
716   OwnedTFStatus c_status(TF_NewStatus());
717 
718   // Create Device
719   platform_fns_.create_device(&platform_, &device_params, c_status.get());
720   TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
721   TF_RETURN_IF_ERROR(ValidateSPDevice(device));
722 
723   // Get Device Count
724   int visible_device_count = 0;
725   platform_fns_.get_device_count(&platform_, &visible_device_count,
726                                  c_status.get());
727   TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
728 
729   auto executor = absl::make_unique<CStreamExecutor>(
730       std::move(device), &device_fns_, &stream_executor_, &platform_,
731       &platform_fns_, &timer_fns_, name_, visible_device_count);
732   auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
733                                                   config.ordinal);
734   return result;
735 }
736 
InitStreamExecutorPlugin(void * dso_handle,std::string * device_type,std::string * platform_name)737 port::Status InitStreamExecutorPlugin(void* dso_handle,
738                                       std::string* device_type,
739                                       std::string* platform_name) {
740   tensorflow::Env* env = tensorflow::Env::Default();
741 
742   // Step 1: Load symbol for `TF_InitPlugin`
743   void* dso_symbol;
744   TF_RETURN_IF_ERROR(
745       env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));
746 
747   // Step 2: Call `TF_InitPlugin`
748   auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
749   return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
750 }
751 
InitStreamExecutorPlugin(SEInitPluginFn init_fn,std::string * device_type,std::string * platform_name)752 port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
753                                       std::string* device_type,
754                                       std::string* platform_name) {
755   SE_PlatformRegistrationParams params{
756       SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
757   SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
758   SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
759   params.major_version = SE_MAJOR;
760   params.minor_version = SE_MINOR;
761   params.patch_version = SE_PATCH;
762   params.platform = &platform;
763   params.platform_fns = &platform_fns;
764 
765   OwnedTFStatus c_status(TF_NewStatus());
766   init_fn(&params, c_status.get());
767   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
768   TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
769   TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
770   TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));
771 
772   // Fill SP_DeviceFns creation params
773   SE_CreateDeviceFnsParams device_fns_params{
774       SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
775   SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
776   device_fns_params.device_fns = &device_fns;
777 
778   // Create StreamExecutor
779   platform_fns.create_device_fns(&platform, &device_fns_params, c_status.get());
780   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
781   TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));
782 
783   // Fill stream executor creation params
784   SE_CreateStreamExecutorParams se_params{
785       SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
786   SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
787   se_params.stream_executor = &se;
788 
789   // Create StreamExecutor
790   platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
791   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
792   TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));
793 
794   SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
795   platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
796   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
797   TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));
798 
799   // Register new platform
800   *device_type = std::string(platform.type);
801   *platform_name = std::string(platform.name);
802   std::unique_ptr<stream_executor::CPlatform> cplatform(
803       new stream_executor::CPlatform(
804           std::move(platform), params.destroy_platform, std::move(platform_fns),
805           params.destroy_platform_fns, std::move(device_fns), std::move(se),
806           std::move(timer_fns)));
807   SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
808       std::move(cplatform)));
809   // TODO(annarev): Return `use_bfc_allocator` value in some way so that it is
810   // available in `PluggableDeviceProcessState` once the latter is checked in.
811   return ::tensorflow::OkStatus();
812 }
813 }  // namespace stream_executor
814