1 #include <stack>
2
3 #include <torch/csrc/jit/api/module.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/constant_pooling.h>
6 #include <torch/csrc/jit/passes/constant_propagation.h>
7 #include <torch/csrc/jit/passes/hoist_conv_packed_params.h>
8 #include <torch/csrc/jit/passes/quantization/helper.h>
9
10 namespace torch::jit {
11
12 // Hoists packed params from a conv module to the parent module.
13 // The benefit is that after this hoisting, the conv module
14 // no longer holds anything and can be deleted, reducing model
15 // size.
16 //
17 // Before (easy case):
18 //
19 // %1 = prim::GetAttr[name="conv1"][%self]
// %2 = prim::GetAttr[name="_packed_params"][%1]
21 //
22 // After (easy case):
23 //
24 // %2 = prim::GetAttr[name="{prefix}.conv1._packed_params"][%self]
25 //
26 // Before (generic case):
27 //
28 // %1 = prim::GetAttr[name="name1"][%self]
29 // %2 = prim::GetAttr[name="name2"][%1]
30 // ...
// %n = prim::GetAttr[name="_packed_params"][%n-1]
32 //
33 // After (generic case):
34 //
35 // %n =
36 // prim::GetAttr[name="{prefix}.name1{...}.name(n-1)._packed_params"][%self]
37 //
hoistConvPackedParams(Module & rootModule,Node * getConvPackedParamsNode,const std::string & prefix,int & nameUniqueCounter)38 static void hoistConvPackedParams(
39 Module& rootModule,
40 Node* getConvPackedParamsNode,
41 const std::string& prefix,
42 int& nameUniqueCounter) {
43 auto method = rootModule.get_method("forward");
44 auto graph = method.graph();
45 Value* rootModuleAsValue = graph->inputs()[0];
46
47 // get a path from root module to conv module
48 Value* convModuleAsValue = getConvPackedParamsNode->inputs()[0];
49 std::vector<std::string> rootToConvPath =
50 getModuleAccessPath(convModuleAsValue, rootModuleAsValue);
51
52 // get a module object representing the conv
53 Module convModule = findChildModule(rootModule, rootToConvPath);
54
55 // get the packed params value
56 c10::IValue packedParams = convModule.attr("_packed_params");
57
58 // create the new name
59
60 std::string suffix = "";
61 for (const auto& attrName : rootToConvPath) {
62 suffix += attrName + ".";
63 }
64 std::string newNameBase = prefix + "." + suffix + "_packed_params";
65 nameUniqueCounter++;
66 std::string newName = newNameBase + "." + std::to_string(nameUniqueCounter);
67 while (rootModule.hasattr(newName)) {
68 nameUniqueCounter++;
69 newName = newNameBase + "." + std::to_string(nameUniqueCounter);
70 }
71
72 // copy the packed params
73 rootModule.register_attribute(newName, packedParams.type(), packedParams);
74
75 // change target module to rootModule
76 getConvPackedParamsNode->replaceInput(0, rootModuleAsValue);
77
78 // change attribute name to new name
79 getConvPackedParamsNode->s_(Symbol::attr("name"), newName);
80 }
81
HoistConvPackedParams(script::Module & m)82 void HoistConvPackedParams(script::Module& m) {
83 auto method = m.get_method("forward");
84 auto graph = method.graph();
85
86 std::stack<Block*> blocks_to_visit;
87 blocks_to_visit.push(graph->block());
88 std::string attr_name_base = "_jit_pass_hoist_conv_packed_params";
89 // counter to ensure new attribute names are unique
90 int nameUniqueCounter = 0;
91
92 while (!blocks_to_visit.empty()) {
93 Block* b = blocks_to_visit.top();
94 blocks_to_visit.pop();
95
96 for (Node* n : b->nodes()) {
97 // make sure this node is fetching {foo}.{_packed_params}
98 bool isGetPackedParamsNode =
99 n->kind() == prim::GetAttr && n->s(attr::name) == "_packed_params";
100 if (isGetPackedParamsNode) {
101 // make sure the foo in {foo}.{_packed_params} is a quantized conv
102 std::optional<std::string> moduleName = getModuleName(n->inputs()[0]);
103 bool moduleNameIsQuantizedConv = moduleName.has_value() &&
104 (moduleName.value() ==
105 "__torch__.torch.ao.nn.quantized.modules.conv.Conv1d" ||
106 moduleName.value() ==
107 "__torch__.torch.ao.nn.quantized.modules.conv.Conv2d" ||
108 moduleName.value() ==
109 "__torch__.torch.ao.nn.quantized.modules.conv.Conv3d" ||
110 moduleName.value() ==
111 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU1d" ||
112 moduleName.value() ==
113 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU2d" ||
114 moduleName.value() ==
115 "__torch__.torch.nn.intrinsic.quantized.modules.conv_relu.ConvReLU3d" ||
116 // BC Stuff
117 moduleName.value() ==
118 "__torch__.torch.nn.quantized.modules.conv.Conv1d" ||
119 moduleName.value() ==
120 "__torch__.torch.nn.quantized.modules.conv.Conv2d" ||
121 moduleName.value() ==
122 "__torch__.torch.nn.quantized.modules.conv.Conv3d");
123
124 if (moduleNameIsQuantizedConv) {
125 GRAPH_UPDATE("Hoisting ", *n, " to root module.");
126 hoistConvPackedParams(m, n, attr_name_base, nameUniqueCounter);
127 }
128 }
129
130 for (Block* subblock : n->blocks()) {
131 blocks_to_visit.push(subblock);
132 }
133
134 } // for
135
136 } // while
137 }
138
139 } // namespace torch::jit
140