xref: /aosp_15_r20/external/pytorch/test/cpp/api/operations.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/irange.h>
4*da0073e9SAndroid Build Coastguard Worker #include <torch/torch.h>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker #include <test/cpp/api/support.h>
7*da0073e9SAndroid Build Coastguard Worker struct OperationTest : torch::test::SeedingFixture {
8*da0073e9SAndroid Build Coastguard Worker  protected:
SetUpOperationTest9*da0073e9SAndroid Build Coastguard Worker   void SetUp() override {}
10*da0073e9SAndroid Build Coastguard Worker 
11*da0073e9SAndroid Build Coastguard Worker   const int TEST_AMOUNT = 10;
12*da0073e9SAndroid Build Coastguard Worker };
13*da0073e9SAndroid Build Coastguard Worker 
// Checks torch::lerp against the closed form `from + w * (to - from)` for
// both the scalar-weight and the tensor-weight overloads, over TEST_AMOUNT
// independently drawn random inputs.
TEST_F(OperationTest, Lerp) {
  for (const auto iteration : c10::irange(TEST_AMOUNT)) {
    (void)iteration; // Suppress unused variable warning

    const auto from = torch::rand({3, 5});
    const auto to = torch::rand({3, 5});

    // Scalar weight (exercises lerp_kernel_scalar).
    const auto half = 0.5;
    const auto scalar_reference = from + half * (to - from);
    const auto scalar_result = torch::lerp(from, to, half);
    ASSERT_EQ(scalar_result.dtype(), scalar_reference.dtype());
    ASSERT_TRUE(torch::allclose(scalar_result, scalar_reference));

    // Element-wise tensor weight (exercises lerp_kernel_tensor).
    const auto weights = torch::rand({3, 5});
    const auto tensor_reference = from + weights * (to - from);
    const auto tensor_result = torch::lerp(from, to, weights);
    ASSERT_EQ(tensor_result.dtype(), tensor_reference.dtype());
    ASSERT_TRUE(torch::allclose(tensor_result, tensor_reference));
  }
}
38*da0073e9SAndroid Build Coastguard Worker 
// Checks torch::cross on batches of 3-vectors against the textbook
// component formula, over TEST_AMOUNT independently drawn random inputs.
TEST_F(OperationTest, Cross) {
  for (const auto i : c10::irange(TEST_AMOUNT)) {
    (void)i; // Suppress unused variable warning
    // input: ten 3-vectors per operand, one vector per row
    const auto a = torch::rand({10, 3});
    const auto b = torch::rand({10, 3});

    // expected: component formula evaluated column-wise for every row at
    // once; select(1, k) picks component k of each row vector.
    const auto u1 = a.select(1, 0);
    const auto u2 = a.select(1, 1);
    const auto u3 = a.select(1, 2);
    const auto v1 = b.select(1, 0);
    const auto v2 = b.select(1, 1);
    const auto v3 = b.select(1, 2);
    const auto expected = torch::stack(
        {u2 * v3 - v2 * u3, v1 * u3 - u1 * v3, u1 * v2 - v1 * u2},
        /*dim=*/1);

    // actual: pass dim explicitly. Omitting it relies on the deprecated
    // "first dimension of size 3" inference; for a (10, 3) input that
    // inference picks dim 1, so the result is unchanged.
    const auto out = torch::cross(a, b, /*dim=*/1);

    // compare; the shape is checked separately because allclose broadcasts
    // and would otherwise accept a broadcast-compatible wrong shape.
    ASSERT_EQ(out.dtype(), expected.dtype());
    ASSERT_EQ(out.sizes(), expected.sizes());
    ASSERT_TRUE(out.allclose(expected));
  }
}
61*da0073e9SAndroid Build Coastguard Worker 
// Checks at::linear_out (y = x @ w^T [+ b]) writing into a preallocated
// output, both with and without a bias, against hand-computed values.
TEST_F(OperationTest, Linear_out) {
  // Shared inputs: x is (3, 3, 2) and w is (3, 2), so y is (3, 3, 3).
  // reshape() states the intent (same 18 / 6 elements, new shape) directly
  // instead of mutating tensor metadata in place with resize_().
  const auto x = torch::arange(100., 118).reshape({3, 3, 2});
  const auto w = torch::arange(200., 206).reshape({3, 2});
  {
    // With bias: y[i][j][k] = sum_m x[i][j][m] * w[k][m] + b[k].
    const auto b = torch::arange(300., 303);
    auto y = torch::empty({3, 3, 3});
    at::linear_out(y, x, w, b);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    const auto y_exp = torch::tensor(
        {{{40601, 41004, 41407}, {41403, 41814, 42225}, {42205, 42624, 43043}},
         {{43007, 43434, 43861}, {43809, 44244, 44679}, {44611, 45054, 45497}},
         {{45413, 45864, 46315}, {46215, 46674, 47133}, {47017, 47484, 47951}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
  {
    // Without bias: y[i][j][k] = sum_m x[i][j][m] * w[k][m].
    auto y = torch::empty({3, 3, 3});
    at::linear_out(y, x, w);
    ASSERT_EQ(y.ndimension(), 3);
    ASSERT_EQ(y.sizes(), torch::IntArrayRef({3, 3, 3}));
    const auto y_exp = torch::tensor(
        {{{40301, 40703, 41105}, {41103, 41513, 41923}, {41905, 42323, 42741}},
         {{42707, 43133, 43559}, {43509, 43943, 44377}, {44311, 44753, 45195}},
         {{45113, 45563, 46013}, {45915, 46373, 46831}, {46717, 47183, 47649}}},
        torch::kFloat);
    ASSERT_TRUE(torch::allclose(y, y_exp));
  }
}
91