/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file extends/implements core stream executor base classes in terms of
// the C API defined in stream_executor.h. A class "CSomething" represents a
// "Something" that can be manipulated via calls in the C interface and a C
// struct called "SP_Something".
//
// This file also contains stream_executor::Platform registration for pluggable
// device.
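//
// For example, CPlatform and CStreamExecutor below wrap SP_Platform and
// SP_StreamExecutor, while CStream, CEvent and CTimer (declared in the
// included stream_executor_internal.h) wrap SP_Stream, SP_Event and SP_Timer.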
#include "tensorflow/c/experimental/stream_executor/stream_executor.h"

#include <string>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/c_api_macros_internal.h"
#include "tensorflow/c/experimental/stream_executor/stream_executor_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device/device_utils.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/stream_executor/executor_cache.h"
#include "tensorflow/stream_executor/multi_platform_manager.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/stream_executor/timer.h"

using tensorflow::StatusFromTF_Status;

namespace stream_executor {
using tensorflow::StringPiece;

// TODO(penporn): Remove OwnedTFStatus.
using OwnedTFStatus = tensorflow::TF_StatusPtr;

namespace {
port::Status ValidateSPPlatform(const SP_Platform& platform) {
  TF_VALIDATE_STRUCT_SIZE(SP_Platform, platform, SP_PLATFORM_STRUCT_SIZE);
  TF_VALIDATE_NOT_NULL(SP_Platform, platform, name);
  TF_VALIDATE_NOT_NULL(SP_Platform, platform, type);
  TF_RETURN_IF_ERROR(
      tensorflow::device_utils::ValidateDeviceType(platform.name));
  TF_RETURN_IF_ERROR(
      tensorflow::device_utils::ValidateDeviceType(platform.type));
  // `visible_device_count` could be 0 at initialization time.
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPPlatformFns(const SP_PlatformFns& platform_fns) {
  TF_VALIDATE_STRUCT_SIZE(SP_PlatformFns, platform_fns,
                          SP_PLATFORM_FNS_STRUCT_SIZE);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_device);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_stream_executor);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_stream_executor);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_timer_fns);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_timer_fns);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, create_device_fns);
  TF_VALIDATE_NOT_NULL(SP_PlatformFns, platform_fns, destroy_device_fns);
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPTimerFns(const SP_TimerFns& timer_fns) {
  TF_VALIDATE_STRUCT_SIZE(SP_TimerFns, timer_fns, SP_TIMER_FNS_STRUCT_SIZE);
  TF_VALIDATE_NOT_NULL(SP_TimerFns, timer_fns, nanoseconds);
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPAllocatorStats(const SP_AllocatorStats& stats) {
  TF_VALIDATE_STRUCT_SIZE(SP_AllocatorStats, stats,
                          SP_ALLOCATORSTATS_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPDeviceMemoryBase(const SP_DeviceMemoryBase& mem) {
  TF_VALIDATE_STRUCT_SIZE(SP_DeviceMemoryBase, mem,
                          SP_DEVICE_MEMORY_BASE_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPDevice(const SP_Device& device) {
  TF_VALIDATE_STRUCT_SIZE(SP_Device, device, SP_DEVICE_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPDeviceFns(const SP_DeviceFns& device_fns) {
  TF_VALIDATE_STRUCT_SIZE(SP_DeviceFns, device_fns, SP_DEVICE_FNS_STRUCT_SIZE);
  // All other fields could theoretically be zero/null.
  return ::tensorflow::OkStatus();
}

port::Status ValidateSPStreamExecutor(const SP_StreamExecutor& se,
                                      const SP_Platform& platform) {
  TF_VALIDATE_STRUCT_SIZE(SP_StreamExecutor, se,
                          SP_STREAM_EXECUTOR_STRUCT_SIZE);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, allocate);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, deallocate);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, get_allocator_stats);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, host_memory_allocate);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, host_memory_deallocate);
  if (platform.supports_unified_memory) {
    TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, unified_memory_allocate);
    TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, unified_memory_deallocate);
  }
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, device_memory_usage);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_stream);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, destroy_stream);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_stream_dependency);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, get_stream_status);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_event);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, destroy_event);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, get_event_status);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, record_event);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, wait_for_event);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, create_timer);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, destroy_timer);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, start_timer);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, stop_timer);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memcpy_dtoh);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memcpy_htod);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, sync_memcpy_dtoh);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, sync_memcpy_htod);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, block_host_for_event);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, synchronize_all_activity);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, host_callback);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, mem_zero);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memset);
  TF_VALIDATE_NOT_NULL(SP_StreamExecutor, se, memset32);
  return ::tensorflow::OkStatus();
}

port::Status ValidateSEPlatformRegistrationParams(
    const SE_PlatformRegistrationParams& params) {
  TF_VALIDATE_STRUCT_SIZE(SE_PlatformRegistrationParams, params,
                          SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE);
  TF_VALIDATE_NOT_NULL(SE_PlatformRegistrationParams, params, destroy_platform);
  TF_VALIDATE_NOT_NULL(SE_PlatformRegistrationParams, params,
                       destroy_platform_fns);
  return ::tensorflow::OkStatus();
}
#undef TF_VALIDATE_NOT_NULL

// Converts SE_EventStatus to Event::Status.
Event::Status SEEventStatusToEventStatus(SE_EventStatus s) {
  switch (s) {
    case SE_EVENT_ERROR:
      return Event::Status::kError;
    case SE_EVENT_PENDING:
      return Event::Status::kPending;
    case SE_EVENT_COMPLETE:
      return Event::Status::kComplete;
    default:
      return Event::Status::kUnknown;
  }
}

// Converts DeviceMemoryBase to a C struct.
SP_DeviceMemoryBase DeviceMemoryBaseToC(const DeviceMemoryBase* mem) {
  SP_DeviceMemoryBase device_memory_base{SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
  // `opaque` field inside SP_DeviceMemoryBase is not const.
  // Therefore, we need to cast away the constness before setting it.
  device_memory_base.opaque = const_cast<void*>(mem->opaque());
  device_memory_base.size = mem->size();
  device_memory_base.payload = mem->payload();
  return device_memory_base;
}

DeviceMemoryBase DeviceMemoryBaseFromC(const SP_DeviceMemoryBase& mem) {
  DeviceMemoryBase base(mem.opaque, mem.size);
  base.SetPayload(mem.payload);
  return base;
}

// Wrapper that allows passing std::function across C API.
struct HostCallbackContext {
  std::function<port::Status()> callback;
};

// This wrapper allows calling `HostCallbackContext::callback` across C API.
// This function matches `SE_StatusCallbackFn` signature and will be passed as
// `callback_fn` to `host_callback` in `SP_StreamExecutor`.
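// Note that the trampoline takes ownership of and deletes the
// `HostCallbackContext`, so the plugin is expected to invoke `callback_fn`
// exactly once per registered callback.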
void HostCallbackTrampoline(void* ctx, TF_Status* status) {
  HostCallbackContext* host_ctx = static_cast<HostCallbackContext*>(ctx);
  port::Status s = host_ctx->callback();
  Set_TF_Status_from_Status(status, s);
  delete host_ctx;
}

class CStreamExecutor : public internal::StreamExecutorInterface {
 public:
  explicit CStreamExecutor(SP_Device device, SP_DeviceFns* device_fns,
                           SP_StreamExecutor* stream_executor,
                           SP_Platform* platform, SP_PlatformFns* platform_fns,
                           SP_TimerFns* timer_fns, const std::string& name,
                           int visible_device_count)
      : device_(std::move(device)),
        device_fns_(device_fns),
        stream_executor_(stream_executor),
        platform_(platform),
        platform_fns_(platform_fns),
        timer_fns_(timer_fns),
        platform_name_(name),
        visible_device_count_(visible_device_count) {}

  ~CStreamExecutor() override {
    platform_fns_->destroy_device(platform_, &device_);
  }

  port::Status Init(int device_ordinal, DeviceOptions device_options) override {
    return ::tensorflow::OkStatus();
  }

  DeviceMemoryBase Allocate(uint64 size, int64_t memory_space) override {
    SP_DeviceMemoryBase mem = {SP_DEVICE_MEMORY_BASE_STRUCT_SIZE};
    stream_executor_->allocate(&device_, size, memory_space, &mem);
    port::Status status = ValidateSPDeviceMemoryBase(mem);
    if (!status.ok()) {
      LOG(ERROR) << status.error_message();
    }
    return DeviceMemoryBaseFromC(mem);
  }
  DeviceMemoryBase Allocate(uint64 size) {
    return Allocate(size, /*memory_space=*/0);
  }
  void* GetSubBuffer(DeviceMemoryBase* parent, uint64 offset,
                     uint64 size) override {
    LOG(FATAL) << "GetSubBuffer is not supported by pluggable device.";
  }

  void Deallocate(DeviceMemoryBase* mem) override {
    SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(mem);
    stream_executor_->deallocate(&device_, &device_memory_base);
  }

  void* HostMemoryAllocate(uint64 size) override {
    return stream_executor_->host_memory_allocate(&device_, size);
  }

  void HostMemoryDeallocate(void* mem) override {
    stream_executor_->host_memory_deallocate(&device_, mem);
  }

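  // Host memory registration is not supported for pluggable devices, so these
  // report failure.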
  bool HostMemoryRegister(void* mem, uint64 size) override { return false; }
  bool HostMemoryUnregister(void* mem) override { return false; }

  void* UnifiedMemoryAllocate(uint64 size) override {
    CHECK(stream_executor_->unified_memory_allocate);
    return stream_executor_->unified_memory_allocate(&device_, size);
  }

  void UnifiedMemoryDeallocate(void* mem) override {
    CHECK(stream_executor_->unified_memory_deallocate);
    stream_executor_->unified_memory_deallocate(&device_, mem);
  }

  absl::optional<AllocatorStats> GetAllocatorStats() override {
    SP_AllocatorStats c_stats{SP_ALLOCATORSTATS_STRUCT_SIZE};
    TF_Bool has_stats =
        stream_executor_->get_allocator_stats(&device_, &c_stats);
    if (!has_stats) {
      return absl::nullopt;
    }
    port::Status status = ValidateSPAllocatorStats(c_stats);
    if (!status.ok()) {
      LOG(ERROR) << status.error_message();
      return absl::nullopt;
    }
    ::stream_executor::AllocatorStats stats;
    stats.num_allocs = c_stats.num_allocs;
    stats.bytes_in_use = c_stats.bytes_in_use;
    stats.peak_bytes_in_use = c_stats.peak_bytes_in_use;
    stats.largest_alloc_size = c_stats.largest_alloc_size;
    if (c_stats.has_bytes_limit) {
      stats.bytes_limit = c_stats.bytes_limit;
    }
    stats.bytes_reserved = c_stats.bytes_reserved;
    stats.peak_bytes_reserved = c_stats.peak_bytes_reserved;
    if (c_stats.has_bytes_reservable_limit) {
      stats.bytes_reservable_limit = c_stats.bytes_reservable_limit;
    }
    stats.largest_free_block_bytes = c_stats.largest_free_block_bytes;
    return stats;
  }
  bool SynchronizeAllActivity() override {
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->synchronize_all_activity(&device_, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  port::Status SynchronousMemZero(DeviceMemoryBase* location,
                                  uint64 size) override {
    // TODO(annarev): figure out if we should support memzero/memset
    // functionality by allocating on host and then copying to device.
    return port::UnimplementedError(
        "SynchronousMemZero is not supported by pluggable device.");
  }
  port::Status SynchronousMemSet(DeviceMemoryBase* location, int value,
                                 uint64 size) override {
    return port::UnimplementedError(
        "SynchronousMemSet is not supported by pluggable device.");
  }
  port::Status SynchronousMemcpy(DeviceMemoryBase* gpu_dst,
                                 const void* host_src, uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(gpu_dst);
    stream_executor_->sync_memcpy_htod(&device_, &device_memory_base, host_src,
                                       size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status SynchronousMemcpy(void* host_dst,
                                 const DeviceMemoryBase& gpu_src,
                                 uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_DeviceMemoryBase device_memory_base = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->sync_memcpy_dtoh(&device_, host_dst, &device_memory_base,
                                       size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status SynchronousMemcpyDeviceToDevice(DeviceMemoryBase* gpu_dst,
                                               const DeviceMemoryBase& gpu_src,
                                               uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->sync_memcpy_dtod(&device_, &device_mem_dst,
                                       &device_mem_src, size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status MemZero(Stream* stream, DeviceMemoryBase* location,
                       uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
    stream_executor_->mem_zero(&device_, stream_handle, &device_mem, size,
                               c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status Memset(Stream* stream, DeviceMemoryBase* location, uint8 pattern,
                      uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
    stream_executor_->memset(&device_, stream_handle, &device_mem, pattern,
                             size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  port::Status Memset32(Stream* stream, DeviceMemoryBase* location,
                        uint32 pattern, uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem = DeviceMemoryBaseToC(location);
    stream_executor_->memset32(&device_, stream_handle, &device_mem, pattern,
                               size, c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  bool Memcpy(Stream* stream, void* host_dst, const DeviceMemoryBase& gpu_src,
              uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->memcpy_dtoh(&device_, stream_handle, host_dst,
                                  &device_mem_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool Memcpy(Stream* stream, DeviceMemoryBase* gpu_dst, const void* host_src,
              uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    stream_executor_->memcpy_htod(&device_, stream_handle, &device_mem_dst,
                                  host_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool MemcpyDeviceToDevice(Stream* stream, DeviceMemoryBase* gpu_dst,
                            const DeviceMemoryBase& gpu_src,
                            uint64 size) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_DeviceMemoryBase device_mem_dst = DeviceMemoryBaseToC(gpu_dst);
    SP_DeviceMemoryBase device_mem_src = DeviceMemoryBaseToC(&gpu_src);
    stream_executor_->memcpy_dtod(&device_, stream_handle, &device_mem_dst,
                                  &device_mem_src, size, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool HostCallback(Stream* stream,
                    std::function<port::Status()> callback) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    HostCallbackContext* ctx = new HostCallbackContext{callback};
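    // Ownership of `ctx` is transferred to `HostCallbackTrampoline`, which
    // deletes it after invoking the callback.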
    return stream_executor_->host_callback(&device_, stream_handle,
                                           &HostCallbackTrampoline, ctx);
  }
  port::Status AllocateEvent(Event* event) override {
    DCHECK(event != nullptr);
    return static_cast<CEvent*>(event->implementation())->Create();
  }
  port::Status DeallocateEvent(Event* event) override {
    static_cast<CEvent*>(event->implementation())->Destroy();
    return ::tensorflow::OkStatus();
  }
  port::Status RecordEvent(Stream* stream, Event* event) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    return static_cast<CEvent*>(event->implementation())->Record(stream_handle);
  }
  port::Status WaitForEvent(Stream* stream, Event* event) override {
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    OwnedTFStatus c_status(TF_NewStatus());
    stream_executor_->wait_for_event(&device_, stream_handle, event_handle,
                                     c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    return s;
  }
  Event::Status PollForEventStatus(Event* event) override {
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    SE_EventStatus event_status =
        stream_executor_->get_event_status(&device_, event_handle);
    return SEEventStatusToEventStatus(event_status);
  }
  bool AllocateStream(Stream* stream) override {
    DCHECK(stream != nullptr);
    port::Status status =
        static_cast<CStream*>(stream->implementation())->Create();
    // TODO(annarev): update AllocateStream to return status instead
    // (similar to AllocateEvent).
    return status.ok();
  }
  void DeallocateStream(Stream* stream) override {
    static_cast<CStream*>(stream->implementation())->Destroy();
  }
  bool CreateStreamDependency(Stream* dependent, Stream* other) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream dependent_handle =
        static_cast<CStream*>(dependent->implementation())->Handle();
    SP_Stream other_handle =
        static_cast<CStream*>(other->implementation())->Handle();
    stream_executor_->create_stream_dependency(&device_, dependent_handle,
                                               other_handle, c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool AllocateTimer(Timer* timer) override {
    port::Status status =
        static_cast<CTimer*>(timer->implementation())->Create();
    // TODO(annarev): change return value of AllocateTimer
    // to status (similar to AllocateEvent).
    return status.ok();
  }
  void DeallocateTimer(Timer* timer) override {
    static_cast<CTimer*>(timer->implementation())->Destroy();
  }
  bool StartTimer(Stream* stream, Timer* timer) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Timer timer_handle =
        static_cast<CTimer*>(timer->implementation())->Handle();
    stream_executor_->start_timer(&device_, stream_handle, timer_handle,
                                  c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  bool StopTimer(Stream* stream, Timer* timer) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    SP_Timer timer_handle =
        static_cast<CTimer*>(timer->implementation())->Handle();
    stream_executor_->stop_timer(&device_, stream_handle, timer_handle,
                                 c_status.get());
    if (TF_GetCode(c_status.get()) != TF_OK) {
      LOG(ERROR) << TF_Message(c_status.get());
      return false;
    }
    return true;
  }
  port::Status BlockHostForEvent(Stream* stream, Event* event) {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Event event_handle =
        static_cast<CEvent*>(event->implementation())->Handle();
    stream_executor_->block_host_for_event(&device_, event_handle,
                                           c_status.get());
    return StatusFromTF_Status(c_status.get());
  }

  port::Status BlockHostUntilDone(Stream* stream) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();

    // If `block_host_until_done` is set, use it.
    if (stream_executor_->block_host_until_done != nullptr) {
      stream_executor_->block_host_until_done(&device_, stream_handle,
                                              c_status.get());
      return StatusFromTF_Status(c_status.get());
    }
    // Create and record an event and then wait for it.
    SP_Event event_handle;
    stream_executor_->create_event(&device_, &event_handle, c_status.get());
    TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
    stream_executor_->record_event(&device_, stream_handle, event_handle,
                                   c_status.get());
    port::Status s = StatusFromTF_Status(c_status.get());
    if (!s.ok()) {
      stream_executor_->destroy_event(&device_, event_handle);
      return s;
    }
    stream_executor_->block_host_for_event(&device_, event_handle,
                                           c_status.get());
    stream_executor_->destroy_event(&device_, event_handle);
    return StatusFromTF_Status(c_status.get());
  }

  port::Status GetStatus(Stream* stream) override {
    OwnedTFStatus c_status(TF_NewStatus());
    SP_Stream stream_handle =
        static_cast<CStream*>(stream->implementation())->Handle();
    stream_executor_->get_stream_status(&device_, stream_handle,
                                        c_status.get());
    return StatusFromTF_Status(c_status.get());
  }
  int PlatformDeviceCount() override { return visible_device_count_; }
  port::Status EnablePeerAccessTo(StreamExecutorInterface* other) override {
    return port::UnimplementedError(
        "EnablePeerAccessTo is not supported by pluggable device.");
  }
  bool CanEnablePeerAccessTo(StreamExecutorInterface* other) override {
    return false;
  }

  bool DeviceMemoryUsage(int64_t* free, int64_t* total) const override {
    return stream_executor_->device_memory_usage(
        &device_, reinterpret_cast<int64_t*>(free),
        reinterpret_cast<int64_t*>(total));
  }

  // Creates a new DeviceDescription object.
  // Ownership is transferred to the caller.
  port::StatusOr<std::unique_ptr<DeviceDescription>> CreateDeviceDescription()
      const override {
    OwnedTFStatus c_status(TF_NewStatus());

    internal::DeviceDescriptionBuilder builder;
    if (device_.hardware_name != nullptr) {
      builder.set_name(device_.hardware_name);
    }
    if (device_.device_vendor != nullptr) {
      builder.set_device_vendor(device_.device_vendor);
    }
    if (device_.pci_bus_id != nullptr) {
      builder.set_pci_bus_id(device_.pci_bus_id);
    }

    if (device_fns_->get_numa_node != nullptr) {
      int32_t numa_node = device_fns_->get_numa_node(&device_);
      if (numa_node >= 0) {
        builder.set_numa_node(numa_node);
      }
    }

    if (device_fns_->get_memory_bandwidth != nullptr) {
      int64_t memory_bandwidth = device_fns_->get_memory_bandwidth(&device_);
      if (memory_bandwidth >= 0) {
        builder.set_memory_bandwidth(memory_bandwidth);
      }
    }
    // TODO(annarev): Add gflops field in DeviceDescription and set it here.
    // TODO(annarev): Perhaps add `supports_unified_memory` in
    // DeviceDescription.
    return builder.Build();
  }

  // Each call creates a new instance of the platform-specific implementation
  // of the corresponding interface type.
  std::unique_ptr<internal::EventInterface> CreateEventImplementation()
      override {
    return std::unique_ptr<internal::EventInterface>(
        new CEvent(&device_, stream_executor_));
  }
  std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
      override {
    LOG(FATAL)
        << "CreateKernelImplementation is not supported by pluggable device.";
  }
  std::unique_ptr<internal::StreamInterface> GetStreamImplementation()
      override {
    return std::unique_ptr<internal::StreamInterface>(
        new CStream(&device_, stream_executor_));
  }
  std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override {
    return std::unique_ptr<internal::TimerInterface>(
        new CTimer(&device_, stream_executor_, timer_fns_));
  }

 private:
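  // `device_` is owned by this class and destroyed in the destructor via
  // `platform_fns_->destroy_device`. The remaining pointers refer to structs
  // owned by the CPlatform that created this executor.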
  SP_Device device_;
  SP_DeviceFns* device_fns_;
  SP_StreamExecutor* stream_executor_;
  SP_Platform* platform_;
  SP_PlatformFns* platform_fns_;
  SP_TimerFns* timer_fns_;
  std::string platform_name_;
  int visible_device_count_;
};
}  // namespace

CPlatform::CPlatform(SP_Platform platform,
                     void (*destroy_platform)(SP_Platform*),
                     SP_PlatformFns platform_fns,
                     void (*destroy_platform_fns)(SP_PlatformFns*),
                     SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
                     SP_TimerFns timer_fns)
    : platform_(std::move(platform)),
      destroy_platform_(destroy_platform),
      platform_fns_(std::move(platform_fns)),
      destroy_platform_fns_(destroy_platform_fns),
      device_fns_(std::move(device_fns)),
      stream_executor_(std::move(stream_executor)),
      timer_fns_(std::move(timer_fns)),
      name_(platform.name) {}

CPlatform::~CPlatform() {
  executor_cache_.DestroyAllExecutors();
  platform_fns_.destroy_device_fns(&platform_, &device_fns_);
  platform_fns_.destroy_stream_executor(&platform_, &stream_executor_);
  platform_fns_.destroy_timer_fns(&platform_, &timer_fns_);
  destroy_platform_(&platform_);
  destroy_platform_fns_(&platform_fns_);
}

port::StatusOr<std::unique_ptr<DeviceDescription>>
CPlatform::DescriptionForDevice(int ordinal) const {
  // TODO(annarev): see if we can get StreamExecutor instance
  // and call GetDeviceDescription. executor_cache_.Get would need
  // to be made const for it to work.
  internal::DeviceDescriptionBuilder builder;
  builder.set_name(name_);
  return builder.Build();
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDevice(int ordinal) {
  stream_executor::StreamExecutorConfig config;
  config.ordinal = ordinal;
  return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::ExecutorForDeviceWithPluginConfig(
    int ordinal, const PluginConfig& plugin_config) {
  StreamExecutorConfig config;
  config.ordinal = ordinal;
  config.plugin_config = plugin_config;
  return GetExecutor(config);
}
port::StatusOr<StreamExecutor*> CPlatform::GetExecutor(
    const StreamExecutorConfig& config) {
  return executor_cache_.GetOrCreate(
      config, [&]() { return GetUncachedExecutor(config); });
}
port::StatusOr<std::unique_ptr<StreamExecutor>> CPlatform::GetUncachedExecutor(
    const StreamExecutorConfig& config) {
  // Fill device creation params
  SE_CreateDeviceParams device_params{SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE};
  SP_Device device{SP_DEVICE_STRUCT_SIZE};
  device_params.device = &device;
  device_params.ext = nullptr;
  device_params.ordinal = config.ordinal;
  OwnedTFStatus c_status(TF_NewStatus());

  // Create Device
  platform_fns_.create_device(&platform_, &device_params, c_status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPDevice(device));

  // Get Device Count
  int visible_device_count = 0;
  platform_fns_.get_device_count(&platform_, &visible_device_count,
                                 c_status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(c_status.get()));

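  // The CStreamExecutor takes ownership of `device` and destroys it with
  // `platform_fns_->destroy_device` in its destructor.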
  auto executor = absl::make_unique<CStreamExecutor>(
      std::move(device), &device_fns_, &stream_executor_, &platform_,
      &platform_fns_, &timer_fns_, name_, visible_device_count);
  auto result = absl::make_unique<StreamExecutor>(this, std::move(executor),
                                                  config.ordinal);
  return result;
}

port::Status InitStreamExecutorPlugin(void* dso_handle,
                                      std::string* device_type,
                                      std::string* platform_name) {
  tensorflow::Env* env = tensorflow::Env::Default();

  // Step 1: Load symbol for `SE_InitPlugin`.
  void* dso_symbol;
  TF_RETURN_IF_ERROR(
      env->GetSymbolFromLibrary(dso_handle, "SE_InitPlugin", &dso_symbol));

  // Step 2: Call `SE_InitPlugin`.
  auto init_fn = reinterpret_cast<SEInitPluginFn>(dso_symbol);
  return InitStreamExecutorPlugin(init_fn, device_type, platform_name);
}

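// For reference, a plugin's `SE_InitPlugin` is expected to fill in the
// registration params passed to it below. An illustrative sketch only; the
// device name/type and the `my_*` callbacks are hypothetical placeholders,
// not part of this API:
//
//   void SE_InitPlugin(SE_PlatformRegistrationParams* params,
//                      TF_Status* status) {
//     params->platform->name = "MY_DEVICE";
//     params->platform->type = "XPU";
//     params->platform_fns->get_device_count = my_get_device_count;
//     params->platform_fns->create_device = my_create_device;
//     params->platform_fns->destroy_device = my_destroy_device;
//     params->platform_fns->create_device_fns = my_create_device_fns;
//     params->platform_fns->destroy_device_fns = my_destroy_device_fns;
//     params->platform_fns->create_stream_executor = my_create_stream_executor;
//     params->platform_fns->destroy_stream_executor =
//         my_destroy_stream_executor;
//     params->platform_fns->create_timer_fns = my_create_timer_fns;
//     params->platform_fns->destroy_timer_fns = my_destroy_timer_fns;
//     params->destroy_platform = my_destroy_platform;
//     params->destroy_platform_fns = my_destroy_platform_fns;
//   }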
port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
                                      std::string* device_type,
                                      std::string* platform_name) {
  SE_PlatformRegistrationParams params{
      SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE};
  SP_Platform platform{SP_PLATFORM_STRUCT_SIZE};
  SP_PlatformFns platform_fns{SP_PLATFORM_FNS_STRUCT_SIZE};
  params.major_version = SE_MAJOR;
  params.minor_version = SE_MINOR;
  params.patch_version = SE_PATCH;
  params.platform = &platform;
  params.platform_fns = &platform_fns;

  OwnedTFStatus c_status(TF_NewStatus());
  init_fn(&params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSEPlatformRegistrationParams(params));
  TF_RETURN_IF_ERROR(ValidateSPPlatform(platform));
  TF_RETURN_IF_ERROR(ValidateSPPlatformFns(platform_fns));

  // Fill SP_DeviceFns creation params
  SE_CreateDeviceFnsParams device_fns_params{
      SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE};
  SP_DeviceFns device_fns{SP_DEVICE_FNS_STRUCT_SIZE};
  device_fns_params.device_fns = &device_fns;

  // Create device functions
  platform_fns.create_device_fns(&platform, &device_fns_params,
                                 c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPDeviceFns(device_fns));

  // Fill stream executor creation params
  SE_CreateStreamExecutorParams se_params{
      SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE};
  SP_StreamExecutor se{SP_STREAMEXECUTOR_STRUCT_SIZE};
  se_params.stream_executor = &se;

  // Create StreamExecutor
  platform_fns.create_stream_executor(&platform, &se_params, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPStreamExecutor(se, platform));

  SP_TimerFns timer_fns{SP_TIMER_FNS_STRUCT_SIZE};
  platform_fns.create_timer_fns(&platform, &timer_fns, c_status.get());
  TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(c_status.get()));
  TF_RETURN_IF_ERROR(ValidateSPTimerFns(timer_fns));

  // Register new platform
  *device_type = std::string(platform.type);
  *platform_name = std::string(platform.name);
  std::unique_ptr<stream_executor::CPlatform> cplatform(
      new stream_executor::CPlatform(
          std::move(platform), params.destroy_platform, std::move(platform_fns),
          params.destroy_platform_fns, std::move(device_fns), std::move(se),
          std::move(timer_fns)));
  SE_CHECK_OK(stream_executor::MultiPlatformManager::RegisterPlatform(
      std::move(cplatform)));
  // TODO(annarev): Return `use_bfc_allocator` value in some way so that it is
  // available in `PluggableDeviceProcessState` once the latter is checked in.
  return ::tensorflow::OkStatus();
}
}  // namespace stream_executor