xref: /aosp_15_r20/external/pytorch/test/cpp/api/enum.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/torch.h>
4 #include <variant>
5 
6 #include <test/cpp/api/support.h>
7 
// Checks the pretty-printed name of a single enum tag.
//
// Expands to a braced compound statement (call sites therefore take no
// trailing semicolon) that:
//   1. stores torch::k<name> into a variable named `v`, which must already be
//      declared at the expansion site (the macro does not declare it), and
//   2. asserts that torch::enumtype::get_enum_name(v) equals "k<name>".
//
// The expected string is formed at preprocessing time: `#name` stringizes the
// argument and the adjacent literal "k" is concatenated onto it.
#define TORCH_ENUM_PRETTY_PRINT_TEST(name)                              \
  {                                                                     \
    const std::string expected_enum_name{"k" #name};                    \
    v = torch::k##name;                                                 \
    ASSERT_EQ(torch::enumtype::get_enum_name(v), expected_enum_name);   \
  }

TEST(EnumTest,AllEnums)16 TEST(EnumTest, AllEnums) {
17   std::variant<
18       torch::enumtype::kLinear,
19       torch::enumtype::kConv1D,
20       torch::enumtype::kConv2D,
21       torch::enumtype::kConv3D,
22       torch::enumtype::kConvTranspose1D,
23       torch::enumtype::kConvTranspose2D,
24       torch::enumtype::kConvTranspose3D,
25       torch::enumtype::kSigmoid,
26       torch::enumtype::kTanh,
27       torch::enumtype::kReLU,
28       torch::enumtype::kLeakyReLU,
29       torch::enumtype::kFanIn,
30       torch::enumtype::kFanOut,
31       torch::enumtype::kConstant,
32       torch::enumtype::kReflect,
33       torch::enumtype::kReplicate,
34       torch::enumtype::kCircular,
35       torch::enumtype::kNearest,
36       torch::enumtype::kBilinear,
37       torch::enumtype::kBicubic,
38       torch::enumtype::kTrilinear,
39       torch::enumtype::kArea,
40       torch::enumtype::kSum,
41       torch::enumtype::kMean,
42       torch::enumtype::kMax,
43       torch::enumtype::kNone,
44       torch::enumtype::kBatchMean,
45       torch::enumtype::kZeros,
46       torch::enumtype::kBorder,
47       torch::enumtype::kReflection,
48       torch::enumtype::kRNN_TANH,
49       torch::enumtype::kRNN_RELU,
50       torch::enumtype::kLSTM,
51       torch::enumtype::kGRU>
52       v;
53 
54   TORCH_ENUM_PRETTY_PRINT_TEST(Linear)
55   TORCH_ENUM_PRETTY_PRINT_TEST(Conv1D)
56   TORCH_ENUM_PRETTY_PRINT_TEST(Conv2D)
57   TORCH_ENUM_PRETTY_PRINT_TEST(Conv3D)
58   TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose1D)
59   TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose2D)
60   TORCH_ENUM_PRETTY_PRINT_TEST(ConvTranspose3D)
61   TORCH_ENUM_PRETTY_PRINT_TEST(Sigmoid)
62   TORCH_ENUM_PRETTY_PRINT_TEST(Tanh)
63   TORCH_ENUM_PRETTY_PRINT_TEST(ReLU)
64   TORCH_ENUM_PRETTY_PRINT_TEST(LeakyReLU)
65   TORCH_ENUM_PRETTY_PRINT_TEST(FanIn)
66   TORCH_ENUM_PRETTY_PRINT_TEST(FanOut)
67   TORCH_ENUM_PRETTY_PRINT_TEST(Constant)
68   TORCH_ENUM_PRETTY_PRINT_TEST(Reflect)
69   TORCH_ENUM_PRETTY_PRINT_TEST(Replicate)
70   TORCH_ENUM_PRETTY_PRINT_TEST(Circular)
71   TORCH_ENUM_PRETTY_PRINT_TEST(Nearest)
72   TORCH_ENUM_PRETTY_PRINT_TEST(Bilinear)
73   TORCH_ENUM_PRETTY_PRINT_TEST(Bicubic)
74   TORCH_ENUM_PRETTY_PRINT_TEST(Trilinear)
75   TORCH_ENUM_PRETTY_PRINT_TEST(Area)
76   TORCH_ENUM_PRETTY_PRINT_TEST(Sum)
77   TORCH_ENUM_PRETTY_PRINT_TEST(Mean)
78   TORCH_ENUM_PRETTY_PRINT_TEST(Max)
79   TORCH_ENUM_PRETTY_PRINT_TEST(None)
80   TORCH_ENUM_PRETTY_PRINT_TEST(BatchMean)
81   TORCH_ENUM_PRETTY_PRINT_TEST(Zeros)
82   TORCH_ENUM_PRETTY_PRINT_TEST(Border)
83   TORCH_ENUM_PRETTY_PRINT_TEST(Reflection)
84   TORCH_ENUM_PRETTY_PRINT_TEST(RNN_TANH)
85   TORCH_ENUM_PRETTY_PRINT_TEST(RNN_RELU)
86   TORCH_ENUM_PRETTY_PRINT_TEST(LSTM)
87   TORCH_ENUM_PRETTY_PRINT_TEST(GRU)
88 }
89