xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/mobile/train/export_data.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/mobile/module.h>
4 
5 namespace torch::jit {
6 
/**
 * Serializes the provided tensor map to the provided stream.
 *
 * @param[in] map The tensors to serialize, keyed by parameter name.
 * @param[in] out The stream to write the serialized data to.
 * @param[in] use_flatbuffer If true, use Flatbuffers to serialize the data.
 *     If false (the default), use Pickle.
 */
TORCH_API void _save_parameters(
    const std::map<std::string, at::Tensor>& map,
    std::ostream& out,
    bool use_flatbuffer = false);
19 
/**
 * Serializes the provided tensor map to a file.
 *
 * @param[in] map The tensors to serialize, keyed by parameter name.
 * @param[in] filename The stem of the file name to write to. If
 *     @p use_flatbuffer is false, the extension ".pkl" will be appended. If
 *     @p use_flatbuffer is true, the extension ".ff" will be appended.
 * @param[in] use_flatbuffer If true, use Flatbuffers to serialize the data.
 *     If false (the default), use Pickle.
 */
TORCH_API void _save_parameters(
    const std::map<std::string, at::Tensor>& map,
    const std::string& filename,
    bool use_flatbuffer = false);
34 
namespace mobile {

// NOTE: Please prefer using _save_parameters directly over using the 2
// functions below.

/**
 * Builds a mobile::Module from the provided tensor dict so it can be handled
 * by the mobile-module serialization machinery.
 *
 * NOTE(review): the exact structure of the resulting Module (attribute names,
 * nesting) is defined in the implementation, not visible here — confirm
 * before relying on its layout.
 *
 * @param[in] dict The named tensors to wrap.
 * @return A mobile::Module representing @p dict.
 */
TORCH_API mobile::Module tensor_dict_to_mobile(
    const c10::Dict<std::string, at::Tensor>& dict);

/**
 * Converts a std::map of named tensors into the equivalent c10::Dict.
 *
 * NOTE(review): unlike tensor_dict_to_mobile above, this declaration is not
 * marked TORCH_API — presumably an intentional choice to keep it
 * library-internal; confirm before exporting it.
 *
 * @param[in] map The named tensors to convert.
 * @return A c10::Dict holding the same name -> tensor entries.
 */
c10::Dict<std::string, at::Tensor> tensor_map_to_dict(
    const std::map<std::string, at::Tensor>& map);

} // namespace mobile
46 
// Global hook for writing a mobile::Module to an arbitrary sink via a
// caller-supplied writer callback. The pointee is defined in another
// translation unit (NOTE(review): presumably installed by the serializer
// implementation — confirm it is non-null before calling). `writer_func` is
// invoked with a buffer pointer and a byte count; its size_t return value is
// presumably the number of bytes actually written — verify against the
// installer.
extern void (*_save_mobile_module_to)(
    const mobile::Module& module,
    const std::function<size_t(const void*, size_t)>& writer_func);
50 
51 } // namespace torch::jit
52