1 #include <ATen/Context.h> 2 #include <torch/mps.h> 3 4 namespace torch { 5 namespace mps { 6 is_available()7bool is_available() { 8 return at::detail::getMPSHooks().hasMPS(); 9 } 10 11 /// Sets the seed for the MPS's default generator. manual_seed(uint64_t seed)12void manual_seed(uint64_t seed) { 13 if (is_available()) { 14 auto gen = at::detail::getMPSHooks().getDefaultMPSGenerator(); 15 { 16 // See Note [Acquire lock when using random generators] 17 std::lock_guard<std::mutex> lock(gen.mutex()); 18 gen.set_current_seed(seed); 19 } 20 } 21 } 22 synchronize()23void synchronize() { 24 at::detail::getMPSHooks().deviceSynchronize(); 25 } 26 commit()27void commit() { 28 at::detail::getMPSHooks().commitStream(); 29 } 30 get_command_buffer()31MTLCommandBuffer_t get_command_buffer() { 32 return at::detail::getMPSHooks().getCommandBuffer(); 33 } 34 get_dispatch_queue()35DispatchQueue_t get_dispatch_queue() { 36 return at::detail::getMPSHooks().getDispatchQueue(); 37 } 38 39 } // namespace mps 40 } // namespace torch 41