1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h> 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <torch/cuda.h> 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker #include <iostream> 6*da0073e9SAndroid Build Coastguard Worker #include <string> 7*da0073e9SAndroid Build Coastguard Worker add_negative_flag(const std::string & flag)8*da0073e9SAndroid Build Coastguard Workerstd::string add_negative_flag(const std::string& flag) { 9*da0073e9SAndroid Build Coastguard Worker std::string filter = ::testing::GTEST_FLAG(filter); 10*da0073e9SAndroid Build Coastguard Worker if (filter.find('-') == std::string::npos) { 11*da0073e9SAndroid Build Coastguard Worker filter.push_back('-'); 12*da0073e9SAndroid Build Coastguard Worker } else { 13*da0073e9SAndroid Build Coastguard Worker filter.push_back(':'); 14*da0073e9SAndroid Build Coastguard Worker } 15*da0073e9SAndroid Build Coastguard Worker filter += flag; 16*da0073e9SAndroid Build Coastguard Worker return filter; 17*da0073e9SAndroid Build Coastguard Worker } 18*da0073e9SAndroid Build Coastguard Worker main(int argc,char * argv[])19*da0073e9SAndroid Build Coastguard Workerint main(int argc, char* argv[]) { 20*da0073e9SAndroid Build Coastguard Worker ::testing::InitGoogleTest(&argc, argv); 21*da0073e9SAndroid Build Coastguard Worker 22*da0073e9SAndroid Build Coastguard Worker if (!torch::cuda::is_available()) { 23*da0073e9SAndroid Build Coastguard Worker std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests" 24*da0073e9SAndroid Build Coastguard Worker << std::endl; 25*da0073e9SAndroid Build Coastguard Worker ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA"); 26*da0073e9SAndroid Build Coastguard Worker } else if (torch::cuda::device_count() < 2) { 27*da0073e9SAndroid Build Coastguard Worker std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests" 28*da0073e9SAndroid Build Coastguard Worker << std::endl; 29*da0073e9SAndroid Build Coastguard Worker ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA"); 30*da0073e9SAndroid Build Coastguard Worker } 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker return RUN_ALL_TESTS(); 33*da0073e9SAndroid Build Coastguard Worker } 34