//  Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/ops/_mps_convolution_native.h>
#include <ATen/ops/_mps_convolution_transpose_native.h>
#include <ATen/ops/mps_convolution_backward_native.h>
#include <ATen/ops/mps_convolution_transpose_backward_native.h>

#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))

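// MPSGraphConvolution3DOpDescriptor only exists in the macOS 13.2+ SDK. When compiling against an
// older SDK, a stand-in class (presumably declared under the same guard in MPSGraphVenturaOps.h)
// takes its place, and the stub @implementation below keeps that class linkable. It should never be
// exercised at runtime: every Conv3D path below is gated on a macOS 13.2 availability check.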
@implementation FakeMPSGraphConvolution3DOpDescriptor
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  return self;
}

@end

#endif

namespace at::native {

// Create 3D convolution descriptor
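// Note on axis naming: MPSGraph uses X = width, Y = height, Z = depth, while PyTorch orders its
// stride/dilation/padding arrays as (D, H, W); callers therefore pass stride[2], stride[1], stride[0], etc.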
static void fill_conv3d_desc(MPSGraphConvolution3DOpDescriptor* descriptor_,
                             NSUInteger strideInX,
                             NSUInteger strideInY,
                             NSUInteger strideInZ,
                             NSUInteger dilationRateInX,
                             NSUInteger dilationRateInY,
                             NSUInteger dilationRateInZ,
                             NSUInteger paddingHorizontal,
                             NSUInteger paddingVertical,
                             NSUInteger paddingDepth,
                             NSUInteger groups) {
  descriptor_.strideInX = strideInX;
  descriptor_.strideInY = strideInY;
  descriptor_.strideInZ = strideInZ;
  descriptor_.dilationRateInX = dilationRateInX;
  descriptor_.dilationRateInY = dilationRateInY;
  descriptor_.dilationRateInZ = dilationRateInZ;

  // TODO: Program the padding style
  descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;

  descriptor_.paddingLeft = paddingHorizontal;
  descriptor_.paddingRight = paddingHorizontal;
  descriptor_.paddingTop = paddingVertical;
  descriptor_.paddingBottom = paddingVertical;
  descriptor_.paddingFront = paddingDepth;
  descriptor_.paddingBack = paddingDepth;

  // PyTorch always uses NCDHW memory layout for 3D tensors
  descriptor_.dataLayout = (MPSGraphTensorNamedDataLayout)7L; // MPSGraphTensorNamedDataLayoutNCDHW;

  // PyTorch always uses OIDHW memory layout for 3D weights
  descriptor_.weightsLayout = (MPSGraphTensorNamedDataLayout)9L; // MPSGraphTensorNamedDataLayoutOIDHW;

  descriptor_.groups = groups; // not yet tested in Xcode/C++
}

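// Create depthwise convolution descriptor. The 2D depthwise case is expressed through MPSGraph's
// depthwise *3D* op: strides/dilations/paddingValues are ordered (depth, height, width) with the
// depth entries fixed to 1/0, and channelDimensionIndex = -3 marks the channel axis of an NCHW input.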
static void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* descriptor_,
                                     NSUInteger strideInX,
                                     NSUInteger strideInY,
                                     NSUInteger dilationRateInX,
                                     NSUInteger dilationRateInY,
                                     NSUInteger paddingHorizontal,
                                     NSUInteger paddingVertical,
                                     c10::MemoryFormat memory_format,
                                     NSUInteger groups) {
  descriptor_.strides =
      @[ @1, [[NSNumber alloc] initWithInteger:strideInY], [[NSNumber alloc] initWithInteger:strideInX] ];
  descriptor_.dilationRates =
      @[ @1, [[NSNumber alloc] initWithInteger:dilationRateInY], [[NSNumber alloc] initWithInteger:dilationRateInX] ];

  descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;
  descriptor_.paddingValues = @[
    @0,
    @0,
    [[NSNumber alloc] initWithInteger:paddingVertical],
    [[NSNumber alloc] initWithInteger:paddingVertical],
    [[NSNumber alloc] initWithInteger:paddingHorizontal],
    [[NSNumber alloc] initWithInteger:paddingHorizontal]
  ];
  descriptor_.channelDimensionIndex = -3LL;
}

// Create convolution descriptor
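// (2D case.) X = width, Y = height; PyTorch orders stride/dilation/padding as (H, W), so callers pass
// stride[1], stride[0], etc. The data layout follows the tensor's memory format (NCHW vs. NHWC),
// while weights are always expected in OIHW.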
static void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
                           NSUInteger strideInX,
                           NSUInteger strideInY,
                           NSUInteger dilationRateInX,
                           NSUInteger dilationRateInY,
                           NSUInteger paddingHorizontal,
                           NSUInteger paddingVertical,
                           c10::MemoryFormat memory_format,
                           NSUInteger groups) {
  descriptor_.strideInX = strideInX;
  descriptor_.strideInY = strideInY;
  descriptor_.dilationRateInX = dilationRateInX;
  descriptor_.dilationRateInY = dilationRateInY;

  // TODO: Program the padding style
  descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;

  descriptor_.paddingLeft = paddingHorizontal;
  descriptor_.paddingRight = paddingHorizontal;
  descriptor_.paddingTop = paddingVertical;
  descriptor_.paddingBottom = paddingVertical;

  descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW
                                                                           : MPSGraphTensorNamedDataLayoutNHWC;

  // PyTorch always uses OIHW memory layout for weights
  descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
  descriptor_.groups = groups;
}

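// Shared implementation for the forward convolution and for the input gradient of a transposed
// convolution (see mps_convolution_transpose_backward_input below), which is why the output shape
// can be overridden via the optional input_shape argument.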
static Tensor _mps_convolution_impl(const Tensor& input_t_,
                                    const Tensor& weight_t,
                                    const std::optional<Tensor>& bias_opt,
                                    IntArrayRef padding,
                                    IntArrayRef stride,
                                    IntArrayRef dilation,
                                    int64_t groups,
                                    std::optional<IntArrayRef> input_shape) {
  const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
  const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  Tensor input_t = input_t_;
  if (!is_macOS_15_0_or_newer) {
    input_t = input_t.contiguous();
  }

  TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
              "Conv3D is only supported on MPS for macOS 13.2 or newer");
  bool is3DConv = input_t.dim() == 5;

  TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");

  using namespace at::native::mps;
  CheckedFrom c = "mps_convolution";
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  bool bias_defined;

  if (bias_opt == std::nullopt)
    bias_defined = false;
  else
    bias_defined = bias_opt->defined();

  auto memory_format = input_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;
  auto output_t =
      at::empty(input_shape.has_value() ? input_shape.value()
                                        : conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
                input->scalar_type(),
                std::nullopt,
                kMPS,
                std::nullopt,
                is_macOS_15_0_or_newer ? memory_format : MemoryFormat::Contiguous);
  if (output_t.numel() == 0) {
    return output_t;
  }
  TensorArg output{output_t, "result", 0};

  // TODO: MPS convolution kernel currently does not support output channels > 2^16
  for (auto elem : output_t.sizes()) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        elem <= (1 << 16),
        "Output channels > 65536 not supported at the MPS device. ",
        "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
        "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
        "on MPS.");
  }

  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // Derive from MPSCachedGraph
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* biasTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  auto stream = at::mps::getCurrentMPSStream();

  @autoreleasepool {
    IntArrayRef bias_shape;
    if (bias_defined)
      bias_shape = bias_opt.value().sizes();

    string mem_format_key;
    switch (memory_format) {
      case at::MemoryFormat::Contiguous:
        mem_format_key = "Contiguous";
        break;
      case at::MemoryFormat::ChannelsLast:
        mem_format_key = "ChannelsLast";
        break;
      default:
        assert(0 && "Check should have been done earlier\n");
    }

    string bias_shape_key;
    if (bias_defined) {
      bias_shape_key = std::to_string(bias_shape[0]);
    } else {
      bias_shape_key = "nobias";
    }

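    // The cache key encodes every convolution hyperparameter plus the dtypes/shapes of the
    // participating tensors, so a compiled MPSGraph is reused only for an exactly matching call.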
    string key;
    if (is3DConv) {
      key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;

    } else {
      key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
    }

    MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
    MPSShape* outputShape = mps::getMPSShape(output_t, memory_format);
    MPSNDArray* inputNDArray = nil;
    MPSNDArray* outputNDArray = nil;

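    // On macOS 15+ a tensor whose storage already matches the requested memory format can be bound
    // directly as an MPSNDArray, which presumably lets MPSGraph consume it without the extra
    // gather/copy that the generic Placeholder path performs for non-contiguous tensors.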
    if (input_t.is_contiguous(memory_format) && output_t.is_contiguous(memory_format) && is_macOS_15_0_or_newer) {
      inputNDArray = getMPSNDArray(input_t, inputShape);
      outputNDArray = getMPSNDArray(*output, outputShape);
    }

    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSShape* weightShape = mps::getMPSShape(weight_t);
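      // Depthwise convolution: groups > 1 and the I dim of the OIHW weight is 1 (input channels == groups).
      // Only taken for 4-D NCHW tensors; channels-last input goes through the regular 2D path instead.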
      bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 &&
                              weightShape.count >= 4 && !is_channels_last);

      MPSGraphTensor* inputTensor =
          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input_t.scalar_type()), inputShape);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
      MPSGraphTensor* outputTensor;
      if (is3DConv) {
        MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
        fill_conv3d_desc(conv3dDescriptor_,
                         stride[2],
                         stride[1],
                         stride[0],
                         dilation[2],
                         dilation[1],
                         dilation[0],
                         padding[2],
                         padding[1],
                         padding[0],
                         groups);

        outputTensor = [mpsGraph convolution3DWithSourceTensor:inputTensor
                                                 weightsTensor:weightTensor
                                                    descriptor:conv3dDescriptor_
                                                          name:nil];
      } else if (isDepthwiseConv) {
        MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
            [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
        fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
                                 stride[1],
                                 stride[0],
                                 dilation[1],
                                 dilation[0],
                                 padding[1],
                                 padding[0],
                                 memory_format,
                                 groups);

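        // The depthwise op expects the feature-channel dim ahead of the (singleton) multiplier dim,
        // i.e. a [1, C, kH, kW]-shaped weight (cf. weightShapeTranspose in the weights-gradient path
        // below), so swap the O and I dims of the OIHW weight.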
        MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
                                                                dimension:-3
                                                            withDimension:-4
                                                                     name:nil];
        outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:inputTensor
                                                          weightsTensor:weightTransposeTensor
                                                             descriptor:depthWiseConv3dDescriptor_
                                                                   name:nil];
      } else {
        MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
        fill_conv_desc(conv2dDescriptor_,
                       stride[1],
                       stride[0],
                       dilation[1],
                       dilation[0],
                       padding[1],
                       padding[0],
                       memory_format,
                       groups);

        outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
                                                 weightsTensor:weightTensor
                                                    descriptor:conv2dDescriptor_
                                                          name:nil];
      }

      MPSGraphTensor* biasTensor = nil;
      if (bias_defined) {
        biasTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(bias_opt.value()));
      }

      if (is_channels_last && !is_macOS_15_0_or_newer) {
        outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor);
      }

      if (bias_defined) {
        outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil];
      }
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->biasTensor_ = biasTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    auto inputPlaceholder = inputNDArray ? Placeholder(cachedGraph->inputTensor_, inputNDArray)
                                         : Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
    auto biasPlaceholder = Placeholder();
    // Reshape the bias to be broadcastable with output of conv2d or conv3d
    if (bias_defined) {
      if (is3DConv) {
        biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1, 1}));
      } else {
        if (is_channels_last && is_macOS_15_0_or_newer) {
          biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, 1, 1, bias_shape[0]}));
        } else {
          biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
        }
      }
    }
    auto outputPlaceholder = outputNDArray ? Placeholder(cachedGraph->outputTensor_, outputNDArray)
                                           : Placeholder(cachedGraph->outputTensor_, *output);

    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
        [[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
    if (bias_defined) {
      feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
    }

    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }

  return *output;
}

Tensor _mps_convolution(const Tensor& input_t,
                        const Tensor& weight_t,
                        const std::optional<Tensor>& bias_opt,
                        IntArrayRef padding,
                        IntArrayRef stride,
                        IntArrayRef dilation,
                        int64_t groups) {
  return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, std::nullopt);
}

static Tensor mps_convolution_backward_input(IntArrayRef input_size,
                                             const Tensor& grad_output_t,
                                             const Tensor& weight_t,
                                             IntArrayRef padding,
                                             IntArrayRef stride,
                                             IntArrayRef dilation,
                                             int64_t groups,
                                             bool bias_defined) {
  using namespace at::native::mps;
  using namespace mps;
  bool is3DConv = grad_output_t.dim() == 5;

  // TODO: MPS convolution kernel currently does not support output channels > 2^16
  for (auto elem : grad_output_t.sizes()) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        elem <= (1 << 16),
        "Output channels > 65536 not supported at the MPS device. ",
        "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
        "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
        "on MPS.");
  }

  TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
  CheckedFrom c = "mps_convolution_backward_input";
  TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2};
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});
  auto memory_format = grad_output_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;
  auto grad_input_t = at::empty(input_size, grad_output_t.options(), std::nullopt);

  // Avoid "grad_input" when this is being used as transposed convolution
  TensorArg grad_input{grad_input_t, "result", 0};
  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  // Derive from MPSCachedGraph
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* gradInputTensor_ = nil;
  };

  // Add backward with input
  @autoreleasepool {
    MPSStream* stream = getCurrentMPSStream();

    string mem_format_key;
    switch (memory_format) {
      case at::MemoryFormat::Contiguous:
        mem_format_key = "Contiguous";
        break;
      case at::MemoryFormat::ChannelsLast:
        mem_format_key = "ChannelsLast";
        break;
      default:
        assert(0 && "Check should have been done earlier\n");
    }

    MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
    MPSShape* mps_input_shape = getMPSShape(input_size);
    NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
    string key;
    if (is3DConv) {
      key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);

    } else {
      key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
    }
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* gradOutputTensor =
          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t), gradOutputShape);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

      MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
      if (is_channels_last) {
        gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
      }
      MPSGraphTensor* gradInputTensor;
      MPSShape* weightOutputShape = mps::getMPSShape(weight_t);
      // Depthwise conv means the number of input feature channels equals groups, so the I dim of the OIHW weight has to be 1.
      bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 &&
                              weightOutputShape.count >= 4 && !is_channels_last);

      if (is3DConv) {
        MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
        fill_conv3d_desc(conv3dDescriptor_,
                         stride[2],
                         stride[1],
                         stride[0],
                         dilation[2],
                         dilation[1],
                         dilation[0],
                         padding[2],
                         padding[1],
                         padding[0],
                         groups);
        gradInputTensor = [mpsGraph convolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                          weightsTensor:weightTensor
                                                                            outputShape:mps_input_shape
                                                           forwardConvolutionDescriptor:conv3dDescriptor_
                                                                                   name:nil];
      } else if (isDepthwiseConv) {
        MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
            [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
        fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
                                 stride[1],
                                 stride[0],
                                 dilation[1],
                                 dilation[0],
                                 padding[1],
                                 padding[0],
                                 at::MemoryFormat::Contiguous,
                                 groups);
        MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
                                                                dimension:-3
                                                            withDimension:-4
                                                                     name:nil];
        gradInputTensor =
            [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                     weightsTensor:weightTransposeTensor
                                                                       outputShape:mps_input_shape
                                                                        descriptor:depthWiseConv3dDescriptor_
                                                                              name:nil];
      } else {
        MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
        fill_conv_desc(conv2dDescriptor_,
                       stride[1],
                       stride[0],
                       dilation[1],
                       dilation[0],
                       padding[1],
                       padding[0],
                       at::MemoryFormat::Contiguous,
                       groups);

        gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                          weightsTensor:weightTensor
                                                                            outputShape:mps_input_shape
                                                           forwardConvolutionDescriptor:conv2dDescriptor_
                                                                                   name:nil];
      }

      newCachedGraph->gradOutputTensor_ = gradOutputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->gradInputTensor_ = gradInputTensor;
    });

    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
    auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);

    auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, weightsPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return *grad_input;
}

static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
                                               const Tensor& grad_output_t,
                                               const Tensor& input_t,
                                               IntArrayRef padding,
                                               IntArrayRef stride,
                                               IntArrayRef dilation,
                                               int64_t groups,
                                               bool bias_defined) {
  using namespace at::native::mps;
  using namespace mps;
  bool is3DConv = input_t.dim() == 5;
  TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
  CheckedFrom c = "mps_convolution_backward_weights";
  auto memory_format = grad_output_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;

  MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);

  // For uniformity with everything else, although it seems grad_weight
  // would be unambiguous too.
  TensorArg grad_output{grad_output_t, "grad_output", 1};
  TensorArg input{input_t, "input", 2};

  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  auto grad_weight_t =
      at::empty(weight_size, grad_output_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
  TensorArg grad_weight{grad_weight_t, "result", 0};

  convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);

  // Derive from MPSCachedGraph
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* gradWeightTensor_ = nil;
  };

  @autoreleasepool {
    MPSStream* stream = getCurrentMPSStream();

    string mem_format_key;
    switch (memory_format) {
      case at::MemoryFormat::Contiguous:
        mem_format_key = "Contiguous";
        break;
      case at::MemoryFormat::ChannelsLast:
        mem_format_key = "ChannelsLast";
        break;
      default:
        assert(0 && "Check should have been done earlier\n");
    }
    MPSShape* mps_weight_shape = getMPSShape(weight_size);
    NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
    string key;
    if (is3DConv) {
      key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
    } else {
      key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
    }
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSShape* inputShape = mps::getMPSShape(input_t);
      bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 &&
                              mps_weight_shape.count >= 4 && !is_channels_last);

      MPSGraphTensor* gradOutputTensor =
          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t), gradOutputShape);
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);

      MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
      if (is_channels_last) {
        gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
      }

      MPSGraphTensor* gradWeightTensor;
      if (is3DConv) {
        MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
        fill_conv3d_desc(conv3dDescriptor_,
                         stride[2],
                         stride[1],
                         stride[0],
                         dilation[2],
                         dilation[1],
                         dilation[0],
                         padding[2],
                         padding[1],
                         padding[0],
                         groups);
        gradWeightTensor = [mpsGraph convolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                               sourceTensor:inputTensor
                                                                                outputShape:mps_weight_shape
                                                               forwardConvolutionDescriptor:conv3dDescriptor_
                                                                                       name:nil];
      } else if (isDepthwiseConv) {
        MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
            [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
        fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
                                 stride[1],
                                 stride[0],
                                 dilation[1],
                                 dilation[0],
                                 padding[1],
                                 padding[0],
                                 at::MemoryFormat::Contiguous,
                                 groups);
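        // The depthwise op produces weights shaped [1, C, kH, kW]; compute the gradient in that
        // layout and transpose dims -4 and -3 afterwards to get back to PyTorch's OIHW weight layout.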
        NSNumber* outputFeatChannelDim = mps_weight_shape[0];
        MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ];
        MPSGraphTensor* gradWeightTensorTranspose =
            [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                         sourceTensor:inputTensor
                                                                          outputShape:weightShapeTranspose
                                                                           descriptor:depthWiseConv3dDescriptor_
                                                                                 name:nil];
        gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose dimension:-3 withDimension:-4 name:nil];
      } else {
        MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
        fill_conv_desc(conv2dDescriptor_,
                       stride[1],
                       stride[0],
                       dilation[1],
                       dilation[0],
                       padding[1],
                       padding[0],
                       at::MemoryFormat::Contiguous,
                       groups);

        gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                               sourceTensor:inputTensor
                                                                                outputShape:mps_weight_shape
                                                               forwardConvolutionDescriptor:conv2dDescriptor_
                                                                                       name:nil];
      }

      newCachedGraph->gradOutputTensor_ = gradOutputTensor;
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->gradWeightTensor_ = gradWeightTensor;
    });

    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
    auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);

    auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, inputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }

  return grad_weight_t;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_convolution_backward(const at::Tensor& input,
                                                                        const at::Tensor& grad_output,
                                                                        const at::Tensor& weight,
                                                                        IntArrayRef padding,
                                                                        IntArrayRef stride,
                                                                        IntArrayRef dilation,
                                                                        int64_t groups,
                                                                        std::array<bool, 3> output_mask) {
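  // Note: grad_bias is never populated here; when output_mask[2] is set, the bias gradient is
  // presumably reduced from grad_output by the caller rather than computed on this path.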
  Tensor grad_input, grad_weight, grad_bias;
  if (input.numel() == 0) {
    if (output_mask[0]) {
      grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    if (output_mask[1]) {
      grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
  } else {
    if (output_mask[0]) {
      grad_input = mps_convolution_backward_input(
          input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
    }
    if (output_mask[1]) {
      grad_weight = mps_convolution_backward_weights(
          weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
    }
  }

  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

static Tensor mps_convolution_transpose_forward(const Tensor& grad_output,
                                                const Tensor& weight,
                                                IntArrayRef padding,
                                                IntArrayRef output_padding,
                                                IntArrayRef stride,
                                                IntArrayRef dilation,
                                                int64_t groups) {
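  // Transposed convolution forward pass == data gradient of the corresponding regular convolution:
  // derive the (larger) input size the regular conv would have had, then run the backward-input
  // kernel with the same weights.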
  auto input_size =
      conv_input_size(grad_output.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
  return mps_convolution_backward_input(input_size, grad_output, weight, padding, stride, dilation, groups, false);
}

Tensor _mps_convolution_transpose(const Tensor& input_t,
                                  const Tensor& weight_t,
                                  IntArrayRef padding,
                                  IntArrayRef output_padding,
                                  IntArrayRef stride,
                                  IntArrayRef dilation,
                                  int64_t groups) {
  TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS");

  auto output_t =
      mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups);
  return output_t;
}

static Tensor mps_convolution_transpose_backward_input(const Tensor& grad_output_t,
                                                       const Tensor& weight_t,
                                                       IntArrayRef padding,
                                                       IntArrayRef stride,
                                                       IntArrayRef dilation,
                                                       int64_t groups,
                                                       IntArrayRef input_shape) {
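  // The gradient w.r.t. the input of a transposed convolution is just the regular forward
  // convolution of grad_output with the same weights (no bias), restricted to the original input_shape.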
  return _mps_convolution_impl(grad_output_t, weight_t, std::nullopt, padding, stride, dilation, groups, input_shape);
}

static Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size,
                                                        const Tensor& grad_output_t,
                                                        const Tensor& input_t,
                                                        IntArrayRef padding,
                                                        IntArrayRef stride,
                                                        IntArrayRef dilation,
                                                        int64_t groups) {
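  // The weight gradient of a transposed convolution equals the weight gradient of the regular
  // convolution with the roles of input and grad_output swapped, hence the reversed argument order.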
  return mps_convolution_backward_weights(
      weight_size, input_t, grad_output_t, padding, stride, dilation, groups, false);
}

std::tuple<Tensor, Tensor> mps_convolution_transpose_backward(const Tensor& input,
                                                              const Tensor& grad_output,
                                                              const Tensor& weight,
                                                              IntArrayRef padding,
                                                              IntArrayRef output_padding,
                                                              IntArrayRef stride,
                                                              IntArrayRef dilation,
                                                              int64_t groups,
                                                              std::array<bool, 2> output_mask) {
  Tensor grad_input, grad_weight;
  if (output_mask[0]) {
    grad_input =
        mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
  }
  if (output_mask[1]) {
    grad_weight = mps_convolution_transpose_backward_weight(
        weight.sizes(), grad_output, input, padding, stride, dilation, groups);
  }

  return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

} // namespace at::native