xref: /aosp_15_r20/external/pytorch/test/cpp/api/meta_tensor.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/MetaFunctions.h>
4 #include <torch/torch.h>
5 
6 #include <vector>
7 
// Verifies that at::add() applied to meta-device inputs yields a meta-device
// result of shape {3, 4}, while the CPU tensors the inputs were derived from
// still report the CPU device.
TEST(MetaTensorTest, MetaDeviceApi) {
  const auto lhs = at::ones({4}, at::kFloat);
  const auto rhs = at::ones({3, 4}, at::kFloat);

  // at::add() will return a meta tensor if its inputs are also meta tensors,
  // so convert both operands to the meta device before adding.
  const auto lhs_meta = lhs.to(c10::kMeta);
  const auto rhs_meta = rhs.to(c10::kMeta);
  const auto result = at::add(lhs_meta, rhs_meta);

  // The source tensors are expected to remain on the CPU; only the result
  // is expected to be on the meta device.
  ASSERT_EQ(lhs.device(), c10::kCPU);
  ASSERT_EQ(rhs.device(), c10::kCPU);
  ASSERT_EQ(result.device(), c10::kMeta);

  // Adding a {4} tensor to a {3, 4} tensor is expected to give {3, 4}.
  const c10::IntArrayRef actual_shape = result.sizes();
  const std::vector<int64_t> expected_shape{3, 4};
  ASSERT_EQ(actual_shape, expected_shape);
}
21 
// Verifies that at::meta::add() called directly on CPU tensors yields a
// meta-device result of shape {3, 4} and leaves the inputs reporting the CPU
// device.
TEST(MetaTensorTest, MetaNamespaceApi) {
  const auto lhs = at::ones({4}, at::kFloat);
  const auto rhs = at::ones({3, 4}, at::kFloat);

  // Functions in the at::meta:: namespace take in tensors from any backend
  // and return a meta tensor, so no explicit device conversion is done here.
  const auto result = at::meta::add(lhs, rhs);

  // Inputs are expected to stay on the CPU; the output is expected to be on
  // the meta device.
  ASSERT_EQ(lhs.device(), c10::kCPU);
  ASSERT_EQ(rhs.device(), c10::kCPU);
  ASSERT_EQ(result.device(), c10::kMeta);

  // Adding a {4} tensor to a {3, 4} tensor is expected to give {3, 4}.
  const c10::IntArrayRef actual_shape = result.sizes();
  const std::vector<int64_t> expected_shape{3, 4};
  ASSERT_EQ(actual_shape, expected_shape);
}
36