xref: /aosp_15_r20/external/pytorch/test/cpp/api/operations.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <c10/util/irange.h>
4 #include <torch/torch.h>
5 
6 #include <test/cpp/api/support.h>
7 struct OperationTest : torch::test::SeedingFixture {
8  protected:
SetUpOperationTest9   void SetUp() override {}
10 
11   const int TEST_AMOUNT = 10;
12 };
13 
// Verifies torch::lerp against its closed form start + weight * (end - start)
// for both kernels: a scalar weight (lerp_kernel_scalar) and an element-wise
// tensor weight (lerp_kernel_tensor). Repeated TEST_AMOUNT times with fresh
// random inputs; output dtype must match the reference dtype.
TEST_F(OperationTest, Lerp) {
  for (const auto trial : c10::irange(TEST_AMOUNT)) {
    static_cast<void>(trial); // only the repetition count matters

    // Random endpoints shared by both kernel checks.
    auto lo = torch::rand({3, 5});
    auto hi = torch::rand({3, 5});

    // lerp_kernel_scalar: a plain double weight.
    {
      const auto w = 0.5;
      const auto reference = lo + w * (hi - lo);
      const auto result = torch::lerp(lo, hi, w);
      ASSERT_EQ(result.dtype(), reference.dtype());
      ASSERT_TRUE(result.allclose(reference));
    }

    // lerp_kernel_tensor: a weight tensor drawn after the endpoints.
    {
      const auto w = torch::rand({3, 5});
      const auto reference = lo + w * (hi - lo);
      const auto result = torch::lerp(lo, hi, w);
      ASSERT_EQ(result.dtype(), reference.dtype());
      ASSERT_TRUE(result.allclose(reference));
    }
  }
}
38 
// Verifies torch::cross on a batch of 3-vectors against the component-wise
// definition u x v = (u2*v3 - u3*v2, u3*v1 - u1*v3, u1*v2 - u2*v1),
// repeated TEST_AMOUNT times with fresh random inputs.
TEST_F(OperationTest, Cross) {
  for (const auto i : c10::irange(TEST_AMOUNT)) {
    (void)i; // Suppress unused variable warning
    // input: 10 row vectors of length 3 each
    auto a = torch::rand({10, 3});
    auto b = torch::rand({10, 3});

    // expected: build each output column from the definition. Columns are
    // taken with select(1, k) so the reference does not depend on cross().
    const auto u1 = a.select(1, 0);
    const auto u2 = a.select(1, 1);
    const auto u3 = a.select(1, 2);
    const auto v1 = b.select(1, 0);
    const auto v2 = b.select(1, 1);
    const auto v3 = b.select(1, 2);
    const auto exp = torch::stack(
        {u2 * v3 - v2 * u3, v1 * u3 - u1 * v3, u1 * v2 - v1 * u2},
        /*dim=*/1);

    // actual: pass dim explicitly. Calling cross without dim is deprecated
    // and relies on picking the first dimension of size 3.
    auto out = torch::cross(a, b, /*dim=*/1);

    // compare
    ASSERT_EQ(out.dtype(), exp.dtype());
    ASSERT_TRUE(out.allclose(exp));
  }
}
61 
// Verifies at::linear_out (y = x * w^T [+ b]) on a batched 3-D input, with and
// without the optional bias, writing into a preallocated output tensor.
// Both cases check the output rank, shape and values.
TEST_F(OperationTest, Linear_out) {
  {
    // With bias. x: (3, 3, 2) holding 100..117; w: (3, 2) holding 200..205;
    // b: (3,) holding 300..302.
    const auto x = torch::arange(100., 118).reshape({3, 3, 2});
    const auto w = torch::arange(200., 206).reshape({3, 2});
    const auto b = torch::arange(300., 303);
    auto y = torch::empty({3, 3, 3});
    at::linear_out(y, x, w, b);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    // e.g. y[0][0][0] = 100*200 + 101*201 + 300 = 40601
    const auto y_exp = torch::tensor(
        {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
         {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
         {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
  {
    // Without bias: same x and w, the bias argument is omitted.
    const auto x = torch::arange(100., 118).reshape({3, 3, 2});
    const auto w = torch::arange(200., 206).reshape({3, 2});
    auto y = torch::empty({3, 3, 3});
    at::linear_out(y, x, w);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    // e.g. y[0][0][0] = 100*200 + 101*201 = 40301
    const auto y_exp = torch::tensor(
        {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
         {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
         {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
}
91