xref: /aosp_15_r20/external/pytorch/test/cpp/lazy/test_shape.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <sstream>
4 
5 #include <torch/csrc/lazy/core/shape.h>
6 
7 namespace torch {
8 namespace lazy {
9 
// A value-initialized Shape has an undefined scalar type and no dimensions.
TEST(ShapeTest, Basic1) {
  Shape empty_shape{};

  // Textual form and dtype of the empty shape.
  EXPECT_STREQ(empty_shape.to_string().c_str(), "UNKNOWN_SCALAR[]");
  EXPECT_EQ(empty_shape.scalar_type(), c10::ScalarType::Undefined);

  // Rank 0 with an empty size list; indexing dimension 0 is out of range.
  EXPECT_EQ(empty_shape.dim(), 0);
  EXPECT_TRUE(empty_shape.sizes().empty());
  EXPECT_THROW(empty_shape.size(0), std::out_of_range);
}
19 
// A Float shape with sizes {1, 2, 3}: checks element count, textual form,
// dtype, rank, and per-dimension extents via both sizes() and size().
TEST(ShapeTest, Basic2) {
  auto shape = Shape(c10::ScalarType::Float, {1, 2, 3});

  EXPECT_EQ(shape.numel(), 6);
  EXPECT_STREQ(shape.to_string().c_str(), "Float[1,2,3]");
  EXPECT_EQ(shape.scalar_type(), c10::ScalarType::Float);
  EXPECT_EQ(shape.dim(), 3);
  // sizes().size() is an unsigned size_t; compare against an unsigned literal
  // so EXPECT_EQ does not perform a signed/unsigned comparison
  // (-Wsign-compare, fatal under -Werror).
  EXPECT_EQ(shape.sizes().size(), 3u);
  for (int64_t i = 0; i < shape.dim(); i++) {
    // Dimension i was constructed with extent i + 1.
    EXPECT_EQ(shape.sizes()[i], i + 1);
    EXPECT_EQ(shape.size(i), i + 1);
  }
}
33 
// A Float shape built from an empty size list describes a 0-D (scalar) shape.
TEST(ShapeTest, Basic3) {
  Shape scalar_shape(c10::ScalarType::Float, {});

  EXPECT_STREQ(scalar_shape.to_string().c_str(), "Float[]");
  EXPECT_EQ(scalar_shape.scalar_type(), c10::ScalarType::Float);

  // Rank 0, yet one element: the empty product is 1, matching how 0-D
  // tensors behave.
  EXPECT_EQ(scalar_shape.dim(), 0);
  EXPECT_EQ(scalar_shape.numel(), 1);

  // There are no dimensions to index.
  EXPECT_TRUE(scalar_shape.sizes().empty());
  EXPECT_THROW(scalar_shape.size(0), std::out_of_range);
}
45 
// set_scalar_type() overwrites the dtype reported by scalar_type().
TEST(ShapeTest, SetScalarType) {
  Shape shape{};
  const auto wanted = c10::ScalarType::Long;

  shape.set_scalar_type(wanted);
  EXPECT_EQ(shape.scalar_type(), wanted);
}
52 
// set_size() rejects a nonexistent dimension and otherwise updates the extent
// seen through both sizes() and size().
TEST(ShapeTest, SetSize) {
  // A value-initialized Shape has no dimension 0 to set.
  Shape no_dims{};
  EXPECT_THROW(no_dims.set_size(0, 0), std::out_of_range);

  Shape three_dims(c10::ScalarType::Float, {1, 2, 3});
  const int64_t new_extent = 3;
  three_dims.set_size(0, new_extent);
  EXPECT_EQ(three_dims.sizes()[0], new_extent);
  EXPECT_EQ(three_dims.size(0), new_extent);
}
62 
// operator== on Shape: shapes that differ in sizes and/or scalar type compare
// unequal; distinct shapes with identical dtype and sizes compare equal.
TEST(ShapeTest, Equal) {
  auto shape1 = Shape(c10::ScalarType::Float, {});
  auto shape2 = Shape(c10::ScalarType::Float, {1, 2, 3});
  auto shape3 = Shape(c10::ScalarType::Long, {1, 2, 3});
  auto shape4 = Shape(c10::ScalarType::Float, {1, 2, 3});

  // Different sizes (same dtype), different dtype (same sizes), and both.
  EXPECT_FALSE(shape1 == shape2);
  EXPECT_FALSE(shape2 == shape3);
  EXPECT_FALSE(shape1 == shape3);

  // Reflexivity.
  EXPECT_TRUE(shape2 == shape2);

  // Two separately constructed shapes with the same dtype and sizes are
  // equal, checked in both directions.
  EXPECT_TRUE(shape2 == shape4);
  EXPECT_TRUE(shape4 == shape2);
}
74 
// Streaming a Shape with operator<< emits exactly its to_string() text.
TEST(ShapeTest, Ostream) {
  Shape shape{};

  std::ostringstream streamed;
  streamed << shape;

  EXPECT_EQ(streamed.str(), shape.to_string());
}
82 
83 } // namespace lazy
84 } // namespace torch
85