# xref: /aosp_15_r20/external/pytorch/cmake/Modules/FindAVX.cmake (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
INCLUDE(CheckCSourceCompiles)
INCLUDE(CheckCSourceRuns)
INCLUDE(CheckCXXSourceCompiles)
INCLUDE(CheckCXXSourceRuns)
4
# Probe program for basic AVX support: builds a 256-bit float vector with
# _mm256_set1_ps. The (void) cast marks `a` as used. Without it,
# -Wunused-but-set-variable (enabled by -Wall) fires, and under -Werror in
# CMAKE_<LANG>_FLAGS the probe would fail even on an AVX-capable compiler.
SET(AVX_CODE "
  #include <immintrin.h>

  int main()
  {
    __m256 a;
    a = _mm256_set1_ps(0);
    (void)a;
    return 0;
  }
")
15
# Probe program for AVX-512: builds a 512-bit integer vector byte-by-byte with
# _mm512_set_epi8 and compares it bytewise with _mm512_cmp_epi8_mask (an
# AVX512BW intrinsic). The (void) cast keeps `equality_mask` from triggering
# unused-variable warnings, which would fail the probe under -Werror.
SET(AVX512_CODE "
  #include <immintrin.h>

  int main()
  {
    __m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0,
                                0, 0, 0, 0, 0, 0, 0, 0);
    __m512i b = a;
    __mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
    (void)equality_mask;
    return 0;
  }
")
34
# Probe program for AVX2: runs _mm256_abs_epi16 on a 256-bit integer vector and
# extracts a 64-bit lane with _mm256_extract_epi64. That intrinsic is only
# provided when targeting 64-bit x86, so this probe also fails on 32-bit
# targets. `x` is initialized from `a` rather than left uninitialized. Reading
# an uninitialized value is UB and triggers MSVC C4700, which /sdl turns into an
# error, so the probe could fail spuriously.
SET(AVX2_CODE "
  #include <immintrin.h>

  int main()
  {
    __m256i a = {0};
    a = _mm256_abs_epi16(a);
    __m256i x = a;
    (void)_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
    return 0;
  }
")
47
# CHECK_SSE(<lang> <type> <flags>)
#
# Checks whether the <lang> compiler ("C" or "CXX") can compile the probe
# program stored in the variable <type>_CODE (e.g. AVX_CODE). Each entry of the
# ;-separated <flags> list is tried in order as CMAKE_REQUIRED_FLAGS, stopping
# at the first one that compiles. An entry consisting of a single space
# effectively means "no extra flag".
#
# Results (cache variables, marked advanced):
#   <lang>_<type>_FOUND  TRUE if some entry compiled, otherwise FALSE.
#   <lang>_<type>_FLAGS  The first entry that compiled, otherwise "".
#
# Each attempt's outcome is stored in <lang>_HAS_<type>_<n>, where <n> is the
# 1-based index of the entry in <flags>.
#
# Requires the CheckCSourceCompiles and CheckCXXSourceCompiles modules to have
# been include()d. As a macro it runs in the caller's scope, so __FLAG,
# __FLAG_I and CMAKE_REQUIRED_FLAGS_SAVE are left behind there.
MACRO(CHECK_SSE lang type flags)
  SET(__FLAG_I 1)
  SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
  FOREACH(__FLAG ${flags})
    IF(NOT ${lang}_${type}_FOUND)
      SET(CMAKE_REQUIRED_FLAGS "${__FLAG}")
      # Macro parameters are text substitutions, not variables. The language
      # must therefore be compared through its expanded value. A bare `lang`
      # would look up an (normally undefined) variable named "lang", making
      # this condition always false, and the C++ probes would silently run
      # through the C compiler.
      IF("${lang}" STREQUAL "CXX")
        CHECK_CXX_SOURCE_COMPILES("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I})
      ELSE()
        CHECK_C_SOURCE_COMPILES("${${type}_CODE}" ${lang}_HAS_${type}_${__FLAG_I})
      ENDIF()
      IF(${lang}_HAS_${type}_${__FLAG_I})
        SET(${lang}_${type}_FOUND TRUE CACHE BOOL "${lang} ${type} support")
        SET(${lang}_${type}_FLAGS "${__FLAG}" CACHE STRING "${lang} ${type} flags")
      ENDIF()
      # Give the next candidate its own result variable. The check modules
      # skip the compile when the result variable is already defined, so
      # reusing one name would just return the first candidate's outcome.
      MATH(EXPR __FLAG_I "${__FLAG_I}+1")
    ENDIF()
  ENDFOREACH()
  SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})

  # Record an explicit negative result so both outputs are always defined.
  IF(NOT ${lang}_${type}_FOUND)
    SET(${lang}_${type}_FOUND FALSE CACHE BOOL "${lang} ${type} support")
    SET(${lang}_${type}_FLAGS "" CACHE STRING "${lang} ${type} flags")
  ENDIF()

  MARK_AS_ADVANCED(${lang}_${type}_FOUND ${lang}_${type}_FLAGS)

ENDMACRO()
76
# Run the AVX, AVX2 and AVX-512 probes for the C compiler first, then for the
# C++ compiler. Each candidate list starts with " " (try with no extra flag),
# followed by the GCC/Clang-style switches and finally the MSVC /arch: option.
# The loop variable is intentionally not called `lang`.
FOREACH(__avx_probe_lang C CXX)
  CHECK_SSE(${__avx_probe_lang} "AVX" " ;-mavx;/arch:AVX")
  CHECK_SSE(${__avx_probe_lang} "AVX2" " ;-mavx2 -mfma -mf16c;/arch:AVX2")
  CHECK_SSE(${__avx_probe_lang} "AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
ENDFOREACH()
84