// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/wrapdim_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 
5 using namespace at;
TestSimpleCase(DeprecatedTypeProperties & T)6 void TestSimpleCase(DeprecatedTypeProperties& T) {
7   auto a = randn({2, 3, 4, 5}, T);
8   ASSERT_TRUE(a.prod(-4).equal(a.prod(0)));
9   ASSERT_TRUE(a.prod(3).equal(a.prod(-1)));
10 }
11 
TestExpressionSpecification(DeprecatedTypeProperties & T)12 void TestExpressionSpecification(DeprecatedTypeProperties& T) {
13   auto a = randn({2, 3, 4, 5}, T);
14   ASSERT_TRUE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
15   ASSERT_TRUE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
16 
17   // can unsqueeze scalar
18   auto b = randn({}, T);
19   ASSERT_TRUE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
20 }
21 
TestEmptyTensor(DeprecatedTypeProperties & T)22 void TestEmptyTensor(DeprecatedTypeProperties& T) {
23   auto a = randn(0, T);
24   ASSERT_TRUE(a.prod(0).equal(at::ones({}, T)));
25 }
26 
TestScalarVs1Dim1Size(DeprecatedTypeProperties & T)27 void TestScalarVs1Dim1Size(DeprecatedTypeProperties& T) {
28   auto a = randn(1, T);
29   ASSERT_TRUE(a.prod(0).equal(a.prod(-1)));
30   a.resize_({});
31   ASSERT_EQ(a.dim(), 0);
32   ASSERT_TRUE(a.prod(0).equal(a.prod(-1)));
33 }
34 
// Exercises dim wrapping (negative dims aliasing their positive counterparts)
// on CPU float tensors, with a fixed RNG seed so the random inputs repeat.
//
// The helpers use ASSERT_*, which on failure only returns from the helper
// itself. Wrapping each call in ASSERT_NO_FATAL_FAILURE stops this test at the
// first failing helper and records which call site failed.
TEST(TestWrapdim, TestWrapdim) {
  manual_seed(123);
  DeprecatedTypeProperties& T = CPU(kFloat);

  ASSERT_NO_FATAL_FAILURE(TestSimpleCase(T));
  ASSERT_NO_FATAL_FAILURE(TestEmptyTensor(T));
  ASSERT_NO_FATAL_FAILURE(TestScalarVs1Dim1Size(T));
  ASSERT_NO_FATAL_FAILURE(TestExpressionSpecification(T));
}