#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #include #endif namespace at::meta { TORCH_PRECOMPUTE_META_FUNC(fractional_max_pool3d)( const at::Tensor& input_, IntArrayRef pool_size, IntArrayRef output_size, const at::Tensor& randomSamples ) { TORCH_CHECK( pool_size.size() == 3, "fractional_max_pool3d: kernel_size must either be a single Int or tuple of three Ints") TORCH_CHECK( output_size.size() == 3, "fractional_max_pool3d: output_size must either be a single Int or tuple of three Ints") int64_t outputT = output_size[0]; int64_t outputH = output_size[1]; int64_t outputW = output_size[2]; int64_t poolSizeT = pool_size[0]; int64_t poolSizeH = pool_size[1]; int64_t poolSizeW = pool_size[2]; int64_t numBatch = 1; int64_t planeDim = 0; int64_t timeDim = 1; int64_t heightDim = 2; int64_t widthDim = 3; int64_t ndims = input_.ndimension(); TORCH_CHECK(ndims == 4 || ndims == 5, "fractional_max_pool3d_out(): Expected 4D or 5D tensor, but got: ", input_.sizes()); for (const auto i : c10::irange(1, ndims)) { TORCH_CHECK(input_.size(i) > 0, "fractional_max_pool3d_out(): Expected input to have non-zero size for non-batch dimensions, but got", input_.sizes(), " with dimension ", i, " being empty."); } if (ndims == 5) { numBatch = input_.size(0); planeDim++; timeDim++; heightDim++; widthDim++; } /* sizes */ int64_t numPlanes = input_.size(planeDim); int64_t inputT = input_.size(timeDim); int64_t inputH = input_.size(heightDim); int64_t inputW = input_.size(widthDim); TORCH_CHECK(outputT + poolSizeT - 1 < inputT, "fractional_max_pool3d_out(): pool time ", poolSizeT, " too large relative to input time ", inputT); TORCH_CHECK(outputW + poolSizeW - 1 < inputW, "fractional_max_pool3d_out(): pool width ", poolSizeW, " too large relative to input width ", inputW); TORCH_CHECK(outputH + poolSizeH - 1 < inputH, "fractional_max_pool3d_out(): pool height ", poolSizeH, " too large relative to input height ", inputH); if (ndims == 4) { /* resize output */ set_output_raw_strided(0, {numPlanes, outputT, outputH, outputW}, {}, input_.options()); /* indices will contain the locations for each output point */ set_output_raw_strided(1, {numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong)); } else { set_output_raw_strided(0, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options()); /* indices will contain the locations for each output point */ set_output_raw_strided(1, {numBatch, numPlanes, outputT, outputH, outputW}, {}, input_.options().dtype(kLong)); } return TORCH_PRECOMPUTE_STRUCT(fractional_max_pool3d)().set_numBatch(numBatch).set_numPlanes(numPlanes).set_inputT(inputT).set_inputH(inputH).set_inputW(inputW) .set_poolSizeT(poolSizeT).set_poolSizeH(poolSizeH).set_poolSizeW(poolSizeW) .set_outputT(outputT).set_outputH(outputH).set_outputW(outputW); } } // namespace at::meta namespace at::native { namespace { template static void fractional_max_pool3d_out_single_batch_frame( const scalar_t* input, scalar_t* output, int64_t* indices, const scalar_t* randomSamples, int64_t numPlanes, int64_t inputT, int64_t inputH, int64_t inputW, int64_t outputT, int64_t outputH, int64_t outputW, int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) { for (const auto plane : c10::irange(start, end)) { /* each plane contains 3 random samples, one for T, one for W, and one for H */ const scalar_t* randomSamplesForPlane = randomSamples + plane * 3; /* Generate interval sequence */ auto sequenceT = generate_intervals( randomSamplesForPlane[0], inputT, outputT, poolSizeT); auto sequenceH = generate_intervals( randomSamplesForPlane[1], inputH, outputH, poolSizeH); auto sequenceW = generate_intervals( randomSamplesForPlane[2], inputW, outputW, poolSizeW); /* loop over output */ // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t t, h, w; const scalar_t* inputForPlane = input + plane * inputT * inputH * inputW; scalar_t* outputForPlane = output + plane * outputT * outputH * outputW; int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW; for (t = 0; t < outputT; ++t) { int64_t inputTStart = sequenceT[t]; for (h = 0; h < outputH; ++h) { int64_t inputHStart = sequenceH[h]; for (w = 0; w < outputW; ++w) { int64_t inputWStart = sequenceW[w]; int64_t t2 = inputTStart, h2 = inputHStart, w2 = inputWStart; scalar_t maxVal = -std::numeric_limits::infinity(); int64_t maxIndex = t2 * inputH * inputW + h2 * inputW + w2; for (t2 = inputTStart; t2 < inputTStart + poolSizeT; ++t2) { for (h2 = inputHStart; h2 < inputHStart + poolSizeH; ++h2) { for (w2 = inputWStart; w2 < inputWStart + poolSizeW; ++w2) { AT_ASSERT(t2 >= 0 && t2 < inputT); AT_ASSERT(h2 >= 0 && h2 < inputH); AT_ASSERT(w2 >= 0 && w2 < inputW); int64_t planeIndex = t2 * inputH * inputW + h2 * inputW + w2; scalar_t val = inputForPlane[planeIndex]; if (val > maxVal || std::isnan(val)) { maxVal = val; maxIndex = planeIndex; } } } } outputForPlane[t * outputH * outputW + h * outputW + w] = maxVal; indicesForPlane[t * outputH * outputW + h * outputW + w] = maxIndex; } } } } }); } template static void fractional_max_pool3d_out_frame( const scalar_t* input, scalar_t* output, int64_t* indices, const scalar_t* randomSamples, int64_t numBatch, int64_t numPlanes, int64_t inputT, int64_t inputH, int64_t inputW, int64_t outputT, int64_t outputH, int64_t outputW, int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW) { if(numBatch == 1) { fractional_max_pool3d_out_single_batch_frame( input, output, indices, randomSamples, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW, poolSizeT, poolSizeH, poolSizeW ); return; } at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) { for (const auto batch : c10::irange(start, end)) { fractional_max_pool3d_out_single_batch_frame( input + batch * numPlanes * inputW * inputH * inputT, output + batch * numPlanes * outputW * outputH * outputT, indices + batch * numPlanes * outputW * outputH * outputT, randomSamples + batch * numPlanes * 3, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW, poolSizeT, poolSizeH, poolSizeW ); } }); } } // anonymous namespace TORCH_IMPL_FUNC(fractional_max_pool3d_out_cpu)( const at::Tensor& input_, int64_t poolSizeT, int64_t poolSizeH, int64_t poolSizeW, int64_t outputT, int64_t outputH, int64_t outputW, const at::Tensor& randomSamples_, int64_t numBatch, int64_t numPlanes, int64_t inputT, int64_t inputH, int64_t inputW, const at::Tensor& output, const at::Tensor& indices) { fractional_max_pool_check_shape(input_, randomSamples_); if (output.numel() == 0) { return; } /* get contiguous input and samples */ auto input = input_.contiguous(); auto randomSamples = randomSamples_.contiguous(); AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, kHalf, input.scalar_type(), "fractional_max_pool3d_out_frame", [&] { fractional_max_pool3d_out_frame( input.const_data_ptr(), output.data_ptr(), indices.data_ptr(), randomSamples.const_data_ptr(), numBatch, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW, poolSizeT, poolSizeH, poolSizeW ); } ); } namespace { template static void fractional_max_pool3d_backward_out_single_batch_frame( scalar_t* gradInput, const scalar_t* gradOutput, const int64_t* indices, int64_t numPlanes, int64_t inputT, int64_t inputH, int64_t inputW, int64_t outputT, int64_t outputH, int64_t outputW) { at::parallel_for(0, numPlanes, 0, [&](int64_t start, int64_t end) { for (const auto plane : c10::irange(start, end)) { scalar_t* gradInputForPlane = gradInput + plane * inputT * inputH * inputW; const scalar_t* gradOutputForPlane = gradOutput + plane * outputT * outputH * outputW; const int64_t* indicesForPlane = indices + plane * outputT * outputH * outputW; // NOLINTNEXTLINE(cppcoreguidelines-init-variables) int64_t h, w, t; for (t = 0; t < outputT; ++t) { for (h = 0; h < outputH; ++h) { for (w = 0; w < outputW; ++w) { int64_t outputIndex = t * outputH * outputW + h * outputW + w; int64_t index = indicesForPlane[outputIndex]; AT_ASSERT(index >= 0 && index < inputT * inputH * inputW); gradInputForPlane[index] += gradOutputForPlane[outputIndex]; } } } } }); } template static void fractional_max_pool3d_backward_out_frame( scalar_t* gradInput, const scalar_t* gradOutput, const int64_t* indices, int64_t numBatch, int64_t numPlanes, int64_t inputT, int64_t inputH, int64_t inputW, int64_t outputT, int64_t outputH, int64_t outputW) { if(numBatch == 1) { fractional_max_pool3d_backward_out_single_batch_frame( gradInput, gradOutput, indices, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW ); return; } at::parallel_for(0, numBatch, 0, [&](int64_t start, int64_t end) { for (const auto batch : c10::irange(start, end)) { fractional_max_pool3d_backward_out_single_batch_frame( gradInput + batch * numPlanes * inputW * inputH * inputT, gradOutput + batch * numPlanes * outputW * outputH * outputT, indices + batch * numPlanes * outputW * outputH * outputT, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW ); } }); } void fractional_max_pool3d_backward_out_cpu_template( const Tensor& input, const Tensor& gradOutput_, Tensor& gradInput, IntArrayRef output_size, IntArrayRef pool_size /* unused */, const Tensor& indices) { int64_t outputT = output_size[0]; int64_t outputH = output_size[1]; int64_t outputW = output_size[2]; int64_t numBatch = 1; int64_t planeDim = 0; int64_t timeDim = 1; int64_t heightDim = 2; int64_t widthDim = 3; int64_t ndims = input.ndimension(); if (ndims == 5) { numBatch = input.size(0); planeDim = 1; heightDim++; widthDim++; timeDim++; } /* sizes */ int64_t numPlanes = input.size(planeDim); int64_t inputT = input.size(timeDim); int64_t inputH = input.size(heightDim); int64_t inputW = input.size(widthDim); TORCH_CHECK(outputT == gradOutput_.size(timeDim), "fractional_max_pool3d_backward_out(): gradOutput time unexpected"); TORCH_CHECK(outputH == gradOutput_.size(heightDim), "fractional_max_pool3d_backward_out(): ", "gradOutput height unexpected"); TORCH_CHECK(outputW == gradOutput_.size(widthDim), "fractional_max_pool3d_backward_out(): gradOutput width unexpected"); /* get contiguous gradOutput */ auto gradOutput = gradOutput_.contiguous(); /* resize */ gradInput.resize_as_(input); gradInput.zero_(); /* backprop */ AT_DISPATCH_FLOATING_TYPES_AND2( kBFloat16, kHalf, input.scalar_type(), "fractional_max_pool3d_backward_out_frame", [&]{ fractional_max_pool3d_backward_out_frame( gradInput.data_ptr(), gradOutput.const_data_ptr(), indices.const_data_ptr(), numBatch, numPlanes, inputT, inputH, inputW, outputT, outputH, outputW ); } ); } }// anonymous namespace Tensor& fractional_max_pool3d_backward_out_cpu(const at::Tensor& gradOutput_, const at::Tensor& input, IntArrayRef pool_size, IntArrayRef output_size, const at::Tensor& indices, at::Tensor& gradInput) { fractional_max_pool3d_backward_out_cpu_template( input, gradOutput_, gradInput, output_size, pool_size, indices); return gradInput; } Tensor fractional_max_pool3d_backward_cpu( const at::Tensor& gradOutput_, const at::Tensor& input, IntArrayRef pool_size, IntArrayRef output_size, const at::Tensor& indices) { Tensor gradInput = at::empty({0}, input.options()); fractional_max_pool3d_backward_out_cpu_template( input, gradOutput_, gradInput, output_size, pool_size, indices); return gradInput; } } // namespace at::native