1 #include <caffe2/serialize/versions.h> 2 #include <gtest/gtest.h> 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/operator_upgraders/upgraders.h> 5 #include <torch/csrc/jit/operator_upgraders/version_map.h> 6 #include <torch/csrc/jit/serialization/import.h> 7 8 #include <test/cpp/jit/test_utils.h> 9 10 namespace torch { 11 namespace jit { 12 13 // Basic tests to check if C++ torch::jit::load 14 // can load the upgraders fine 15 // TODO (tugsuu) add more tests TEST(UpgraderLoad,CanPopulateUpgradersGraph)16TEST(UpgraderLoad, CanPopulateUpgradersGraph) { 17 Module m("m"); 18 m.define(R"( 19 def forward(self, x: Tensor): 20 b = 5 21 return torch.div(x, b) 22 )"); 23 std::stringstream ms; 24 m.save(ms); 25 auto loaded_m = torch::jit::load(ms); 26 auto version_map = get_operator_version_map(); 27 auto upgraders = dump_upgraders_map(); 28 29 for (const auto& entry : version_map) { 30 auto list_of_upgraders_for_op = entry.second; 31 for (const auto& upgrader_entry : list_of_upgraders_for_op) { 32 EXPECT_TRUE( 33 upgraders.find(upgrader_entry.upgrader_name) != upgraders.end()); 34 } 35 } 36 37 auto test_graph = loaded_m.get_method("forward").graph(); 38 // should have saved with version 4, so it is still up to date 39 testing::FileCheck().check_count("aten::div", 1, true)->run(*test_graph); 40 } 41 42 } // namespace jit 43 } // namespace torch 44