1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS 2 #include <ATen/core/Tensor.h> 3 #include <ATen/Context.h> 4 5 #include <torch/custom_class.h> 6 7 #include <ATen/native/ao_sparse/quantized/cpu/fbgemm_utils.h> 8 #include <ATen/native/ao_sparse/quantized/cpu/packed_params.h> 9 #include <ATen/native/ao_sparse/quantized/cpu/qnnpack_utils.h> 10 11 namespace ao { 12 namespace sparse { register_linear_params()13int register_linear_params() { 14 static auto register_linear_params = 15 torch::selective_class_<LinearPackedParamsBase>( 16 "sparse", TORCH_SELECTIVE_CLASS("LinearPackedParamsBase")) 17 .def_pickle( 18 [](const c10::intrusive_ptr<LinearPackedParamsBase>& params) 19 -> BCSRSerializationType { // __getstate__ 20 return params->serialize(); 21 }, 22 [](BCSRSerializationType state) 23 -> c10::intrusive_ptr< 24 LinearPackedParamsBase> { // __setstate__ 25 #ifdef USE_FBGEMM 26 if (at::globalContext().qEngine() == at::QEngine::FBGEMM) { 27 return PackedLinearWeight::deserialize(state); 28 } 29 #endif // USE_FBGEMM 30 #ifdef USE_PYTORCH_QNNPACK 31 if (at::globalContext().qEngine() == at::QEngine::QNNPACK) { 32 return PackedLinearWeightQnnp::deserialize(state); 33 } 34 #endif // USE_FBGEMM 35 TORCH_CHECK(false, "Unknown qengine"); 36 }); 37 // (1) we can't (easily) return the static initializer itself because it can have a different type because of selective build 38 // (2) we can't return void and be able to call the function in the global scope 39 return 0; 40 } 41 42 namespace { 43 static C10_UNUSED auto linear_params = register_linear_params(); 44 } // namespace 45 46 }} // namespace ao::sparse 47