1 #include <gtest/gtest.h>
2
3 #include <torch/torch.h>
4 #include <variant>
5
6 #include <test/cpp/api/support.h>
7
// Asserts that the pretty-printed name of torch::k<name> is exactly "k<name>".
//
// Requires a variable named `v` in the enclosing scope that accepts
// torch::k<name> by assignment. The macro overwrites `v`, then compares
// torch::enumtype::get_enum_name(v) with the string "k" #name. Call sites in
// this file invoke it without a trailing semicolon, so the expansion is a plain
// brace-enclosed block rather than a do/while(0) statement.
#define TORCH_ENUM_PRETTY_PRINT_TEST(name)                              \
  {                                                                     \
    const std::string expected_enum_name = "k" #name;                   \
    v = torch::k##name;                                                 \
    ASSERT_EQ(torch::enumtype::get_enum_name(v), expected_enum_name);   \
  }
15
// Checks torch::enumtype::get_enum_name() for every enum type listed in
// AllEnumTypes. Each TORCH_ENUM_PRETTY_PRINT_TEST(Name) assigns torch::k<Name>
// into `v` (the macro refers to that exact variable name) and compares the
// result with "k<Name>". A type used in an assertion must therefore also be an
// alternative of AllEnumTypes, or the assignment will not compile.
TEST(EnumTest, AllEnums) {
  namespace enum_types = torch::enumtype;

  using AllEnumTypes = std::variant<
      enum_types::kLinear,
      enum_types::kConv1D,
      enum_types::kConv2D,
      enum_types::kConv3D,
      enum_types::kConvTranspose1D,
      enum_types::kConvTranspose2D,
      enum_types::kConvTranspose3D,
      enum_types::kSigmoid,
      enum_types::kTanh,
      enum_types::kReLU,
      enum_types::kLeakyReLU,
      enum_types::kFanIn,
      enum_types::kFanOut,
      enum_types::kConstant,
      enum_types::kReflect,
      enum_types::kReplicate,
      enum_types::kCircular,
      enum_types::kNearest,
      enum_types::kBilinear,
      enum_types::kBicubic,
      enum_types::kTrilinear,
      enum_types::kArea,
      enum_types::kSum,
      enum_types::kMean,
      enum_types::kMax,
      enum_types::kNone,
      enum_types::kBatchMean,
      enum_types::kZeros,
      enum_types::kBorder,
      enum_types::kReflection,
      enum_types::kRNN_TANH,
      enum_types::kRNN_RELU,
      enum_types::kLSTM,
      enum_types::kGRU>;

  // Reassigned by every TORCH_ENUM_PRETTY_PRINT_TEST invocation below.
  AllEnumTypes v;

  // One check per alternative, in the same order as AllEnumTypes.
  TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
  TORCH_ENUM_PRETTY_PRINT_TEST(Conv1D)
  TORCH_ENUM_PRETTY_PRINT_TEST(Conv2D)
  TORCH_ENUM_PRETTY_PRINT_TEST(Conv3D)
  TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose1D)
  TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose2D)
  TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose3D)
  TORCH_ENUM_PRETTY_PRINT_TEST(Sigmoid)
  TORCH_ENUM_PRETTY_PRINT_TEST(Tanh)
  TORCH_ENUM_PRETTY_PRINT_TEST(ReLU)
  TORCH_ENUM_PRETTY_PRINT_TEST(LeakyReLU)
  TORCH_ENUM_PRETTY_PRINT_TEST(FanIn)
  TORCH_ENUM_PRETTY_PRINT_TEST(FanOut)
  TORCH_ENUM_PRETTY_PRINT_TEST(Constant)
  TORCH_ENUM_PRETTY_PRINT_TEST(Reflect)
  TORCH_ENUM_PRETTY_PRINT_TEST(Replicate)
  TORCH_ENUM_PRETTY_PRINT_TEST(Circular)
  TORCH_ENUM_PRETTY_PRINT_TEST(Nearest)
  TORCH_ENUM_PRETTY_PRINT_TEST(Bilinear)
  TORCH_ENUM_PRETTY_PRINT_TEST(Bicubic)
  TORCH_ENUM_PRETTY_PRINT_TEST(Trilinear)
  TORCH_ENUM_PRETTY_PRINT_TEST(Area)
  TORCH_ENUM_PRETTY_PRINT_TEST(Sum)
  TORCH_ENUM_PRETTY_PRINT_TEST(Mean)
  TORCH_ENUM_PRETTY_PRINT_TEST(Max)
  TORCH_ENUM_PRETTY_PRINT_TEST(None)
  TORCH_ENUM_PRETTY_PRINT_TEST(BatchMean)
  TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
  TORCH_ENUM_PRETTY_PRINT_TEST(Border)
  TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
  TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
  TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
  TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
  TORCH_ENUM_PRETTY_PRINT_TEST(GRU)
}
89