1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4
5 using namespace at;
TestSimpleCase(DeprecatedTypeProperties & T)6 void TestSimpleCase(DeprecatedTypeProperties& T) {
7 auto a = randn({2, 3, 4, 5}, T);
8 ASSERT_TRUE(a.prod(-4).equal(a.prod(0)));
9 ASSERT_TRUE(a.prod(3).equal(a.prod(-1)));
10 }
11
TestExpressionSpecification(DeprecatedTypeProperties & T)12 void TestExpressionSpecification(DeprecatedTypeProperties& T) {
13 auto a = randn({2, 3, 4, 5}, T);
14 ASSERT_TRUE(a.unsqueeze(-5).equal(a.unsqueeze(0)));
15 ASSERT_TRUE(a.unsqueeze(4).equal(a.unsqueeze(-1)));
16
17 // can unsqueeze scalar
18 auto b = randn({}, T);
19 ASSERT_TRUE(b.unsqueeze(0).equal(b.unsqueeze(-1)));
20 }
21
TestEmptyTensor(DeprecatedTypeProperties & T)22 void TestEmptyTensor(DeprecatedTypeProperties& T) {
23 auto a = randn(0, T);
24 ASSERT_TRUE(a.prod(0).equal(at::ones({}, T)));
25 }
26
TestScalarVs1Dim1Size(DeprecatedTypeProperties & T)27 void TestScalarVs1Dim1Size(DeprecatedTypeProperties& T) {
28 auto a = randn(1, T);
29 ASSERT_TRUE(a.prod(0).equal(a.prod(-1)));
30 a.resize_({});
31 ASSERT_EQ(a.dim(), 0);
32 ASSERT_TRUE(a.prod(0).equal(a.prod(-1)));
33 }
34
// Entry point: seeds the generator with a fixed value and runs every
// dimension-wrapping helper against the CPU float type.
TEST(TestWrapdim, TestWrapdim) {
  // Fixed seed, set before any helper draws random tensors.
  manual_seed(123);

  DeprecatedTypeProperties& cpu_float = CPU(kFloat);

  TestSimpleCase(cpu_float);
  TestEmptyTensor(cpu_float);
  TestScalarVs1Dim1Size(cpu_float);
  TestExpressionSpecification(cpu_float);
}
44