xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Context.h>

#include <torch/custom_class.h>

#include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/ao_sparse/quantized/cpu/packed_params.h>
#include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h>

namespace ao {
namespace sparse {
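// Registers the TorchBind class "sparse.LinearPackedParamsBase" together with
// pickle (__getstate__/__setstate__) support, so packed sparse linear weights
// can be serialized to and deserialized from their BCSR representation.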
int register_linear_params() {
  static auto register_linear_params =
      torch::selective_class_<LinearPackedParamsBase>(
          "sparse", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase"))
          .def_pickle(
              [](const c10::intrusive_ptr<LinearPackedParamsBase>& params)
                  -> BCSRSerializationType { // __getstate__
                return params->serialize();
              },
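              // __setstate__ dispatches on the active qengine so the
              // deserialized weights are repacked for the backend in use.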
              [](BCSRSerializationType state)
                  -> c10::intrusive_ptr<
                      LinearPackedParamsBase> { // __setstate__
#ifdef USE_FBGEMM
                if (at::globalContext().qEngine() == at::QEngine::FBGEMM) {
                  return PackedLinearWeight::deserialize(state);
                }
#endif // USE_FBGEMM
#ifdef USE_PYTORCH_QNNPACK
                if (at::globalContext().qEngine() == at::QEngine::QNNPACK) {
                  return PackedLinearWeightQnnp::deserialize(state);
                }
#endif // USE_PYTORCH_QNNPACK
                TORCH_CHECK(false, "Unknown qengine");
              });
  // (1) we can't (easily) return the static initializer itself because its
  //     type can differ under selective build
  // (2) we can't return void and still be able to call the function at
  //     global scope
  return 0;
}

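// The anonymous-namespace variable below forces the registration to run at
// static-initialization time when this translation unit is loaded;
// C10_UNUSED suppresses the unused-variable warning.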
namespace {
static C10_UNUSED auto linear_params = register_linear_params();
}  // namespace

}}  // namespace ao::sparse