#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/native/Unfold3d.h>
#include <c10/core/ScalarType.h>
#include <c10/util/irange.h>
#include <cstring>

#if AT_MKL_ENABLED()
#include <mkl.h>
#endif // AT_MKL_ENABLED()

namespace at::native {

namespace {

// Returns true iff 0 <= a < b, using a single unsigned comparison.
bool IsAGeZeroAndALtB(int64_t a, int64_t b) {
  return static_cast<uint64_t>(a) < static_cast<uint64_t>(b);
}

// Copies an M x N row-major matrix from A (leading dimension lda) into B
// (leading dimension ldb).
template <typename T>
void MatCopy(int64_t M, int64_t N, int64_t lda, int64_t ldb, const T* A, T* B) {
  for (const auto i : c10::irange(M)) {
    std::memcpy(B + i * ldb, A + i * lda, N * sizeof(T));
  }
}

// Strided variant: element (i, j) is read from A[i * lda + j * stridea] and
// written to B[i * ldb + j * strideb].
template <typename T>
void MatCopy(
    int64_t M, int64_t N,
    int64_t lda, int64_t stridea,
    int64_t ldb, int64_t strideb,
    const T* A, T* B) {
  for (const auto i : c10::irange(M)) {
    const T* A_ptr = A + i * lda;
    T* B_ptr = B + i * ldb;
    for (const auto j : c10::irange(N)) {
      B_ptr[j * strideb] = A_ptr[j * stridea];
    }
  }
}

// Y += X
template <typename T>
void MatAdd(int64_t M, int64_t N, int64_t ldx, int64_t ldy, const T* X, T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j] += X[i * ldx + j];
    }
  }
}

// Y += X, strided variant.
template <typename T>
void MatAdd(
    int64_t M, int64_t N,
    int64_t ldx, int64_t stridex,
    int64_t ldy, int64_t stridey,
    const T* X, T* Y) {
  for (const auto i : c10::irange(M)) {
    for (const auto j : c10::irange(N)) {
      Y[i * ldy + j * stridey] += X[i * ldx + j * stridex];
    }
  }
}

#if AT_MKL_ENABLED()

// MKL-accelerated specializations for float and double.
template <>
void MatCopy<float>(
    int64_t M, int64_t N, int64_t lda, int64_t ldb, const float* A, float* B) {
  mkl_somatcopy('R', 'N', M, N, 1.0f, A, lda, B, ldb);
}

template <>
void MatCopy<double>(
    int64_t M, int64_t N, int64_t lda, int64_t ldb, const double* A, double* B) {
  mkl_domatcopy('R', 'N', M, N, 1.0, A, lda, B, ldb);
}

template <>
void MatCopy<float>(
    int64_t M, int64_t N,
    int64_t lda, int64_t stridea,
    int64_t ldb, int64_t strideb,
    const float* A, float* B) {
  mkl_somatcopy2('R', 'N', M, N, 1.0f, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatCopy<double>(
    int64_t M, int64_t N,
    int64_t lda, int64_t stridea,
    int64_t ldb, int64_t strideb,
    const double* A, double* B) {
  mkl_domatcopy2('R', 'N', M, N, 1.0, A, lda, stridea, B, ldb, strideb);
}

template <>
void MatAdd<float>(
    int64_t M, int64_t N, int64_t ldx, int64_t ldy, const float* X, float* Y) {
  mkl_somatadd('R', 'N', 'N', M, N, 1.0f, X, ldx, 1.0f, Y, ldy, Y, ldy);
}

template <>
void MatAdd<double>(
    int64_t M, int64_t N, int64_t ldx, int64_t ldy, const double* X, double* Y) {
  mkl_domatadd('R', 'N', 'N', M, N, 1.0, X, ldx, 1.0, Y, ldy, Y, ldy);
}

template <>
void MatAdd<float>(
    int64_t M, int64_t N,
    int64_t ldx, int64_t stridex,
    int64_t ldy, int64_t stridey,
    const float* X, float* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_saxpy(N, 1.0f, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

template <>
void MatAdd<double>(
    int64_t M, int64_t N,
    int64_t ldx, int64_t stridex,
    int64_t ldy, int64_t stridey,
    const double* X, double* Y) {
  for (const auto i : c10::irange(M)) {
    cblas_daxpy(N, 1.0, X + i * ldx, stridex, Y + i * ldy, stridey);
  }
}

#endif // AT_MKL_ENABLED()

// im2col-style copy for the zero-padding case. Every kernel offset maps to an
// in-bounds input element, so each row of the column buffer can be filled with
// matrix copies instead of per-element bounds checks. src has layout
// (C, X_D, X_H, X_W); dst has layout
// (C * kernel_d * kernel_h * kernel_w, Y_D, Y_H, Y_W).
template <typename T>
void Unfold3dZeroPaddingCopyKernelImpl(
    int64_t C,
    int64_t X_D, int64_t X_H, int64_t X_W,
    int64_t Y_D, int64_t Y_H, int64_t Y_W,
    int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
    int64_t stride_d, int64_t stride_h, int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      // Decompose the column-buffer row index p into (c, kd, kh, kw).
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d + kd;
        const T* src_ptr = src + c * X_size + xd * X_H * X_W + kh * X_W + kw;
        T* dst_ptr = dst + p * Y_size + yd * Y_H * Y_W;
        if (stride_w == 1) {
          // Rows along W are contiguous, so a plain matrix copy suffices.
          MatCopy(Y_H, Y_W, stride_h * X_W, Y_W, src_ptr, dst_ptr);
        } else {
          MatCopy(
              Y_H, Y_W, stride_h * X_W, stride_w, Y_W, 1, src_ptr, dst_ptr);
        }
      }
    }
  });
}

// General im2col-style copy. Out-of-bounds positions introduced by padding are
// written as zeros in the column buffer.
template <typename T>
void Unfold3dCopyKernelImpl(
    int64_t C,
    int64_t X_D, int64_t X_H, int64_t X_W,
    int64_t Y_D, int64_t Y_H, int64_t Y_W,
    int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
    int64_t stride_d, int64_t stride_h, int64_t stride_w,
    int64_t pad_d, int64_t pad_h, int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingCopyKernelImpl(
        C,
        X_D, X_H, X_W,
        Y_D, Y_H, Y_W,
        kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w,
        src,
        dst);
    return;
  }
  const int64_t n = C * kernel_d * kernel_h * kernel_w;
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  at::parallel_for(0, n, 0, [=](int64_t begin, int64_t end) {
    for (const auto p : c10::irange(begin, end)) {
      int64_t c = p;
      const int64_t kw = c % kernel_w;
      c /= kernel_w;
      const int64_t kh = c % kernel_h;
      c /= kernel_h;
      const int64_t kd = c % kernel_d;
      c /= kernel_d;
      const T* src_ptr = src + c * X_size;
      T* dst_ptr = dst + p * Y_size;
      for (const auto yd : c10::irange(Y_D)) {
        const int64_t xd = yd * stride_d - pad_d + kd;
        if (!IsAGeZeroAndALtB(xd, X_D)) {
          // The whole depth slice falls into the padding region.
          std::memset(dst_ptr + yd * Y_H * Y_W, 0, Y_H * Y_W * sizeof(T));
          continue;
        }
        for (const auto yh : c10::irange(Y_H)) {
          const int64_t xh = yh * stride_h - pad_h + kh;
          if (!IsAGeZeroAndALtB(xh, X_H)) {
            std::memset(
                dst_ptr + yd * Y_H * Y_W + yh * Y_W, 0, Y_W * sizeof(T));
            continue;
          }
          for (const auto yw : c10::irange(Y_W)) {
            const int64_t xw = yw * stride_w - pad_w + kw;
            dst_ptr[yd * Y_H * Y_W + yh * Y_W + yw] = IsAGeZeroAndALtB(xw, X_W)
                ? src_ptr[xd * X_H * X_W + xh * X_W + xw]
                : T(0);
          }
        }
      }
    }
  });
}

// col2im-style accumulation for the zero-padding case: dst is zeroed first,
// then every row of the column buffer is added back into the input volume.
template <typename T>
void Unfold3dZeroPaddingAccKernelImpl(
    int64_t C,
    int64_t X_D, int64_t X_H, int64_t X_W,
    int64_t Y_D, int64_t Y_H, int64_t Y_W,
    int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
    int64_t stride_d, int64_t stride_h, int64_t stride_w,
    const T* src,
    T* dst) {
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p = c * kernel_size + kd * kernel_h * kernel_w +
                kh * kernel_w + kw;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d + kd;
              const T* src_ptr = src + p * Y_size + yd * Y_H * Y_W;
              T* dst_ptr = dst + c * X_size + xd * X_H * X_W + kh * X_W + kw;
              if (stride_w == 1) {
                MatAdd(Y_H, Y_W, Y_W, stride_h * X_W, src_ptr, dst_ptr);
              } else {
                MatAdd(
                    Y_H, Y_W, Y_W, 1, stride_h * X_W, stride_w,
                    src_ptr, dst_ptr);
              }
            }
          }
        }
      }
    }
  });
}

// General col2im-style accumulation: contributions that fall into the padding
// region are skipped.
template <typename T>
void Unfold3dAccKernelImpl(
    int64_t C,
    int64_t X_D, int64_t X_H, int64_t X_W,
    int64_t Y_D, int64_t Y_H, int64_t Y_W,
    int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
    int64_t stride_d, int64_t stride_h, int64_t stride_w,
    int64_t pad_d, int64_t pad_h, int64_t pad_w,
    const T* src,
    T* dst) {
  if (pad_d == 0 && pad_h == 0 && pad_w == 0) {
    Unfold3dZeroPaddingAccKernelImpl(
        C,
        X_D, X_H, X_W,
        Y_D, Y_H, Y_W,
        kernel_d, kernel_h, kernel_w,
        stride_d, stride_h, stride_w,
        src,
        dst);
    return;
  }
  const int64_t X_size = X_D * X_H * X_W;
  const int64_t Y_size = Y_D * Y_H * Y_W;
  const int64_t kernel_size = kernel_d * kernel_h * kernel_w;
  at::parallel_for(0, C, 0, [=](int64_t begin, int64_t end) {
    std::memset(dst + begin * X_size, 0, (end - begin) * X_size * sizeof(T));
    for (const auto c : c10::irange(begin, end)) {
      T* dst_ptr = dst + c * X_size;
      for (const auto kd : c10::irange(kernel_d)) {
        for (const auto kh : c10::irange(kernel_h)) {
          for (const auto kw : c10::irange(kernel_w)) {
            const int64_t p = c * kernel_size + kd * kernel_h * kernel_w +
                kh * kernel_w + kw;
            const T* src_ptr = src + p * Y_size;
            for (const auto yd : c10::irange(Y_D)) {
              const int64_t xd = yd * stride_d - pad_d + kd;
              if (!IsAGeZeroAndALtB(xd, X_D)) {
                continue;
              }
              for (const auto yh : c10::irange(Y_H)) {
                const int64_t xh = yh * stride_h - pad_h + kh;
                if (!IsAGeZeroAndALtB(xh, X_H)) {
                  continue;
                }
                for (const auto yw : c10::irange(Y_W)) {
                  const int64_t xw = yw * stride_w - pad_w + kw;
                  if (IsAGeZeroAndALtB(xw, X_W)) {
                    dst_ptr[xd * X_H * X_W + xh * X_W + xw] +=
                        src_ptr[yd * Y_H * Y_W + yh * Y_W + yw];
                  }
                }
              }
            }
          }
        }
      }
    }
  });
}

} // namespace

// Public entry point: dispatches on dtype and unfolds the (C, X_D, X_H, X_W)
// volume in src into the (C * kernel_d * kernel_h * kernel_w, Y_D, Y_H, Y_W)
// column buffer in dst.
void Unfold3dCopyCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D, int64_t X_H, int64_t X_W,
    int64_t Y_D, int64_t Y_H, int64_t Y_W,
    int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
    int64_t stride_d, int64_t stride_h, int64_t stride_w,
    int64_t pad_d, int64_t pad_h, int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dCopyCPU",
      [=, &src]() {
        Unfold3dCopyKernelImpl(
            C,
            X_D, X_H, X_W,
            Y_D, Y_H, Y_W,
            kernel_d, kernel_h, kernel_w,
            stride_d, stride_h, stride_w,
            pad_d, pad_h, pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

// Reverse direction: accumulates the column buffer in src back into the
// (C, X_D, X_H, X_W) volume in dst.
void Unfold3dAccCPU(
    ScalarType dtype,
    const void* src,
    int64_t C,
    int64_t X_D, int64_t X_H, int64_t X_W,
    int64_t Y_D, int64_t Y_H, int64_t Y_W,
    int64_t kernel_d, int64_t kernel_h, int64_t kernel_w,
    int64_t stride_d, int64_t stride_h, int64_t stride_w,
    int64_t pad_d, int64_t pad_h, int64_t pad_w,
    void* dst) {
  AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::BFloat16,
      at::ScalarType::Half,
      dtype,
      "Unfold3dAccCPU",
      [=, &src]() {
        Unfold3dAccKernelImpl(
            C,
            X_D, X_H, X_W,
            Y_D, Y_H, Y_W,
            kernel_d, kernel_h, kernel_w,
            stride_d, stride_h, stride_w,
            pad_d, pad_h, pad_w,
            static_cast<const scalar_t*>(src),
            static_cast<scalar_t*>(dst));
      });
}

} // namespace at::native
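
/*
Usage sketch: how a conv3d-style caller might drive Unfold3dCopyCPU. The tensor
setup, the variable names, and the output-size relation below are assumptions
made for this example only; this file does not define them.

  // Standard output-size relation for stride/padding without dilation:
  //   Y_D = (X_D + 2 * pad_d - kernel_d) / stride_d + 1  (likewise for H, W)
  at::Tensor input = at::randn({C, X_D, X_H, X_W});  // one batch element
  at::Tensor columns = at::empty(
      {C * kernel_d * kernel_h * kernel_w, Y_D * Y_H * Y_W}, input.options());
  Unfold3dCopyCPU(
      input.scalar_type(), input.data_ptr(),
      C, X_D, X_H, X_W, Y_D, Y_H, Y_W,
      kernel_d, kernel_h, kernel_w,
      stride_d, stride_h, stride_w,
      pad_d, pad_h, pad_w,
      columns.data_ptr());
  // A forward conv3d for this sample then reduces to a matrix product:
  //   output = weight.view({out_channels, -1}).mm(columns)
  //                  .view({out_channels, Y_D, Y_H, Y_W});
  // Unfold3dAccCPU performs the transposed mapping, accumulating a column
  // buffer back into an input-shaped gradient.
*/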