1 #include <gtest/gtest.h> 2 #include <torch/torch.h> 3 #include <ATen/mps/MPSAllocatorInterface.h> 4 5 namespace replay { 6 std::function<void()> callback_action; 7 8 class ReplayBufferCleaner : virtual public at::mps::IMpsAllocatorCallback { 9 public: executeMPSAllocatorCallback(void * ptr,EventType event)10 void executeMPSAllocatorCallback(void* ptr, EventType event) override { 11 if (event == EventType::ALLOCATION_FAILED) { 12 callback_action(); 13 } 14 } 15 }; 16 } 17 18 namespace at::mps { 19 REGISTER_MPS_ALLOCATOR_CALLBACK("ReplayBufferCleaner", replay::ReplayBufferCleaner); 20 } 21 TEST(MPSAllocator,MPSAllocatorCallbacks)22TEST(MPSAllocator, MPSAllocatorCallbacks) { 23 // fail if mps isn't available 24 ASSERT_TRUE(torch::mps::is_available()); 25 26 std::vector<torch::Tensor> replay_buffer; 27 replay::callback_action = [&]() { 28 if (!replay_buffer.empty()) { 29 replay_buffer.erase(replay_buffer.begin(), replay_buffer.begin() + (replay_buffer.size()/10)); 30 } 31 }; 32 size_t max_iter = 100000; 33 for (size_t i = 0; i < max_iter; i++) { 34 torch::Tensor new_value = torch::randn({10000, 10000}, at::device(at::kMPS)); 35 // early stop the first time the callback is called 36 if (replay_buffer.size() != i) { 37 break; 38 } 39 replay_buffer.push_back(new_value); 40 } 41 // call synchronize() explicitly to wait for all MPS streams to 42 // finish the Metal completionHandlers in MPSAllocator. Note that MPSAllocator 43 // does this implicitly, but we call this for testing purposes. 44 torch::mps::synchronize(); 45 ASSERT_TRUE(replay_buffer.size() < max_iter); 46 } 47