xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/passes/fold_conv_bn.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/passes/fold_conv_bn.h>
2 
3 #include <torch/csrc/jit/ir/subgraph_matcher.h>
4 #include <torch/csrc/jit/jit_log.h>
5 #include <torch/csrc/jit/passes/graph_rewrite_helper.h>
6 #include <torch/csrc/jit/passes/quantization/helper.h>
7 
8 #include <ATen/TensorOperators.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #else
13 #include <ATen/ops/empty_like.h>
14 #include <ATen/ops/ones_like.h>
15 #include <ATen/ops/rsqrt.h>
16 #include <ATen/ops/zeros_like.h>
17 #endif
18 
19 #include <stack>
20 #include <utility>
21 
22 namespace torch::jit {
23 
computeUpdatedConvWeightAndBias(const ConvBNParameters & p)24 std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias(
25     const ConvBNParameters& p) {
26   at::Tensor bn_var_rsqrt = at::rsqrt(p.bn_rv + p.bn_eps);
27   const int64_t ndim = p.conv_w.dim();
28   at::DimVector sizes(ndim, 1);
29   sizes.at(0) = -1;
30 
31   auto conv_w_dtype = p.conv_w.dtype();
32   auto conv_b_dtype = p.conv_b.dtype();
33 
34   at::Tensor new_w = p.conv_w * (p.bn_w * bn_var_rsqrt).reshape(sizes);
35   at::Tensor new_b = (p.conv_b - p.bn_rm) * bn_var_rsqrt * p.bn_w + p.bn_b;
36   return std::make_tuple(new_w.to(conv_w_dtype), new_b.to(conv_b_dtype));
37 }
38 
39 namespace {
40 using graph_rewrite_helper::PatternInfo;
41 
hastensor(Module & m,const char * name)42 static bool hastensor(Module& m, const char* name) {
43   return m.hasattr(name) && m.attr(name).isTensor();
44 }
45 
replaceConvBiasWithGetAttr(Module & module)46 void replaceConvBiasWithGetAttr(Module& module) {
47   for (const auto& method : module.get_methods()) {
48     auto graph = method.graph();
49     // Only looks for _convolution pattern.
50     // Thus assumes that tracing will have always gotten rid of aten::conv2d or
51     // aten::conv3d. If it did not, BN folding will fail.
52     const PatternInfo& pattern_convolution = PatternInfo::parse_from_str(R"(
53         graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
54             %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
55             %deterministic:bool, %cudnn_enabled:bool, %allow_tf32:bool):
56           %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
57               %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled, %allow_tf32)
58           return (%conv_out) )");
59     const PatternInfo& pattern_convolution_deprecated =
60         PatternInfo::parse_from_str(R"(
61         graph(%a, %w, %b, %stride:int[], %padding:int[], %dilation:int[],
62             %transposed:bool, %output_padding:int[], %groups:int, %benchmark:bool,
63             %deterministic:bool, %cudnn_enabled:bool):
64           %conv_out = aten::_convolution(%a, %w, %b, %stride, %padding, %dilation,
65               %transposed, %output_padding, %groups, %benchmark, %deterministic, %cudnn_enabled)
66           return (%conv_out) )");
67     auto replace_pattern = [&](const PatternInfo& pattern_convolution) {
68       const Graph& pattern_convolution_graph =
69           *pattern_convolution.pattern_graph;
70       const auto& convolution_vmap = pattern_convolution.vmap;
71 
72       const auto& matches =
73           findPatternMatches(pattern_convolution_graph, *graph);
74       for (const auto& match : matches) {
75         // We come here only if the bias was not present in the module.
76         // In that case, the corresponding graph will not have getAttr("bias")
77         // Insert that in the graph.
78         // And change _convolution to take the new value.
79         auto conv_node =
80             match.values_map.at(convolution_vmap.at("conv_out"))->node();
81         WithInsertPoint ins(conv_node);
82         Value* bias_attr_val = graph->insertGetAttr(graph->inputs()[0], "bias")
83                                    ->setType(TensorType::get());
84         constexpr size_t conv_bias_index = 2;
85         conv_node->replaceInput(conv_bias_index, bias_attr_val);
86       }
87     };
88     replace_pattern(pattern_convolution);
89     replace_pattern(pattern_convolution_deprecated);
90   }
91 }
92 
addBiasForConvIfNone(Module & module,const std::string & pattern_name)93 void addBiasForConvIfNone(Module& module, const std::string& pattern_name) {
94   auto t = module.type()->expect<ClassType>();
95 
96   const std::string real_typename = t->name()->qualifiedName();
97   const std::string demangled_typename = removeTorchMangle(real_typename);
98   bool is_floating_point_conv =
99       ((demangled_typename == "__torch__.torch.nn.modules.conv.Conv1d") ||
100        (demangled_typename == "__torch__.torch.nn.modules.conv.Conv2d") ||
101        (demangled_typename == "__torch__.torch.nn.modules.conv.Conv3d"));
102 
103   if (is_floating_point_conv) {
104     if (!t->hasAttribute("bias")) {
105       auto optional_tensor_type = OptionalType::create(TensorType::get());
106       t->addAttribute("bias", std::move(optional_tensor_type), true);
107       auto optional_tensor = std::optional<at::Tensor>();
108       module.setattr("bias", std::move(optional_tensor));
109       replaceConvBiasWithGetAttr(module);
110     }
111   }
112   for (Module m : module.children()) {
113     addBiasForConvIfNone(m, pattern_name);
114   }
115 }
116 
117 class FoldConvBatchNormHelper {
118  public:
119   /**
120    * In this step we find all Conv - BatchNorm patterns in the graph
121    * and extract the corresponding parameters for these two modules,
122    * and record information for the modifications of the graph without
123    * actually performing these modifications.
124    */
125   void analyze(Module& module, const PatternInfo& pattern);
126   /**
127    * In this step we perform all the modifications including
128    * setting the attributes for conv module, rewriting values
129    * and deleting nodes in the graph
130    */
131   void transform();
132 
133  private:
134   bool tryExtractingConvBNParameters(
135       Module& conv,
136       Module& bn,
137       ConvBNParameters& r);
138 
139   std::unordered_map<ModulePtr, std::tuple<at::Tensor, at::Tensor>>
140       conv_module_and_params_;
141 
142   // A map from graph to a list of tuple of paths of matched conv and bn module
143   // e.g. if we have a graph `g` containing following code
144   // x = self.sub.conv1(..)
145   // x = self.sub.bn1(..)
146   // x = self.sub.conv2(..)
147   // x = self.sub.bn2(..)
148   // then the value for graph `g` in this map will be:
149   // [(['sub', 'conv1'], ['sub', 'bn1']), (['sub', 'conv2'], ['sub', 'bn2'])]
150   // the first entry of the list is the paths to first conv-bn match
151   // the second entry of the list is the path to second match
152   std::unordered_map<
153       Graph*,
154       std::vector<
155           std::tuple<std::vector<std::string>, std::vector<std::string>>>>
156       conv_bn_paths_;
157 
158   std::unordered_map<Value*, Value*> rewrite_map_;
159   std::vector<Value*> values_to_rewrite_;
160   std::unordered_set<Node*> nodes_to_delete_;
161 };
162 
extractOptionalBNParams(const script::Module & bn,ConvBNParameters & r)163 bool extractOptionalBNParams(const script::Module& bn, ConvBNParameters& r) {
164   auto bn_forward = bn.get_method("forward");
165   auto graph = bn_forward.graph();
166   const PatternInfo& pattern_bn = PatternInfo::parse_from_str(R"(
167       graph(%a, %weight, %bias, %running_mean, %running_var,
168           %training, %momentum, %eps, %cudnn_enabled):
169         %bn_out = aten::batch_norm(%a, %weight, %bias, %running_mean,
170             %running_var, %training, %momentum, %eps, %cudnn_enabled)
171         return (%bn_out) )");
172   const Graph& pattern_bn_graph = *pattern_bn.pattern_graph;
173   const auto& bn_vmap = pattern_bn.vmap;
174 
175   const auto& matches = findPatternMatches(pattern_bn_graph, *graph);
176 
177   if (matches.size() > 1) {
178     return false;
179   }
180 
181   if (bn.hasattr("eps")) {
182     r.bn_eps = bn.attr("eps").toDouble();
183   } else {
184     auto optional_eps = toIValue(matches[0].values_map.at(bn_vmap.at("eps")));
185     if (!optional_eps) {
186       return false;
187     }
188     r.bn_eps = optional_eps.value().toDouble();
189   }
190   r.bn_w = at::ones_like(bn.attr("running_mean").toTensor());
191   if (bn.hasattr("weight")) {
192     if (bn.attr("weight").isTensor()) {
193       r.bn_w = bn.attr("weight").toTensor();
194     }
195   } else {
196     auto optional_bn_weight =
197         toIValue(matches[0].values_map.at(bn_vmap.at("weight")));
198     if (!optional_bn_weight) {
199       return false;
200     }
201     if (optional_bn_weight.value().isTensor()) {
202       r.bn_w = optional_bn_weight.value().toTensor();
203     }
204   }
205   r.bn_b = at::zeros_like(bn.attr("running_mean").toTensor());
206   if (bn.hasattr("bias")) {
207     if (bn.attr("bias").isTensor()) {
208       r.bn_b = bn.attr("bias").toTensor();
209     }
210   } else {
211     auto optional_bn_bias =
212         toIValue(matches[0].values_map.at(bn_vmap.at("bias")));
213     if (!optional_bn_bias) {
214       return false;
215     }
216 
217     if (optional_bn_bias.value().isTensor()) {
218       r.bn_b = optional_bn_bias.value().toTensor();
219     }
220   }
221   return true;
222 }
223 
tryExtractingConvBNParameters(Module & conv,Module & bn,ConvBNParameters & r)224 bool FoldConvBatchNormHelper::tryExtractingConvBNParameters(
225     Module& conv,
226     Module& bn,
227     ConvBNParameters& r) {
228   if (!hastensor(conv, "weight") || !conv.hasattr("bias") ||
229       !hastensor(bn, "running_mean") || !hastensor(bn, "running_var")) {
230     return false;
231   }
232 
233   r.bn_rm = bn.attr("running_mean").toTensor();
234   r.bn_rv = bn.attr("running_var").toTensor();
235   if (!extractOptionalBNParams(bn, r)) {
236     return false;
237   }
238 
239   r.conv_w = conv.attr("weight").toTensor();
240   r.conv_b = at::zeros_like(r.bn_rm);
241   auto bias_opt = conv.attr("bias").toOptional<at::Tensor>();
242   if (bias_opt) {
243     r.conv_b = *bias_opt;
244   }
245 
246   return true;
247 }
248 
analyze(Module & module,const PatternInfo & pattern)249 void FoldConvBatchNormHelper::analyze(
250     Module& module,
251     const PatternInfo& pattern) {
252   const Graph& pattern_graph = *pattern.pattern_graph;
253   const auto& vmap = pattern.vmap;
254   Value* pattern_conv_out = vmap.at("conv_out");
255   Value* pattern_bn_out = vmap.at("bn_out");
256   Value* pattern_bn_submodule = vmap.at("batchnorm");
257   Node* pattern_conv = pattern_conv_out->node();
258   Node* pattern_bn = pattern_bn_out->node();
259 
260   // We will put submodules into this worklist and keep processing items from it
261   // one by one. We start by just putting the top module there.
262   std::stack<Module> worklist({module});
263   while (!worklist.empty()) {
264     Module current = worklist.top();
265     worklist.pop();
266 
267     // Queue submodules for processing
268     for (const Module& submodule : current.children()) {
269       worklist.push(submodule);
270     }
271 
272     // Process all method of the current module
273     for (auto& method : current.get_methods()) {
274       GRAPH_DUMP(
275           current.type()->name()->name() + "::" + method.name() +
276               "() before Conv-BatchNorm folding",
277           method.graph());
278       const auto& matches = findPatternMatches(pattern_graph, *method.graph());
279 
280       GRAPH_DEBUG("number of Conv-BatchNorm matches: ", matches.size());
281       Graph* g = method.graph().get();
282       if (!conv_bn_paths_.count(g)) {
283         // This is to make sure we don't visit one graph multiple times
284         conv_bn_paths_[g] = {};
285         for (const Match& match : matches) {
286           if (!std::all_of(
287                   pattern.filters.begin(),
288                   pattern.filters.end(),
289                   [&](const MatchFilter& f) { return f(match, vmap); })) {
290             continue;
291           }
292           GRAPH_DEBUG("Checking next match...");
293           // Get the conv and bn submodule
294           Node* matched_conv = match.nodes_map.at(pattern_conv);
295           Node* matched_bn = match.nodes_map.at(pattern_bn);
296           Node* matched_bn_submodule =
297               match.values_map.at(pattern_bn_submodule)->node();
298           Value* conv_instance = matched_conv->input(0);
299           Value* bn_instance = matched_bn->input(0);
300           Value* self = g->inputs()[0];
301           auto conv_module_path = getModuleAccessPath(conv_instance, self);
302           auto bn_module_path = getModuleAccessPath(bn_instance, self);
303           Module conv_submodule = findChildModule(current, conv_module_path);
304           Module bn_submodule = findChildModule(current, bn_module_path);
305 
306           ConvBNParameters params;
307           if (!tryExtractingConvBNParameters(
308                   conv_submodule, bn_submodule, params)) {
309             GRAPH_DEBUG(
310                 "Conv and BN modules didn't have all required parameters or attributes...");
311             continue;
312           }
313           conv_bn_paths_[g].emplace_back(conv_module_path, bn_module_path);
314           // We are using a separate vector for saving Values we want to rewrite
315           // to make sure that the order in which we perform these
316           // transformations is deterministic. Iterating through keys of
317           // rewrite_map would result in non-determinism that might not manifest
318           // as a bug now, but can bite us later.
319           values_to_rewrite_.push_back(matched_bn->output());
320           rewrite_map_[matched_bn->output()] = matched_conv->output();
321           GRAPH_UPDATE(
322               "Rewriting %",
323               matched_bn->output()->debugName(),
324               " with %",
325               matched_conv->output()->debugName());
326 
327           nodes_to_delete_.insert(matched_bn);
328           nodes_to_delete_.insert(matched_bn_submodule);
329           GRAPH_UPDATE("Deleting ", *matched_bn);
330           GRAPH_UPDATE("Deleting ", *matched_bn_submodule);
331 
332           auto slot = conv_submodule.type()->getAttributeSlot("bias");
333           TORCH_CHECK(
334               conv_submodule.type()->is_parameter(slot),
335               "Expected conv module to have a bias parameter");
336         } // matches
337       }
338 
339       for (const auto& conv_bn : conv_bn_paths_.at(g)) {
340         Module conv_submodule = findChildModule(current, std::get<0>(conv_bn));
341         Module bn_submodule = findChildModule(current, std::get<1>(conv_bn));
342 
343         ConvBNParameters params;
344         TORCH_INTERNAL_ASSERT(tryExtractingConvBNParameters(
345             conv_submodule, bn_submodule, params));
346         auto new_w_b = computeUpdatedConvWeightAndBias(params);
347         conv_module_and_params_[conv_submodule._ivalue()] = new_w_b;
348       } // conv_bn module
349     } // methods
350   } // while
351 }
352 
transform()353 void FoldConvBatchNormHelper::transform() {
354   for (const auto& item : conv_module_and_params_) {
355     Module conv(item.first);
356     auto w_b = item.second;
357     conv.setattr("weight", std::get<0>(w_b));
358     conv.setattr("bias", std::get<1>(w_b));
359   }
360 
361   // Perform planned rewritings
362   for (auto v : values_to_rewrite_) {
363     v->replaceAllUsesWith(rewrite_map_.at(v));
364   }
365 
366   // Perform planned deletions
367   for (auto n : nodes_to_delete_) {
368     n->removeAllInputs();
369   }
370   for (auto n : nodes_to_delete_) {
371     n->destroy();
372   }
373 }
374 
375 } // namespace
376 
/**
 * Returns a clone of `module` in which Conv2d+BatchNorm2d and
 * Conv3d+BatchNorm3d call sequences have been folded. The input module itself
 * is not modified; all work is done on the clone.
 */
Module FoldConvBatchNorm(const Module& module) {
  Module m = module.clone();

  addBiasForConvIfNone(m, "Conv2d");
  addBiasForConvIfNone(m, "Conv3d");

  // conv.forward(input) immediately followed by batchnorm.forward(conv_out).
  // The 2d and 3d variants share this graph and differ only in their filters.
  const char* const conv_bn_pattern =
      R"(
graph(%self, %input, %conv, %batchnorm):
    %conv_out = prim::CallMethod[name="forward"](%conv, %input)
    %bn_out = prim::CallMethod[name="forward"](%batchnorm, %conv_out)
    return (%bn_out))";

  // Conv2d + BatchNorm2d
  const PatternInfo pattern2d = PatternInfo::parse_from_str(
      conv_bn_pattern, {is_conv2d_module, is_batchnorm2d_module});
  // Conv3d + BatchNorm3d
  const PatternInfo pattern3d = PatternInfo::parse_from_str(
      conv_bn_pattern, {is_conv3d_module, is_batchnorm3d_module});

  // Each pattern gets its own helper: analyze first, then transform.
  const auto fold_with = [&m](const PatternInfo& pattern) {
    FoldConvBatchNormHelper helper;
    helper.analyze(m, pattern);
    helper.transform();
  };
  fold_with(pattern2d);
  fold_with(pattern3d);

  return m;
}
408 
409 } // namespace torch::jit
410