xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/insert_observers.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/passes/quantization/quantization_type.h>
5 
6 namespace std {
7 
8 template <>
9 struct hash<torch::jit::Module> {
10   inline size_t operator()(const torch::jit::Module& arg) const {
11     return std::hash<c10::intrusive_ptr<c10::ivalue::Object>>()(arg._ivalue());
12   }
13 };
14 
15 } // namespace std
16 
17 namespace torch {
18 namespace jit {
19 
20 using QConfig = std::tuple<Module, Module>;
21 using QConfigDict = std::unordered_map<std::string, std::optional<QConfig>>;
22 
23 /** \brief Insert observer module and observer function call for
24  *  the Tensors that needs to be observed.
25  *
26  * For each Tensor that needs to be observed in the method, insert observer
27  * module to the input module and add forward calls of observer to the specified
28  * method.
29  *
30  * \param module the input module
31  * \param method_name the method we want to insert observers for
32  * \param qconfig_dict the qconfig dictionary that specifies how
33  * each module is going to be quantized
34  * \param inplace whether we want to do inplace modification to the input module
35  * or clone the module
36  * \param is_dynamic whether the dynamic quantization script is being used.
37  */
38 TORCH_API Module InsertObservers(
39     Module& module,
40     const std::string& method_name,
41     const QConfigDict& qconfig_dict,
42     bool inplace,
43     QuantType quant_type = QuantType::STATIC);
44 
45 /** \brief Insert observer module and observer method for
46  *  the Tensors that needs to be observed.
47  *
48  * For each Tensor that needs to be observed in the method, insert observer
49  * module to the input module and observe_<method-name> methods to the module.
50  * This method is clone of mehtod_name with forward calls of observer added.
51  *
52  * \param module the input module
53  * \param method_name the method we want to insert observers for
54  * \param qconfig_dict the qconfig dictionary that specifies how
55  * each module is going to be quantized
56  * \param inplace whether we want to do inplace modification to the input module
57  * or clone the module
58  * \param is_dynamic whether the dynamic quantization script is being used.
59  */
60 TORCH_API Module InsertObserversForOnDevicePTQ(
61     Module& module,
62     const std::string& method_name,
63     const QConfigDict& qconfig_dict,
64     bool inplace,
65     QuantType quant_type = QuantType::STATIC);
66 
67 } // namespace jit
68 } // namespace torch
69