1 #include <torch/serialize/output-archive.h>
2
3 #include <torch/types.h>
4 #include <torch/utils.h>
5
6 #include <torch/csrc/jit/api/module.h>
7 #include <torch/csrc/jit/serialization/export.h>
8
9 #include <c10/util/Exception.h>
10
11 #include <memory>
12 #include <ostream>
13 #include <string>
14
15 namespace torch {
16 namespace serialize {
OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)17 OutputArchive::OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)
18 : cu_(std::move(cu)),
19 module_("__torch__.Module", cu_, /*shouldMangle=*/true) {}
20
write(const std::string & key,const c10::IValue & ivalue)21 void OutputArchive::write(const std::string& key, const c10::IValue& ivalue) {
22 module_.register_attribute(key, ivalue.type(), ivalue);
23 }
24
write(const std::string & key,const Tensor & tensor,bool is_buffer)25 void OutputArchive::write(
26 const std::string& key,
27 const Tensor& tensor,
28 bool is_buffer) {
29 module_.register_parameter(key, tensor, is_buffer);
30 }
31
write(const std::string & key,OutputArchive & nested_archive)32 void OutputArchive::write(
33 const std::string& key,
34 OutputArchive& nested_archive) {
35 module_.register_module(key, nested_archive.module_);
36 }
37
save_to(const std::string & filename)38 void OutputArchive::save_to(const std::string& filename) {
39 jit::ExportModule(module_, filename);
40 }
41
save_to(std::ostream & stream)42 void OutputArchive::save_to(std::ostream& stream) {
43 jit::ExportModule(module_, stream);
44 }
45
save_to(const std::function<size_t (const void *,size_t)> & func)46 void OutputArchive::save_to(
47 const std::function<size_t(const void*, size_t)>& func) {
48 jit::ExportModule(module_, func);
49 }
50 } // namespace serialize
51 } // namespace torch
52