1 #include <gtest/gtest.h> 2 3 #include <torch/cuda.h> 4 5 #include <iostream> 6 #include <string> 7 add_negative_flag(const std::string & flag)8std::string add_negative_flag(const std::string& flag) { 9 std::string filter = ::testing::GTEST_FLAG(filter); 10 if (filter.find('-') == std::string::npos) { 11 filter.push_back('-'); 12 } else { 13 filter.push_back(':'); 14 } 15 filter += flag; 16 return filter; 17 } 18 main(int argc,char * argv[])19int main(int argc, char* argv[]) { 20 ::testing::InitGoogleTest(&argc, argv); 21 22 if (!torch::cuda::is_available()) { 23 std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests" 24 << std::endl; 25 ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); 26 } else if (torch::cuda::device_count() < 2) { 27 std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests" 28 << std::endl; 29 ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA"); 30 } 31 32 return RUN_ALL_TESTS(); 33 } 34