xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/hoist_conv_packed_params.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <stack>
2 
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <torch/csrc/jit/passes/constant_propagation.h>
7 #include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
8 #include <torch/csrc/jit/passes/quantization/helper.h>
9 
10 namespace torch::jit {
11 
12 // Hoists packed params from a conv module to the parent module.
13 // The benefit is that after this hoisting, the conv module
14 // no longer holds anything and can be deleted, reducing model
15 // size.
16 //
17 // Before (easy case):
18 //
19 // %1 = prim::GetAttr[name="conv1"][%self]
20 // %2 = prim::GetAttr[name="_packed_params"][%1]
21 //
22 // After (easy case):
23 //
24 // %2 = prim::GetAttr[name="{prefix}.conv1._packed_params"][%self]
25 //
26 // Before (generic case):
27 //
28 // %1 = prim::GetAttr[name="name1"][%self]
29 // %2 = prim::GetAttr[name="name2"][%1]
30 // ...
31 // %n = prim::GetAttr[name="_packed_params"][%n-1]
32 //
33 // After (generic case):
34 //
35 // %n =
36 // prim::GetAttr[name="{prefix}.name1{...}.name(n-1)._packed_params"][%self]
37 //
hoistConvPackedParams(Module & rootModule,Node * getConvPackedParamsNode,const std::string & prefix,int & nameUniqueCounter)38 static void hoistConvPackedParams(
39     Module& rootModule,
40     Node* getConvPackedParamsNode,
41     const std::string& prefix,
42     int& nameUniqueCounter) {
43   auto method = rootModule.get_method("forward");
44   auto graph = method.graph();
45   Value* rootModuleAsValue = graph->inputs()[0];
46 
47   // get a path from root module to conv module
48   Value* convModuleAsValue = getConvPackedParamsNode->inputs()[0];
49   std::vector<std::string> rootToConvPath =
50       getModuleAccessPath(convModuleAsValue, rootModuleAsValue);
51 
52   // get a module object representing the conv
53   Module convModule = findChildModule(rootModule, rootToConvPath);
54 
55   // get the packed params value
56   c10::IValue packedParams = convModule.attr("_packed_params");
57 
58   // create the new name
59 
60   std::string suffix = "";
61   for (const auto& attrName : rootToConvPath) {
62     suffix += attrName + ".";
63   }
64   std::string newNameBase = prefix + "." + suffix + "_packed_params";
65   nameUniqueCounter++;
66   std::string newName = newNameBase + "." + std::to_string(nameUniqueCounter);
67   while (rootModule.hasattr(newName)) {
68     nameUniqueCounter++;
69     newName = newNameBase + "." + std::to_string(nameUniqueCounter);
70   }
71 
72   // copy the packed params
73   rootModule.register_attribute(newName, packedParams.type(), packedParams);
74 
75   // change target module to rootModule
76   getConvPackedParamsNode->replaceInput(0, rootModuleAsValue);
77 
78   // change attribute name to new name
79   getConvPackedParamsNode->s_(Symbol::attr("name"), newName);
80 }
81 
HoistConvPackedParams(script::Module & m)82 void HoistConvPackedParams(script::Module& m) {
83   auto method = m.get_method("forward");
84   auto graph = method.graph();
85 
86   std::stack<Block*> blocks_to_visit;
87   blocks_to_visit.push(graph->block());
88   std::string attr_name_base = "_jit_pass_hoist_conv_packed_params";
89   // counter to ensure new attribute names are unique
90   int nameUniqueCounter = 0;
91 
92   while (!blocks_to_visit.empty()) {
93     Block* b = blocks_to_visit.top();
94     blocks_to_visit.pop();
95 
96     for (Node* n : b->nodes()) {
97       // make sure this node is fetching {foo}.{_packed_params}
98       bool isGetPackedParamsNode =
99           n->kind() == prim::GetAttr && n->s(attr::name) == "_packed_params";
100       if (isGetPackedParamsNode) {
101         // make sure the foo in {foo}.{_packed_params} is a quantized conv
102         std::optional<std::string> moduleName = getModuleName(n->inputs()[0]);
103         bool moduleNameIsQuantizedConv = moduleName.has_value() &&
104             (moduleName.value() ==
105                  "__torch__.torch.ao.nn.quantized.modules.conv.Conv1d" ||
106              moduleName.value() ==
107                  "__torch__.torch.ao.nn.quantized.modules.conv.Conv2d" ||
108              moduleName.value() ==
109                  "__torch__.torch.ao.nn.quantized.modules.conv.Conv3d" ||
110              moduleName.value() ==
111                  "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d" ||
112              moduleName.value() ==
113                  "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d" ||
114              moduleName.value() ==
115                  "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d" ||
116              // BC Stuff
117              moduleName.value() ==
118                  "__torch__.torch.nn.quantized.modules.conv.Conv1d" ||
119              moduleName.value() ==
120                  "__torch__.torch.nn.quantized.modules.conv.Conv2d" ||
121              moduleName.value() ==
122                  "__torch__.torch.nn.quantized.modules.conv.Conv3d");
123 
124         if (moduleNameIsQuantizedConv) {
125           GRAPH_UPDATE("Hoisting ", *n, " to root module.");
126           hoistConvPackedParams(m, n, attr_name_base, nameUniqueCounter);
127         }
128       }
129 
130       for (Block* subblock : n->blocks()) {
131         blocks_to_visit.push(subblock);
132       }
133 
134     } // for
135 
136   } // while
137 }
138 
139 } // namespace torch::jit
140