// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/ConvUtils.h>
#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/ops/_mps_convolution_native.h>
#include <ATen/ops/_mps_convolution_transpose_native.h>
#include <ATen/ops/mps_convolution_backward_native.h>
#include <ATen/ops/mps_convolution_transpose_backward_native.h>

// When building against an SDK older than macOS 13.2, MPSGraphConvolution3DOpDescriptor
// does not exist; provide a stub so the selectors still compile. It is never
// instantiated at runtime on OSes where Conv3D is actually used (guarded below).
#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))

@implementation FakeMPSGraphConvolution3DOpDescriptor
- (nonnull id)copyWithZone:(nullable NSZone*)zone {
  return self;
}

@end

#endif

namespace at::native {

// Populate a 3D convolution descriptor.
// Padding is symmetric per spatial axis (left==right, top==bottom, front==back),
// matching PyTorch's single-value-per-dim `padding` argument.
static void fill_conv3d_desc(MPSGraphConvolution3DOpDescriptor* descriptor_,
                             NSUInteger strideInX,
                             NSUInteger strideInY,
                             NSUInteger strideInZ,
                             NSUInteger dilationRateInX,
                             NSUInteger dilationRateInY,
                             NSUInteger dilationRateInZ,
                             NSUInteger paddingHorizontal,
                             NSUInteger paddingVertical,
                             NSUInteger paddingDepth,
                             NSUInteger groups) {
  descriptor_.strideInX = strideInX;
  descriptor_.strideInY = strideInY;
  descriptor_.strideInZ = strideInZ;
  descriptor_.dilationRateInX = dilationRateInX;
  descriptor_.dilationRateInY = dilationRateInY;
  descriptor_.dilationRateInZ = dilationRateInZ;

  // TODO: Program the padding style
  descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;

  descriptor_.paddingLeft = paddingHorizontal;
  descriptor_.paddingRight = paddingHorizontal;
  descriptor_.paddingTop = paddingVertical;
  descriptor_.paddingBottom = paddingVertical;
  descriptor_.paddingFront = paddingDepth;
  descriptor_.paddingBack = paddingDepth;

  // PyTorch always uses NCDHW memory layout for 3D tensors.
  // Raw enum values are used so this compiles against pre-13.2 SDK headers
  // where the NCDHW/OIDHW enumerators are not declared.
  descriptor_.dataLayout = (MPSGraphTensorNamedDataLayout)7L; // MPSGraphTensorNamedDataLayoutNCDHW;

  // PyTorch always uses OIDHW memory layout for 3D weights
  descriptor_.weightsLayout = (MPSGraphTensorNamedDataLayout)9L; // MPSGraphTensorNamedDataLayoutOIDHW;

  descriptor_.groups = groups; // not yet tested in Xcode/C++
}

// Populate a depthwise-convolution descriptor (expressed as MPSGraph's 3D depthwise
// op with a unit leading spatial dimension: arrays are ordered {depth, Y, X}).
// `memory_format` and `groups` are currently unused here; grouping is implied by
// the depthwise formulation and the channel axis is fixed via channelDimensionIndex.
static void fill_depthwise_conv_desc(MPSGraphDepthwiseConvolution3DOpDescriptor* descriptor_,
                                     NSUInteger strideInX,
                                     NSUInteger strideInY,
                                     NSUInteger dilationRateInX,
                                     NSUInteger dilationRateInY,
                                     NSUInteger paddingHorizontal,
                                     NSUInteger paddingVertical,
                                     c10::MemoryFormat memory_format,
                                     NSUInteger groups) {
  // Boxed literals (@(…)) return autoreleased NSNumbers; the previous
  // [[NSNumber alloc] initWithInteger:…] objects were never released (MRC leak).
  descriptor_.strides = @[ @1, @(strideInY), @(strideInX) ];
  descriptor_.dilationRates = @[ @1, @(dilationRateInY), @(dilationRateInX) ];

  descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;
  // Padding values are {before, after} pairs per spatial dim: depth, Y, X.
  descriptor_.paddingValues = @[
    @0,
    @0,
    @(paddingVertical),
    @(paddingVertical),
    @(paddingHorizontal),
    @(paddingHorizontal)
  ];
  // Channels sit three axes from the end (NCHW input viewed as N,C,H,W).
  descriptor_.channelDimensionIndex = -3LL;
}

// Populate a 2D convolution descriptor. Padding is symmetric per axis.
static void fill_conv_desc(MPSGraphConvolution2DOpDescriptor* descriptor_,
                           NSUInteger strideInX,
                           NSUInteger strideInY,
                           NSUInteger dilationRateInX,
                           NSUInteger dilationRateInY,
                           NSUInteger paddingHorizontal,
                           NSUInteger paddingVertical,
                           c10::MemoryFormat memory_format,
                           NSUInteger groups) {
  descriptor_.strideInX = strideInX;
  descriptor_.strideInY = strideInY;
  descriptor_.dilationRateInX = dilationRateInX;
  descriptor_.dilationRateInY = dilationRateInY;

  // TODO: Program the padding style
  descriptor_.paddingStyle = MPSGraphPaddingStyleExplicit;

  descriptor_.paddingLeft = paddingHorizontal;
  descriptor_.paddingRight = paddingHorizontal;
  descriptor_.paddingTop = paddingVertical;
  descriptor_.paddingBottom = paddingVertical;

  descriptor_.dataLayout = (memory_format == at::MemoryFormat::Contiguous) ? MPSGraphTensorNamedDataLayoutNCHW
                                                                           : MPSGraphTensorNamedDataLayoutNHWC;

  // PyTorch always uses OIHW memory layout for weights
  descriptor_.weightsLayout = MPSGraphTensorNamedDataLayoutOIHW;
  descriptor_.groups = groups;
}

// Shared forward-convolution implementation (2D, 3D, and depthwise).
// When `input_shape` is provided the output is allocated with that shape instead
// of the computed conv output size — used by transposed-convolution backward,
// which runs the forward kernel with an explicitly-known result shape.
static Tensor _mps_convolution_impl(const Tensor& input_t_,
                                    const Tensor& weight_t,
                                    const std::optional<Tensor>& bias_opt,
                                    IntArrayRef padding,
                                    IntArrayRef stride,
                                    IntArrayRef dilation,
                                    int64_t groups,
                                    std::optional<IntArrayRef> input_shape) {
  const bool is_macOS_13_2_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_2_PLUS);
  const bool is_macOS_15_0_or_newer = is_macos_13_or_newer(MacOSVersion::MACOS_VER_15_0_PLUS);
  Tensor input_t = input_t_;
  // Before macOS 15 the graph path cannot consume arbitrary strided inputs.
  if (!is_macOS_15_0_or_newer) {
    input_t = input_t.contiguous();
  }

  TORCH_CHECK(((input_t.dim() < 5) || is_macOS_13_2_or_newer),
              "Conv3D is only supported on MPS for MacOS_13_2 or newer");
  bool is3DConv = input_t.dim() == 5;

  TORCH_CHECK(isFloatingType(input_t.scalar_type()), "Convolution is supported only for Floating types");

  using namespace at::native::mps;
  CheckedFrom c = "mps_convolution";
  TensorArg input{input_t, "input", 1}, weight{weight_t, "weight", 2};
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  bool bias_defined;

  if (bias_opt == std::nullopt)
    bias_defined = false;
  else
    bias_defined = bias_opt->defined();

  auto memory_format = input_t.suggest_memory_format();
  // Channels-last handling only applies to 2D convs; 3D always runs NCDHW.
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;
  auto output_t =
      at::empty(input_shape.has_value() ? input_shape.value()
                                        : conv_output_size(input->sizes(), weight->sizes(), padding, stride, dilation),
                input->scalar_type(),
                std::nullopt,
                kMPS,
                std::nullopt,
                is_macOS_15_0_or_newer ? memory_format : MemoryFormat::Contiguous);
  if (output_t.numel() == 0) {
    return output_t;
  }
  TensorArg output{output_t, "result", 0};

  // TODO: MPS convolution kernel currently does not support output channels > 2^16
  for (auto elem : output_t.sizes()) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        elem <= (1 << 16),
        "Output channels > 65536 not supported at the MPS device. ",
        "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
        "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
        "on MPS.");
  }

  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  // Derive from MPSCachedGraph
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* biasTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* outputTensor_ = nil;
  };

  auto stream = at::mps::getCurrentMPSStream();

  @autoreleasepool {
    IntArrayRef bias_shape;
    if (bias_defined)
      bias_shape = bias_opt.value().sizes();

    string mem_format_key;
    switch (memory_format) {
      case at::MemoryFormat::Contiguous:
        mem_format_key = "Contiguous";
        break;
      case at::MemoryFormat::ChannelsLast:
        mem_format_key = "ChannelsLast";
        break;
      default:
        assert(0 && "Check should have been done earlier\n");
    }

    string bias_shape_key;
    if (bias_defined) {
      bias_shape_key = std::to_string(bias_shape[0]);
    } else {
      bias_shape_key = "nobias";
    }

    // Cache key encodes every parameter that changes the compiled graph.
    string key;
    if (is3DConv) {
      key = "mps_3d_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;

    } else {
      key = "mps_convolution:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          mps::getTensorsStringKey({input_t, weight_t}) + ":" + std::to_string(bias_defined) + ":" + bias_shape_key;
    }

    MPSShape* inputShape = mps::getMPSShape(input_t, memory_format);
    MPSShape* outputShape = mps::getMPSShape(output_t, memory_format);
    MPSNDArray* inputNDArray = nil;
    MPSNDArray* outputNDArray = nil;

    // Fast path (macOS 15+): feed NDArrays directly when both tensors are
    // contiguous in the chosen memory format, avoiding gather/scatter copies.
    if (input_t.is_contiguous(memory_format) && output_t.is_contiguous(memory_format) && is_macOS_15_0_or_newer) {
      inputNDArray = getMPSNDArray(input_t, inputShape);
      outputNDArray = getMPSNDArray(*output, outputShape);
    }

    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSShape* weightShape = mps::getMPSShape(weight_t);
      // Depthwise conv is detected as groups > 1 with a single input channel
      // per group (I == 1 in OIHW).
      bool isDepthwiseConv = ((groups > 1 && (weightShape[1].intValue == 1)) && inputShape.count >= 4 &&
                              weightShape.count >= 4 && !is_channels_last);

      MPSGraphTensor* inputTensor =
          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(input_t.scalar_type()), inputShape);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);
      MPSGraphTensor* outputTensor;
      if (is3DConv) {
        MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
        // Note the axis order: PyTorch stride/dilation/padding are (D, H, W),
        // while the descriptor takes (X, Y, Z) = (W, H, D) — hence the reversal.
        fill_conv3d_desc(conv3dDescriptor_,
                         stride[2],
                         stride[1],
                         stride[0],
                         dilation[2],
                         dilation[1],
                         dilation[0],
                         padding[2],
                         padding[1],
                         padding[0],
                         groups);

        outputTensor = [mpsGraph convolution3DWithSourceTensor:inputTensor
                                                 weightsTensor:weightTensor
                                                    descriptor:conv3dDescriptor_
                                                          name:nil];
      } else if (isDepthwiseConv) {
        MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
            [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
        fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
                                 stride[1],
                                 stride[0],
                                 dilation[1],
                                 dilation[0],
                                 padding[1],
                                 padding[0],
                                 memory_format,
                                 groups);

        // The depthwise op expects weights with the O and I axes swapped
        // relative to PyTorch's OIHW layout.
        MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
                                                                dimension:-3
                                                            withDimension:-4
                                                                     name:nil];
        outputTensor = [mpsGraph depthwiseConvolution3DWithSourceTensor:inputTensor
                                                          weightsTensor:weightTransposeTensor
                                                             descriptor:depthWiseConv3dDescriptor_
                                                                   name:nil];
      } else {
        MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
        fill_conv_desc(conv2dDescriptor_,
                       stride[1],
                       stride[0],
                       dilation[1],
                       dilation[0],
                       padding[1],
                       padding[0],
                       memory_format,
                       groups);

        outputTensor = [mpsGraph convolution2DWithSourceTensor:inputTensor
                                                 weightsTensor:weightTensor
                                                    descriptor:conv2dDescriptor_
                                                          name:nil];
      }

      MPSGraphTensor* biasTensor = nil;
      if (bias_defined) {
        biasTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(bias_opt.value()));
      }

      // Pre-macOS 15 channels-last outputs come back NHWC; convert so the
      // contiguous output tensor receives NCHW data.
      if (is_channels_last && !is_macOS_15_0_or_newer) {
        outputTensor = mps::convertNHWCtoNCHW(mpsGraph, outputTensor);
      }

      if (bias_defined) {
        outputTensor = [mpsGraph additionWithPrimaryTensor:outputTensor secondaryTensor:biasTensor name:nil];
      }
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->biasTensor_ = biasTensor;
      newCachedGraph->outputTensor_ = outputTensor;
    });

    auto inputPlaceholder = inputNDArray ? Placeholder(cachedGraph->inputTensor_, inputNDArray)
                                         : Placeholder(cachedGraph->inputTensor_, input_t, inputShape);
    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
    auto biasPlaceholder = Placeholder();
    // Reshape the bias to be broadcastable with output of conv2d or conv3d
    if (bias_defined) {
      if (is3DConv) {
        biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1, 1}));
      } else {
        if (is_channels_last && is_macOS_15_0_or_newer) {
          biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, 1, 1, bias_shape[0]}));
        } else {
          biasPlaceholder = Placeholder(cachedGraph->biasTensor_, (bias_opt.value()).view({1, bias_shape[0], 1, 1}));
        }
      }
    }
    auto outputPlaceholder = outputNDArray ? Placeholder(cachedGraph->outputTensor_, outputNDArray)
                                           : Placeholder(cachedGraph->outputTensor_, *output);

    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =
        [[[NSMutableDictionary alloc] initWithCapacity:3] autorelease];
    feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
    feeds[weightsPlaceholder.getMPSGraphTensor()] = weightsPlaceholder.getMPSGraphTensorData();
    if (bias_defined) {
      feeds[biasPlaceholder.getMPSGraphTensor()] = biasPlaceholder.getMPSGraphTensorData();
    }

    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }

  return *output;
}

// Public forward entry point (conv_output_size-derived output shape).
Tensor _mps_convolution(const Tensor& input_t,
                        const Tensor& weight_t,
                        const std::optional<Tensor>& bias_opt,
                        IntArrayRef padding,
                        IntArrayRef stride,
                        IntArrayRef dilation,
                        int64_t groups) {
  return _mps_convolution_impl(input_t, weight_t, bias_opt, padding, stride, dilation, groups, std::nullopt);
}

// Gradient w.r.t. the convolution input (also reused as the forward pass of
// transposed convolution, hence the "grad_input"-avoiding naming note below).
static Tensor mps_convolution_backward_input(IntArrayRef input_size,
                                             const Tensor& grad_output_t,
                                             const Tensor& weight_t,
                                             IntArrayRef padding,
                                             IntArrayRef stride,
                                             IntArrayRef dilation,
                                             int64_t groups,
                                             bool bias_defined) {
  using namespace at::native::mps;
  using namespace mps;
  bool is3DConv = grad_output_t.dim() == 5;

  // TODO: MPS convolution kernel currently does not support output channels > 2^16
  for (auto elem : grad_output_t.sizes()) {
    TORCH_CHECK_NOT_IMPLEMENTED(
        elem <= (1 << 16),
        "Output channels > 65536 not supported at the MPS device. ",
        "As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` ",
        "to use the CPU as a fallback for this op. WARNING: this will be slower than running natively ",
        "on MPS.");
  }

  TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
  CheckedFrom c = "mps_convolution_backward_input";
  TensorArg grad_output{grad_output_t, "grad_output", 1}, weight{weight_t, "weight", 2};
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});
  auto memory_format = grad_output_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;
  auto grad_input_t = at::empty(input_size, grad_output_t.options(), std::nullopt);

  // Avoid "grad_input" when this is being used as transposed convolution
  TensorArg grad_input{grad_input_t, "result", 0};
  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  // Derive from MPSCachedGraph
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* weightTensor_ = nil;
    MPSGraphTensor* gradInputTensor_ = nil;
  };

  // Add backward with input
  @autoreleasepool {
    MPSStream* stream = getCurrentMPSStream();

    string mem_format_key;
    switch (memory_format) {
      case at::MemoryFormat::Contiguous:
        mem_format_key = "Contiguous";
        break;
      case at::MemoryFormat::ChannelsLast:
        mem_format_key = "ChannelsLast";
        break;
      default:
        assert(0 && "Check should have been done earlier\n");
    }

    MPSShape* gradOutputShape = getMPSShape(grad_output_t, memory_format);
    MPSShape* mps_input_shape = getMPSShape(input_size);
    NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
    string key;
    if (is3DConv) {
      // Fixed: the key previously contained a doubled ":" and fused
      // stride[2] with dilation[0] (no separator), which made distinct
      // (stride, dilation) combinations serialize to the same cache key.
      key = "mps_3d_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);

    } else {
      key = "mps_convolution_backward_input:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, weight_t}) + ":" + string([ns_shape_key UTF8String]);
    }
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSGraphTensor* gradOutputTensor =
          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t), gradOutputShape);
      MPSGraphTensor* weightTensor = mpsGraphRankedPlaceHolder(mpsGraph, weight_t);

      MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
      if (is_channels_last) {
        gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
      }
      MPSGraphTensor* gradInputTensor;
      MPSShape* weightOutputShape = mps::getMPSShape(weight_t);
      // Depthwise conv is input feature channels = groups. So I in OIHW has to be 1.
      bool isDepthwiseConv = ((groups > 1 && (weightOutputShape[1].intValue == 1)) && gradOutputShape.count >= 4 &&
                              weightOutputShape.count >= 4 && !is_channels_last);

      if (is3DConv) {
        MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
        // (X, Y, Z) = (W, H, D): reverse PyTorch's (D, H, W) parameter order.
        fill_conv3d_desc(conv3dDescriptor_,
                         stride[2],
                         stride[1],
                         stride[0],
                         dilation[2],
                         dilation[1],
                         dilation[0],
                         padding[2],
                         padding[1],
                         padding[0],
                         groups);
        gradInputTensor = [mpsGraph convolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                          weightsTensor:weightTensor
                                                                            outputShape:mps_input_shape
                                                           forwardConvolutionDescriptor:conv3dDescriptor_
                                                                                   name:nil];
      } else if (isDepthwiseConv) {
        MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
            [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
        fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
                                 stride[1],
                                 stride[0],
                                 dilation[1],
                                 dilation[0],
                                 padding[1],
                                 padding[0],
                                 at::MemoryFormat::Contiguous,
                                 groups);
        MPSGraphTensor* weightTransposeTensor = [mpsGraph transposeTensor:weightTensor
                                                                dimension:-3
                                                            withDimension:-4
                                                                     name:nil];
        gradInputTensor =
            [mpsGraph depthwiseConvolution3DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                     weightsTensor:weightTransposeTensor
                                                                       outputShape:mps_input_shape
                                                                        descriptor:depthWiseConv3dDescriptor_
                                                                              name:nil];
      } else {
        MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
        fill_conv_desc(conv2dDescriptor_,
                       stride[1],
                       stride[0],
                       dilation[1],
                       dilation[0],
                       padding[1],
                       padding[0],
                       at::MemoryFormat::Contiguous,
                       groups);

        gradInputTensor = [mpsGraph convolution2DDataGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                          weightsTensor:weightTensor
                                                                            outputShape:mps_input_shape
                                                           forwardConvolutionDescriptor:conv2dDescriptor_
                                                                                   name:nil];
      }

      newCachedGraph->gradOutputTensor_ = gradOutputTensor;
      newCachedGraph->weightTensor_ = weightTensor;
      newCachedGraph->gradInputTensor_ = gradInputTensor;
    });

    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
    auto weightsPlaceholder = Placeholder(cachedGraph->weightTensor_, weight_t);
    auto outputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, *grad_input);

    auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, weightsPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }
  return *grad_input;
}

// Gradient w.r.t. the convolution weights.
static Tensor mps_convolution_backward_weights(IntArrayRef weight_size,
                                               const Tensor& grad_output_t,
                                               const Tensor& input_t,
                                               IntArrayRef padding,
                                               IntArrayRef stride,
                                               IntArrayRef dilation,
                                               int64_t groups,
                                               bool bias_defined) {
  using namespace at::native::mps;
  using namespace mps;
  bool is3DConv = input_t.dim() == 5;
  TORCH_CHECK(isFloatingType(grad_output_t.scalar_type()), "Convolution is supported only for Floating types");
  CheckedFrom c = "mps_convolution_backward_weights";
  auto memory_format = grad_output_t.suggest_memory_format();
  bool is_channels_last = (memory_format == at::MemoryFormat::ChannelsLast) && !is3DConv;

  MPSShape* gradOutputShape = mps::getMPSShape(grad_output_t, memory_format);

  // For uniformity with everything else, although it seems grad_weight
  // would be unambiguous too.
  TensorArg grad_output{grad_output_t, "grad_output", 1};
  TensorArg input{input_t, "input", 2};

  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  auto grad_weight_t =
      at::empty(weight_size, grad_output_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
  TensorArg grad_weight{grad_weight_t, "result", 0};

  convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);

  // Derive from MPSCachedGraph
  struct CachedGraph : public MPSCachedGraph {
    CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
    MPSGraphTensor* gradOutputTensor_ = nil;
    MPSGraphTensor* inputTensor_ = nil;
    MPSGraphTensor* gradWeightTensor_ = nil;
  };

  @autoreleasepool {
    MPSStream* stream = getCurrentMPSStream();

    string mem_format_key;
    switch (memory_format) {
      case at::MemoryFormat::Contiguous:
        mem_format_key = "Contiguous";
        break;
      case at::MemoryFormat::ChannelsLast:
        mem_format_key = "ChannelsLast";
        break;
      default:
        assert(0 && "Check should have been done earlier\n");
    }
    MPSShape* mps_weight_shape = getMPSShape(weight_size);
    NSString* ns_shape_key = [[gradOutputShape valueForKey:@"description"] componentsJoinedByString:@","];
    string key;
    if (is3DConv) {
      key = "mps_3d_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(stride[2]) + ":" + std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" +
          std::to_string(dilation[2]) + ":" + std::to_string(padding[0]) + ":" + std::to_string(padding[1]) + ":" +
          std::to_string(padding[2]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
    } else {
      key = "mps_convolution_backward_weights:" + std::to_string(stride[0]) + ":" + std::to_string(stride[1]) + ":" +
          std::to_string(dilation[0]) + ":" + std::to_string(dilation[1]) + ":" + std::to_string(padding[0]) + ":" +
          std::to_string(padding[1]) + ":" + std::to_string(groups) + ":" + mem_format_key +
          getTensorsStringKey({grad_output_t, input_t, grad_weight_t}) + ":" + string([ns_shape_key UTF8String]);
    }
    auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
      MPSShape* inputShape = mps::getMPSShape(input_t);
      bool isDepthwiseConv = ((groups > 1 && (mps_weight_shape[1].intValue == 1)) && inputShape.count >= 4 &&
                              mps_weight_shape.count >= 4 && !is_channels_last);

      MPSGraphTensor* gradOutputTensor =
          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(grad_output_t), gradOutputShape);
      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);

      MPSGraphTensor* gradOutputTensorTranspose = gradOutputTensor;
      if (is_channels_last) {
        gradOutputTensorTranspose = mps::convertNHWCtoNCHW(mpsGraph, gradOutputTensorTranspose);
      }

      MPSGraphTensor* gradWeightTensor;
      if (is3DConv) {
        MPSGraphConvolution3DOpDescriptor* conv3dDescriptor_ = [[MPSGraphConvolution3DOpDescriptor new] autorelease];
        fill_conv3d_desc(conv3dDescriptor_,
                         stride[2],
                         stride[1],
                         stride[0],
                         dilation[2],
                         dilation[1],
                         dilation[0],
                         padding[2],
                         padding[1],
                         padding[0],
                         groups);
        gradWeightTensor = [mpsGraph convolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                               sourceTensor:inputTensor
                                                                                outputShape:mps_weight_shape
                                                               forwardConvolutionDescriptor:conv3dDescriptor_
                                                                                       name:nil];
      } else if (isDepthwiseConv) {
        MPSGraphDepthwiseConvolution3DOpDescriptor* depthWiseConv3dDescriptor_ =
            [[MPSGraphDepthwiseConvolution3DOpDescriptor new] autorelease];
        fill_depthwise_conv_desc(depthWiseConv3dDescriptor_,
                                 stride[1],
                                 stride[0],
                                 dilation[1],
                                 dilation[0],
                                 padding[1],
                                 padding[0],
                                 at::MemoryFormat::Contiguous,
                                 groups);
        // Compute the gradient in the depthwise op's (1, O, H, W) layout, then
        // transpose back to PyTorch's OIHW.
        NSNumber* outputFeatChannelDim = mps_weight_shape[0];
        MPSShape* weightShapeTranspose = @[ @1, outputFeatChannelDim, mps_weight_shape[2], mps_weight_shape[3] ];
        MPSGraphTensor* gradWeightTensorTranspose =
            [mpsGraph depthwiseConvolution3DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                         sourceTensor:inputTensor
                                                                          outputShape:weightShapeTranspose
                                                                           descriptor:depthWiseConv3dDescriptor_
                                                                                 name:nil];
        gradWeightTensor = [mpsGraph transposeTensor:gradWeightTensorTranspose dimension:-3 withDimension:-4 name:nil];
      } else {
        MPSGraphConvolution2DOpDescriptor* conv2dDescriptor_ = [[MPSGraphConvolution2DOpDescriptor new] autorelease];
        fill_conv_desc(conv2dDescriptor_,
                       stride[1],
                       stride[0],
                       dilation[1],
                       dilation[0],
                       padding[1],
                       padding[0],
                       at::MemoryFormat::Contiguous,
                       groups);

        gradWeightTensor = [mpsGraph convolution2DWeightsGradientWithIncomingGradientTensor:gradOutputTensorTranspose
                                                                               sourceTensor:inputTensor
                                                                                outputShape:mps_weight_shape
                                                               forwardConvolutionDescriptor:conv2dDescriptor_
                                                                                       name:nil];
      }

      newCachedGraph->gradOutputTensor_ = gradOutputTensor;
      newCachedGraph->inputTensor_ = inputTensor;
      newCachedGraph->gradWeightTensor_ = gradWeightTensor;
    });

    auto gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output_t, gradOutputShape);
    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
    auto outputPlaceholder = Placeholder(cachedGraph->gradWeightTensor_, grad_weight_t);

    auto feeds = dictionaryFromPlaceholders(gradOutputPlaceholder, inputPlaceholder);
    runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
  }

  return grad_weight_t;
}

// Dispatcher-facing backward: computes (grad_input, grad_weight, grad_bias)
// per output_mask. grad_bias is left undefined here (handled by the caller).
std::tuple<at::Tensor, at::Tensor, at::Tensor> mps_convolution_backward(const at::Tensor& input,
                                                                        const at::Tensor& grad_output,
                                                                        const at::Tensor& weight,
                                                                        IntArrayRef padding,
                                                                        IntArrayRef stride,
                                                                        IntArrayRef dilation,
                                                                        int64_t groups,
                                                                        std::array<bool, 3> output_mask) {
  Tensor grad_input, grad_weight, grad_bias;
  if (input.numel() == 0) {
    if (output_mask[0]) {
      grad_input = at::empty_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
    if (output_mask[1]) {
      grad_weight = at::zeros_like(weight, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
    }
  } else {
    if (output_mask[0]) {
      grad_input = mps_convolution_backward_input(
          input.sizes(), grad_output, weight, padding, stride, dilation, groups, output_mask[2]);
    }
    if (output_mask[1]) {
      grad_weight = mps_convolution_backward_weights(
          weight.sizes(), grad_output, input, padding, stride, dilation, groups, output_mask[2]);
    }
  }

  return std::tuple<Tensor, Tensor, Tensor>{grad_input, grad_weight, grad_bias};
}

// Transposed convolution forward == data-gradient of the regular convolution.
static Tensor mps_convolution_transpose_forward(const Tensor& grad_output,
                                                const Tensor& weight,
                                                IntArrayRef padding,
                                                IntArrayRef output_padding,
                                                IntArrayRef stride,
                                                IntArrayRef dilation,
                                                int64_t groups) {
  auto input_size =
      conv_input_size(grad_output.sizes(), weight.sizes(), padding, output_padding, stride, dilation, groups);
  return mps_convolution_backward_input(input_size, grad_output, weight, padding, stride, dilation, groups, false);
}

Tensor _mps_convolution_transpose(const Tensor& input_t,
                                  const Tensor& weight_t,
                                  IntArrayRef padding,
                                  IntArrayRef output_padding,
                                  IntArrayRef stride,
                                  IntArrayRef dilation,
                                  int64_t groups) {
  TORCH_CHECK(input_t.dim() < 5, "ConvTranspose 3D is not supported on MPS");

  auto output_t =
      mps_convolution_transpose_forward(input_t, weight_t, padding, output_padding, stride, dilation, groups);
  return output_t;
}

// Transposed convolution input-gradient == regular convolution forward with an
// explicitly supplied output shape.
static Tensor mps_convolution_transpose_backward_input(const Tensor& grad_output_t,
                                                       const Tensor& weight_t,
                                                       IntArrayRef padding,
                                                       IntArrayRef stride,
                                                       IntArrayRef dilation,
                                                       int64_t groups,
                                                       IntArrayRef input_shape) {
  return _mps_convolution_impl(grad_output_t, weight_t, std::nullopt, padding, stride, dilation, groups, input_shape);
}

// Transposed convolution weight-gradient: same kernel as regular backward-weights,
// but with grad_output and input swapped.
static Tensor mps_convolution_transpose_backward_weight(IntArrayRef weight_size,
                                                        const Tensor& grad_output_t,
                                                        const Tensor& input_t,
                                                        IntArrayRef padding,
                                                        IntArrayRef stride,
                                                        IntArrayRef dilation,
                                                        int64_t groups) {
  return mps_convolution_backward_weights(
      weight_size, input_t, grad_output_t, padding, stride, dilation, groups, false);
}

std::tuple<Tensor, Tensor> mps_convolution_transpose_backward(const Tensor& input,
                                                              const Tensor& grad_output,
                                                              const Tensor& weight,
                                                              IntArrayRef padding,
                                                              IntArrayRef output_padding,
                                                              IntArrayRef stride,
                                                              IntArrayRef dilation,
                                                              int64_t groups,
                                                              std::array<bool, 2> output_mask) {
  Tensor grad_input, grad_weight;
  if (output_mask[0]) {
    grad_input =
        mps_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, input.sizes());
  }
  if (output_mask[1]) {
    grad_weight = mps_convolution_transpose_backward_weight(
        weight.sizes(), grad_output, input, padding, stride, dilation, groups);
  }

  return std::tuple<Tensor, Tensor>{grad_input, grad_weight};
}

} // namespace at::native