// external/pytorch/torch/csrc/jit/passes/frozen_conv_folding.cpp
#include <ATen/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

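// This file implements three peephole passes over frozen (inference-only)
// graphs: folding batch_norm, add/sub, and mul/div into a preceding
// non-transposed conv1d/conv2d/conv3d whose weight and bias are constants.
// Each pass bakes the pointwise op into new constant conv weights/bias,
// reroutes uses of the op's output to the conv output, and leaves the
// now-dead nodes for the dead-code-elimination run in the public wrappers.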
namespace torch::jit {

namespace {

using Tensor = at::Tensor;

bool supportedConvNode(Node* n) {
  switch (n->kind()) {
    case aten::conv1d:
    case aten::conv2d:
    case aten::conv3d:
      return true;
    case aten::_convolution: {
      auto transposed_conv =
          constant_as<bool>(n->namedInput("transposed")).value_or(true);
      // don't handle transposed conv, or a non-constant transposed
      // parameter, yet
      return !transposed_conv;
    }
    default:
      return false;
  }
}

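// Folds a frozen batch_norm into the conv that feeds it. With running mean
// mu, running variance sigma^2, scale gamma, shift beta, and epsilon eps,
// the fused parameters (mirroring torch/nn/utils/fusion.py) are:
//   W' = W * gamma / sqrt(sigma^2 + eps)   (scaled per output channel)
//   b' = (b - mu) * gamma / sqrt(sigma^2 + eps) + beta
// This only fires when all conv/bn parameters are constant and the conv
// output has no other uses.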
bool FoldFrozenConvBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto bn = n;
      if (nonConstantParameters(conv) || nonConstantParameters(bn)) {
        continue;
      }
      if (conv->output()->uses().size() > 1) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");
      // Check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto conv_w = constant_as<Tensor>(conv->namedInput("weight")).value();

      // implementation taken from torch/nn/utils/fusion.py
      Tensor conv_b;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        // If this is on GPU, the bias is None, and the weight was
        // half/bfloat16 while bn_rm was float, then this was probably a case
        // where autocasting cast the inputs to the conv. Since the CUDA conv
        // implementation requires all inputs to have the same scalar dtype,
        // make this placeholder bias match the dtype of conv_w.
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = conv_w.scalar_type();
        if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        conv_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        conv_b = constant_as<Tensor>(conv->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (bn->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      ConvBNParameters params;
      params.conv_w = conv_w;
      params.conv_b = conv_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
      std::tuple<Tensor, Tensor> out = computeUpdatedConvWeightAndBias(params);
      WithInsertPoint guard(conv);
      auto fused_conv_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_conv_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto conv_w_value = conv->namedInput("weight");
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_w->setDebugName(conv_w_value->debugName() + "_fused_bn");
      fused_conv_b->setDebugName(conv_b_value->debugName() + "_fused_bn");

      conv->replaceInputWith(conv_w_value, fused_conv_w);
      conv->replaceInputWith(conv_b_value, fused_conv_b);

      bn->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

bool supportedAddOrSub(Node* n) {
  static const OperatorSet add_set{
      "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
      // sub is equivalent to add
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
  };
  return n->isMemberOf(add_set);
}

// In order to fuse an add/sub/mul/div with a conv, the constant tensor of
// the pointwise op must satisfy two conditions:
// - with resizing, it broadcasts to the weight/bias tensor shape
// - it broadcasts to the conv output shape
// It needs a shape that can be resized to the weight/bias tensor shape
// because we need to run the op with the conv weights/bias without changing
// their sizes.
// It needs to broadcast to the conv output shape so that we do not
// accidentally change the shape of the op output by pre-fusing it, compared
// to eager mode.
// The only dimension value shared by weight/bias/conv output is that they
// all contain a dim with value == channels-out. In the conv output tensor,
// this is the second dimension, so the pointwise op tensor may have a second
// dimension of value == channels-out, but all the other dimensions have to
// be 1.
// Returns true when op_tensor is safe to fold: every dimension is 1, except
// possibly dim 1, which may equal channels-out (weight_tensor.size(0)).
bool opDoesNotBroadCastWithConv(Tensor& op_tensor, Tensor& weight_tensor) {
  if (op_tensor.ndimension() > weight_tensor.ndimension()) {
    return false;
  }
  for (int64_t i = op_tensor.ndimension() - 1; i >= 0; i--) {
    // channels-out dimension == weight_tensor.size(0)
    if (i == 1 && op_tensor.size(i) == weight_tensor.size(0)) {
      continue;
    }
    if (op_tensor.size(i) != 1) {
      return false;
    }
  }
  return true;
}

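// Shared preconditions for folding a pointwise op into a conv: all conv and
// op parameters must be constants, the conv output must have no other uses,
// the conv weight must be floating point, and a constant tensor operand must
// be foldable without type promotion or shape changes.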
bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
  if (nonConstantParameters(conv) || nonConstantParameters(op)) {
    return false;
  }

  if (conv->output()->uses().size() > 1) {
    return false;
  }

  Tensor weight_tensor =
      constant_as<Tensor>(conv->namedInput("weight")).value();

  // avoid fusing an op that causes type promotion;
  // restricting to float avoids int/float difficulties with the scalar
  // overload
  if (!weight_tensor.is_floating_point()) {
    return false;
  }

  if (op->inputs().at(1)->type()->cast<TensorType>()) {
    auto op_tensor = constant_as<Tensor>(op->inputs().at(1)).value();
    if (!opDoesNotBroadCastWithConv(op_tensor, weight_tensor)) {
      return false;
    }

    if (!op_tensor.is_floating_point() &&
        c10::promoteTypes(
            op_tensor.scalar_type(), weight_tensor.scalar_type()) !=
            weight_tensor.scalar_type()) {
      return false;
    }
  }
  return true;
}

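// Materializes the second input of the pointwise op as a tensor of `shape`:
// constant tensors are reshaped/expanded, while int/double scalars are
// splatted into a freshly allocated tensor. For example, a scalar 2.0
// resized to {C, 1, 1} yields a {C, 1, 1} tensor filled with 2.0.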
Tensor resizeConstantScalarOrTensorToShape(
    Value* v,
    const std::vector<int64_t>& shape,
    at::TensorOptions options) {
  Tensor ret_tensor;
  if (v->type()->cast<TensorType>()) {
    ret_tensor = constant_as<Tensor>(v).value();
  } else {
    ret_tensor = at::zeros(shape, options);
    if (v->type()->cast<IntType>()) {
      ret_tensor.fill_(constant_as<int64_t>(v).value());
    } else {
      ret_tensor.fill_(constant_as<double>(v).value());
    }
  }

  if (ret_tensor.numel() == 1) {
    // expand errors if the shape input has fewer dims than the tensor input
    ret_tensor = ret_tensor.reshape({1});
    ret_tensor = ret_tensor.expand(shape);
  } else {
    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
    ret_tensor = ret_tensor.view(shape);
  }
  return ret_tensor;
}

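// Folds add/sub into the conv bias. Since convolution is affine in its bias,
// conv(x, W, b) + c == conv(x, W, b + c) whenever c broadcasts only along
// the channels-out dimension; the same holds for sub (and for a non-unit
// alpha, which is handled by re-running the original add/sub node on the
// constant inputs below).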
bool FoldFrozenConvAddOrSub(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvAddOrSub(block);
    }

    if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto add_or_sub = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, add_or_sub)) {
        continue;
      }

      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();

      Tensor add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
          add_or_sub->inputs().at(1),
          {weight_tensor.size(0)},
          weight_tensor.options());
      Tensor bias;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        bias = at::zeros_like(add_or_sub_tensor, weight_tensor.dtype());
      } else {
        bias = constant_as<Tensor>(conv->namedInput("bias")).value();
      }

      WithInsertPoint guard(conv);

      add_or_sub->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(bias));
      add_or_sub->replaceInput(
          1, b->owningGraph()->insertConstant(add_or_sub_tensor));

      auto stack_out = runNodeIfInputsAreConstant(add_or_sub);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

      auto fused_conv_b = b->owningGraph()->insertConstant(fuse_bias);
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_b->setDebugName(
          conv_b_value->debugName() + "_fused_" +
          add_or_sub->kind().toUnqualString());
      conv->replaceInputWith(conv_b_value, fused_conv_b);
      add_or_sub->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
      // the DCE run afterwards cleans up the now-dead add/sub nodes
    }
  }
  return graph_modified;
}

bool supportedMulOrDiv(Node* n) {
  static const OperatorSet mul_or_div_set{
      "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor",
      // div is equivalent to mul
      "aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::div.Scalar(Tensor self, Scalar other) -> Tensor",
  };
  return n->isMemberOf(mul_or_div_set);
}

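// Folds mul/div into the conv weights and bias. Convolution is linear in
// both W and b, so conv(x, W, b) * s == conv(x, W * s', b * s), where s' is
// s reshaped to {channels_out, 1, ..., 1} so that it scales one output
// channel of the weight at a time; division works identically since the
// original div node itself is re-run on the constant inputs.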
bool FoldFrozenConvMulOrDiv(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvMulOrDiv(block);
    }

    if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto mul_or_div = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, mul_or_div)) {
        continue;
      }

      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();
      int64_t out_channels = weight_tensor.size(0);

      // We've already verified that the second input has numel == 1 or
      // numel == channels-out; resize it to a shape that broadcasts with
      // weight_tensor when the op is run, so we don't change the weight size.
      std::vector<int64_t> weight_compatible_size = {out_channels};
      for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
        (void)i; // Suppress unused variable warning
        weight_compatible_size.push_back(1);
      }

      WithInsertPoint guard(conv);

      Tensor mul_tensor = resizeConstantScalarOrTensorToShape(
          mul_or_div->inputs().at(1),
          weight_compatible_size,
          weight_tensor.options());

      // First fold with weight tensor
      mul_or_div->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(weight_tensor));
      mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor));

      auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_weight = (*stack_out)[0].toTensor().to(weight_tensor.dtype());

      auto fused_conv_weight = b->owningGraph()->insertConstant(fuse_weight);
      auto conv_weight_value = conv->namedInput("weight");

      fused_conv_weight->setDebugName(
          conv_weight_value->debugName() + "_fused_" +
          mul_or_div->kind().toUnqualString());
      conv->replaceInputWith(conv_weight_value, fused_conv_weight);
      mul_or_div->output()->replaceAllUsesWith(conv->output());

      // now fold with bias tensor
      if (conv->namedInput("bias")->type() != NoneType::get()) {
        Tensor bias = constant_as<Tensor>(conv->namedInput("bias")).value();
        // bias is of shape {channels_out}
        auto mul_tensor = resizeConstantScalarOrTensorToShape(
            mul_or_div->inputs().at(1), {out_channels}, bias.options());

        mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias));
        mul_or_div->replaceInput(
            1, b->owningGraph()->insertConstant(mul_tensor));

        auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
        TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
        Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

        auto fused_conv_bias = b->owningGraph()->insertConstant(fuse_bias);
        auto conv_b_value = conv->namedInput("bias");

        fused_conv_bias->setDebugName(
            conv_b_value->debugName() + "_fused_" +
            mul_or_div->kind().toUnqualString());
        conv->replaceInputWith(conv_b_value, fused_conv_bias);
      }
      graph_modified = true;
      // the DCE run afterwards cleans up the now-dead mul/div nodes
    }
  }
  return graph_modified;
}

} // namespace

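// Public entry points: each wrapper runs its folding pass over every block
// of the graph and then eliminates the dead nodes and constants it left
// behind. The returned bool reports whether the graph was modified.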
bool FoldFrozenConvBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvAddOrSub(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvMulOrDiv(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace torch::jit