#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #if !defined(__s390x__) && !defined(__powerpc__) #include #endif #include #include #ifdef HAVE_ZVECTOR_CPU_DEFINITION #include #endif namespace at::native { #ifdef HAVE_ZVECTOR_CPU_DEFINITION static inline bool cpu_has_vxe() { return (getauxval(AT_HWCAP) & HWCAP_S390_VXE); } #endif static CPUCapability compute_cpu_capability() { auto envar = std::getenv("ATEN_CPU_CAPABILITY"); if (envar) { #if defined(HAVE_VSX_CPU_DEFINITION) if (strcmp(envar, "vsx") == 0) { return CPUCapability::VSX; } #elif defined(HAVE_ZVECTOR_CPU_DEFINITION) if (strcmp(envar, "zvector") == 0) { return CPUCapability::ZVECTOR; } #else #ifdef HAVE_AVX512_CPU_DEFINITION if (strcmp(envar, "avx512") == 0) { return CPUCapability::AVX512; } #endif #ifdef HAVE_AVX2_CPU_DEFINITION if (strcmp(envar, "avx2") == 0) { return CPUCapability::AVX2; } #endif #endif if (strcmp(envar, "default") == 0) { return CPUCapability::DEFAULT; } TORCH_WARN("ignoring invalid value for ATEN_CPU_CAPABILITY: ", envar); } #if !defined(__powerpc__) && !defined(__s390x__) if (cpuinfo_initialize()) { #if defined(HAVE_AVX512_CPU_DEFINITION) // GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in // versions 9 & beyond. So, we want to ensure that only releases built with // supported compilers on supported hardware return CPU Capability AVX512, // if it's supported on the hardware PyTorch is running on. if (cpuinfo_has_x86_avx512vl() && cpuinfo_has_x86_avx512bw() && \ cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_fma3()) { return CPUCapability::AVX512; } #endif #ifdef HAVE_AVX2_CPU_DEFINITION if (cpuinfo_has_x86_avx2() && cpuinfo_has_x86_fma3()) { return CPUCapability::AVX2; } #endif } #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION // vxe is needed for fp32 vector instructions if (cpu_has_vxe()) { return CPUCapability::ZVECTOR; } #endif #ifdef HAVE_VSX_CPU_DEFINITION return CPUCapability::VSX; #else return CPUCapability::DEFAULT; #endif } CPUCapability get_cpu_capability() { static CPUCapability capability = compute_cpu_capability(); return capability; } DispatchResult DispatchStubImpl::try_get_call_ptr( const DeviceType device_type , void *DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , void *VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif ) { constexpr auto supported_devices = c10::array_of( c10::DeviceType::CPU, c10::DeviceType::CUDA, c10::DeviceType::HIP, c10::DeviceType::MPS, c10::DeviceType::MTIA, c10::DeviceType::XPU, c10::DeviceType::PrivateUse1 ); // Check if the device type is supported. if (std::find(supported_devices.begin(), supported_devices.end(), device_type) == supported_devices.end()) { return ErrorType::DeviceNotSupported; } switch (device_type) { case DeviceType::CPU: { // Use memory_order_relaxed here since even if two threads race, // they will still compute the same value for cpu_dispatch_ptr. auto fptr = cpu_dispatch_ptr.load(std::memory_order_relaxed); if (!fptr) { auto result = try_choose_cpu_impl( DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR #endif ); if (!std::holds_alternative(result)) { cpu_dispatch_ptr.store(fptr, std::memory_order_relaxed); } return result; } return DispatchResult(fptr); } case DeviceType::CUDA: return cuda_dispatch_ptr != nullptr ? DispatchResult(cuda_dispatch_ptr) : ErrorType::MissingDeviceKernel; case DeviceType::HIP: return hip_dispatch_ptr != nullptr ? DispatchResult(hip_dispatch_ptr) : ErrorType::MissingDeviceKernel; #if defined(USE_MPS) case DeviceType::MPS: return mps_dispatch_ptr != nullptr ? DispatchResult(mps_dispatch_ptr) : ErrorType::MissingDeviceKernel; #endif case DeviceType::MTIA: return mtia_dispatch_ptr != nullptr ? DispatchResult(mtia_dispatch_ptr) : ErrorType::MissingDeviceKernel; #if defined(USE_XPU) case DeviceType::XPU: return xpu_dispatch_ptr != nullptr ? DispatchResult(xpu_dispatch_ptr) : ErrorType::MissingDeviceKernel; #endif case DeviceType::PrivateUse1: return privateuse1_dispatch_ptr != nullptr ? DispatchResult(privateuse1_dispatch_ptr) : ErrorType::MissingDeviceKernel; default: TORCH_INTERNAL_ASSERT(false, "An unexpected device type was provided ", device_type); return ErrorType::DeviceNotSupported; } } void* DispatchStubImpl::get_call_ptr( const DeviceType device_type , void *DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , void *VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif ) { auto result = try_get_call_ptr( device_type, DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , ZVECTOR #endif ); if (std::holds_alternative(result)) { auto error = std::get(result); switch (error) { case ErrorType::MissingDeviceKernel: TORCH_INTERNAL_ASSERT( false, "DispatchStub: missing kernel for ", device_type); return nullptr; case ErrorType::DeviceNotSupported: AT_ERROR("DispatchStub: unsupported device type", device_type); } } void* fptr = std::get(result); return fptr; } DispatchResult DispatchStubImpl::try_choose_cpu_impl( void *DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , void *VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif ){ auto capability = static_cast(get_cpu_capability()); (void)capability; #ifdef HAVE_AVX512_CPU_DEFINITION if (capability >= static_cast(CPUCapability::AVX512)) { // Quantization kernels have also been disabled on Windows // for AVX512 because some of their tests are flaky on Windows. // Ideally, we should have AVX512 kernels for all kernels. if (C10_UNLIKELY(!AVX512)) { // dispatch to AVX2, since the AVX512 kernel is missing return AVX2 != nullptr ? DispatchResult(AVX2) : ErrorType::MissingDeviceKernel; } else { return DispatchResult(AVX512); } } #endif #ifdef HAVE_AVX2_CPU_DEFINITION if (capability >= static_cast(CPUCapability::AVX2)) { return AVX2 != nullptr ? DispatchResult(AVX2) : ErrorType::MissingDeviceKernel; } #endif #ifdef HAVE_VSX_CPU_DEFINITION if (capability >= static_cast(CPUCapability::VSX)) { return VSX != nullptr ? DispatchResult(VSX) : ErrorType::MissingDeviceKernel; } #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION if (capability >= static_cast(CPUCapability::ZVECTOR)) { return ZVECTOR != nullptr ? DispatchResult(ZVECTOR) : ErrorType::MissingDeviceKernel; } #endif return DEFAULT != nullptr ? DispatchResult(DEFAULT) : ErrorType::MissingDeviceKernel; } void* DispatchStubImpl::choose_cpu_impl( void *DEFAULT #ifdef HAVE_AVX512_CPU_DEFINITION , void *AVX512 #endif #ifdef HAVE_AVX2_CPU_DEFINITION , void *AVX2 #endif #ifdef HAVE_VSX_CPU_DEFINITION , void *VSX #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION , void *ZVECTOR #endif ) { auto capability = static_cast(get_cpu_capability()); (void)capability; #ifdef HAVE_AVX512_CPU_DEFINITION if (capability >= static_cast(CPUCapability::AVX512)) { // Quantization kernels have also been disabled on Windows // for AVX512 because some of their tests are flaky on Windows. // Ideally, we should have AVX512 kernels for all kernels. if (C10_UNLIKELY(!AVX512)) { // dispatch to AVX2, since the AVX512 kernel is missing TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); return AVX2; } else { return AVX512; } } #endif #ifdef HAVE_AVX2_CPU_DEFINITION if (capability >= static_cast(CPUCapability::AVX2)) { TORCH_INTERNAL_ASSERT(AVX2, "DispatchStub: missing AVX2 kernel"); return AVX2; } #endif #ifdef HAVE_VSX_CPU_DEFINITION if (capability >= static_cast(CPUCapability::VSX)) { TORCH_INTERNAL_ASSERT(VSX, "DispatchStub: missing VSX kernel"); return VSX; } #endif #ifdef HAVE_ZVECTOR_CPU_DEFINITION if (capability >= static_cast(CPUCapability::ZVECTOR)) { TORCH_INTERNAL_ASSERT(ZVECTOR, "DispatchStub: missing ZVECTOR kernel"); return ZVECTOR; } #endif TORCH_INTERNAL_ASSERT(DEFAULT, "DispatchStub: missing default kernel"); return DEFAULT; } } // namespace at::native