xref: /aosp_15_r20/external/pytorch/test/cpp/jit/test_upgrader_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <torch/csrc/jit/operator_upgraders/utils.h>
3 #include <torch/csrc/jit/operator_upgraders/version_map.h>
4 
5 #include <test/cpp/jit/test_utils.h>
6 
7 #include <vector>
8 
9 namespace torch {
10 namespace jit {
11 
// Checks that findUpgrader picks the upgrader whose range covers the given
// version, and returns nothing once the version is past the last bump.
TEST(UpgraderUtils, FindCorrectUpgrader) {
  // "foo__0_3" is bumped at version 4 and "foo__4_7" is bumped at version 8.
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // ASSERT (not EXPECT) before .value(): an empty optional would otherwise
  // throw instead of producing a clean test failure.
  auto upgrader_at_6 = findUpgrader(dummy_entry, 6);
  ASSERT_TRUE(upgrader_at_6.has_value());
  EXPECT_EQ(upgrader_at_6.value().upgrader_name, "foo__4_7");

  auto upgrader_at_1 = findUpgrader(dummy_entry, 1);
  ASSERT_TRUE(upgrader_at_1.has_value());
  EXPECT_EQ(upgrader_at_1.value().upgrader_name, "foo__0_3");

  // Version 10 is beyond the last bump (8), so no upgrader should apply.
  auto upgrader_at_10 = findUpgrader(dummy_entry, 10);
  EXPECT_FALSE(upgrader_at_10.has_value());
}
30 
// Checks that, for every operator in the global version map, its list of
// UpgraderEntry is sorted by the bumped_at_version field.
TEST(UpgraderUtils, IsVersionMapSorted) {
  // The map is only read here. A const reference avoids copying it, and it
  // binds correctly whether the accessor returns by reference or by value.
  const auto& map = get_operator_version_map();
  for (const auto& entry : map) {
    const auto& upgraders = entry.second;
    EXPECT_TRUE(std::is_sorted(
        upgraders.begin(),
        upgraders.end(),
        [](const UpgraderEntry& lhs, const UpgraderEntry& rhs) {
          return lhs.bumped_at_version < rhs.bumped_at_version;
        }))
        << "upgrader entries for " << entry.first
        << " are not sorted by bumped_at_version";
  }
}
43 
// Checks operator "currentness" in two ways: directly against a list of
// upgrader entries, and by symbol name through the global version map.
TEST(UpgraderUtils, FindIfOpIsCurrent) {
  const std::vector<UpgraderEntry> entries = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.bar()"},
  };

  // Entry-list lookup: expected not current at version 6, current at 8.
  const bool current_at_6 = isOpCurrentBasedOnUpgraderEntries(entries, 6);
  const bool current_at_8 = isOpCurrentBasedOnUpgraderEntries(entries, 8);
  EXPECT_FALSE(current_at_6);
  EXPECT_TRUE(current_at_8);

  // Symbol lookup reads the global version map, so register both entries
  // under "foo" for the checks and remove them again afterwards.
  for (const auto& upgrader : entries) {
    test_only_add_entry("foo", upgrader);
  }
  EXPECT_FALSE(isOpSymbolCurrent("foo", 6));
  EXPECT_TRUE(isOpSymbolCurrent("foo", 8));
  test_only_remove_entry("foo");
}
62 
// Checks that loadPossibleHistoricOps collects the old schemas registered
// under every overload of a base operator name (and under a bare name with
// no overload) for a version that still needs upgrading.
TEST(UpgraderUtils, CanLoadHistoricOp) {
  std::vector<UpgraderEntry> dummy_entry = {
      {4, "foo__0_3", "foo.bar()"},
      {8, "foo__4_7", "foo.foo()"},
  };

  std::vector<std::string> schemas = {"foo.bar()", "foo.foo()"};

  // Register two overloads of the same base name in the global version map.
  test_only_add_entry("old_op_not_exist.first", dummy_entry[0]);
  test_only_add_entry("old_op_not_exist.second", dummy_entry[1]);

  auto oldSchemas = loadPossibleHistoricOps("old_op_not_exist", 2);
  auto oldSchemasWithCurrentVersion =
      loadPossibleHistoricOps("old_op_not_exist", 9);

  // The results are returned by value, so clean up the global map right away.
  // This guarantees both overloads are removed, even if a later ASSERT
  // aborts the test.
  test_only_remove_entry("old_op_not_exist.first");
  test_only_remove_entry("old_op_not_exist.second");

  EXPECT_EQ(oldSchemas.size(), 2u);
  for (const auto& entry : oldSchemas) {
    EXPECT_TRUE(
        std::find(schemas.begin(), schemas.end(), entry) != schemas.end());
  }
  EXPECT_TRUE(oldSchemasWithCurrentVersion.empty());

  // it is ok to have old schemas without overload
  test_only_add_entry("old_op_not_exist_no_overload", dummy_entry[0]);
  auto oldSchemasNoOverload =
      loadPossibleHistoricOps("old_op_not_exist_no_overload", 2);
  test_only_remove_entry("old_op_not_exist_no_overload");

  // ASSERT before indexing so an empty result fails cleanly instead of
  // reading out of bounds.
  ASSERT_EQ(oldSchemasNoOverload.size(), 1u);
  EXPECT_EQ(oldSchemasNoOverload[0], "foo.bar()");
}
97 
98 } // namespace jit
99 } // namespace torch
100