// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/quantization/dedup_module_uses.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#pragma once

#include <torch/csrc/jit/api/module.h>

5 namespace torch {
6 namespace jit {
7 
8 /** Recursively deduplicate multiple uses of the same module by
9  *  creating an instance clone for each use of the module, which means
10  *  the type will be the same as before and all the attributes will be
11  *  copied, then we'll change the use of the original module to the use
12  *  of cloned module in the Graph.
13  *
14  *  This is done to ensure that modules can survive destructive passes
15  *  without changing model behavior. For example, here:
16  *
17  *    x = self.conv1(x)
18  *    x = self.relu(x)
19  *    x = self.conv2(x)
20  *    x = self.relu(x)
21  *
22  *  self.relu needs to be deduplicated for potential future destructive passes
23  *  to work properly.
24  */
25 TORCH_API void DedupModuleUses(Module& module);
26 
27 } // namespace jit
28 } // namespace torch
29