#include <ATen/Utils.h>
#include <c10/core/ScalarType.h>
#include <c10/util/Exception.h>
#include <c10/util/accumulate.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/constants.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/utils/optimization_utils.h>
#include <torch/csrc/jit/tensorexpr/types.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/ones_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/zeros_like.h>
#endif

namespace torch::jit {

namespace {

using Tensor = at::Tensor;

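// Returns true for convolution nodes this pass knows how to fold into:
// conv1d/conv2d/conv3d, plus aten::_convolution when its "transposed"
// argument is a constant false.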
bool supportedConvNode(Node* n) {
  switch (n->kind()) {
    case aten::conv1d:
    case aten::conv2d:
    case aten::conv3d:
      return true;
    case aten::_convolution: {
      auto transposed_conv =
          constant_as<bool>(n->namedInput("transposed")).value_or(true);
      // don't handle transposed convs yet, or a non-constant "transposed"
      // parameter
      return !transposed_conv;
    }
    default:
      return false;
  }
}

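// Recursively folds frozen batch norms into the conv they consume: for every
// aten::batch_norm whose input is a supported conv with constant parameters
// and a single use, the batch norm statistics are baked into new constant
// conv weight/bias tensors and the batch norm output is redirected to the
// conv output. Returns true if the block (or any nested block) was modified.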
bool FoldFrozenConvBatchnorm(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvBatchnorm(block);
    }

    if (n->kind() == aten::batch_norm &&
        supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto bn = n;
      if (nonConstantParameters(conv) || nonConstantParameters(bn)) {
        continue;
      }
      if (conv->output()->uses().size() > 1) {
        continue;
      }

      auto bn_rm_ivalue = bn->namedInput("running_mean");
      auto bn_rv_ivalue = bn->namedInput("running_var");
      // check that running_mean and running_var have values; if they are
      // None (track_running_stats=False), skip the folding path.
      if (bn_rm_ivalue->type() == NoneType::get() &&
          bn_rv_ivalue->type() == NoneType::get()) {
        continue;
      }

      auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
      auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
      auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
      auto conv_w = constant_as<Tensor>(conv->namedInput("weight")).value();

      // implementation taken from torch/nn/utils/fusion.py
      Tensor conv_b;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        // If this is on GPU, the bias is None, and the weight was
        // half/bfloat16 while bn_rm is float, then this was probably a case
        // where autocasting cast the conv inputs. Since the CUDA conv
        // implementation requires all of its inputs to have the same scalar
        // dtype, give this placeholder bias the same dtype as conv_w.
        at::ScalarType bias_dtype = bn_rm.scalar_type();
        at::ScalarType weight_dtype = conv_w.scalar_type();
        if ((weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
            bias_dtype == at::kFloat) {
          bias_dtype = weight_dtype;
        }
        conv_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
      } else {
        conv_b = constant_as<Tensor>(conv->namedInput("bias")).value();
      }
      Tensor bn_w;
      if (bn->namedInput("weight")->type() == NoneType::get()) {
        bn_w = at::ones_like(bn_rm);
      } else {
        bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
      }
      Tensor bn_b;
      if (n->namedInput("bias")->type() == NoneType::get()) {
        bn_b = at::zeros_like(bn_rm);
      } else {
        bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
      }

      ConvBNParameters params;
      params.conv_w = conv_w;
      params.conv_b = conv_b;
      params.bn_rm = bn_rm;
      params.bn_rv = bn_rv;
      params.bn_eps = bn_eps;
      params.bn_w = bn_w;
      params.bn_b = bn_b;
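      // computeUpdatedConvWeightAndBias applies the usual conv-bn fusion
      // formula (see torch/nn/utils/fusion.py):
      //   scale  = bn_w / sqrt(bn_rv + bn_eps)
      //   new_w  = conv_w * scale   (scale broadcast over the out-channel dim)
      //   new_b  = (conv_b - bn_rm) * scale + bn_b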
      std::tuple<Tensor, Tensor> out = computeUpdatedConvWeightAndBias(params);
      WithInsertPoint guard(conv);
      auto fused_conv_w = b->owningGraph()->insertConstant(std::get<0>(out));
      auto fused_conv_b = b->owningGraph()->insertConstant(std::get<1>(out));
      auto conv_w_value = conv->namedInput("weight");
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_w->setDebugName(conv_w_value->debugName() + "_fused_bn");
      fused_conv_b->setDebugName(conv_b_value->debugName() + "_fused_bn");

      conv->replaceInputWith(conv_w_value, fused_conv_w);
      conv->replaceInputWith(conv_b_value, fused_conv_b);

      bn->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
    }
  }
  return graph_modified;
}

bool supportedAddOrSub(Node* n) {
  static const OperatorSet add_set{
      "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
      // sub is equivalent to add
      "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
      "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor",
  };
  return n->isMemberOf(add_set);
}

// In order to fuse an add/sub/mul/div with a conv, the dimensions of its
// constant tensor must satisfy the following:
//   - with resizing, it broadcasts to the weight/bias tensor shape
//   - it broadcasts to the conv output shape
// It needs a shape that can be resized to the weight/bias tensor shape
// because we need to run the op with the conv weights/bias without changing
// their sizes.
// It needs to broadcast to the conv output shape so that we do not
// accidentally change the shape of the op output by pre-fusing it, compared
// to eager mode.
// The only dimension value shared by weight/bias/conv output is that they
// all contain a dim with value == channels-out. In the conv output tensor,
// this is the second dimension, so the pointwise op tensor may have a second
// dimension of value == channels-out, but all its other dimensions must be 1.
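// For example, with a conv2d weight of shape [C_out, C_in, kH, kW], an op
// tensor that is 0-dim, has a single element, or has shape [1, C_out, 1, 1]
// is accepted, while shapes such as [1, C_out, kH, kW] (kH, kW > 1) are
// rejected.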
bool opDoesNotBroadCastWithConv(Tensor& op_tensor, Tensor& weight_tensor) {
  if (op_tensor.ndimension() > weight_tensor.ndimension()) {
    return false;
  }
  for (int64_t i = op_tensor.ndimension() - 1; i >= 0; i--) {
    // channels-out dimension == weight_tensor.size(0)
    if (i == 1 && op_tensor.size(i) == weight_tensor.size(0)) {
      continue;
    }
    if (op_tensor.size(i) != 1) {
      return false;
    }
  }
  return true;
}

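// Common preconditions for folding a pointwise op into a frozen conv: both
// nodes must have constant parameters, the conv output must have exactly one
// use, the conv weight must be floating point, and a constant tensor operand
// must be a scalar or a per-output-channel tensor so that fusing it cannot
// change the conv output shape or promote the weight dtype.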
bool checkConvAndBroadcastingOpPreConditions(Node* conv, Node* op) {
  if (nonConstantParameters(conv) || nonConstantParameters(op)) {
    return false;
  }

  if (conv->output()->uses().size() > 1) {
    return false;
  }

  Tensor weight_tensor =
      constant_as<Tensor>(conv->namedInput("weight")).value();

  // avoid fusing op that causes type promotion
  // restricting to float avoids int/float difficulties with scalar overload
  if (!weight_tensor.is_floating_point()) {
    return false;
  }

  if (op->inputs().at(1)->type()->cast<TensorType>()) {
    auto op_tensor = constant_as<Tensor>(op->inputs().at(1)).value();
    if (!opDoesNotBroadCastWithConv(op_tensor, weight_tensor)) {
      return false;
    }

    if (!op_tensor.is_floating_point() &&
        c10::promoteTypes(
            op_tensor.scalar_type(), weight_tensor.scalar_type()) !=
            weight_tensor.scalar_type()) {
      return false;
    }
  }
  return true;
}

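// Materializes a constant scalar or tensor Value as a tensor with exactly
// `shape`: scalar constants are splatted into a freshly allocated tensor,
// one-element tensors are expanded, and larger tensors are viewed as `shape`
// (their numel must match the product of `shape`).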
Tensor resizeConstantScalarOrTensorToShape(
    Value* v,
    const std::vector<int64_t>& shape,
    at::TensorOptions options) {
  Tensor ret_tensor;
  if (v->type()->cast<TensorType>()) {
    ret_tensor = constant_as<Tensor>(v).value();
  } else {
    ret_tensor = at::zeros(shape, options);
    if (v->type()->cast<IntType>()) {
      ret_tensor.fill_(constant_as<int64_t>(v).value());
    } else {
      ret_tensor.fill_(constant_as<double>(v).value());
    }
  }

  if (ret_tensor.numel() == 1) {
    // expand errors if the shape input has fewer dims than the tensor input
    ret_tensor = ret_tensor.reshape({1});
    ret_tensor = ret_tensor.expand(shape);
  } else {
    TORCH_INTERNAL_ASSERT(ret_tensor.numel() == c10::multiply_integers(shape));
    ret_tensor = ret_tensor.view(shape);
  }
  return ret_tensor;
}

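// Folds a constant add/sub whose first input is a frozen conv into the conv
// bias, using the identity conv(x, W, b) + t == conv(x, W, b + t) for a
// scalar or per-output-channel t (likewise for sub). The add/sub is re-run on
// the constant bias and the result becomes the new conv bias.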
bool FoldFrozenConvAddOrSub(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvAddOrSub(block);
    }

    if (supportedAddOrSub(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto add_or_sub = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, add_or_sub)) {
        continue;
      }

      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();

      Tensor add_or_sub_tensor = resizeConstantScalarOrTensorToShape(
          add_or_sub->inputs().at(1),
          {weight_tensor.size(0)},
          weight_tensor.options());
      Tensor bias;
      if (conv->namedInput("bias")->type() == NoneType::get()) {
        bias = at::zeros_like(add_or_sub_tensor, weight_tensor.dtype());
      } else {
        bias = constant_as<Tensor>(conv->namedInput("bias")).value();
      }

      WithInsertPoint guard(conv);

      add_or_sub->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(bias));
      add_or_sub->replaceInput(
          1, b->owningGraph()->insertConstant(add_or_sub_tensor));

      auto stack_out = runNodeIfInputsAreConstant(add_or_sub);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

      auto fused_conv_b = b->owningGraph()->insertConstant(fuse_bias);
      auto conv_b_value = conv->namedInput("bias");

      fused_conv_b->setDebugName(
          conv_b_value->debugName() + "_fused_" +
          add_or_sub->kind().toUnqualString());
      conv->replaceInputWith(conv_b_value, fused_conv_b);
      add_or_sub->output()->replaceAllUsesWith(conv->output());
      graph_modified = true;
      // the DCE run afterwards cleans up the now-dead nodes
    }
  }
  return graph_modified;
}

bool supportedMulOrDiv(Node* n) {
  static const OperatorSet mul_or_div_set{
      "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor",
      // div is equivalent to mul
      "aten::div.Tensor(Tensor self, Tensor other) -> Tensor",
      "aten::div.Scalar(Tensor self, Scalar other) -> Tensor",
  };
  return n->isMemberOf(mul_or_div_set);
}

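// Folds a constant mul/div whose first input is a frozen conv into the conv
// parameters, using conv(x, W, b) * t == conv(x, W * t', b * t) for a scalar
// or per-output-channel t, where t' is t reshaped to broadcast over the
// weight's output-channel dimension. The same reasoning applies to div.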
bool FoldFrozenConvMulOrDiv(Block* b) {
  bool graph_modified = false;
  for (Node* n : b->nodes()) {
    for (Block* block : n->blocks()) {
      graph_modified |= FoldFrozenConvMulOrDiv(block);
    }

    if (supportedMulOrDiv(n) && supportedConvNode(n->inputs().at(0)->node())) {
      auto conv = n->inputs().at(0)->node();
      auto mul_or_div = n;

      if (!checkConvAndBroadcastingOpPreConditions(conv, mul_or_div)) {
        continue;
      }

      Tensor weight_tensor =
          constant_as<Tensor>(conv->namedInput("weight")).value();
      int64_t out_channels = weight_tensor.size(0);

      // We've already verified that the second input has numel == 1 or
      // numel == channels-out. Resize it to a shape that broadcasts against
      // weight_tensor when the op is run, so we don't change the weight size.
      std::vector<int64_t> weight_compatible_size = {out_channels};
      for (const auto i : c10::irange(1, weight_tensor.ndimension())) {
        (void)i; // Suppress unused variable warning
        weight_compatible_size.push_back(1);
      }

      WithInsertPoint guard(conv);

      Tensor mul_tensor = resizeConstantScalarOrTensorToShape(
          mul_or_div->inputs().at(1),
          weight_compatible_size,
          weight_tensor.options());

      // First fold with weight tensor
      mul_or_div->replaceInputWith(
          conv->output(), b->owningGraph()->insertConstant(weight_tensor));
      mul_or_div->replaceInput(1, b->owningGraph()->insertConstant(mul_tensor));

      auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
      TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
      Tensor fuse_weight = (*stack_out)[0].toTensor().to(weight_tensor.dtype());

      auto fused_conv_weight = b->owningGraph()->insertConstant(fuse_weight);
      auto conv_weight_value = conv->namedInput("weight");

      fused_conv_weight->setDebugName(
          conv_weight_value->debugName() + "_fused_" +
          mul_or_div->kind().toUnqualString());
      conv->replaceInputWith(conv_weight_value, fused_conv_weight);
      mul_or_div->output()->replaceAllUsesWith(conv->output());

      // now fold with bias tensor
      if (conv->namedInput("bias")->type() != NoneType::get()) {
        Tensor bias = constant_as<Tensor>(conv->namedInput("bias")).value();
        // bias is of shape {channels_out}
        auto mul_tensor = resizeConstantScalarOrTensorToShape(
            mul_or_div->inputs().at(1), {out_channels}, bias.options());

        mul_or_div->replaceInput(0, b->owningGraph()->insertConstant(bias));
        mul_or_div->replaceInput(
            1, b->owningGraph()->insertConstant(mul_tensor));

        auto stack_out = runNodeIfInputsAreConstant(mul_or_div);
        TORCH_INTERNAL_ASSERT(stack_out && stack_out->size() == 1);
        Tensor fuse_bias = (*stack_out)[0].toTensor().to(bias.dtype());

        auto fused_conv_bias = b->owningGraph()->insertConstant(fuse_bias);
        auto conv_b_value = conv->namedInput("bias");

        fused_conv_bias->setDebugName(
            conv_b_value->debugName() + "_fused_" +
            mul_or_div->kind().toUnqualString());
        conv->replaceInputWith(conv_b_value, fused_conv_bias);
      }
      graph_modified = true;
      // the DCE run afterwards cleans up the now-dead nodes
    }
  }
  return graph_modified;
}

} // namespace

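// Graph-level entry points: each runs the corresponding block-level folding
// pass over the whole graph, then eliminates the dead nodes left behind by
// the rewrites. Each returns true if the graph was modified.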
bool FoldFrozenConvBatchnorm(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvBatchnorm(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvAddOrSub(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvAddOrSub(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

bool FoldFrozenConvMulOrDiv(std::shared_ptr<Graph>& graph) {
  bool graph_modified = FoldFrozenConvMulOrDiv(graph->block());
  EliminateDeadCode(graph);
  return graph_modified;
}

} // namespace torch::jit