1// Copyright © 2022 Apple Inc. 2 3#include <ATen/mps/MPSDevice.h> 4#include <ATen/mps/MPSGuardImpl.h> 5 6namespace at::mps { 7 8void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {} 9 10void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept { 11 if (!event) 12 return; 13 auto mps_event = static_cast<mpsEvent_t>(event); 14 mps_event->~MPSEvent(); 15} 16 17void MPSGuardImpl::record(void** event, 18 const Stream& stream, 19 const DeviceIndex device_index, 20 const EventFlag flag) const { 21 TORCH_CHECK(device_index == -1 || device_index == stream.device_index(), 22 "Event device index ", 23 device_index, 24 " does not match recording stream's device index ", 25 stream.device_index(), 26 "."); 27 28 auto mps_event = static_cast<mpsEvent_t>(*event); 29 MPSStream mps_stream{stream}; 30 mps_event->record(true); 31} 32 33void MPSGuardImpl::block(void* event, const Stream& stream) const { 34 auto mps_event = static_cast<mpsEvent_t>(event); 35 MPSStream mps_stream{stream}; 36 37 mps_event->wait(true, false); 38} 39 40bool MPSGuardImpl::queryEvent(void* event) const { 41 auto mps_event = static_cast<mpsEvent_t>(event); 42 return mps_event->query(); 43} 44 45} // namespace at::mps 46