xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/Dimname_test.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <ATen/Dimname.h>
4 #include <c10/util/Exception.h>
5 #include <optional>
6 
7 using at::NameType;
8 using at::Symbol;
9 using at::Dimname;
10 
// Exercises Dimname::isValidName against a table of names that must be
// accepted and a table of names that must be rejected.
TEST(DimnameTest, isValidIdentifier) {
  // Letters, digits after the first character, and underscores (including a
  // leading or lone underscore) are expected to be accepted.
  static const char* const kAccepted[] = {
      "a",
      "batch",
      "N",
      "CHANNELS",
      "foo_bar_baz",
      "batch1",
      "batch_9",
      "_",
      "_1",
  };

  // Empty names, whitespace, punctuation, and names starting with a digit are
  // expected to be rejected.
  static const char* const kRejected[] = {
      "",
      " ",
      " a ",
      "1batch",
      "?",
      "-",
      "1",
      "01",
  };

  for (const char* candidate : kAccepted) {
    ASSERT_TRUE(Dimname::isValidName(candidate))
        << "name: '" << candidate << "'";
  }

  for (const char* candidate : kRejected) {
    ASSERT_FALSE(Dimname::isValidName(candidate))
        << "name: '" << candidate << "'";
  }
}
31 
// Checks that Dimname::wildcard() reports the WILDCARD name type and carries
// the "*" symbol.
TEST(DimnameTest, wildcardName) {
  const auto star = Dimname::wildcard();
  ASSERT_EQ(star.type(), NameType::WILDCARD);

  const auto star_symbol = Symbol::dimname("*");
  ASSERT_EQ(star.symbol(), star_symbol);
}
37 
// Checks that a Dimname built from a valid symbol is BASIC and keeps that
// symbol, and that building one from a malformed name throws c10::Error.
TEST(DimnameTest, createNormalName) {
  const Symbol foo_symbol = Symbol::dimname("foo");
  const Dimname basic_name = Dimname::fromSymbol(foo_symbol);
  ASSERT_EQ(basic_name.type(), NameType::BASIC);
  ASSERT_EQ(basic_name.symbol(), foo_symbol);

  // One name containing a '.', one starting with a digit.
  static const char* const kMalformed[] = {"inva.lid", "1invalid"};
  for (const char* malformed : kMalformed) {
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-goto,hicpp-avoid-goto)
    ASSERT_THROW(Dimname::fromSymbol(Symbol::dimname(malformed)), c10::Error);
  }
}
48 
check_unify_and_match(const std::string & dimname,const std::string & other,std::optional<const std::string> expected)49 static void check_unify_and_match(
50     const std::string& dimname,
51     const std::string& other,
52     std::optional<const std::string> expected) {
53   auto dimname1 = Dimname::fromSymbol(Symbol::dimname(dimname));
54   auto dimname2 = Dimname::fromSymbol(Symbol::dimname(other));
55   auto result = dimname1.unify(dimname2);
56   if (expected) {
57     auto expected_result = Dimname::fromSymbol(Symbol::dimname(*expected));
58     ASSERT_EQ(result->symbol(), expected_result.symbol());
59     ASSERT_EQ(result->type(), expected_result.type());
60     ASSERT_TRUE(dimname1.matches(dimname2));
61   } else {
62     ASSERT_FALSE(result);
63     ASSERT_FALSE(dimname1.matches(dimname2));
64   }
65 }
66 
// Runs check_unify_and_match over a table of name pairs: identical names,
// a wildcard on either or both sides, and two different names.
TEST(DimnameTest, unifyAndMatch) {
  struct UnifyCase {
    const char* lhs;
    const char* rhs;
    // Expected unified name; std::nullopt when the pair must not unify.
    std::optional<const std::string> unified;
  };

  const UnifyCase cases[] = {
      {"a", "a", "a"},
      {"a", "*", "a"},
      {"*", "a", "a"},
      {"*", "*", "*"},
      {"a", "b", std::nullopt},
  };

  for (const UnifyCase& unify_case : cases) {
    check_unify_and_match(
        unify_case.lhs, unify_case.rhs, unify_case.unified);
  }
}
74