xref: /aosp_15_r20/external/pytorch/test/cpp/common/main.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <gtest/gtest.h>
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <torch/cuda.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <iostream>
6*da0073e9SAndroid Build Coastguard Worker #include <string>
7*da0073e9SAndroid Build Coastguard Worker 
add_negative_flag(const std::string & flag)8*da0073e9SAndroid Build Coastguard Worker std::string add_negative_flag(const std::string& flag) {
9*da0073e9SAndroid Build Coastguard Worker   std::string filter = ::testing::GTEST_FLAG(filter);
10*da0073e9SAndroid Build Coastguard Worker   if (filter.find('-') == std::string::npos) {
11*da0073e9SAndroid Build Coastguard Worker     filter.push_back('-');
12*da0073e9SAndroid Build Coastguard Worker   } else {
13*da0073e9SAndroid Build Coastguard Worker     filter.push_back(':');
14*da0073e9SAndroid Build Coastguard Worker   }
15*da0073e9SAndroid Build Coastguard Worker   filter += flag;
16*da0073e9SAndroid Build Coastguard Worker   return filter;
17*da0073e9SAndroid Build Coastguard Worker }
18*da0073e9SAndroid Build Coastguard Worker 
main(int argc,char * argv[])19*da0073e9SAndroid Build Coastguard Worker int main(int argc, char* argv[]) {
20*da0073e9SAndroid Build Coastguard Worker   ::testing::InitGoogleTest(&argc, argv);
21*da0073e9SAndroid Build Coastguard Worker 
22*da0073e9SAndroid Build Coastguard Worker   if (!torch::cuda::is_available()) {
23*da0073e9SAndroid Build Coastguard Worker     std::cout << "CUDA not available. Disabling CUDA and MultiCUDA tests"
24*da0073e9SAndroid Build Coastguard Worker               << std::endl;
25*da0073e9SAndroid Build Coastguard Worker     ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");
26*da0073e9SAndroid Build Coastguard Worker   } else if (torch::cuda::device_count() < 2) {
27*da0073e9SAndroid Build Coastguard Worker     std::cout << "Only one CUDA device detected. Disabling MultiCUDA tests"
28*da0073e9SAndroid Build Coastguard Worker               << std::endl;
29*da0073e9SAndroid Build Coastguard Worker     ::testing::GTEST_FLAG(filter) = add_negative_flag("*_MultiCUDA");
30*da0073e9SAndroid Build Coastguard Worker   }
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker   return RUN_ALL_TESTS();
33*da0073e9SAndroid Build Coastguard Worker }
34