// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/mps_test_allocator.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <gtest/gtest.h>
#include <torch/torch.h>
#include <ATen/mps/MPSAllocatorInterface.h>

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

5 namespace replay {
6 std::function<void()> callback_action;
7 
8 class ReplayBufferCleaner : virtual public at::mps::IMpsAllocatorCallback {
9     public:
executeMPSAllocatorCallback(void * ptr,EventType event)10     void executeMPSAllocatorCallback(void* ptr, EventType event) override {
11      if (event == EventType::ALLOCATION_FAILED) {
12         callback_action();
13      }
14     }
15 };
16 }
17 
18 namespace at::mps {
19 REGISTER_MPS_ALLOCATOR_CALLBACK("ReplayBufferCleaner", replay::ReplayBufferCleaner);
20 }
21 
TEST(MPSAllocator, MPSAllocatorCallbacks) {
    // The test exercises the MPS allocator, so it is meaningless (and the
    // allocations below would fail) without an MPS device.
    ASSERT_TRUE(torch::mps::is_available());

    std::vector<torch::Tensor> replay_buffer;
    // On allocation failure, drop the oldest ~10% of the buffer so a retried
    // allocation can succeed. Erase at least one element: with fewer than 10
    // tensors, size()/10 == 0 and the original erase(begin, begin) released
    // nothing, turning the cleanup into a silent no-op under memory pressure.
    replay::callback_action = [&]() {
        if (!replay_buffer.empty()) {
            const size_t drop = std::max<size_t>(replay_buffer.size() / 10, 1);
            replay_buffer.erase(replay_buffer.begin(), replay_buffer.begin() + drop);
        }
    };
    const size_t max_iter = 100000;
    for (size_t i = 0; i < max_iter; i++) {
        // ~400 MB per tensor (10000*10000 floats) — keeps allocating until
        // the device runs out of memory and the callback fires.
        torch::Tensor new_value = torch::randn({10000, 10000}, at::device(at::kMPS));
        // Early-stop the first time the callback is called: the callback
        // shrinks the buffer, so its size stops matching the iteration count.
        if (replay_buffer.size() != i) {
            break;
        }
        replay_buffer.push_back(new_value);
    }
    // call synchronize() explicitly to wait for all MPS streams to
    // finish the Metal completionHandlers in MPSAllocator. Note that MPSAllocator
    // does this implicitly, but we call this for testing purposes.
    torch::mps::synchronize();
    // If the callback ever fired, the loop exited early and the buffer holds
    // fewer than max_iter tensors.
    ASSERT_TRUE(replay_buffer.size() < max_iter);
}
47