1 #include <torch/csrc/jit/operator_upgraders/utils.h>
2
#include <caffe2/serialize/versions.h>
#include <torch/csrc/jit/operator_upgraders/version_map.h>
#include <algorithm>
#include <iostream>
#include <optional>
#include <regex>
#include <string>
#include <utility>
#include <vector>
10
11 namespace torch::jit {
12
findUpgrader(const std::vector<UpgraderEntry> & upgraders_for_schema,size_t current_version)13 std::optional<UpgraderEntry> findUpgrader(
14 const std::vector<UpgraderEntry>& upgraders_for_schema,
15 size_t current_version) {
16 // we want to find the entry which satisfies following two conditions:
17 // 1. the version entry must be greater than current_version
18 // 2. Among the version entries, we need to see if the current version
19 // is in the upgrader name range
20 auto pos = std::find_if(
21 upgraders_for_schema.begin(),
22 upgraders_for_schema.end(),
23 [current_version](const UpgraderEntry& entry) {
24 return entry.bumped_at_version > static_cast<int>(current_version);
25 });
26
27 if (pos != upgraders_for_schema.end()) {
28 return *pos;
29 }
30 return std::nullopt;
31 }
32
isOpCurrentBasedOnUpgraderEntries(const std::vector<UpgraderEntry> & upgraders_for_schema,size_t current_version)33 bool isOpCurrentBasedOnUpgraderEntries(
34 const std::vector<UpgraderEntry>& upgraders_for_schema,
35 size_t current_version) {
36 auto latest_update =
37 upgraders_for_schema[upgraders_for_schema.size() - 1].bumped_at_version;
38 if (latest_update > static_cast<int>(current_version)) {
39 return false;
40 }
41 return true;
42 }
43
isOpSymbolCurrent(const std::string & name,size_t current_version)44 bool isOpSymbolCurrent(const std::string& name, size_t current_version) {
45 auto it = get_operator_version_map().find(name);
46 if (it != get_operator_version_map().end()) {
47 return isOpCurrentBasedOnUpgraderEntries(it->second, current_version);
48 }
49 return true;
50 }
51
loadPossibleHistoricOps(const std::string & name,std::optional<size_t> version)52 std::vector<std::string> loadPossibleHistoricOps(
53 const std::string& name,
54 std::optional<size_t> version) {
55 std::vector<std::string> possibleSchemas;
56
57 if (!version.has_value()) {
58 return possibleSchemas;
59 }
60
61 for (const auto& entry : get_operator_version_map()) {
62 auto old_symbol_name = entry.first;
63 // strip off the overload name, if exist
64 auto base_name = old_symbol_name.substr(0, old_symbol_name.find('.'));
65 if (base_name == name) {
66 auto possibleUpgrader = findUpgrader(entry.second, version.value());
67 if (possibleUpgrader.has_value()) {
68 possibleSchemas.push_back(possibleUpgrader.value().old_schema);
69 }
70 }
71 }
72
73 return possibleSchemas;
74 }
75
getMaxOperatorVersion()76 uint64_t getMaxOperatorVersion() {
77 return caffe2::serialize::kProducedFileFormatVersion;
78 }
79
getUpgradersRangeForOp(const std::string & name)80 std::vector<UpgraderRange> getUpgradersRangeForOp(const std::string& name) {
81 std::vector<UpgraderRange> output;
82 auto it = get_operator_version_map().find(name);
83 if (it == get_operator_version_map().end()) {
84 return output;
85 }
86
87 output.reserve(it->second.size());
88 int cur_min = 0;
89 for (const auto& entry : it->second) {
90 int cur_max = entry.bumped_at_version - 1;
91 output.emplace_back(UpgraderRange{cur_min, cur_max});
92 cur_min = entry.bumped_at_version;
93 }
94 return output;
95 }
96
97 } // namespace torch::jit
98