xref: /aosp_15_r20/external/pytorch/aten/src/ATen/mps/MPSGuardImpl.mm (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1//  Copyright © 2022 Apple Inc.
2
3#include <ATen/mps/MPSDevice.h>
4#include <ATen/mps/MPSGuardImpl.h>
5
6namespace at::mps {
7
8void MPSGuardImpl::createEvent(mpsEvent_t* event, const EventFlag flag) const {}
9
10void MPSGuardImpl::destroyEvent(void* event, const DeviceIndex device_index) const noexcept {
11  if (!event)
12    return;
13  auto mps_event = static_cast<mpsEvent_t>(event);
14  mps_event->~MPSEvent();
15}
16
17void MPSGuardImpl::record(void** event,
18                          const Stream& stream,
19                          const DeviceIndex device_index,
20                          const EventFlag flag) const {
21  TORCH_CHECK(device_index == -1 || device_index == stream.device_index(),
22              "Event device index ",
23              device_index,
24              " does not match recording stream's device index ",
25              stream.device_index(),
26              ".");
27
28  auto mps_event = static_cast<mpsEvent_t>(*event);
29  MPSStream mps_stream{stream};
30  mps_event->record(true);
31}
32
33void MPSGuardImpl::block(void* event, const Stream& stream) const {
34  auto mps_event = static_cast<mpsEvent_t>(event);
35  MPSStream mps_stream{stream};
36
37  mps_event->wait(true, false);
38}
39
40bool MPSGuardImpl::queryEvent(void* event) const {
41  auto mps_event = static_cast<mpsEvent_t>(event);
42  return mps_event->query();
43}
44
45} // namespace at::mps
46