#include <gtest/gtest.h>

#include <ATen/MetaFunctions.h>
#include <torch/torch.h>

#include <vector>

TEST(MetaTensorTest, MetaDeviceApi) {
  auto a = at::ones({4}, at::kFloat);
  auto b = at::ones({3, 4}, at::kFloat);
  // at::add() will return a meta tensor if its inputs are also meta tensors.
  auto out_meta = at::add(a.to(c10::kMeta), b.to(c10::kMeta));

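  // The original CPU tensors are untouched by .to(c10::kMeta); out_meta holds
  // only metadata (shape, dtype, device) and no actual data.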
  ASSERT_EQ(a.device(), c10::kCPU);
  ASSERT_EQ(b.device(), c10::kCPU);
  ASSERT_EQ(out_meta.device(), c10::kMeta);
  c10::IntArrayRef sizes_actual = out_meta.sizes();
  std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4};
  ASSERT_EQ(sizes_actual, sizes_expected);
}

TEST(MetaTensorTest, MetaNamespaceApi) {
  auto a = at::ones({4}, at::kFloat);
  auto b = at::ones({3, 4}, at::kFloat);
  // The at::meta:: namespace takes in tensors from any backend
  // and returns a meta tensor.
  auto out_meta = at::meta::add(a, b);

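  // The CPU inputs are passed in directly and left unchanged; at::meta::add
  // only computes the output metadata (here, the broadcast shape {3, 4})
  // without performing the actual addition.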
  ASSERT_EQ(a.device(), c10::kCPU);
  ASSERT_EQ(b.device(), c10::kCPU);
  ASSERT_EQ(out_meta.device(), c10::kMeta);
  c10::IntArrayRef sizes_actual = out_meta.sizes();
  std::vector<int64_t> sizes_expected = std::vector<int64_t>{3, 4};
  ASSERT_EQ(sizes_actual, sizes_expected);
}