1 #include <gtest/gtest.h>
2
3 #include <sstream>
4
5 #include <torch/csrc/lazy/core/shape.h>
6
7 namespace torch {
8 namespace lazy {
9
TEST(ShapeTest, Basic1) {
  // A value-initialized Shape carries no dtype and no dimensions.
  auto default_shape = Shape{};

  EXPECT_EQ(default_shape.to_string(), "UNKNOWN_SCALAR[]");
  EXPECT_EQ(default_shape.scalar_type(), c10::ScalarType::Undefined);
  EXPECT_EQ(default_shape.dim(), 0);
  EXPECT_TRUE(default_shape.sizes().empty());
  // With no dimensions, any index passed to size() is out of range.
  EXPECT_THROW(default_shape.size(0), std::out_of_range);
}
19
TEST(ShapeTest, Basic2) {
  // A rank-3 Float shape; element count is the product of its sizes.
  auto shape = Shape(c10::ScalarType::Float, {1, 2, 3});

  EXPECT_EQ(shape.numel(), 6);
  EXPECT_STREQ(shape.to_string().c_str(), "Float[1,2,3]");
  EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float);
  EXPECT_EQ(shape.dim(), 3);

  const auto& sizes = shape.sizes();
  // sizes().size() is unsigned; compare against an unsigned literal so the
  // gtest comparator does not trigger -Wsign-compare.
  EXPECT_EQ(sizes.size(), 3u);
  // Both accessors must agree: dimension i holds i + 1.
  for (int64_t i = 0; i < shape.dim(); i++) {
    EXPECT_EQ(sizes[i], i + 1);
    EXPECT_EQ(shape.size(i), i + 1);
  }
}
33
TEST(ShapeTest, Basic3) {
  // A Float shape built from an empty size list, i.e. a 0-D (scalar) shape.
  auto scalar_shape = Shape(c10::ScalarType::Float, {});

  EXPECT_EQ(scalar_shape.to_string(), "Float[]");
  EXPECT_EQ(scalar_shape.scalar_type(), c10::ScalarType::Float);
  EXPECT_EQ(scalar_shape.dim(), 0);
  // Zero dimensions still hold one element; this matches how 0-D tensors
  // behave, even though it may look surprising at first.
  EXPECT_EQ(scalar_shape.numel(), 1);
  EXPECT_TRUE(scalar_shape.sizes().empty());
  EXPECT_THROW(scalar_shape.size(0), std::out_of_range);
}
45
TEST(ShapeTest, SetScalarType) {
  // Overwriting the dtype of a value-initialized Shape must be observable.
  auto mutable_shape = Shape{};
  const auto target_type = c10::ScalarType::Long;

  mutable_shape.set_scalar_type(target_type);

  EXPECT_EQ(mutable_shape.scalar_type(), target_type);
}
52
TEST(ShapeTest, SetSize) {
  // Setting a size on a shape with no dimensions is rejected.
  auto no_dims = Shape{};
  EXPECT_THROW(no_dims.set_size(0, 0), std::out_of_range);

  // On a rank-3 shape, updating dimension 0 is visible through both accessors.
  auto float_shape = Shape(c10::ScalarType::Float, {1, 2, 3});
  float_shape.set_size(0, 3);
  EXPECT_EQ(float_shape.sizes()[0], 3);
  EXPECT_EQ(float_shape.size(0), 3);
}
62
TEST(ShapeTest, Equal) {
  auto shape1 = Shape(c10::ScalarType::Float, {});
  auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3});
  auto shape3 = Shape(c10::ScalarType::Long, {1, 2, 3});
  auto shape4 = Shape(c10::ScalarType::Float, {1, 2, 3});

  // Shapes differing in sizes, dtype, or both must compare unequal.
  EXPECT_FALSE(shape1 == shape2);
  EXPECT_FALSE(shape2 == shape3);
  EXPECT_FALSE(shape1 == shape3);
  EXPECT_FALSE(shape3 == shape4);

  // Reflexivity.
  EXPECT_TRUE(shape2 == shape2);
  // Two distinct objects with identical dtype and sizes must compare equal,
  // in both directions. shape4 exists for exactly this check, which the
  // self-comparison above cannot cover.
  EXPECT_TRUE(shape2 == shape4);
  EXPECT_TRUE(shape4 == shape2);
}
74
TEST(ShapeTest, Ostream) {
  // Streaming a Shape must produce exactly its to_string() text.
  auto unknown_shape = Shape{};
  std::ostringstream out;

  out << unknown_shape;

  EXPECT_EQ(out.str(), unknown_shape.to_string());
}
82
83 } // namespace lazy
84 } // namespace torch
85