#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <caffe2/serialize/in_memory_adapter.h>
#include <fstream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

namespace c10 {
TypePtr parseType(const std::string& pythonStr);
} // namespace c10

namespace torch::jit {

using caffe2::serialize::FileAdapter;
using caffe2::serialize::IStreamAdapter;
using caffe2::serialize::PyTorchStreamReader;
using caffe2::serialize::ReadAdapterInterface;

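// Reads and deserializes the named archive (e.g. "bytecode") from a mobile
// model zip via the given PyTorchStreamReader, using lightweight mobile
// type and object resolvers instead of the full JIT importer.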
c10::IValue readArchive(
    const std::string& archive_name,
    PyTorchStreamReader& stream_reader) {
  std::optional<at::Device> device;
  std::shared_ptr<CompilationUnit> compilation_unit =
      std::make_shared<CompilationUnit>();

  // TODO (T90180710): Simplify type_resolver and obj_loader when getting
  // bytecode version from model
  auto type_resolver = [&](const c10::QualifiedName& qn) {
    return typeResolverMobile(qn, compilation_unit);
  };

  std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
      std::make_shared<mobile::CompilationUnit>();
  auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
    return objLoaderMobile(type, input, *mobile_compilation_unit);
  };
  bool bytecode_tensor_in_constants_archive =
      (archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader));
  auto ivalues = torch::jit::readArchiveAndTensors(
      archive_name,
      /*pickle_prefix=*/"",
      /*tensor_prefix=*/
      bytecode_tensor_in_constants_archive ? "constants/" : "",
      type_resolver,
      obj_loader,
      device,
      stream_reader,
      nullptr);
  return ivalues;
}

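// Returns the elements of the top-level tuple stored in the "bytecode"
// archive of a mobile model; element 0 is the bytecode version number and
// the remaining elements describe the model's methods.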
std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
  return std::move(*readArchive("bytecode", reader).toTuple()).elements().vec();
}

/********************** Bytecode **********************/

// Forward declare
uint64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues);
static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);

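// Overloads that read a mobile model from a stream, a file path, or a
// ReadAdapterInterface and return its bytecode version. The stream overload
// restores the original read position before returning.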
uint64_t _get_model_bytecode_version(std::istream& in) {
  auto orig_pos = in.tellg();
  in.seekg(0, in.beg);
  auto [data, size] = get_stream_content(in);
  in.seekg(orig_pos, in.beg);
  return _get_model_bytecode_version_from_bytes(data.get(), size);
}

uint64_t _get_model_bytecode_version(const std::string& filename) {
  std::ifstream ifile(filename);
  return _get_model_bytecode_version(ifile);
}

uint64_t _get_model_bytecode_version(
    const std::shared_ptr<ReadAdapterInterface>& rai) {
  auto [data, size] = get_rai_content(rai.get());
  return _get_model_bytecode_version_from_bytes(data.get(), size);
}

static uint64_t _get_model_bytecode_version_zip(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false,
        "Failed to open .ptl file; please ensure the model was exported for mobile.");
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_model_bytecode_version(bytecode_values);
}

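// Determines the container format of the raw bytes (FlatBuffer vs. ZIP) and
// dispatches to the matching reader to extract the bytecode version.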
uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
  TORCH_CHECK(data != nullptr, "Pointer to bytes is null.");
  TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
  auto format = getFileFormat(data);
  switch (format) {
    case FileFormat::FlatbufferFileFormat: {
      return get_bytecode_version_from_bytes(data);
    }
    case FileFormat::ZipFileFormat: {
      auto rai =
          std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
      auto version = _get_model_bytecode_version_zip(std::move(rai));
      return version;
    }

    default:
      TORCH_CHECK(false, "Unrecognized data format");
  }
}

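// The first element of the bytecode tuple is the bytecode version number;
// validate it and return it as an unsigned value.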
uint64_t _get_model_bytecode_version(
    const std::vector<IValue>& bytecode_ivalues) {
  if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
    int64_t model_version = bytecode_ivalues[0].toInt();
    TORCH_CHECK(
        model_version > 0,
        "Expected model bytecode version > 0, got ",
        model_version);
    return static_cast<uint64_t>(model_version);
  }
  TORCH_CHECK(false, "Failed to get bytecode version.");
}

/********************** Operator Version **********************/

uint64_t _get_model_operator_version(
    PyTorchStreamReader& reader); // Forward Declare

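// Overloads that read a mobile model from a stream, a file path, or a
// ReadAdapterInterface and return the archive's operator version, as
// recorded by PyTorchStreamReader.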
uint64_t _get_model_operator_version(std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_model_operator_version(std::move(rai));
}

uint64_t _get_model_operator_version(const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_model_operator_version(std::move(rai));
}

uint64_t _get_model_operator_version(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false,
        "Failed to open .ptl file; please ensure the model was exported for mobile.");
  }
  PyTorchStreamReader reader(std::move(rai));
  return _get_model_operator_version(reader);
}

uint64_t _get_model_operator_version(PyTorchStreamReader& reader) {
  return reader.version();
}

/********************** Operators and Info **********************/

// Forward declare
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues);

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_model_ops_and_info(std::move(rai));
}

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_model_ops_and_info(std::move(rai));
}

std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_WARN("Failed to open zip file for model ops.");
    return std::unordered_map<std::string, OperatorInfo>{};
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_model_ops_and_info(bytecode_values);
}

/* Retrieves the root (top-level) operators of a model and their corresponding
 * compatibility info. Root operators can call other operators within them
 * (traced ops), and a root op can call many different traced ops depending on
 * internal code paths in the root op. Those traced ops are not returned by
 * this function: they are abstracted into the runtime as an implementation
 * detail (and can themselves call other operators), which makes retrieving
 * them difficult, and their value to this API is negligible since they differ
 * across the runtime versions the model may be run on. Because of this, there
 * is a false positive this API cannot prevent in a compatibility use case:
 * all the root ops of a model may be present in a target runtime while some
 * traced ops are not, which still prevents the model from running.
 **/
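// A minimal usage sketch (the model path below is hypothetical):
//
//   auto ops = torch::jit::_get_model_ops_and_info("model.ptl");
//   for (const auto& entry : ops) {
//     // entry.first is the root op name (e.g. "aten::add.Tensor");
//     // entry.second.num_schema_args holds the schema arg count if present.
//   }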
std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
    std::vector<IValue> bytecode_ivalues) {
  constexpr uint64_t min_version_with_schema = 6;
  if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
    TORCH_WARN(
        "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it.");
  }
  std::unordered_map<std::string, OperatorInfo> result;
  if (bytecode_ivalues.empty()) {
    TORCH_WARN("Failed to get model ops and info.");
    return result;
  }
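  // Each entry after the version number is expected to be a method tuple of
  // the form (name, table, ...), where the table's second element is the
  // ("operators", (ops...)) pair and each op is
  // (name, overload_name[, num_args]).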
  // loop over all the functions in the bytecode
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    // descend to the operators list
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1];
    auto operators = operators_tuple.toTupleRef().elements()[1];
    for (auto& op_tuple : operators.toTupleRef().elements()) {
      const auto& op = op_tuple.toTupleRef().elements();

      // grab name
      std::string op_name = op.at(0).toStringRef();
      std::string op_overload_name = op.at(1).toStringRef();
      if (!op_overload_name.empty()) {
        op_name.append(".");
        op_name.append(op_overload_name);
      }

      // grab schema size
      if (op.size() > 2) {
        result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()});
      } else { // no schema information, use default
        result.emplace(op_name, OperatorInfo{});
      }
    }
  }
  return result;
}

/********************** Get Type Table **********************/

// Forward declare
std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::vector<IValue>& bytecode_ivalues);

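// Overloads that open a mobile model from a stream, a file path, or a
// ReadAdapterInterface and return the set of atomic type names used in its
// bytecode.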
std::unordered_set<std::string> _get_mobile_model_contained_types(
    std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return _get_mobile_model_contained_types(std::move(rai));
}

std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return _get_mobile_model_contained_types(std::move(rai));
}

std::unordered_set<std::string> _get_mobile_model_contained_types(
    std::shared_ptr<ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false,
        "Failed to open .ptl file; please ensure the model was exported for mobile.");
  }
  PyTorchStreamReader reader(std::move(rai));
  auto bytecode_values = get_bytecode_ivalues(reader);
  return _get_mobile_model_contained_types(bytecode_values);
}

// Get the deduplicated type table from bytecode; each entry is an atomic type
// name such as str, int, or Tensor. For example,
// input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
// output: {Dict, int, Tuple, Tensor}
std::unordered_set<std::string> _get_mobile_model_contained_types(
    const std::vector<IValue>& bytecode_ivalues) {
  std::unordered_set<std::string> contained_types;
  // To avoid parsing the same type twice, parsed_type_names_records uses the
  // full type name string (e.g. "Dict[int, Tuple[Tensor, Tensor, Tensor]]")
  // as the key to record which types have already been parsed.
  std::unordered_set<std::string> parsed_type_names_records;
  for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
    const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
    auto type_table_tuple =
        method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE];
    const auto& type_table =
        type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements();

    // type_table is a list of IValue, and each IValue is a string,
    // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
    std::vector<std::string> type_name_list;
    for (const auto& type_definition : type_table) {
      std::unordered_set<std::string> type_tokens;
      std::string type_name = type_definition.toStringRef();
      type_name_list.emplace_back(type_name);
    }
    at::TypeParser parser(type_name_list);
    parser.parseList();
    contained_types = parser.getContainedTypes();
  }

  return contained_types;
}


/********************** Compatibility Checker **********************/

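// Collects a model's bytecode version, root operators (with schema info),
// contained types, and operator version into a ModelCompatibilityInfo,
// reading the model from a stream, a file path, or a ReadAdapterInterface.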
ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) {
  std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
  return get(std::move(rai));
}

ModelCompatibilityInfo ModelCompatibilityInfo::get(
    const std::string& filename) {
  std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
  return get(std::move(rai));
}

ModelCompatibilityInfo ModelCompatibilityInfo::get(
    std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) {
  if (!check_zip_file(rai)) {
    TORCH_CHECK(
        false, "Failed to open zip file for model compatibility information");
  }
  PyTorchStreamReader reader(std::move(rai));
  std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
  uint64_t model_bytecode_version =
      _get_model_bytecode_version(bytecode_values);
  auto model_info = _get_model_ops_and_info(bytecode_values);
  std::unordered_set<std::string> type_table =
      _get_mobile_model_contained_types(bytecode_values);
  uint64_t operator_version = _get_model_operator_version(reader);
  return ModelCompatibilityInfo{
      model_bytecode_version, model_info, type_table, operator_version};
}

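// Checks a model's compatibility info against a runtime's capabilities and
// returns OK or ERROR plus a list of human-readable error strings. A minimal
// usage sketch (the model path below is hypothetical; RuntimeCompatibilityInfo
// is declared in runtime_compatibility.h):
//
//   auto model_info = ModelCompatibilityInfo::get("model.ptl");
//   auto runtime_info = RuntimeCompatibilityInfo::get();
//   auto result = is_compatible(runtime_info, model_info);
//   if (result.status != ModelCompatibilityStatus::OK) {
//     for (const auto& err : result.errors) {
//       TORCH_WARN(err);
//     }
//   }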
ModelCompatCheckResult is_compatible(
    RuntimeCompatibilityInfo runtime_info,
    const ModelCompatibilityInfo& model_info) {
  ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};
  // Check that the model's bytecode version is less than or equal to
  // kMaxSupportedBytecodeVersion from the runtime
  if (model_info.bytecode_version >
      runtime_info.min_max_supported_bytecode_version.second) {
    result.status = ModelCompatibilityStatus::ERROR;
    std::ostringstream s;
    s << "model bytecode version " << model_info.bytecode_version
      << " is greater than the max supported bytecode version in the runtime "
      << runtime_info.min_max_supported_bytecode_version.second;
    result.errors.emplace_back(s.str());
  } else if (
      model_info.bytecode_version <
      runtime_info.min_max_supported_bytecode_version.first) {
    result.status = ModelCompatibilityStatus::ERROR;
    std::ostringstream s;
    s << "model bytecode version " << model_info.bytecode_version
      << " is less than the minimum supported bytecode version in the runtime "
      << runtime_info.min_max_supported_bytecode_version.first;
    result.errors.emplace_back(s.str());
  }

  std::unordered_set<std::string> supported_type = runtime_info.supported_types;

  // Check type table
  for (const auto& type_name : model_info.type_table) {
    if (supported_type.find(type_name) == supported_type.end()) {
      result.status = ModelCompatibilityStatus::ERROR;
      std::ostringstream s;
      s << "Primitive type: '" << type_name
        << "' is not supported in the current runtime";
      result.errors.push_back(s.str());
    }
  }

  // Check operators
  std::unordered_map<std::string, OperatorInfo> operator_info =
      model_info.operator_info;
  for (auto const& op : operator_info) {
    std::string op_name = op.first;
    OperatorInfo model_op_info = op.second;

    // Check if operator not present in runtime
    if (runtime_info.operator_info.find(op_name) ==
        runtime_info.operator_info.end()) {
      result.status = ModelCompatibilityStatus::ERROR;
      std::ostringstream s;
      s << "Operator '" << op_name << "' missing from runtime (not found)";
      result.errors.push_back(s.str());
    } else {
      OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name);

      // If the runtime op has no schema information, it's a false alarm and
      // isn't actually usable
      if (!runtime_op_info.num_schema_args.has_value()) {
        result.status = ModelCompatibilityStatus::ERROR;
        std::ostringstream s;
        s << "Operator '" << op_name
          << "' missing from runtime (missing schema)";
        result.errors.push_back(s.str());
      } else {
        // Check if the model operator has schema information. If it doesn't,
        // then the model is from a bytecode version < 6 and we are done. If
        // the model has more args than the runtime, then the runtime can't
        // know what to do, so we aren't compatible. If the runtime has more
        // args than the model, then we can just use default values and be
        // fine.
        if (model_op_info.num_schema_args.has_value() &&
            (model_op_info.num_schema_args.value() >
             runtime_op_info.num_schema_args.value())) {
          result.status = ModelCompatibilityStatus::ERROR;
          std::ostringstream s;
          s << "Operator schema for '" << op_name << "' has "
            << model_op_info.num_schema_args.value()
            << " args in model but only "
            << runtime_op_info.num_schema_args.value() << " in the runtime";
          result.errors.push_back(s.str());
        }
      }
    }
  }

  // Check Operator Versions
  if (model_info.operator_version <
          runtime_info.min_max_supported_opperator_versions.first ||
      model_info.operator_version >
          runtime_info.min_max_supported_opperator_versions.second) {
    result.status = ModelCompatibilityStatus::ERROR;
    std::ostringstream s;
    s << "Model Operator Version " << model_info.operator_version
      << " is not within the supported version range of the runtime "
      << runtime_info.min_max_supported_opperator_versions.first << " to "
      << runtime_info.min_max_supported_opperator_versions.second;
    result.errors.push_back(s.str());
  }

  return result;
}

} // namespace torch::jit