xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/serialize/output-archive.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/serialize/output-archive.h>
2 
3 #include <torch/types.h>
4 #include <torch/utils.h>
5 
6 #include <torch/csrc/jit/api/module.h>
7 #include <torch/csrc/jit/serialization/export.h>
8 
9 #include <c10/util/Exception.h>
10 
11 #include <memory>
12 #include <ostream>
13 #include <string>
14 
15 namespace torch {
16 namespace serialize {
OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)17 OutputArchive::OutputArchive(std::shared_ptr<jit::CompilationUnit> cu)
18     : cu_(std::move(cu)),
19       module_("__torch__.Module", cu_, /*shouldMangle=*/true) {}
20 
write(const std::string & key,const c10::IValue & ivalue)21 void OutputArchive::write(const std::string& key, const c10::IValue& ivalue) {
22   module_.register_attribute(key, ivalue.type(), ivalue);
23 }
24 
write(const std::string & key,const Tensor & tensor,bool is_buffer)25 void OutputArchive::write(
26     const std::string& key,
27     const Tensor& tensor,
28     bool is_buffer) {
29   module_.register_parameter(key, tensor, is_buffer);
30 }
31 
write(const std::string & key,OutputArchive & nested_archive)32 void OutputArchive::write(
33     const std::string& key,
34     OutputArchive& nested_archive) {
35   module_.register_module(key, nested_archive.module_);
36 }
37 
save_to(const std::string & filename)38 void OutputArchive::save_to(const std::string& filename) {
39   jit::ExportModule(module_, filename);
40 }
41 
save_to(std::ostream & stream)42 void OutputArchive::save_to(std::ostream& stream) {
43   jit::ExportModule(module_, stream);
44 }
45 
save_to(const std::function<size_t (const void *,size_t)> & func)46 void OutputArchive::save_to(
47     const std::function<size_t(const void*, size_t)>& func) {
48   jit::ExportModule(module_, func);
49 }
50 } // namespace serialize
51 } // namespace torch
52