xref: /aosp_15_r20/external/pytorch/test/cpp/lite_interpreter_runtime/main.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <gtest/gtest.h>
3 #include <torch/csrc/autograd/generated/variable_factories.h>
4 #include <torch/csrc/jit/mobile/import.h>
5 #include <iostream>
6 #include <string>
7 
add_negative_flag(const std::string & flag)8 std::string add_negative_flag(const std::string& flag) {
9   std::string filter = ::testing::GTEST_FLAG(filter);
10   if (filter.find('-') == std::string::npos) {
11     filter.push_back('-');
12   } else {
13     filter.push_back(':');
14   }
15   filter += flag;
16   return filter;
17 }
main(int argc,char * argv[])18 int main(int argc, char* argv[]) {
19   ::testing::InitGoogleTest(&argc, argv);
20   ::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");
21 
22   return RUN_ALL_TESTS();
23 }
24