xref: /aosp_15_r20/external/pytorch/test/cpp/common/main.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 
3 #include <torch/cuda.h>
4 
5 #include <iostream>
6 #include <string>
7 
add_negative_flag(const std::string & flag)8 std::string add_negative_flag(const std::string& flag) {
9   std::string filter = ::testing::GTEST_FLAG(filter);
10   if (filter.find('-') == std::string::npos) {
11     filter.push_back('-');
12   } else {
13     filter.push_back(':');
14   }
15   filter += flag;
16   return filter;
17 }
18 
main(int argc,char * argv[])19 int main(int argc, char* argv[]) {
20   ::testing::InitGoogleTest(&argc, argv);
21 
22   if (!torch::cuda::is_available()) {
23     std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests"
24               << std::endl;
25     ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");
26   } else if (torch::cuda::device_count() < 2) {
27     std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests"
28               << std::endl;
29     ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA");
30   }
31 
32   return RUN_ALL_TESTS();
33 }
34