xref: /aosp_15_r20/external/pytorch/third_party/generate-xnnpack-wrappers.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker#!/usr/bin/env python3
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom __future__ import print_function
4*da0073e9SAndroid Build Coastguard Workerimport collections
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport sys
7*da0073e9SAndroid Build Coastguard Workerimport logging
8*da0073e9SAndroid Build Coastguard Worker
# Banner written at the top of every generated file.
# NOTE(review): the text names "generate-wrappers.py" although this script is
# generate-xnnpack-wrappers.py; kept verbatim so regenerated files do not churn.
BANNER = "Auto-generated by generate-wrappers.py script. Do not modify"

# Preprocessor conditions shared by several microkernel lists below.
_COND_X86 = "defined(__i386__) || defined(__i686__) || defined(__x86_64__)"
_COND_ARM = "defined(__arm__)"
_COND_AARCH64 = "defined(__aarch64__)"
_COND_ARM_OR_AARCH64 = "defined(__arm__) || defined(__aarch64__)"
_COND_RISCV = "defined(__riscv) || defined(__riscv__)"

# Microkernel source-list name -> C preprocessor condition that guards the
# `#include` in each generated wrapper file (None = include unconditionally).
# Iteration order of this dict is the order of the lists written to
# xnnpack_wrapper_defs.bzl, so keep new entries appended, not interleaved.
WRAPPER_SRC_NAMES = {
    "PROD_SCALAR_MICROKERNEL_SRCS": None,
    "PROD_FMA_MICROKERNEL_SRCS": _COND_RISCV,
    "PROD_ARMSIMD32_MICROKERNEL_SRCS": _COND_ARM,
    "PROD_FP16ARITH_MICROKERNEL_SRCS": _COND_ARM,
    "PROD_NEON_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEONFP16_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEONFMA_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEON_AARCH64_MICROKERNEL_SRCS": _COND_AARCH64,
    "PROD_NEONV8_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEONFP16ARITH_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEONFP16ARITH_AARCH64_MICROKERNEL_SRCS": _COND_AARCH64,
    "PROD_NEONDOT_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEONDOT_AARCH64_MICROKERNEL_SRCS": _COND_AARCH64,
    "PROD_NEONDOTFP16ARITH_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_NEONDOTFP16ARITH_AARCH64_MICROKERNEL_SRCS": _COND_AARCH64,
    "PROD_NEONI8MM_MICROKERNEL_SRCS": _COND_ARM_OR_AARCH64,
    "PROD_SSE_MICROKERNEL_SRCS": _COND_X86,
    "PROD_SSE2_MICROKERNEL_SRCS": _COND_X86,
    "PROD_SSSE3_MICROKERNEL_SRCS": _COND_X86,
    "PROD_SSE41_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX_MICROKERNEL_SRCS": _COND_X86,
    "PROD_F16C_MICROKERNEL_SRCS": _COND_X86,
    "PROD_XOP_MICROKERNEL_SRCS": _COND_X86,
    "PROD_FMA3_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX2_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX512F_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX512SKX_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX512VBMI_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX512VNNI_MICROKERNEL_SRCS": _COND_X86,
    "PROD_AVX512VNNIGFNI_MICROKERNEL_SRCS": _COND_X86,
    "PROD_RVV_MICROKERNEL_SRCS": _COND_RISCV,
    "PROD_AVXVNNI_MICROKERNEL_SRCS": _COND_X86,
    "AARCH32_ASM_MICROKERNEL_SRCS": _COND_ARM,
    "AARCH64_ASM_MICROKERNEL_SRCS": _COND_AARCH64,

    # add non-prod microkernel sources here:
}
48*da0073e9SAndroid Build Coastguard Worker
# Every source list exported to xnnpack_src_defs.bzl: the core
# (non-microkernel) lists named here plus every microkernel list in
# WRAPPER_SRC_NAMES. Deriving the microkernel part, instead of repeating all
# of its names, keeps the two tables from drifting apart; add new (including
# non-prod) microkernel lists to WRAPPER_SRC_NAMES only.
SRC_NAMES = {
    "OPERATOR_SRCS",
    "SUBGRAPH_SRCS",
    "LOGGING_SRCS",
    "XNNPACK_SRCS",
    "TABLE_SRCS",
    "JIT_SRCS",
} | set(WRAPPER_SRC_NAMES)
93*da0073e9SAndroid Build Coastguard Worker
def handle_singleline_parse(line):
    """Parse a one-line CMake ``SET(NAME src/a.c src/b.c ...)`` statement.

    Args:
        line: the raw CMake line.

    Returns:
        ``(name, paths)``, where *paths* are the tokens after the name with
        their first four characters (the ``src/`` prefix) removed. An empty
        ``SET()`` yields ``("", [])``.
    """
    start_index = line.find("(")
    end_index = line.find(")")
    if end_index == -1:
        # No closing paren on this line: take everything after '(' rather
        # than slicing to -1, which would silently drop the last character.
        end_index = len(line)
    # split() with no separator collapses runs of spaces/tabs, so repeated
    # separators do not produce empty entries.
    tokens = line[start_index + 1:end_index].split()
    if not tokens:
        return "", []
    return tokens[0], [token[4:] for token in tokens[1:]]
100*da0073e9SAndroid Build Coastguard Worker
def update_sources(xnnpack_path, cmakefile = "XNNPACK/CMakeLists.txt"):
    """Collect the source lists declared with ``SET(...)`` in a CMake file.

    Two layouts are recognized:

    * single-line ``SET(NAME src/a.c src/b.c)``: any line starting with
      ``SET`` that mentions ``src/`` (parsed by handle_singleline_parse);
    * multi-line ``SET(NAME`` followed by one path per line up to the first
      line containing ``)``, for NAME in WRAPPER_SRC_NAMES or SRC_NAMES.

    Paths are stored with their first four characters (the ``src/`` prefix)
    removed. Blank lines and a bare closing ``)`` contribute no entry.

    Args:
        xnnpack_path: directory that *cmakefile* is relative to.
        cmakefile: path of the CMake file, relative to *xnnpack_path*.

    Returns:
        collections.defaultdict(list) mapping list name -> list of paths.
    """
    # Built once rather than on every loop iteration.
    known_names = set(WRAPPER_SRC_NAMES.keys()) | set(SRC_NAMES)
    sources = collections.defaultdict(list)

    def _add(name, raw_line):
        # Strip whitespace and a closing ')' and then the "src/" prefix. Skip
        # empty results: an "" entry later becomes a directory path that
        # gen_wrappers would try to open() as a file.
        value = raw_line.strip(' \t\n\r)')[4:]
        if value:
            sources[name].append(value)

    with open(os.path.join(xnnpack_path, cmakefile)) as cmake:
        lines = cmake.readlines()

    i = 0
    while i < len(lines):
        line = lines[i]

        if line.startswith("SET") and "src/" in line:
            name, val = handle_singleline_parse(line)
            sources[name].extend(val)
            i += 1
            continue

        # Guard against a SET line without '(' (split()[1] would IndexError).
        parts = line.split('(')
        name = parts[1].strip(' \t\n\r') if len(parts) > 1 else None
        if not (line.startswith("SET") and name in known_names):
            i += 1
            continue

        i += 1
        while i < len(lines) and ')' not in lines[i]:
            _add(name, lines[i])
            i += 1
        if i < len(lines):
            # Closing line, e.g. "  src/last.c)" or ")".
            _add(name, lines[i])
        # i is left on the closing line on purpose: the next iteration
        # re-examines it and then advances past it.
    return sources
130*da0073e9SAndroid Build Coastguard Worker
def _write_wrapper(filepath, filename, condition):
    """Write a C wrapper at *filepath* that ``#include``s *filename*.

    When *condition* is not None the include is wrapped in ``#if condition``
    / ``#endif`` so the real source only compiles where the condition holds.
    """
    # exist_ok avoids the check-then-create race of isdir() + makedirs().
    os.makedirs(os.path.dirname(filepath), exist_ok=True)
    with open(filepath, "w") as wrapper:
        print("/* {} */".format(BANNER), file=wrapper)
        print(file=wrapper)

        # Architecture- or platform-dependent preprocessor flags can be
        # defined here. Note: platform_preprocessor_flags can't be used
        # because they are ignored by arc focus & buck project.

        if condition is None:
            print("#include <%s>" % filename, file=wrapper)
        else:
            # Include source file only if condition is satisfied
            print("#if %s" % condition, file=wrapper)
            print("#include <%s>" % filename, file=wrapper)
            print("#endif /* %s */" % condition, file=wrapper)


def _write_bzl_defs(path, names, sources, prefix):
    """Write a .bzl file assigning one list per name in *names*.

    Each entry of ``sources[name]`` is emitted as ``"<prefix><entry>",``.
    """
    with open(path, 'w') as defs:
        print('"""', file=defs)
        print(BANNER, file=defs)
        print('"""', file=defs)
        for name in names:
            print('\n' + name + ' = [', file=defs)
            for file_name in sources[name]:
                print('    "{}{}",'.format(prefix, file_name), file=defs)
            print(']', file=defs)


def gen_wrappers(xnnpack_path):
    """Generate per-file C wrappers and the two .bzl source-list files.

    Reads XNNPACK/CMakeLists.txt and XNNPACK/cmake/microkernels.cmake under
    *xnnpack_path*, writes one wrapper per microkernel source under
    ``<xnnpack_path>/xnnpack_wrappers/`` (printing each guard condition to
    stdout), and writes xnnpack_wrapper_defs.bzl and xnnpack_src_defs.bzl
    into the directory containing this script.
    """
    sources = update_sources(xnnpack_path)
    # Lists from microkernels.cmake replace same-named lists from CMakeLists.txt.
    sources.update(update_sources(xnnpack_path, "XNNPACK/cmake/microkernels.cmake"))

    # Group wrapper files by the preprocessor condition that guards them.
    xnnpack_sources = collections.defaultdict(list)
    for name, condition in WRAPPER_SRC_NAMES.items():
        xnnpack_sources[condition].extend(sources[name])

    for condition, filenames in xnnpack_sources.items():
        print(condition)
        for filename in filenames:
            filepath = os.path.join(xnnpack_path, "xnnpack_wrappers", filename)
            _write_wrapper(filepath, filename, condition)

    # Both .bzl files go into the same folder as this script.
    out_dir = os.path.dirname(__file__)
    _write_bzl_defs(
        os.path.join(out_dir, "xnnpack_wrapper_defs.bzl"),
        WRAPPER_SRC_NAMES, sources, "xnnpack_wrappers/")
    # SRC_NAMES is a set; iterating it directly gives a hash-seed-dependent
    # order, so sort it to make the generated file reproducible across runs.
    _write_bzl_defs(
        os.path.join(out_dir, "xnnpack_src_defs.bzl"),
        sorted(SRC_NAMES), sources, "XNNPACK/src/")
186*da0073e9SAndroid Build Coastguard Worker
187*da0073e9SAndroid Build Coastguard Worker
def main(argv):
    """Run gen_wrappers on ``argv[0]``, or on ``"."`` when no argument is given.

    Args:
        argv: command-line arguments without the program name; None or an
            empty list selects the current directory.
    """
    gen_wrappers(argv[0] if argv else ".")

# Usage: generate-xnnpack-wrappers.py [XNNPACK_PARENT_DIR]
# The optional argument is the directory that contains the XNNPACK checkout:
# XNNPACK/CMakeLists.txt and XNNPACK/cmake/microkernels.cmake are read from it,
# and the "xnnpack_wrappers" folder is created in it. It defaults to the
# current directory. The two .bzl files (xnnpack_wrapper_defs.bzl and
# xnnpack_src_defs.bzl) are always written next to this script, regardless of
# the working directory.
if __name__ == "__main__":
    main(sys.argv[1:])
199