#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #endif namespace at::meta { TORCH_META_FUNC(adaptive_max_pool2d) (const Tensor& input, IntArrayRef output_size) { int ndim = input.ndimension(); TORCH_CHECK(ndim == 3 || ndim == 4, "adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: ", input.sizes()); for (const auto i : c10::irange(1, ndim)) { TORCH_CHECK(input.size(i) > 0, "adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, " "but input has sizes ", input.sizes(), " with dimension ", i, " being empty"); } TORCH_CHECK(output_size.size() == 2, "adaptive_max_pool2d(): internal error: output_size.size() must be 2"); int dimH = 1; int64_t sizeB = 1; int64_t sizeD = 0; if (input.ndimension() == 4) { sizeB = input.size(0); dimH++; } sizeD = input.size(dimH - 1); int64_t osizeH = output_size[0]; int64_t osizeW = output_size[1]; /* resize output */ if (input.ndimension() == 3) { set_output_raw_strided(0, {sizeD, osizeH, osizeW}, {}, input.options()); /* indices will contain i,j locations for each output point */ set_output_raw_strided(1, {sizeD, osizeH, osizeW}, {}, input.options().dtype(kLong)); } else { set_output_raw_strided(0, {sizeB, sizeD, osizeH, osizeW}, {}, input.options().memory_format(input.suggest_memory_format())); /* indices will contain i,j locations for each output point */ set_output_raw_strided(1, {sizeB, sizeD, osizeH, osizeW}, {}, input.options().memory_format(input.suggest_memory_format()).dtype(kLong)); } } TORCH_META_FUNC(adaptive_max_pool2d_backward) (const Tensor& grad_output, const Tensor& input, const Tensor& indices) { int64_t ndim = grad_output.ndimension(); TORCH_CHECK(ndim == 3 || ndim == 4, "adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: ", grad_output.sizes()); at::native::adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward"); TORCH_CHECK(input.dtype() == grad_output.dtype(), "expected dtype ", input.dtype(), " for `grad_output` but got dtype ", grad_output.dtype()); set_output_raw_strided(0, input.sizes(), {}, input.options().memory_format(input.suggest_memory_format())); } } // namespace at::meta namespace at::native { TORCH_IMPL_FUNC(adaptive_max_pool2d_out_cpu) (const Tensor& input, IntArrayRef output_size, const Tensor& output, const Tensor& indices) { adaptive_max_pool2d_kernel(kCPU, output, indices, input, output_size); } TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cpu) (const Tensor& grad_output, const Tensor& input, const Tensor& indices, const Tensor& grad_input) { grad_input.zero_(); adaptive_max_pool2d_backward_kernel(kCPU, grad_input, grad_output, indices); } DEFINE_DISPATCH(adaptive_max_pool2d_kernel); DEFINE_DISPATCH(adaptive_max_pool2d_backward_kernel); } // namespace at::native