#include <ATen/core/ivalue.h>
#include <caffe2/serialize/file_adapter.h>
#include <caffe2/serialize/inline_container.h>
#include <torch/csrc/jit/api/compilation_unit.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/compatibility/model_compatibility.h>
#include <torch/csrc/jit/mobile/file_format.h>
#include <torch/csrc/jit/mobile/flatbuffer_loader.h>
#include <torch/csrc/jit/mobile/import.h> // removed after using simple type_resolver/obj_loader
#include <torch/csrc/jit/mobile/type_parser.h>
#include <torch/csrc/jit/serialization/import_export_constants.h>
#include <torch/csrc/jit/serialization/import_read.h>

#include <caffe2/serialize/in_memory_adapter.h>
#include <fstream>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>
18
19 namespace c10 {
20 TypePtr parseType(const std::string& pythonStr);
21 } // namespace c10
22
23 namespace torch::jit {
24
25 using caffe2::serialize::FileAdapter;
26 using caffe2::serialize::IStreamAdapter;
27 using caffe2::serialize::PyTorchStreamReader;
28 using caffe2::serialize::ReadAdapterInterface;
29
readArchive(const std::string & archive_name,PyTorchStreamReader & stream_reader)30 c10::IValue readArchive(
31 const std::string& archive_name,
32 PyTorchStreamReader& stream_reader) {
33 std::optional<at::Device> device;
34 std::shared_ptr<CompilationUnit> compilation_unit =
35 std::make_shared<CompilationUnit>();
36
37 // TODO (T90180710): Simplify type_resolver and obj_loader when getting
38 // bytecode version from model
39 auto type_resolver = [&](const c10::QualifiedName& qn) {
40 return typeResolverMobile(qn, compilation_unit);
41 };
42
43 std::shared_ptr<mobile::CompilationUnit> mobile_compilation_unit =
44 std::make_shared<mobile::CompilationUnit>();
45 auto obj_loader = [&](const at::StrongTypePtr& type, const IValue& input) {
46 return objLoaderMobile(type, input, *mobile_compilation_unit);
47 };
48 bool bytecode_tensor_in_constants_archive =
49 (archive_name == "bytecode" && !isTensorInBytecodeArchive(stream_reader));
50 auto ivalues = torch::jit::readArchiveAndTensors(
51 archive_name,
52 /*pickle_prefix=*/"",
53 /*tensor_prefix=*/
54 bytecode_tensor_in_constants_archive ? "constants/" : "",
55 type_resolver,
56 obj_loader,
57 device,
58 stream_reader,
59 nullptr);
60 return ivalues;
61 }
62
get_bytecode_ivalues(PyTorchStreamReader & reader)63 std::vector<IValue> get_bytecode_ivalues(PyTorchStreamReader& reader) {
64 return std::move(*readArchive("bytecode", reader).toTuple()).elements().vec();
65 }
66
67 /********************** Bytecode **********************/
68
69 // Forward declare
70 uint64_t _get_model_bytecode_version(
71 const std::vector<IValue>& bytecode_ivalues);
72 static uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size);
73
_get_model_bytecode_version(std::istream & in)74 uint64_t _get_model_bytecode_version(std::istream& in) {
75 auto orig_pos = in.tellg();
76 in.seekg(0, in.beg);
77 auto [data, size] = get_stream_content(in);
78 in.seekg(orig_pos, in.beg);
79 return _get_model_bytecode_version_from_bytes(data.get(), size);
80 }
81
_get_model_bytecode_version(const std::string & filename)82 uint64_t _get_model_bytecode_version(const std::string& filename) {
83 std::ifstream ifile(filename);
84 return _get_model_bytecode_version(ifile);
85 }
86
_get_model_bytecode_version(const std::shared_ptr<ReadAdapterInterface> & rai)87 uint64_t _get_model_bytecode_version(
88 const std::shared_ptr<ReadAdapterInterface>& rai) {
89 auto [data, size] = get_rai_content(rai.get());
90 return _get_model_bytecode_version_from_bytes(data.get(), size);
91 }
92
_get_model_bytecode_version_zip(std::shared_ptr<ReadAdapterInterface> rai)93 static uint64_t _get_model_bytecode_version_zip(
94 std::shared_ptr<ReadAdapterInterface> rai) {
95 if (!check_zip_file(rai)) {
96 TORCH_CHECK(
97 false,
98 "Failed to open .ptl file please ensure the model was exported for mobile");
99 }
100 PyTorchStreamReader reader(std::move(rai));
101 auto bytecode_values = get_bytecode_ivalues(reader);
102 return _get_model_bytecode_version(bytecode_values);
103 }
104
_get_model_bytecode_version_from_bytes(char * data,size_t size)105 uint64_t _get_model_bytecode_version_from_bytes(char* data, size_t size) {
106 TORCH_CHECK(data != nullptr, "Pointer to bytes is null.");
107 TORCH_CHECK(size >= kFileFormatHeaderSize, "Unrecognized data format");
108 auto format = getFileFormat(data);
109 switch (format) {
110 case FileFormat::FlatbufferFileFormat: {
111 return get_bytecode_version_from_bytes(data);
112 }
113 case FileFormat::ZipFileFormat: {
114 auto rai =
115 std::make_unique<caffe2::serialize::MemoryReadAdapter>(data, size);
116 auto version = _get_model_bytecode_version_zip(std::move(rai));
117 return version;
118 }
119
120 default:
121 TORCH_CHECK(false, "Unrecognized data format");
122 }
123 }
124
_get_model_bytecode_version(const std::vector<IValue> & bytecode_ivalues)125 uint64_t _get_model_bytecode_version(
126 const std::vector<IValue>& bytecode_ivalues) {
127 if (!bytecode_ivalues.empty() && bytecode_ivalues[0].isInt()) {
128 int64_t model_version = bytecode_ivalues[0].toInt();
129 TORCH_CHECK(
130 model_version > 0,
131 "Expected model bytecode version > 0 got ",
132 model_version);
133 return static_cast<uint64_t>(model_version);
134 }
135 TORCH_CHECK(false, "Failed to get bytecode version.");
136 }
137
138 /********************** Operator Version **********************/
139
140 uint64_t _get_model_operator_version(
141 PyTorchStreamReader& reader); // Forward Declare
142
_get_model_operator_version(std::istream & in)143 uint64_t _get_model_operator_version(std::istream& in) {
144 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
145 return _get_model_operator_version(std::move(rai));
146 }
147
_get_model_operator_version(const std::string & filename)148 uint64_t _get_model_operator_version(const std::string& filename) {
149 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
150 return _get_model_operator_version(std::move(rai));
151 }
152
_get_model_operator_version(std::shared_ptr<ReadAdapterInterface> rai)153 uint64_t _get_model_operator_version(
154 std::shared_ptr<ReadAdapterInterface> rai) {
155 if (!check_zip_file(rai)) {
156 TORCH_CHECK(
157 false,
158 "Failed to open .ptl file please ensure the model was exported for mobile");
159 }
160 PyTorchStreamReader reader(std::move(rai));
161 return _get_model_operator_version(reader);
162 }
163
_get_model_operator_version(PyTorchStreamReader & reader)164 uint64_t _get_model_operator_version(PyTorchStreamReader& reader) {
165 return reader.version();
166 }
167
168 /********************** Operators and Info **********************/
169
170 // Forward declare
171 std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
172 std::vector<IValue> bytecode_ivalues);
173
_get_model_ops_and_info(std::istream & in)174 std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
175 std::istream& in) {
176 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
177 return _get_model_ops_and_info(std::move(rai));
178 }
179
_get_model_ops_and_info(const std::string & filename)180 std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
181 const std::string& filename) {
182 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
183 return _get_model_ops_and_info(std::move(rai));
184 }
185
_get_model_ops_and_info(std::shared_ptr<ReadAdapterInterface> rai)186 std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
187 std::shared_ptr<ReadAdapterInterface> rai) {
188 if (!check_zip_file(rai)) {
189 TORCH_WARN("Failed to open zip file for model ops.");
190 return std::unordered_map<std::string, OperatorInfo>{};
191 }
192 PyTorchStreamReader reader(std::move(rai));
193 auto bytecode_values = get_bytecode_ivalues(reader);
194 return _get_model_ops_and_info(bytecode_values);
195 }
196
197 /* A function to retrieve the root (top level) operators of a model and their
198 * corresponding compatibility info. These root operators can call other
199 * operators within them (traced ops), and a root op can call many different
200 * traced ops depending on internal code paths in the root op. These traced ops
201 * are not returned by this function. Those operators are abstracted into the
202 * runtime as an implementation detail (and the traced ops themselves can also
203 * call other operators) making retrieving them difficult and their value from
204 * this api negligible since they will differ between which runtime version the
205 * model is run on. Because of this, there is a false positive this api can't
206 * prevent in a compatibility usecase. All the root ops of a model are present
207 * in a target runtime, but not all the traced ops are which prevents a model
208 * from being able to run.
209 **/
_get_model_ops_and_info(std::vector<IValue> bytecode_ivalues)210 std::unordered_map<std::string, OperatorInfo> _get_model_ops_and_info(
211 std::vector<IValue> bytecode_ivalues) {
212 constexpr uint64_t min_version_with_schema = 6;
213 if (_get_model_bytecode_version(bytecode_ivalues) < min_version_with_schema) {
214 TORCH_WARN(
215 "Only models with bytecode version 6 and above contain operator schema information. Please re-export your model to generate it");
216 }
217 std::unordered_map<std::string, OperatorInfo> result;
218 if (bytecode_ivalues.empty()) {
219 TORCH_WARN("Failed to get model ops and info.");
220 return result;
221 }
222 // loop over all the functions in the bytecode
223 for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
224 // descend to the operators list
225 const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
226 auto operators_tuple = method_tuple.at(1).toTupleRef().elements()[1];
227 auto operators = operators_tuple.toTupleRef().elements()[1];
228 for (auto& op_tuple : operators.toTupleRef().elements()) {
229 const auto& op = op_tuple.toTupleRef().elements();
230
231 // grab name
232 std::string op_name = op.at(0).toStringRef();
233 std::string op_overload_name = op.at(1).toStringRef();
234 if (!op_overload_name.empty()) {
235 op_name.append(".");
236 op_name.append(op_overload_name);
237 }
238
239 // grab schema size
240 if (op.size() > 2) {
241 result.emplace(op_name, OperatorInfo{(int)op.at(2).toInt()});
242 } else { // no schema information use default
243 result.emplace(op_name, OperatorInfo{});
244 }
245 }
246 }
247 return result;
248 }
249
250 /********************** Get Type Table **********************/
251
252 // Forward declare
253 std::unordered_set<std::string> _get_mobile_model_contained_types(
254 const std::vector<IValue>& bytecode_ivalues);
255
_get_mobile_model_contained_types(std::istream & in)256 std::unordered_set<std::string> _get_mobile_model_contained_types(
257 std::istream& in) {
258 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
259 return _get_mobile_model_contained_types(std::move(rai));
260 }
261
_get_mobile_model_contained_types(const std::string & filename)262 std::unordered_set<std::string> _get_mobile_model_contained_types(
263 const std::string& filename) {
264 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
265 return _get_mobile_model_contained_types(std::move(rai));
266 }
267
_get_mobile_model_contained_types(std::shared_ptr<ReadAdapterInterface> rai)268 std::unordered_set<std::string> _get_mobile_model_contained_types(
269 std::shared_ptr<ReadAdapterInterface> rai) {
270 if (!check_zip_file(rai)) {
271 TORCH_CHECK(
272 false,
273 "Failed to open .ptl file please ensure the model was exported for mobile");
274 }
275 PyTorchStreamReader reader(std::move(rai));
276 auto bytecode_values = get_bytecode_ivalues(reader);
277 return _get_mobile_model_contained_types(bytecode_values);
278 }
279
280 // Get deduplicate type table given bytecode, and each string is a atomic type,
281 // like str, Tensor and etc. For example,
282 // input: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
283 // output: {Dict, int, Tuple, Tensor}
_get_mobile_model_contained_types(const std::vector<IValue> & bytecode_ivalues)284 std::unordered_set<std::string> _get_mobile_model_contained_types(
285 const std::vector<IValue>& bytecode_ivalues) {
286 std::unordered_set<std::string> contained_types;
287 // To avoid parsing same type twice, declare $parsed_type_names_records and
288 // use type name (string, ex: "Dict[int, Tuple[Tensor, Tensor, Tensor]]") as
289 // the hash to record which types are parsed.
290 std::unordered_set<std::string> parsed_type_names_records;
291 for (const auto i : c10::irange(1, bytecode_ivalues.size())) {
292 const auto& method_tuple = bytecode_ivalues.at(i).toTupleRef().elements();
293 auto type_table_tuple =
294 method_tuple.at(1).toTupleRef().elements()[BYTECODE_INDEX_TYPE];
295 const auto& type_table =
296 type_table_tuple.toTupleRef().elements()[1].toTupleRef().elements();
297
298 // type_table is a list of IValue, and each IValue is a string,
299 // for example: "Dict[int, Tuple[Tensor, Tensor, Tensor]]"
300 std::vector<std::string> type_name_list;
301 for (const auto& type_definition : type_table) {
302 std::unordered_set<std::string> type_tokens;
303 std::string type_name = type_definition.toStringRef();
304 type_name_list.emplace_back(type_name);
305 }
306 at::TypeParser parser(type_name_list);
307 parser.parseList();
308 contained_types = parser.getContainedTypes();
309 }
310
311 return contained_types;
312 }
313
314 /********************** Compatibility Checker **********************/
315
get(std::istream & in)316 ModelCompatibilityInfo ModelCompatibilityInfo::get(std::istream& in) {
317 std::unique_ptr<IStreamAdapter> rai = std::make_unique<IStreamAdapter>(&in);
318 return get(std::move(rai));
319 }
320
get(const std::string & filename)321 ModelCompatibilityInfo ModelCompatibilityInfo::get(
322 const std::string& filename) {
323 std::unique_ptr<FileAdapter> rai = std::make_unique<FileAdapter>(filename);
324 return get(std::move(rai));
325 }
326
get(std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai)327 ModelCompatibilityInfo ModelCompatibilityInfo::get(
328 std::shared_ptr<caffe2::serialize::ReadAdapterInterface> rai) {
329 if (!check_zip_file(rai)) {
330 TORCH_CHECK(
331 false, "Failed to open zip file for model compatibility information");
332 }
333 PyTorchStreamReader reader(std::move(rai));
334 std::vector<IValue> bytecode_values = get_bytecode_ivalues(reader);
335 uint64_t model_bytecode_version =
336 _get_model_bytecode_version(bytecode_values);
337 auto model_info = _get_model_ops_and_info(bytecode_values);
338 std::unordered_set<std::string> type_table =
339 _get_mobile_model_contained_types(bytecode_values);
340 uint64_t operator_version = _get_model_operator_version(reader);
341 return ModelCompatibilityInfo{
342 model_bytecode_version, model_info, type_table, operator_version};
343 }
344
is_compatible(RuntimeCompatibilityInfo runtime_info,const ModelCompatibilityInfo & model_info)345 ModelCompatCheckResult is_compatible(
346 RuntimeCompatibilityInfo runtime_info,
347 const ModelCompatibilityInfo& model_info) {
348 ModelCompatCheckResult result = {ModelCompatibilityStatus::OK, {}};
349 // Check that the models bytecode version is less than or equal to
350 // kMaxSupportedBytecodeVersion from the runtime
351 if (model_info.bytecode_version >
352 runtime_info.min_max_supported_bytecode_version.second) {
353 result.status = ModelCompatibilityStatus::ERROR;
354 std::ostringstream s;
355 s << "model bytecode version " << model_info.bytecode_version
356 << "is greater than the max supported bytecode version in runtimes "
357 << runtime_info.min_max_supported_bytecode_version.second;
358 result.errors.emplace_back(s.str());
359 } else if (
360 model_info.bytecode_version <
361 runtime_info.min_max_supported_bytecode_version.first) {
362 result.status = ModelCompatibilityStatus::ERROR;
363 std::ostringstream s;
364 s << "model bytecode version " << model_info.bytecode_version
365 << "is less than the minimum supported bytecode version in runtime "
366 << runtime_info.min_max_supported_bytecode_version.first;
367 result.errors.emplace_back(s.str());
368 }
369
370 std::unordered_set<std::string> supported_type = runtime_info.supported_types;
371
372 // Check type table
373 for (const auto& type_name : model_info.type_table) {
374 if (supported_type.find(type_name) == supported_type.end()) {
375 result.status = ModelCompatibilityStatus::ERROR;
376 std::ostringstream s;
377 s << "Primitive type: '" << type_name
378 << "' is not supported in current runtime";
379 result.errors.push_back(s.str());
380 }
381 }
382
383 // Check operators
384 std::unordered_map<std::string, OperatorInfo> operator_info =
385 model_info.operator_info;
386 for (auto const& op : operator_info) {
387 std::string op_name = op.first;
388 OperatorInfo model_op_info = op.second;
389
390 // Check if operator not present in runtime
391 if (runtime_info.operator_info.find(op_name) ==
392 runtime_info.operator_info.end()) {
393 result.status = ModelCompatibilityStatus::ERROR;
394 std::ostringstream s;
395 s << "Operator '" << op_name << "' missing from runtime (not found)";
396 result.errors.push_back(s.str());
397 } else {
398 OperatorInfo runtime_op_info = runtime_info.operator_info.at(op_name);
399
400 // If the runtime op has no schema information its a false alarm and isn't
401 // actually useable
402 if (!runtime_op_info.num_schema_args.has_value()) {
403 result.status = ModelCompatibilityStatus::ERROR;
404 std::ostringstream s;
405 s << "Operator '" << op_name
406 << "' missing from runtime (missing schema)";
407 result.errors.push_back(s.str());
408 } else {
409 // Check if the model operator has schema information. If it doesn't
410 // then the model is from a bytecode version < 6 and we are done. If the
411 // model has more args than the runtime, then the runtime can't know
412 // what to do so we aren't compatible. If the runtime has more args than
413 // the model then we can just use default values and be fine.
414 if (model_op_info.num_schema_args.has_value() &&
415 (model_op_info.num_schema_args.value() >
416 runtime_op_info.num_schema_args.value())) {
417 result.status = ModelCompatibilityStatus::ERROR;
418 std::ostringstream s;
419 s << "Operator schema for'" << op_name << "' has "
420 << model_op_info.num_schema_args.value()
421 << " args in model but only "
422 << runtime_op_info.num_schema_args.value() << " in the runtime";
423 result.errors.push_back(s.str());
424 }
425 }
426 }
427 }
428
429 // Check Operator Versions
430 if (model_info.operator_version <
431 runtime_info.min_max_supported_opperator_versions.first ||
432 model_info.operator_version >
433 runtime_info.min_max_supported_opperator_versions.second) {
434 result.status = ModelCompatibilityStatus::ERROR;
435 std::ostringstream s;
436 s << "Model Operator Version " << model_info.operator_version
437 << "is not within supported version range of the runtime "
438 << runtime_info.min_max_supported_opperator_versions.first << " to "
439 << runtime_info.min_max_supported_opperator_versions.second;
440 result.errors.push_back(s.str());
441 }
442
443 return result;
444 }
445
446 } // namespace torch::jit
447