xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/operator_upgraders/utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/operator_upgraders/utils.h>
2 
3 #include <caffe2/serialize/versions.h>
4 #include <torch/csrc/jit/operator_upgraders/version_map.h>
5 #include <iostream>
6 #include <optional>
7 #include <regex>
8 #include <string>
9 #include <vector>
10 
11 namespace torch::jit {
12 
findUpgrader(const std::vector<UpgraderEntry> & upgraders_for_schema,size_t current_version)13 std::optional<UpgraderEntry> findUpgrader(
14     const std::vector<UpgraderEntry>& upgraders_for_schema,
15     size_t current_version) {
16   // we want to find the entry which satisfies following two conditions:
17   //    1. the version entry must be greater than current_version
18   //    2. Among the version entries, we need to see if the current version
19   //       is in the upgrader name range
20   auto pos = std::find_if(
21       upgraders_for_schema.begin(),
22       upgraders_for_schema.end(),
23       [current_version](const UpgraderEntry& entry) {
24         return entry.bumped_at_version > static_cast<int>(current_version);
25       });
26 
27   if (pos != upgraders_for_schema.end()) {
28     return *pos;
29   }
30   return std::nullopt;
31 }
32 
isOpCurrentBasedOnUpgraderEntries(const std::vector<UpgraderEntry> & upgraders_for_schema,size_t current_version)33 bool isOpCurrentBasedOnUpgraderEntries(
34     const std::vector<UpgraderEntry>& upgraders_for_schema,
35     size_t current_version) {
36   auto latest_update =
37       upgraders_for_schema[upgraders_for_schema.size() - 1].bumped_at_version;
38   if (latest_update > static_cast<int>(current_version)) {
39     return false;
40   }
41   return true;
42 }
43 
isOpSymbolCurrent(const std::string & name,size_t current_version)44 bool isOpSymbolCurrent(const std::string& name, size_t current_version) {
45   auto it = get_operator_version_map().find(name);
46   if (it != get_operator_version_map().end()) {
47     return isOpCurrentBasedOnUpgraderEntries(it->second, current_version);
48   }
49   return true;
50 }
51 
loadPossibleHistoricOps(const std::string & name,std::optional<size_t> version)52 std::vector<std::string> loadPossibleHistoricOps(
53     const std::string& name,
54     std::optional<size_t> version) {
55   std::vector<std::string> possibleSchemas;
56 
57   if (!version.has_value()) {
58     return possibleSchemas;
59   }
60 
61   for (const auto& entry : get_operator_version_map()) {
62     auto old_symbol_name = entry.first;
63     // strip off the overload name, if exist
64     auto base_name = old_symbol_name.substr(0, old_symbol_name.find('.'));
65     if (base_name == name) {
66       auto possibleUpgrader = findUpgrader(entry.second, version.value());
67       if (possibleUpgrader.has_value()) {
68         possibleSchemas.push_back(possibleUpgrader.value().old_schema);
69       }
70     }
71   }
72 
73   return possibleSchemas;
74 }
75 
getMaxOperatorVersion()76 uint64_t getMaxOperatorVersion() {
77   return caffe2::serialize::kProducedFileFormatVersion;
78 }
79 
getUpgradersRangeForOp(const std::string & name)80 std::vector<UpgraderRange> getUpgradersRangeForOp(const std::string& name) {
81   std::vector<UpgraderRange> output;
82   auto it = get_operator_version_map().find(name);
83   if (it == get_operator_version_map().end()) {
84     return output;
85   }
86 
87   output.reserve(it->second.size());
88   int cur_min = 0;
89   for (const auto& entry : it->second) {
90     int cur_max = entry.bumped_at_version - 1;
91     output.emplace_back(UpgraderRange{cur_min, cur_max});
92     cur_min = entry.bumped_at_version;
93   }
94   return output;
95 }
96 
97 } // namespace torch::jit
98