xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/allocator_clone_test.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <gtest/gtest.h>
3 #include <ATen/ATen.h>
4 
test_allocator_clone(c10::Allocator * allocator)5 void test_allocator_clone(c10::Allocator* allocator) {
6   ASSERT_TRUE(allocator != nullptr);
7 
8   c10::Storage a_storage(c10::make_intrusive<c10::StorageImpl>(
9     c10::StorageImpl::use_byte_size_t(),
10     0,
11     allocator,
12     /*resizable=*/true));
13 
14   c10::Storage b_storage(c10::make_intrusive<c10::StorageImpl>(
15     c10::StorageImpl::use_byte_size_t(),
16     0,
17     allocator,
18     /*resizable=*/true));
19 
20   at::Tensor a = at::empty({0}, at::TensorOptions().device(a_storage.device())).set_(a_storage);
21   at::Tensor b = at::empty({0}, at::TensorOptions().device(b_storage.device())).set_(b_storage);
22 
23   std::vector<int64_t> sizes({13, 4, 5});
24 
25   at::rand_out(a, sizes);
26   at::rand_out(b, sizes);
27 
28   ASSERT_TRUE(a_storage.nbytes() == static_cast<size_t>(a.numel() * a.element_size()));
29   ASSERT_TRUE(a_storage.nbytes() == b_storage.nbytes());
30 
31   void* a_data_ptr = a_storage.mutable_data();
32   b_storage.set_data_ptr(allocator->clone(a_data_ptr, a_storage.nbytes()));
33 
34   ASSERT_TRUE((a == b).all().item<bool>());
35 }
36