1 #pragma once 2 3 #include <torch/csrc/jit/api/module.h> 4 #include <torch/csrc/jit/passes/quantization/quantization_type.h> 5 6 namespace std { 7 8 template <> 9 struct hash<torch::jit::Module> { 10 inline size_t operator()(const torch::jit::Module& arg) const { 11 return std::hash<c10::intrusive_ptr<c10::ivalue::Object>>()(arg._ivalue()); 12 } 13 }; 14 15 } // namespace std 16 17 namespace torch { 18 namespace jit { 19 20 using QConfig = std::tuple<Module, Module>; 21 using QConfigDict = std::unordered_map<std::string, std::optional<QConfig>>; 22 23 /** \brief Insert observer module and observer function call for 24 * the Tensors that needs to be observed. 25 * 26 * For each Tensor that needs to be observed in the method, insert observer 27 * module to the input module and add forward calls of observer to the specified 28 * method. 29 * 30 * \param module the input module 31 * \param method_name the method we want to insert observers for 32 * \param qconfig_dict the qconfig dictionary that specifies how 33 * each module is going to be quantized 34 * \param inplace whether we want to do inplace modification to the input module 35 * or clone the module 36 * \param is_dynamic whether the dynamic quantization script is being used. 37 */ 38 TORCH_API Module InsertObservers( 39 Module& module, 40 const std::string& method_name, 41 const QConfigDict& qconfig_dict, 42 bool inplace, 43 QuantType quant_type = QuantType::STATIC); 44 45 /** \brief Insert observer module and observer method for 46 * the Tensors that needs to be observed. 47 * 48 * For each Tensor that needs to be observed in the method, insert observer 49 * module to the input module and observe_<method-name> methods to the module. 50 * This method is clone of mehtod_name with forward calls of observer added. 51 * 52 * \param module the input module 53 * \param method_name the method we want to insert observers for 54 * \param qconfig_dict the qconfig dictionary that specifies how 55 * each module is going to be quantized 56 * \param inplace whether we want to do inplace modification to the input module 57 * or clone the module 58 * \param is_dynamic whether the dynamic quantization script is being used. 59 */ 60 TORCH_API Module InsertObserversForOnDevicePTQ( 61 Module& module, 62 const std::string& method_name, 63 const QConfigDict& qconfig_dict, 64 bool inplace, 65 QuantType quant_type = QuantType::STATIC); 66 67 } // namespace jit 68 } // namespace torch 69