xref: /aosp_15_r20/external/pytorch/buckbuild.bzl (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# NOTE: This file is shared by the internal and OSS BUCK builds.
# The load paths below resolve to different files in the internal and OSS environments.
3
4load("@bazel_skylib//lib:paths.bzl", "paths")
5load("//tools/build_defs:fb_native_wrapper.bzl", "fb_native")
6load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library")
7load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule")
8load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode")
9load("//tools/build_defs:glob_defs.bzl", "subdir_glob")
10load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX")
11load("//tools/build_defs:type_defs.bzl", "is_list", "is_string")
12load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build")
13load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build")
14load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags")
15load(
16    ":build_variables.bzl",
17    "aten_cpu_source_list",
18    "aten_native_source_list",
19    "core_sources_common",
20    "core_sources_full_mobile_no_backend_interface_xplat",
21    "core_trainer_sources",
22    "jit_core_headers",
23    "jit_core_sources",
24    "libtorch_profiler_sources",
25    "torch_cpp_srcs",
26    "torch_mobile_tracer_sources",
27)
28load(
29    ":pt_ops.bzl",
30    "USED_PT_BACKENDS",
31)
32load(
33    ":pt_template_srcs.bzl",
34    "METAL_MASKRCNN_SOURCE_LIST",
35    "METAL_SOURCE_LIST",
36    "TEMPLATE_MASKRCNN_SOURCE_LIST",
37    "TEMPLATE_SOURCE_LIST",
38    "aten_ufunc_generated_all_cpu_sources",
39    "get_gen_oplist_outs",
40    "get_generate_code_bin_outs",
41    "get_metal_registration_files_outs",
42    "get_metal_registration_files_outs_windows",
43    "get_metal_source_dict",
44    "get_template_registration_file_rules",
45    "get_template_registration_files_outs",
46    "get_template_source_dict",
47)
48load(
49    ":ufunc_defs.bzl",
50    "aten_ufunc_generated_cpu_kernel_sources",
51    "aten_ufunc_generated_cpu_sources",
52    "aten_ufunc_generated_cuda_sources",
53)
54
def read_bool(section, field, default, required = True):
    """Read a boolean option from the Buck config.

    Args:
        section: config section to read from (e.g. "pt").
        field: key inside `section`.
        default: value to return when the key is unset; None means "no default".
        required: when True, fail if the key is unset and `default` is None.

    Returns:
        True or False parsed from the config value; `default` when the key is
        unset; None when the key is unset, `default` is None and `required`
        is False.
    """
    raw = read_config(section, field)
    if raw == None:
        # Unset: fall back to the default, then to optional-None, else fail.
        if default != None:
            return default
        if not required:
            return None
        fail("`{}:{}`: no value set".format(section, field))

    if raw in ("true", "True", "1"):
        return True
    if raw in ("false", "False", "0"):
        return False
    fail(
        "`{}:{}`: must be one of (0, 1, true, false, True, False), but was {}".format(section, field, raw),
    )
72
def _is_build_mode_dev():
    """Return True unless this is an Android or iOS production build."""

    # A production build on either platform turns dev mode off; the Android
    # check runs first and short-circuits the iOS one.
    return not (is_production_build_android() or is_production_build_ios())
82
def _get_enable_lightweight_dispatch():
    """Return the `pt.enable_lightweight_dispatch` config flag (default False)."""
    return read_bool("pt", "enable_lightweight_dispatch", default = False)
85
def _get_enable_record_kernel_dtype():
    """Return the `pt.enable_record_kernel_dtype` config flag (default False)."""
    return read_bool("pt", "enable_record_kernel_dtype", default = False)
88
def get_enable_mobile_dispatch_keys_trimming():
    """Return the `pt.enable_mobile_dispatch_keys_trimming` config flag (default False)."""
    return read_bool("pt", "enable_mobile_dispatch_keys_trimming", default = False)
91
def get_disable_per_op_profiling():
    """Return the `pt.disable_per_op_profiling` config flag (default True)."""
    return read_bool("pt", "disable_per_op_profiling", default = True)
94
def get_strip_error_messages():
    """Decide whether error messages are stripped from the build.

    Internal builds read `pt.strip_error_messages`, defaulting to stripping
    everywhere except dev builds. OSS builds always strip.
    """
    if not IS_OSS:
        return read_bool("pt", "strip_error_messages", not _is_build_mode_dev())

    # OSS CI always strips so that problems caused by stripping surface there.
    return True
99
def get_disable_warn():
    """Return the `pt.disable_warn` config flag (default False)."""
    return read_bool("pt", "disable_warn", default = False)
102
def get_enable_eager_symbolication():
    """Return the optional `pt.enable_eager_symbolication` flag (default False)."""
    return read_bool("pt", "enable_eager_symbolication", False, False)
105
def get_static_dispatch_backend():
    """Return the backends listed in the `pt.static_dispatch_backend` config.

    The value is a ";"-separated list of backend names. Whitespace around
    each entry is trimmed and empty entries are dropped. As a result, a blank
    value, a trailing ";" or a ";;" yields no entry, rather than a "" backend
    that would enable static dispatch with a nameless backend.

    Returns:
        A list of backend names; [] when the option is unset or blank.
    """
    raw = read_config("pt", "static_dispatch_backend", None)
    if raw == None:
        return []
    return [backend.strip() for backend in raw.split(";") if backend.strip()]
111
def get_glsl_image_format():
    """Return the GLSL float image format used for Vulkan shaders.

    `pt.vulkan_full_precision` defaults to "0", which selects "rgba16f".
    Any other value selects "rgba32f".
    """
    full_precision = read_config("pt", "vulkan_full_precision", "0") != "0"
    return "rgba32f" if full_precision else "rgba16f"
116
def get_glsl_paths():
    """Build the GLSL source directory list passed to gen_vulkan_spv.

    The list always starts with the in-tree ATen Vulkan shader directory. More
    directories can be added through `gen_vulkan_spv.additional_glsl_paths`:
    a space-separated sequence of alternating `<target> <relative path>` items.

    Returns:
        A space-separated string of `$(location <target>)/<relative path>` entries.
    """

    # Flat list of alternating (target, relative path) items. It is not called
    # `paths`, to avoid shadowing the skylib `paths` module loaded by this file.
    extra_items = [
        item
        for item in read_config("gen_vulkan_spv", "additional_glsl_paths", "").split(" ")
        if item
    ]
    glsl_items = [
        "//xplat/caffe2:aten_vulkan_glsl_src_path",
        "aten/src/ATen/native/vulkan/glsl",
    ] + extra_items

    if len(glsl_items) % 2 != 0:
        fail(
            "gen_vulkan_spv.additional_glsl_paths must contain an even number of elements",
        )

    # Pair even-indexed targets with the odd-indexed paths that follow them.
    return " ".join([
        "$(location {})/{}".format(target, rel_path)
        for target, rel_path in zip(glsl_items[0::2], glsl_items[1::2])
    ])
145
def spv_shader_library():
    """No-op: this shared file defines no rules here.

    NOTE(review): presumably kept so that callers loading this symbol still
    resolve in this build flavor; confirm before removing.
    """
    return None
148
# True for the OSS BUCK build, False for the internal (fbsource) BUCK build.
IS_OSS = read_config("pt", "is_oss", "0") == "1"

NOT_OSS = not IS_OSS

# Package prefix for targets in the caffe2 root package; callers append ":<target>".
ROOT = "//xplat/caffe2" if NOT_OSS else "//"

# Prefix for targets in subpackages; callers append "<subdir>:<target>" directly.
ROOT_PATH = "//xplat/caffe2/" if NOT_OSS else "//"

C10 = "//xplat/caffe2/c10:c10" if NOT_OSS else "//c10:c10"
160
# Maps each third-party library name to a two-element list of targets:
#   [0] the fbsource (internal) target, [1] the OSS target.
# third_party() picks the right element based on IS_OSS.
THIRD_PARTY_LIBS = {
    "FP16": [
        "//xplat/third-party/FP16:FP16",
        "//third_party:FP16",
    ],
    "FXdiv": [
        "//xplat/third-party/FXdiv:FXdiv",
        "//third_party:FXdiv",
    ],
    "XNNPACK": [
        "//xplat/third-party/XNNPACK:XNNPACK",
        "//third_party:XNNPACK",
    ],
    "clog": [
        "//xplat/third-party/clog:clog",
        "//third_party:clog",
    ],
    "cpuinfo": [
        "//third-party/cpuinfo:cpuinfo",
        "//third_party:cpuinfo",
    ],
    "flatbuffers-api": [
        "//third-party/flatbuffers/fbsource_namespace:flatbuffers-api",
        "//third_party:flatbuffers-api",
    ],
    "flatc": [
        "//third-party/flatbuffers/fbsource_namespace:flatc",
        "//third_party:flatc",
    ],
    "fmt": [
        "//third-party/fmt:fmt",
        "//third_party:fmt",
    ],
    "glog": [
        "//third-party/glog:glog",
        "//third_party:glog",
    ],
    "gmock": [
        "//third-party/googletest:gmock_main",
        "//third_party:gmock",
    ],
    "gtest": [
        "//third-party/googletest:gtest_main",
        "//third_party:gtest",
    ],
    "kineto": [
        "//xplat/kineto/libkineto:libkineto",
        "//third_party:libkineto",
    ],
    "libkineto_headers": [
        "//xplat/kineto/libkineto:libkineto_headers",
        "//third_party:libkineto_headers",
    ],
    "omp": [
        "//xplat/third-party/linker_lib:omp",
        "//third_party:no-op",
    ],
    "pocketfft": [
        "//third-party/pocket_fft:pocketfft",
        "//third_party:pocketfft_header",
    ],
    "psimd": [
        "//xplat/third-party/psimd:psimd",
        "//third_party:psimd",
    ],
    "pthreadpool": [
        "//xplat/third-party/pthreadpool:pthreadpool",
        "//third_party:pthreadpool",
    ],
    "pthreadpool_header": [
        "//xplat/third-party/pthreadpool:pthreadpool_header",
        "//third_party:pthreadpool_header",
    ],
    "pyyaml": [
        "//third-party/pyyaml:pyyaml",
        "//third_party:pyyaml",
    ],
    "rt": [
        "//xplat/third-party/linker_lib:rt",
        "//third_party:rt",
    ],
    "ruy": [
        "//third-party/ruy:ruy_xplat_lib",
        "//third_party:ruy_lib",
    ],
    "sleef_arm": [
        "//third-party/sleef:sleef_arm",
        "//third_party:sleef_arm",
    ],
    "typing-extensions": [
        "//third-party/typing-extensions:typing-extensions",
        "//third_party:typing-extensions",
    ],
}
187
def third_party(name):
    """Resolve a THIRD_PARTY_LIBS key to the target for the current build flavor.

    Returns the OSS target when IS_OSS is set, otherwise the fbsource target.
    Fails the build when `name` has not been registered in THIRD_PARTY_LIBS.
    """
    targets = THIRD_PARTY_LIBS.get(name)
    if targets == None:
        fail("Cannot find third party library " + name + ", please register it in THIRD_PARTY_LIBS first!")
    return targets[1 if IS_OSS else 0]
192
def get_pt_compiler_flags():
    """Return a `select()` of the PyTorch compiler flags.

    The `ovr_config//compiler:cl` configuration gets the flags translated by
    windows_convert_gcc_clang_flags(); every other configuration gets them as is.
    """
    flags_by_compiler = {
        "DEFAULT": _PT_COMPILER_FLAGS,
        "ovr_config//compiler:cl": windows_convert_gcc_clang_flags(_PT_COMPILER_FLAGS),
    }
    return select(flags_by_compiler)
198
# Base compiler flags for PyTorch targets; consumed by get_pt_compiler_flags().
_PT_COMPILER_FLAGS = [
    # Enable C++ exceptions and RTTI.
    "-fexceptions",
    "-frtti",
    # Optimize for size.
    "-Os",
] + [
    # Silenced warnings, in their original order.
    "-Wno-" + warning
    for warning in [
        "unknown-pragmas",
        "write-strings",
        "unused-variable",
        "unused-function",
        "deprecated-declarations",
        "shadow",
        "global-constructors",
        "missing-prototypes",
    ]
]
212
# Compiler flags for ATen targets; returned by get_aten_compiler_flags().
ATEN_COMPILER_FLAGS = [
    # Enable C++ exceptions and RTTI.
    "-fexceptions",
    "-frtti",
    # Position-independent code.
    "-fPIC",
    # Optimize for size.
    "-Os",
] + [
    # Silenced warnings, in their original order.
    "-Wno-" + warning
    for warning in [
        "absolute-value",
        "deprecated-declarations",
        "macro-redefined",
        "tautological-constant-out-of-range-compare",
        "unknown-pragmas",
        "unknown-warning-option",
        "unused-function",
        "unused-variable",
        "pass-failed",
        "shadow",
    ]
]
229
def get_aten_compiler_flags():
    """Return the compiler flags for ATen targets.

    Returns a fresh copy of ATEN_COMPILER_FLAGS so callers can extend the
    result in place. Module-level Starlark lists are frozen once this file is
    loaded, so mutating the shared list would fail. This also matches
    get_aten_preprocessor_flags(), which builds a new list on every call.
    """
    return list(ATEN_COMPILER_FLAGS)
232
# Preprocessor defines shared by ATen and PyTorch targets. The optional
# defines are included when their `pt.*` config getter returns True; the
# getters are evaluated in the order listed.
_COMMON_PREPROCESSOR_FLAGS = [
    "-DC10_MOBILE",
    "-DNO_EXPORT",
] + [
    define
    for define, enabled in [
        ("-DC10_MOBILE_TRIM_DISPATCH_KEYS", get_enable_mobile_dispatch_keys_trimming()),
        ("-DSTRIP_ERROR_MESSAGES", get_strip_error_messages()),
        ("-DDISABLE_WARN", get_disable_warn()),
    ]
    if enabled
]
243
def get_aten_preprocessor_flags():
    """Return the preprocessor defines for ATen targets.

    The result is built in this order: the shared common defines, then the
    fixed ATen feature configuration, then the optional defines enabled by
    `pt.disable_per_op_profiling` and `pt.enable_record_kernel_dtype`.
    A new list is returned on every call.
    """

    # Built inside a function (per the original note, read_config is not
    # allowed outside of function in Starlark).
    fixed_defines = [
        "-DCPU_CAPABILITY_DEFAULT",
        "-DCPU_CAPABILITY=DEFAULT",
        "-DCAFFE2_USE_LITE_PROTO",
        # Optional libraries, all switched off.
        "-DATEN_CUDNN_ENABLED_FBXPLAT=0",
        "-DATEN_MKLDNN_ENABLED_FBXPLAT=0",
        "-DATEN_MKLDNN_ACL_ENABLED_FBXPLAT=0",
        "-DATEN_NNPACK_ENABLED_FBXPLAT=0",
        "-DATEN_MKL_ENABLED_FBXPLAT=0",
        "-DATEN_MKL_SEQUENTIAL_FBXPLAT=0",
        # Mobile backends that are built in.
        "-DUSE_PYTORCH_METAL",
        "-DUSE_PYTORCH_QNNPACK",
        "-DUSE_XNNPACK",
        "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION",
        # Parallelism and BLAS configuration.
        "-DAT_PARALLEL_OPENMP_FBXPLAT=0",
        "-DAT_PARALLEL_NATIVE_FBXPLAT=1",
        "-DUSE_LAPACK_FBXPLAT=0",
        "-DAT_BLAS_F2C_FBXPLAT=0",
        "-DAT_BLAS_USE_CBLAS_DOT_FBXPLAT=0",
        "-DUSE_RUY_QMATMUL",
    ]
    flags = _COMMON_PREPROCESSOR_FLAGS + fixed_defines

    optional_defines = [
        ("-DPYTORCH_DISABLE_PER_OP_PROFILING", get_disable_per_op_profiling()),
        ("-DENABLE_RECORD_KERNEL_FUNCTION_DTYPE", _get_enable_record_kernel_dtype()),
    ]
    return flags + [define for define, enabled in optional_defines if enabled]
272
def get_pt_preprocessor_flags():
    """Return the preprocessor defines for PyTorch targets.

    The shared common defines come first, then the PyTorch-specific ones.
    `-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS` is appended only in dev
    (non-production) builds. A new list is returned on every call.
    """
    pt_defines = [
        "-D_THP_CORE",
        "-DUSE_SCALARS",
        "-DNO_CUDNN_DESTROY_HANDLE",
    ]
    dev_only_defines = ["-DENABLE_PYTORCH_NON_PRODUCTION_BUILDS"] if _is_build_mode_dev() else []
    return _COMMON_PREPROCESSOR_FLAGS + pt_defines + dev_only_defines
284
# Backends whose `<Backend>Functions.h` and `<Backend>Functions_inl.h` headers
# are listed as gen_aten outputs (see get_aten_derived_type_srcs()).
# This needs to be kept in sync with https://github.com/pytorch/pytorch/blob/release/1.9/torchgen/gen.py#L892
# NOTE(review): that link pins release/1.9; confirm this list against the
# current torchgen/gen.py.
PT_BACKEND_HEADERS = [
    # Device backends.
    "CPU",
    "CUDA",
    # Composite kernels.
    "CompositeExplicitAutograd",
    "CompositeExplicitAutogradNonFunctional",
    "CompositeImplicitAutograd",
    "CompositeImplicitAutogradNestedTensor",
    # Meta kernels.
    "Meta",
]
295
def get_aten_static_dispatch_backend_headers(existing_headers):
    """Register the gen_aten Functions headers of the static dispatch backends.

    For every backend in `pt.static_dispatch_backend` except "CPU", this maps
    `<Backend>Functions.h` and `<Backend>Functions_inl.h` to the matching
    `:gen_aten[...]` output.

    NOTE(review): CPU is skipped; presumably its headers come from elsewhere.
    Confirm with the callers.

    Args:
        existing_headers: dict of header name -> target; updated in place.

    Returns:
        The same `existing_headers` dict.
    """
    non_cpu_backends = [b for b in get_static_dispatch_backend() if b != "CPU"]
    for backend in non_cpu_backends:
        for suffix in ["Functions.h", "Functions_inl.h"]:
            header = backend + suffix
            existing_headers[header] = ":gen_aten[" + header + "]"
    return existing_headers
303
def get_aten_codegen_extra_params(backends):
    """Build an `extra_flags` dict in the shape gen_aten_files() reads.

    Schema registration is always forced. When `pt.static_dispatch_backend`
    is configured, that list becomes both the static dispatch backends and
    the enabled backends. Otherwise `backends` is used as the enabled set.
    """
    params = {"force_schema_registration": True}
    static_backends = get_static_dispatch_backend()
    if not static_backends:
        params["enabled_backends"] = backends
        return params

    params["static_dispatch_backend"] = static_backends
    params["enabled_backends"] = static_backends
    return params
315
def get_jit_codegen_params():
    """Return the extra JIT codegen parameters; this build uses none."""
    params = []
    return params
318
def get_unboxing_generated_files():
    """Return the genrule `outs` mapping for the lightweight-dispatch unboxing files.

    Each file name maps to a one-element list containing itself. Returns {}
    unless `pt.enable_lightweight_dispatch` is set.

    NOTE(review): the shard counts (5 UnboxingFunctions and 10
    RegisterCodegenUnboxedKernels sources) presumably must match what
    tools:gen_unboxing_bin emits; confirm when changing them.
    """
    if not _get_enable_lightweight_dispatch():
        return {}

    names = ["UnboxingFunctions.h"]
    names.extend(["UnboxingFunctions_{}.cpp".format(i) for i in range(5)])
    names.extend(["RegisterCodegenUnboxedKernels_{}.cpp".format(i) for i in range(10)])
    return {file_name: [file_name] for file_name in names}
344
def get_aten_generated_files(enabled_backends):
    """Return the genrule `outs` mapping for the torchgen ATen outputs.

    Each generated file name maps to a one-element list containing itself.
    The result covers a fixed set of outputs, the per-backend sources and
    headers from get_aten_derived_type_srcs(), and the CPU/CUDA ufunc sources
    when those backends are in `enabled_backends`.
    """

    # NB: RegisterMeta counts as an optionally enabled backend,
    # and is intentionally omitted from the fixed list below.
    outputs = [
        "RegisterBackendSelect.cpp",
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "CompositeViewCopyKernels.cpp",
        "RegisterSchema.cpp",
        "Declarations.yaml",
        "Functions.cpp",
        "Functions.h",
        "RedispatchFunctions.h",
        "NativeFunctions.h",
        "NativeMetaFunctions.h",
        "MethodOperators.h",
        "FunctionalInverses.h",
        "Operators.h",
        "Operators_0.cpp",
        "Operators_1.cpp",
        "Operators_2.cpp",
        "Operators_3.cpp",
        "Operators_4.cpp",
        "CompositeImplicitAutogradFunctions.h",
        "CompositeImplicitAutogradFunctions_inl.h",
        "CompositeImplicitAutogradNestedTensorFunctions.h",
        "CompositeImplicitAutogradNestedTensorFunctions_inl.h",
        "CompositeExplicitAutogradFunctions.h",
        "CompositeExplicitAutogradFunctions_inl.h",
        "CompositeExplicitAutogradNonFunctionalFunctions.h",
        "CompositeExplicitAutogradNonFunctionalFunctions_inl.h",
        "VmapGeneratedPlumbing.h",
        "core/ATenOpList.cpp",
        "core/TensorBody.h",
        "core/TensorMethods.cpp",
        "core/aten_interned_strings.h",
        "core/enum_tag.h",
        "torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.cpp",
    ] + get_aten_derived_type_srcs(enabled_backends)

    # ufunc outputs are listed only for backends that are actually generated.
    # Ideally they would always be generated and only compiled on demand.
    # CUDA in particular cannot be listed unconditionally: the Edge selective
    # build does not enable CUDA, so its ufunc codegen is skipped there.
    if "CPU" in enabled_backends:
        outputs.extend(aten_ufunc_generated_cpu_sources())
        outputs.extend(aten_ufunc_generated_cpu_kernel_sources())
    if "CUDA" in enabled_backends:
        outputs.extend(aten_ufunc_generated_cuda_sources())

    return {file_name: [file_name] for file_name in outputs}
403
def get_aten_derived_type_src_rules(aten_rule_name, enabled_backends):
    """Return `:<rule>[Register<Backend>.cpp]` references, one per backend, in order."""
    rules = []
    for backend in enabled_backends:
        register_src = "Register" + backend + ".cpp"
        rules.append(":{}[{}]".format(aten_rule_name, register_src))
    return rules
409
def get_aten_selective_cpp_rules(aten_rule_name, enabled_backends):
    """Return the gen_aten `.cpp` references a selective build compiles.

    The fixed registration/view-copy sources come first, followed by one
    `Register<Backend>.cpp` per entry in `enabled_backends`.
    """
    backend_independent_srcs = [
        "RegisterCompositeImplicitAutograd.cpp",
        "RegisterCompositeImplicitAutogradNestedTensor.cpp",
        "RegisterCompositeExplicitAutograd.cpp",
        "RegisterCompositeExplicitAutogradNonFunctional.cpp",
        "RegisterSchema.cpp",
        "RegisterBackendSelect.cpp",
        "CompositeViewCopyKernels.cpp",
    ]
    rules = [":{}[{}]".format(aten_rule_name, src) for src in backend_independent_srcs]
    return rules + get_aten_derived_type_src_rules(aten_rule_name, enabled_backends)
415
def get_aten_derived_type_srcs(enabled_backends):
    """List the per-backend files torchgen generates for `enabled_backends`.

    Returns, in this order:
      * `Register<Backend>.cpp` for every enabled backend;
      * `<Backend>Functions.h`, then `<Backend>Functions_inl.h`, for each
        enabled backend that is in PT_BACKEND_HEADERS or configured in
        `pt.static_dispatch_backend`.
    """

    # Read the static-dispatch config once, not once per backend per list.
    static_backends = get_static_dispatch_backend()
    header_backends = [
        backend
        for backend in enabled_backends
        if backend in PT_BACKEND_HEADERS or backend in static_backends
    ]
    return (
        ["Register" + backend + ".cpp" for backend in enabled_backends] +
        [backend + "Functions.h" for backend in header_backends] +
        [backend + "Functions_inl.h" for backend in header_backends]
    )
429
def gen_aten_files(
        name,
        extra_flags = {},
        visibility = [],
        compatible_with = [],
        apple_sdks = None):
    """Declare the genrule that runs torchgen to generate the ATen sources.

    Args:
        name: genrule name.
        extra_flags: optional settings translated into torchgen flags:
            `force_schema_registration` -> --force_schema_registration,
            `op_registration_allowlist` (string) -> --op_registration_whitelist,
            `op_selection_yaml_path` (string) -> --op_selection_yaml_path,
            `enabled_backends` (list) -> --backend_whitelist,
            `static_dispatch_backend` (list) -> --static_dispatch_backend.
            When `static_dispatch_backend` is non-empty it also determines
            the declared per-backend outputs in place of `enabled_backends`.
        visibility, compatible_with, apple_sdks: forwarded to the genrule.
    """
    allowlist = extra_flags.get("op_registration_allowlist", None)
    selection_yaml = extra_flags.get("op_selection_yaml_path", None)
    enabled_backends = extra_flags.get("enabled_backends", None)
    static_backends = extra_flags.get("static_dispatch_backend", None)

    # Flags are appended in a fixed order so the command line is stable.
    codegen_args = []
    if extra_flags.get("force_schema_registration", False):
        codegen_args.append("--force_schema_registration")
    if allowlist != None and is_string(allowlist):
        codegen_args.extend(["--op_registration_whitelist", allowlist])
    if selection_yaml != None and is_string(selection_yaml):
        codegen_args.extend(["--op_selection_yaml_path", selection_yaml])
    if enabled_backends != None and is_list(enabled_backends):
        codegen_args.append("--backend_whitelist")
        codegen_args.extend(enabled_backends)
    if _get_enable_lightweight_dispatch():
        codegen_args.append("--skip_dispatcher_op_registration")
    if static_backends:
        codegen_args.append("--static_dispatch_backend")
        codegen_args.extend(static_backends)

    output_backends = static_backends if static_backends else enabled_backends

    base_args = [
        "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
        "--install_dir $OUT",
        "--aoti_install_dir $OUT/torch/csrc/inductor/aoti_torch/generated",
    ]
    fb_xplat_genrule(
        name = name,
        default_outs = ["."],
        outs = get_aten_generated_files(output_backends),
        cmd = "$(exe {}torchgen:gen) ".format(ROOT_PATH) + " ".join(base_args + codegen_args),
        visibility = visibility,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )
475
def gen_aten_unboxing_files(
        genrule_name,
        extra_flags = {}):
    """Declare the public genrule that runs tools:gen_unboxing_bin.

    Only the `op_selection_yaml_path` and `op_registration_allowlist` string
    entries of `extra_flags` are used. Each is passed under the flag of the
    same name, in that order. Outputs come from get_unboxing_generated_files().
    """
    optional_args = []
    for key, flag in [
        ("op_selection_yaml_path", "--op_selection_yaml_path"),
        ("op_registration_allowlist", "--op_registration_allowlist"),
    ]:
        value = extra_flags.get(key, None)
        if value != None and is_string(value):
            optional_args.extend([flag, value])

    base_args = [
        "--source-path $(location {}:aten_src_path)/aten/src/ATen".format(ROOT),
        "--install_dir $OUT",
    ]
    fb_xplat_genrule(
        name = genrule_name,
        default_outs = ["."],
        outs = get_unboxing_generated_files(),
        cmd = "$(exe {}tools:gen_unboxing_bin) ".format(ROOT_PATH) + " ".join(base_args + optional_args),
        visibility = ["PUBLIC"],
    )
499
def copy_template_registration_files(name, apple_sdks = None):
    """Declare a genrule that gathers the templated selective-build sources into $OUT.

    Parallel POSIX-shell (`cmd`) and PowerShell (`cmd_exe`) step lists copy:
      * the `.cpp` files of each directory in get_template_source_dict();
      * the MaskRCNN template sources (internal builds only);
      * the CPU ufunc sources generated by `gen_aten`, into $OUT/aten/src/ATen;
      * register_batch_box_cox_ops.cpp (internal builds only).
    """
    sh_steps = []
    ps_steps = []

    template_source_dict = get_template_source_dict()
    template_srcs = "$(location {}:templated_selective_build_srcs)".format(ROOT)

    # Ideally one recursive copy per top-level source directory would be
    # enough. However, the metadata available here (relative directory paths)
    # cannot reliably tell whether one directory is a child of another: two
    # paths may look nested yet come from different filegroup() rules. So
    # each directory is copied on its own.
    for path_prefix in template_source_dict.keys():
        out_dir = "$OUT/" + path_prefix
        sh_steps.append("mkdir -p " + out_dir)
        ps_steps.append("md " + out_dir)

        # Globbing *.cpp stops cp (run without -r) from failing on
        # subdirectories. Files with any other extension are NOT copied; this
        # command must be updated if such files are ever needed.
        sh_steps.append("cp -f {}/{}/*.cpp {}/".format(template_srcs, path_prefix, out_dir))
        ps_steps.append("robocopy /E {}/{} {}".format(template_srcs, path_prefix, out_dir))

    if NOT_OSS:
        for file_path in TEMPLATE_MASKRCNN_SOURCE_LIST:
            maskrcnn_file = "$(location //xplat/caffe2/fb/custom_ops/maskrcnn:templated_selective_build_srcs)/" + file_path
            sh_steps.append("cp -f " + maskrcnn_file + " $OUT")
            ps_steps.append("copy " + maskrcnn_file + " $OUT")

    sh_steps.append("mkdir -p $OUT/aten/src/ATen")
    ps_steps.append("md $OUT/aten/src/ATen")

    # Only CPU ufunc sources are copied: this is a selective build, and
    # selective build does not support CUDA.
    ufunc_location = "$(location " + ROOT + ":gen_aten[{}])"
    for ufunc_file in aten_ufunc_generated_all_cpu_sources(ufunc_location):
        sh_steps.append("cp -f " + ufunc_file + " $OUT/aten/src/ATen")
        ps_steps.append("copy " + ufunc_file + " $OUT/aten/src/ATen")

    if NOT_OSS:
        box_cox_file = "$(location //xplat/caffe2/fb/custom_ops/batch_box_cox:templated_selective_build_srcs)/register_batch_box_cox_ops.cpp"
        sh_steps.append("cp -f " + box_cox_file + " $OUT")
        ps_steps.append("copy " + box_cox_file + " $OUT")

    fb_xplat_genrule(
        name = name,
        cmd = " && ".join(sh_steps),
        cmd_exe = "@powershell -Command " + "; ".join(ps_steps),
        outs = get_template_registration_files_outs(IS_OSS),
        default_outs = ["."],
        apple_sdks = apple_sdks,
    )
552
def get_feature_tracer_source_list():
    """
    Return just the feature-specific handlers used in the model tracer.

    These are the `torch_mobile_tracer_sources` entries whose names end in
    "Tracer.cpp", kept in their original order.
    """
    return [src for src in torch_mobile_tracer_sources if src.endswith("Tracer.cpp")]
562
def pt_operator_query_codegen(
        name,
        deps = [],
        train = False,
        enforce_traced_op_list = False,
        pt_allow_forced_schema_registration = True,
        compatible_with = [],
        apple_sdks = None):
    """Declare the codegen rules for a selective (per-app) PyTorch build.

    Declares these rules, all named with the `name` prefix:
      * `<name>_pt_oplist`: runs tools:gen_oplist over the
        `pt_operator_library`-labelled targets in the transitive deps of `deps`;
      * `<name>_unboxing`: lightweight-dispatch unboxing sources (only when
        `pt.enable_lightweight_dispatch` is set);
      * `<name>_aten`: ATen sources limited to the selected operators;
      * `<name>_unboxing_and_autograd`: tools:generate_code_bin outputs;
      * `<name>_template_registration`: templated runtime files (prim ops, etc.);
      * `<name>_metal`: Metal sources (internal builds only).

    Args:
        name: prefix for every declared rule.
        deps: targets whose transitive `pt_operator_library` deps supply the
            model operator lists.
        train: also compile the generated VariableType/ADInplaceOrViewType
            sources; together with `pt_allow_forced_schema_registration`,
            also forces schema registration.
        enforce_traced_op_list: when False, gen_oplist gets
            `--allow_include_all_overloads`.
        pt_allow_forced_schema_registration: see `train`.
        compatible_with, apple_sdks: forwarded to the declared rules that accept them.

    Returns:
        {"headers": {header name: generated-file reference},
         "srcs": [generated-file references]} for a consuming library.
    """
    oplist_dir_name = name + "_pt_oplist"

    # @lint-ignore BUCKLINT
    fb_native.genrule(
        name = oplist_dir_name,
        cmd = ("$(exe {}tools:gen_oplist) ".format(ROOT_PATH) +
               "--model_file_list_path $(@query_outputs 'attrfilter(labels, pt_operator_library, deps(set({deps})))') " +
               ("" if enforce_traced_op_list else "--allow_include_all_overloads ") +
               "--output_dir $OUT ").format(deps = " ".join(["\"{}\"".format(d) for d in deps])),
        outs = get_gen_oplist_outs(),
        default_outs = ["."],
        compatible_with = compatible_with,
    )

    # Aten files
    aten_genrule = name + "_aten"
    extra_flags = {
        "enabled_backends": USED_PT_BACKENDS,
        "op_selection_yaml_path": "$(location :{}[selected_operators.yaml])".format(oplist_dir_name),
    }

    if train and pt_allow_forced_schema_registration:
        extra_flags["force_schema_registration"] = True

    unboxing_genrule = name + "_unboxing"
    if _get_enable_lightweight_dispatch():
        gen_aten_unboxing_files(
            unboxing_genrule,
            extra_flags = extra_flags,
        )

    static_dispatch_backend = get_static_dispatch_backend()
    if static_dispatch_backend:
        extra_flags["static_dispatch_backend"] = static_dispatch_backend

    gen_aten_files(
        aten_genrule,
        extra_flags = extra_flags,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )

    # unboxing_wrappers files
    extra_params = [
        "--operators_yaml_path",
        "$(location :" + oplist_dir_name + "[selected_operators.yaml])",
    ]
    unboxing_and_autograd_genrule = name + "_unboxing_and_autograd"
    gen_aten_libtorch_files(
        unboxing_and_autograd_genrule,
        extra_params,
        compatible_with,
        apple_sdks = apple_sdks,
    )

    # Template runtime files (prim ops, etc)
    template_registration_genrule = name + "_template_registration"
    copy_template_registration_files(template_registration_genrule, apple_sdks = apple_sdks)

    # Files needed for metal
    if NOT_OSS:
        metal_genrule = name + "_metal"
        copy_metal(metal_genrule, apple_sdks = apple_sdks)

    srcs = get_aten_selective_cpp_rules(
        aten_genrule,
        static_dispatch_backend if static_dispatch_backend else USED_PT_BACKENDS,
    ) + get_template_registration_file_rules(template_registration_genrule, IS_OSS) + ([
        ":{}[autograd/generated/VariableType_0.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_1.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_2.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_3.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/VariableType_4.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/ADInplaceOrViewType_0.cpp]".format(unboxing_and_autograd_genrule),
        ":{}[autograd/generated/ADInplaceOrViewType_1.cpp]".format(unboxing_and_autograd_genrule),
    ] if train else []) + ([
        ":{}[SupportedMobileModelsRegistration.cpp]".format(oplist_dir_name),
    ] if NOT_OSS else [])

    headers = {
        "selected_mobile_ops.h": ":{}[selected_mobile_ops.h]".format(oplist_dir_name),
    }

    if _get_enable_lightweight_dispatch():
        # Derive the unboxing shards from the same mapping that declares the
        # unboxing genrule's outputs, so the two lists cannot drift apart.
        # Dict iteration keeps declaration order (UnboxingFunctions_*.cpp,
        # then RegisterCodegenUnboxedKernels_*.cpp).
        srcs.extend([
            ":{}[{}]".format(unboxing_genrule, file_name)
            for file_name in get_unboxing_generated_files()
            if file_name.endswith(".cpp")
        ])
        headers["UnboxingFunctions.h"] = ":{}[UnboxingFunctions.h]".format(unboxing_genrule)
    return {"headers": headers, "srcs": srcs}
674
def gen_aten_libtorch_files(name, extra_params = [], compatible_with = [], apple_sdks = None):
    """Declare the genrule that runs tools:generate_code_bin.

    `bash` and `cmd_exe` (PowerShell) run the same generator invocation. Only
    the step that first creates the `tools` directory differs. The argument
    list is built once so the two variants cannot drift apart.

    Args:
        name: genrule name.
        extra_params: extra arguments appended to the generator command line.
        compatible_with, apple_sdks: forwarded to the genrule.
    """

    # Mobile build only needs libtorch - skip python bindings for now, except
    # for ovrsource, which needs Python bindings.
    subset_args = [] if is_arvr_mode() else ["--subset libtorch"]
    generator_cmd = "$(exe {}tools:generate_code_bin) ".format(ROOT_PATH) + " ".join(
        subset_args + [
            "--native-functions-path $(location {}:aten_src_path)/aten/src/ATen/native/native_functions.yaml".format(ROOT),
            "--tags-path $(location {}:aten_src_path)/aten/src/ATen/native/tags.yaml".format(ROOT),
            "--install_dir $OUT",
        ] + extra_params,
    )
    fb_xplat_genrule(
        name = name,
        outs = get_generate_code_bin_outs(),
        default_outs = ["."],
        bash = "mkdir -p tools && " + generator_cmd,
        cmd_exe = "@powershell -Command New-Item -Path tools -ItemType Directory -Force; " + generator_cmd,
        compatible_with = compatible_with,
        apple_sdks = apple_sdks,
    )
703
def vulkan_spv_shader_library(name, spv_filegroup):
    """Compile the GLSL shaders in `spv_filegroup` into a linkable C++ library.

    Declares two rules:
      * `gen_<name>_cpp`: runs gen_aten_vulkan_spv_bin (using glslc). Its
        `spv.cpp` output is exposed as the named output `<name>.cpp`;
      * `name`: a public link_whole cxx library built from that source,
        exporting //xplat/caffe2:torch_vulkan_api.
    """
    genrule_name = "gen_{}_cpp".format(name)
    genrule_cmd = [
        "$(exe //xplat/caffe2/tools:gen_aten_vulkan_spv_bin)",
        "--glsl-paths $(location {})".format(spv_filegroup),
        "--output-path $OUT --env FLOAT_IMAGE_FORMAT={}".format(get_glsl_image_format()),
        "--glslc-path=$(exe //xplat/caffe2/fb/vulkan/dotslash:glslc)",
        "--tmp-dir-path=$TMP",
    ]

    fb_xplat_genrule(
        name = genrule_name,
        outs = {
            "{}.cpp".format(name): ["spv.cpp"],
        },
        cmd = " ".join(genrule_cmd),
        default_outs = ["."],
        labels = ["uses_dotslash"],
    )

    fb_xplat_cxx_library(
        name = name,
        srcs = [
            ":{}[{}.cpp]".format(genrule_name, name),
        ],
        # Static initialization is used to register shaders to the global shader registry,
        # therefore link_whole must be True to make sure unused symbols are not discarded.
        # @lint-ignore BUCKLINT: Avoid `link_whole=True`
        link_whole = True,
        # Define a soname that can be used for dynamic loading in Java, Python, etc.
        soname = "lib{}.$(ext)".format(name),
        visibility = ["PUBLIC"],
        exported_deps = [
            "//xplat/caffe2:torch_vulkan_api",
        ],
    )
740
def copy_metal(name, apple_sdks = None):
    """Define genrule `name`, which copies Metal sources into its output dir.

    Sources copied by the POSIX `cmd`:
      * `.mm` and `.cpp` files from `//xplat/caffe2:metal_build_srcs`, for
        every directory key of `get_metal_source_dict()` (sorted);
      * `.mm` files for the Metal MaskRCNN custom op, one directory per
        entry of `METAL_MASKRCNN_SOURCE_LIST`;
      * `unet_metal_prepack.cpp` and `unet_metal_prepack.mm`.

    The Windows `cmd_exe` (PowerShell) only mirrors the first group, using
    robocopy.

    Args:
        name: Name of the genrule.
        apple_sdks: Forwarded to `fb_xplat_genrule`.
    """
    source_dirs = get_metal_source_dict()

    metal_srcs = "$(location //xplat/caffe2:metal_build_srcs)"
    maskrcnn_srcs = "$(location //xplat/caffe2/fb/metal:metal_maskrcnn_sources)"
    unet_srcs = "$(location //xplat/caffe2/fb/custom_ops/unet_metal_prepack:unet_metal_prepack_sources)"

    sh_steps = []
    ps_steps = []

    # Mirror every Metal source directory into the per-app build output.
    for prefix in sorted(source_dirs.keys()):
        dest = "$OUT/" + prefix
        sh_steps.append("mkdir -p " + dest)
        ps_steps.append("mkdir -Force " + dest)

        # Some directories lack .mm or .cpp files; `2>/dev/null || :` hides
        # both the error output and the non-zero exit status of `cp`.
        for ext in ("mm", "cpp"):
            sh_steps.append("cp -f {}/{}/*.{} {}/ 2>/dev/null || :".format(metal_srcs, prefix, ext, dest))

        # robocopy exits with 1 on success, which buck treats as a failure;
        # the trailing ECHO masks that exit code.
        ps_steps.append("(robocopy /E /NFL /NDL /NJH /NJS {}/{} {}) || ECHO robocopy failed".format(metal_srcs, prefix, dest))

    # Metal custom ops reference Metal ops directly instead of going through
    # the dispatcher, so they must be pulled into selective build here. They
    # are only copied by the POSIX command: the genrule has file-location
    # problems on Windows, where these ops are very likely not needed anyway.

    # Metal MaskRCNN custom op.
    for src_file in METAL_MASKRCNN_SOURCE_LIST:
        prefix = paths.dirname(src_file)
        sh_steps.append("mkdir -p $OUT/" + prefix)
        sh_steps.append("cp -f {}/{}/*.mm $OUT/{}/ 2>/dev/null || :".format(maskrcnn_srcs, prefix, prefix))

    # UNet Metal prepack custom op.
    for ext in ("cpp", "mm"):
        sh_steps.append("cp -f {}/unet_metal_prepack.{} $OUT".format(unet_srcs, ext))

    fb_xplat_genrule(
        name = name,
        cmd = " && ".join(sh_steps),
        cmd_exe = "@powershell -Command " + "; ".join(ps_steps),
        # Some custom ops were not copied correctly on Windows, and ARVR
        # sometimes builds Android targets on Windows, so both platforms use
        # the reduced output list (those targets end up uncompiled anyway).
        outs = select({
            "DEFAULT": get_metal_registration_files_outs(),
            "ovr_config//os:android": get_metal_registration_files_outs_windows(),
            "ovr_config//os:windows": get_metal_registration_files_outs_windows(),
        }),
        default_outs = ["."],
        apple_sdks = apple_sdks,
    )
786
def get_pt_operator_registry_dict(
        name,
        deps = [],
        train = False,
        labels = [],
        env = [],
        template_select = True,
        enforce_traced_op_list = False,
        pt_allow_forced_schema_registration = True,
        enable_flatbuffer = False,
        **kwargs):
    """Return cxx library attributes for a generated PyTorch operator registry.

    Runs `pt_operator_query_codegen` for `name` and returns a dict of
    attributes for a library that compiles and exports the generated files.

    Args:
        name: Base name forwarded to `pt_operator_query_codegen`.
        deps: Forwarded to `pt_operator_query_codegen` only; these are NOT
            added to the returned library's `deps`.
        train: Forwarded to codegen. When True, `torch_mobile_train` is also
            added to the library's `deps`.
        labels: Caller labels, placed before this library's fixed labels.
        env: Unused; kept for call-site compatibility.
        template_select: When True, `-DTEMPLATE_SELECTIVE_BUILD` is added to
            the exported preprocessor flags.
        enforce_traced_op_list: Forwarded to `pt_operator_query_codegen`.
        pt_allow_forced_schema_registration: Forwarded to
            `pt_operator_query_codegen`.
        enable_flatbuffer: Unused; kept for call-site compatibility.
        **kwargs: Extra library attributes, passed through to the result.
            `headers` is consumed and set explicitly. `compatible_with` and
            `apple_sdks` are also read for the codegen step (and still passed
            through). Keys that collide with attributes set here (e.g.
            `srcs`, `compiler_flags`) are an error.

    Returns:
        A dict of library attributes.
    """
    code_gen_files = pt_operator_query_codegen(
        name,
        deps = deps,
        train = train,
        enforce_traced_op_list = enforce_traced_op_list,
        pt_allow_forced_schema_registration = pt_allow_forced_schema_registration,
        compatible_with = kwargs.get("compatible_with", []),
        apple_sdks = kwargs.get("apple_sdks"),
    )

    # Remove `headers` up front so it is not forwarded a second time through
    # **kwargs, instead of relying on argument evaluation order in the call.
    headers = kwargs.pop("headers", [])

    return dict(
        srcs = code_gen_files["srcs"],
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        soname = "libtorch-code-gen.$(ext)",
        header_namespace = "ATen",
        compiler_flags = get_aten_compiler_flags(),
        exported_headers = code_gen_files["headers"],
        exported_preprocessor_flags = get_aten_preprocessor_flags() + (["-DTEMPLATE_SELECTIVE_BUILD"] if template_select else []),
        headers = headers,
        # `labels` is a named parameter, so it can never appear in kwargs;
        # use it directly so caller-supplied labels are not silently dropped.
        labels = labels + [
            # This library has multiple sources with the same file name
            # and does not work with Buck filegroup used in bad practices.
            # Opt out of the bad practices check with the below label.
            "bad_practices_ignore_override",
            "pt_operator_registry",
        ],
        deps = [
            # need absolute path here
            ROOT + ":torch_mobile_core",
            ROOT + ":aten_cpu",
            ROOT + ":aten_metal_prepack_header",
            third_party("glog"),
            C10,
        ] + ([ROOT + ":torch_mobile_train"] if train else []),
        **kwargs
    )
838
839# these targets are shared by internal and OSS BUCK
840def define_buck_targets(
841        aten_default_args = dict(),
842        pt_xplat_cxx_library = fb_xplat_cxx_library,
843        c2_fbandroid_xplat_compiler_flags = [],
844        labels = []):
845    # @lint-ignore BUCKLINT
846    fb_native.filegroup(
847        name = "metal_build_srcs",
848        srcs = glob(METAL_SOURCE_LIST),
849        visibility = [
850            "PUBLIC",
851        ],
852    )
853
854    # @lint-ignore BUCKLINT
855    fb_native.filegroup(
856        name = "templated_selective_build_srcs",
857        # NB: no glob here, there are generated targets in this list!
858        srcs = glob(TEMPLATE_SOURCE_LIST) + aten_ufunc_generated_all_cpu_sources(":gen_aten[{}]"),
859        visibility = [
860            "PUBLIC",
861        ],
862    )
863
864    fb_xplat_cxx_library(
865        name = "aten_header",
866        header_namespace = "",
867        exported_headers = subdir_glob([
868            # ATen Core
869            ("aten/src", "ATen/core/**/*.h"),
870            ("aten/src", "ATen/ops/*.h"),
871            # ATen Base
872            ("aten/src", "ATen/*.h"),
873            ("aten/src", "ATen/cpu/**/*.h"),
874            ("aten/src", "ATen/detail/*.h"),
875            ("aten/src", "ATen/functorch/**/*.h"),
876            ("aten/src", "ATen/quantized/*.h"),
877            ("aten/src", "ATen/vulkan/*.h"),
878            ("aten/src", "ATen/metal/*.h"),
879            ("aten/src", "ATen/nnapi/*.h"),
880            # ATen Native
881            ("aten/src", "ATen/native/*.h"),
882            ("aten/src", "ATen/native/ao_sparse/quantized/cpu/*.h"),
883            ("aten/src", "ATen/native/cpu/**/*.h"),
884            ("aten/src", "ATen/native/sparse/*.h"),
885            ("aten/src", "ATen/native/nested/*.h"),
886            ("aten/src", "ATen/native/quantized/*.h"),
887            ("aten/src", "ATen/native/quantized/cpu/*.h"),
888            ("aten/src", "ATen/native/transformers/*.h"),
889            ("aten/src", "ATen/native/ufunc/*.h"),
890            ("aten/src", "ATen/native/utils/*.h"),
891            ("aten/src", "ATen/native/vulkan/ops/*.h"),
892            ("aten/src", "ATen/native/xnnpack/*.h"),
893            ("aten/src", "ATen/mps/*.h"),
894            ("aten/src", "ATen/native/mps/*.h"),
895            # Remove the following after modifying codegen for mobile.
896            ("aten/src", "ATen/mkl/*.h"),
897            ("aten/src", "ATen/native/mkl/*.h"),
898            ("aten/src", "ATen/native/mkldnn/*.h"),
899        ]),
900        visibility = ["PUBLIC"],
901        labels = labels,
902    )
903
904    fb_xplat_cxx_library(
905        name = "aten_vulkan_header",
906        header_namespace = "",
907        exported_headers = subdir_glob([
908            ("aten/src", "ATen/native/vulkan/*.h"),
909            ("aten/src", "ATen/native/vulkan/ops/*.h"),
910            ("aten/src", "ATen/vulkan/*.h"),
911        ]),
912        labels = labels,
913        visibility = ["PUBLIC"],
914    )
915
916    fb_xplat_cxx_library(
917        name = "jit_core_headers",
918        header_namespace = "",
919        exported_headers = subdir_glob([("", x) for x in jit_core_headers]),
920        labels = labels,
921    )
922
923    fb_xplat_cxx_library(
924        name = "torch_headers",
925        header_namespace = "",
926        exported_headers = subdir_glob(
927            [
928                ("torch/csrc/api/include", "torch/**/*.h"),
929                ("", "torch/csrc/**/*.h"),
930                ("", "torch/script.h"),
931                ("", "torch/library.h"),
932                ("", "torch/custom_class.h"),
933                ("", "torch/custom_class_detail.h"),
934                # Add again due to namespace difference from aten_header.
935                ("", "aten/src/ATen/*.h"),
936                ("", "aten/src/ATen/functorch/**/*.h"),
937                ("", "aten/src/ATen/quantized/*.h"),
938            ],
939            exclude = [
940                # Don't need on mobile.
941                "torch/csrc/Exceptions.h",
942                "torch/csrc/python_headers.h",
943                "torch/csrc/jit/serialization/mobile_bytecode_generated.h",
944            ],
945        ),
946        labels = labels,
947        visibility = ["PUBLIC"],
948        deps = [
949            ":generated-version-header",
950        ],
951    )
952
953    fb_xplat_cxx_library(
954        name = "aten_test_header",
955        header_namespace = "",
956        exported_headers = subdir_glob([
957            ("aten/src", "ATen/test/*.h"),
958        ]),
959    )
960
961    fb_xplat_cxx_library(
962        name = "aten_metal_prepack_header",
963        header_namespace = "",
964        exported_headers = subdir_glob([
965            ("aten/src", "ATen/native/metal/MetalPrepackOpContext.h"),
966        ]),
967        labels = labels,
968        visibility = ["PUBLIC"],
969    )
970
971    fb_xplat_cxx_library(
972        name = "torch_mobile_headers",
973        header_namespace = "",
974        exported_headers = subdir_glob(
975            [
976                ("", "torch/csrc/jit/mobile/*.h"),
977            ],
978        ),
979        labels = labels,
980        visibility = ["PUBLIC"],
981    )
982
983    fb_xplat_cxx_library(
984        name = "generated_aten_config_header",
985        header_namespace = "ATen",
986        exported_headers = {
987            "Config.h": ":generate_aten_config[Config.h]",
988        },
989        labels = labels,
990    )
991
992    fb_xplat_cxx_library(
993        name = "generated-autograd-headers",
994        header_namespace = "torch/csrc/autograd/generated",
995        exported_headers = {
996            "Functions.h": ":gen_aten_libtorch[autograd/generated/Functions.h]",
997            "VariableType.h": ":gen_aten_libtorch[autograd/generated/VariableType.h]",
998            "variable_factories.h": ":gen_aten_libtorch[autograd/generated/variable_factories.h]",
999            "ViewFuncs.h": ":gen_aten_libtorch[autograd/generated/ViewFuncs.h]",
1000            # Don't build python bindings on mobile.
1001            #"python_functions.h",
1002        },
1003        labels = labels,
1004        visibility = ["PUBLIC"],
1005    )
1006
1007    fb_xplat_cxx_library(
1008        name = "generated-version-header",
1009        header_namespace = "torch",
1010        exported_headers = {
1011            "version.h": ":generate-version-header[version.h]",
1012        },
1013        labels = labels,
1014    )
1015
1016    # @lint-ignore BUCKLINT
1017    fb_native.genrule(
1018        name = "generate-version-header",
1019        srcs = [
1020            "torch/csrc/api/include/torch/version.h.in",
1021            "version.txt",
1022        ],
1023        cmd = "$(exe {}tools:gen-version-header) ".format(ROOT_PATH) + " ".join([
1024            "--template-path",
1025            "torch/csrc/api/include/torch/version.h.in",
1026            "--version-path",
1027            "version.txt",
1028            "--output-path",
1029            "$OUT/version.h",
1030        ]),
1031        outs = {
1032            "version.h": ["version.h"],
1033        },
1034        default_outs = ["."],
1035    )
1036
1037    # @lint-ignore BUCKLINT
1038    fb_native.filegroup(
1039        name = "aten_src_path",
1040        srcs = [
1041            "aten/src/ATen/native/native_functions.yaml",
1042            "aten/src/ATen/native/tags.yaml",
1043        ] + glob(["aten/src/ATen/templates/*"]),
1044        visibility = [
1045            "PUBLIC",
1046        ],
1047    )
1048
1049    fb_xplat_cxx_library(
1050        name = "common_core",
1051        srcs = [
1052            "caffe2/core/common.cc",
1053        ],
1054        apple_sdks = (IOS, MACOSX, APPLETVOS),
1055        compiler_flags = get_pt_compiler_flags(),
1056        labels = labels,
1057        # @lint-ignore BUCKLINT link_whole
1058        link_whole = True,
1059        visibility = ["PUBLIC"],
1060        windows_preferred_linkage = "static" if is_arvr_mode() else None,
1061        deps = [
1062            ":caffe2_headers",
1063            C10,
1064        ],
1065    )
1066
1067    # @lint-ignore BUCKLINT
1068    fb_native.genrule(
1069        name = "generate_aten_config",
1070        srcs = [
1071            "aten/src/ATen/Config.h.in",
1072        ],
1073        cmd = "$(exe {}tools:substitute) ".format(ROOT_PATH) + " ".join([
1074            "--install_dir",
1075            "$OUT",
1076            "--input-file",
1077            "aten/src/ATen/Config.h.in",
1078            "--output-file",
1079            "Config.h",
1080            "--replace",
1081            "@AT_MKLDNN_ENABLED@",
1082            "ATEN_MKLDNN_ENABLED_FBXPLAT",
1083            "--replace",
1084            "@AT_MKLDNN_ACL_ENABLED@",
1085            "ATEN_MKLDNN_ACL_ENABLED_FBXPLAT",
1086            "--replace",
1087            "@AT_MKL_ENABLED@",
1088            "ATEN_MKL_ENABLED_FBXPLAT",
1089            "--replace",
1090            "@AT_MKL_SEQUENTIAL@",
1091            "ATEN_MKL_SEQUENTIAL_FBXPLAT",
1092            "--replace",
1093            "@AT_POCKETFFT_ENABLED@",
1094            "1",
1095            "--replace",
1096            "@AT_NNPACK_ENABLED@",
1097            "ATEN_NNPACK_ENABLED_FBXPLAT",
1098            "--replace",
1099            "@CAFFE2_STATIC_LINK_CUDA_INT@",
1100            "CAFFE2_STATIC_LINK_CUDA_FBXPLAT",
1101            "--replace",
1102            "@AT_BUILD_WITH_BLAS@",
1103            "USE_BLAS_FBXPLAT",
1104            "--replace",
1105            "@AT_PARALLEL_OPENMP@",
1106            "AT_PARALLEL_OPENMP_FBXPLAT",
1107            "--replace",
1108            "@AT_PARALLEL_NATIVE@",
1109            "AT_PARALLEL_NATIVE_FBXPLAT",
1110            "--replace",
1111            "@AT_BUILD_WITH_LAPACK@",
1112            "USE_LAPACK_FBXPLAT",
1113            "--replace",
1114            "@AT_BLAS_F2C@",
1115            "AT_BLAS_F2C_FBXPLAT",
1116            "--replace",
1117            "@AT_BLAS_USE_CBLAS_DOT@",
1118            "AT_BLAS_USE_CBLAS_DOT_FBXPLAT",
1119        ]),
1120        outs = {
1121            "Config.h": ["Config.h"],
1122        },
1123        default_outs = ["."],
1124    )
1125
1126    gen_aten_files(
1127        name = "gen_aten",
1128        extra_flags = get_aten_codegen_extra_params(USED_PT_BACKENDS),
1129        visibility = ["PUBLIC"],
1130    )
1131
1132    gen_aten_libtorch_files(name = "gen_aten_libtorch")
1133
1134    gen_aten_libtorch_files(
1135        name = "gen_aten_libtorch_lite",
1136        extra_params = get_jit_codegen_params(),
1137    )
1138
1139    fb_xplat_cxx_library(
1140        name = "generated_aten_headers_cpu",
1141        header_namespace = "ATen",
1142        exported_headers = get_aten_static_dispatch_backend_headers({
1143            "CPUFunctions.h": ":gen_aten[CPUFunctions.h]",
1144            "CPUFunctions_inl.h": ":gen_aten[CPUFunctions_inl.h]",
1145            "CompositeExplicitAutogradFunctions.h": ":gen_aten[CompositeExplicitAutogradFunctions.h]",
1146            "CompositeExplicitAutogradFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradFunctions_inl.h]",
1147            "CompositeExplicitAutogradNonFunctionalFunctions.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions.h]",
1148            "CompositeExplicitAutogradNonFunctionalFunctions_inl.h": ":gen_aten[CompositeExplicitAutogradNonFunctionalFunctions_inl.h]",
1149            "CompositeImplicitAutogradFunctions.h": ":gen_aten[CompositeImplicitAutogradFunctions.h]",
1150            "CompositeImplicitAutogradFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradFunctions_inl.h]",
1151            "CompositeImplicitAutogradNestedTensorFunctions.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions.h]",
1152            "CompositeImplicitAutogradNestedTensorFunctions_inl.h": ":gen_aten[CompositeImplicitAutogradNestedTensorFunctions_inl.h]",
1153            "FunctionalInverses.h": ":gen_aten[FunctionalInverses.h]",
1154            "Functions.h": ":gen_aten[Functions.h]",
1155            "MethodOperators.h": ":gen_aten[MethodOperators.h]",
1156            "NativeFunctions.h": ":gen_aten[NativeFunctions.h]",
1157            "NativeMetaFunctions.h": ":gen_aten[NativeMetaFunctions.h]",
1158            "Operators.h": ":gen_aten[Operators.h]",
1159            "RedispatchFunctions.h": ":gen_aten[RedispatchFunctions.h]",
1160            "core/TensorBody.h": ":gen_aten[core/TensorBody.h]",
1161            "core/aten_interned_strings.h": ":gen_aten[core/aten_interned_strings.h]",
1162            "core/enum_tag.h": ":gen_aten[core/enum_tag.h]",
1163        }),
1164        labels = labels,
1165    )
1166
1167    fb_xplat_cxx_library(
1168        name = "torch_mobile_observer",
1169        srcs = [
1170            "torch/csrc/jit/mobile/observer.cpp",
1171        ] + ([] if IS_OSS else ["torch/fb/observers/MobileObserverUtil.cpp"]),
1172        compiler_flags = ["-fexceptions"],
1173        header_namespace = "",
1174        exported_headers = subdir_glob(
1175            [
1176                ("", "torch/csrc/jit/mobile/observer.h"),
1177            ] + ([] if IS_OSS else [
1178                ("", "torch/fb/observers/ObserverUtil.h"),
1179                ("", "torch/fb/observers/MobileObserverUtil.h"),
1180            ]),
1181        ),
1182        fbobjc_compiler_flags = [
1183            "-Wno-missing-prototypes",
1184        ],
1185        labels = labels,
1186        visibility = ["PUBLIC"],
1187        deps = [
1188            C10,
1189        ],
1190    )
1191
1192    # Base library shared by lite-interpreter and full-jit.
1193    pt_xplat_cxx_library(
1194        name = "torch_common",
1195        srcs = core_sources_common,
1196        compiler_flags = get_pt_compiler_flags(),
1197        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1198        # @lint-ignore BUCKLINT link_whole
1199        link_whole = True,
1200        visibility = ["PUBLIC"],
1201        deps = [
1202            ":aten_cpu",
1203            ":generated-autograd-headers",
1204            ":torch_headers",
1205            C10,
1206            third_party("libkineto_headers"),
1207        ],
1208    )
1209
1210    pt_xplat_cxx_library(
1211        name = "torch_mobile_deserialize_common",
1212        srcs = [
1213            "torch/csrc/jit/mobile/parse_bytecode.cpp",
1214            "torch/csrc/jit/mobile/parse_operators.cpp",
1215            "torch/csrc/jit/mobile/upgrader_mobile.cpp",
1216            "torch/csrc/jit/serialization/import_read.cpp",
1217            "torch/csrc/jit/serialization/unpickler.cpp",
1218        ],
1219        header_namespace = "",
1220        exported_headers = [
1221            "torch/csrc/jit/serialization/import_read.h",
1222            "torch/csrc/jit/serialization/unpickler.h",
1223        ],
1224        compiler_flags = get_pt_compiler_flags(),
1225        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1226        extra_flags = {
1227            "fbandroid_compiler_flags": ["-frtti"],
1228        },
1229        # torch_mobile_deserialize brings in sources neccessary to read a module
1230        # which depends on mobile module definition
1231        # link_whole is enable so that all symbols neccessary for mobile module are compiled
1232        # instead of only symbols used while loading; this prevents symbol
1233        # found definied in runtime
1234        # @lint-ignore BUCKLINT link_whole
1235        link_whole = True,
1236        linker_flags = ["-Wl,--no-as-needed"],
1237        visibility = ["PUBLIC"],
1238        exported_deps = [
1239            ":aten_cpu",
1240            ":caffe2_headers",
1241            ":caffe2_serialize",
1242            ":torch_common",
1243            ":torch_headers",
1244            ":torch_mobile_headers",
1245            ":torch_mobile_module",
1246            ":torch_mobile_observer",
1247            C10,
1248        ],
1249    )
1250
1251    pt_xplat_cxx_library(
1252        name = "torch_mobile_module",
1253        srcs = [
1254            "torch/csrc/jit/mobile/function.cpp",
1255            "torch/csrc/jit/mobile/interpreter.cpp",
1256            "torch/csrc/jit/mobile/module.cpp",
1257        ],
1258        header_namespace = "",
1259        exported_headers = [
1260        ],
1261        compiler_flags = get_pt_compiler_flags(),
1262        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1263        extra_flags = {
1264            "fbandroid_compiler_flags": ["-frtti"],
1265        },
1266        # @lint-ignore BUCKLINT link_whole
1267        link_whole = True,
1268        linker_flags = [
1269            "-Wl,--no-as-needed",
1270        ],
1271        visibility = ["PUBLIC"],
1272        exported_deps = [
1273            ":aten_cpu",
1274            ":caffe2_headers",
1275            ":torch_common",
1276            ":torch_headers",
1277            ":torch_mobile_headers",
1278            ":torch_mobile_observer",
1279            C10,
1280        ],
1281    )
1282
1283    pt_xplat_cxx_library(
1284        name = "torch_mobile_debug_symbolication",
1285        srcs = [
1286            # included in aten_cpu "torch/csrc/jit/frontend/source_range.cpp",
1287            "torch/csrc/jit/ir/scope.cpp",
1288            "torch/csrc/jit/mobile/debug_info.cpp",
1289            "torch/csrc/jit/serialization/callstack_debug_info_serialization.cpp",
1290            "torch/csrc/jit/serialization/source_range_serialization.cpp",
1291            "torch/csrc/jit/serialization/pickle.cpp",
1292            # pickler.cpp doesn't seem to be needed.
1293            # "torch/csrc/jit/serialization/pickler.cpp",
1294            # included in core_sources_common "torch/csrc/jit/serialization/unpickler.cpp",
1295        ],
1296        compiler_flags = get_pt_compiler_flags(),
1297        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1298        header_namespace = "",
1299        # @lint-ignore BUCKLINT link_whole
1300        link_whole = True,
1301        linker_flags = [
1302            "-Wl,--no-as-needed",
1303        ],
1304        visibility = ["PUBLIC"],
1305        deps = [
1306            ":torch_mobile_deserialize",
1307        ],
1308        exported_deps = [
1309            ":torch_common",
1310        ],
1311    )
1312
1313    pt_xplat_cxx_library(
1314        name = "torch_model_tracer",
1315        srcs = [
1316            "torch/csrc/jit/mobile/model_tracer/TracerRunner.cpp",
1317        ] + get_feature_tracer_source_list(),
1318        header_namespace = "",
1319        compiler_flags = get_pt_compiler_flags(),
1320        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1321        # @lint-ignore BUCKLINT link_whole
1322        link_whole = True,
1323        linker_flags = [
1324            "-Wl,--no-as-needed",
1325        ],
1326        visibility = ["PUBLIC"],
1327        deps = [
1328            ":generated-autograd-headers",
1329            ":torch_mobile_deserialize",
1330            ":torch_mobile_headers",
1331            ":torch_mobile_observer",
1332        ] + ([] if IS_OSS else ["//xplat/folly:molly"]),
1333        exported_deps = [
1334            ":aten_cpu",
1335            ":torch_common",
1336        ] + ([] if IS_OSS else [
1337            "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1338            "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1339        ]),
1340    )
1341
1342    pt_xplat_cxx_library(
1343        name = "torch_mobile_deserialize",
1344        srcs = [
1345            "torch/csrc/jit/mobile/import.cpp",
1346            "torch/csrc/jit/mobile/flatbuffer_loader.cpp",
1347        ],
1348        compiler_flags = get_pt_compiler_flags(),
1349        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
1350        header_namespace = "",
1351        exported_headers = [
1352            "torch/csrc/jit/mobile/import.h",
1353            "torch/csrc/jit/mobile/flatbuffer_loader.h",
1354        ],
1355        # torch_mobile_deserialize brings in sources neccessary to read a module
1356        # which depends on mobile module definition
1357        # link_whole is enable so that all symbols neccessary for mobile module are compiled
1358        # instead of only symbols used while loading; this prevents symbol
1359        # found definied in runtime
1360        # @lint-ignore BUCKLINT link_whole
1361        link_whole = True,
1362        linker_flags = [
1363            "-Wl,--no-as-needed",
1364        ],
1365        visibility = ["PUBLIC"],
1366        exported_deps = [
1367            ":aten_cpu",
1368            ":caffe2_headers",
1369            ":caffe2_serialize",
1370            ":torch_common",
1371            ":torch_headers",
1372            ":torch_mobile_headers",
1373            ":torch_mobile_module",
1374            ":torch_mobile_observer",
1375            ":torch_mobile_deserialize_common",
1376            ":mobile_bytecode",
1377            C10,
1378        ],
1379    )
1380
1381    pt_xplat_cxx_library(
1382        name = "torch_mobile_core",
1383        srcs = [],
1384        header_namespace = "",
1385        exported_headers = [],
1386        compiler_flags = get_pt_compiler_flags(),
1387        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
1388        # torch_mobile_core brings in sources neccessary to read and run a module
1389        # link_whole is enabled so that all symbols linked
1390        # operators, registerations and other few symbols are need in runtime
1391        # @lint-ignore BUCKLINT link_whole
1392        link_whole = True,
1393        linker_flags = [
1394            "-Wl,--no-as-needed",
1395        ],
1396        visibility = ["PUBLIC"],
1397        deps = [
1398            ":generated-autograd-headers",
1399            ":torch_mobile_headers",
1400            ":torch_mobile_observer",
1401        ],
1402        exported_deps = [
1403            ":aten_cpu",
1404            ":torch_common",
1405            ":torch_mobile_deserialize",
1406            ":torch_supported_mobile_models",
1407        ],
1408    )
1409
1410    pt_xplat_cxx_library(
1411        name = "torch_mobile_core_pickle_and_flatbuffer",
1412        compiler_flags = get_pt_compiler_flags(),
1413        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1414        visibility = ["PUBLIC"],
1415        exported_deps = [
1416            ":flatbuffers_mobile",
1417            ":torch_mobile_core",
1418        ],
1419    )
1420
1421    pt_xplat_cxx_library(
1422        name = "torch_cpp_cpu",
1423        srcs = torch_cpp_srcs,
1424        headers = native.glob(["torch/csrc/api/include/**/*.h"]) + ["torch/script.h"],
1425        compiler_flags = get_pt_compiler_flags(),
1426        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1427        visibility = ["PUBLIC"],
1428        exported_deps = [
1429            ":torch",
1430            ":torch_mobile_deserialize_common",  # for torch/csrc/api/src/serialize/input-archive.cpp
1431        ],
1432    )
1433
1434    pt_xplat_cxx_library(
1435        name = "torch_core",
1436        srcs = core_sources_full_mobile_no_backend_interface_xplat,
1437        compiler_flags = get_pt_compiler_flags(),
1438        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1439        visibility = [
1440            "//xplat/caffe2/android/...",
1441            "//xplat/caffe2/fb/...",
1442            "//xplat/caffe2/fb/model_tracer/...",
1443        ],
1444        deps = [
1445            ":aten_cpu",
1446            ":backend_interface_lib",
1447            ":generated-autograd-headers",
1448            ":torch_headers",
1449            ":torch_mobile_deserialize",
1450            third_party("glog"),
1451            third_party("rt"),
1452            C10,
1453        ] + ([] if IS_OSS else [
1454            "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
1455            "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
1456        ]),
1457        exported_deps = [
1458            ":torch_common",
1459            ":torch_mobile_train",
1460        ],
1461    )
1462
1463    pt_xplat_cxx_library(
1464        name = "torch_train",
1465        srcs = [
1466            "torch/csrc/api/src/data/samplers/random.cpp",
1467            "torch/csrc/api/src/data/samplers/sequential.cpp",
1468            "torch/csrc/api/src/optim/optimizer.cpp",
1469            "torch/csrc/api/src/optim/serialize.cpp",
1470            "torch/csrc/api/src/optim/sgd.cpp",
1471            "torch/csrc/api/src/serialize/input-archive.cpp",
1472            "torch/csrc/api/src/serialize/output-archive.cpp",
1473            "torch/csrc/jit/api/module_save.cpp",
1474        ],
1475        compiler_flags = get_pt_compiler_flags(),
1476        exported_preprocessor_flags = get_pt_preprocessor_flags(),
1477        visibility = ["PUBLIC"],
1478        deps = [
1479            ":aten_cpu",
1480            ":torch_headers",
1481            ":torch",
1482            ":torch_core",
1483            ":torch_mobile_deserialize",
1484            ":torch_mobile_train",
1485            ":jit_module_saving",
1486            C10,
1487        ],
1488    )
1489
    # Mobile on-device training library: core_trainer_sources plus the manual
    # autograd kernels, the MNIST dataset reader, mobile quantization, the
    # utilities under torch/csrc/jit/mobile/train/, and the autograd sources
    # generated by :gen_aten_libtorch.
    pt_xplat_cxx_library(
        name = "torch_mobile_train",
        srcs = core_trainer_sources + [
            "torch/csrc/autograd/VariableTypeManual.cpp",
            "torch/csrc/autograd/FunctionsManual.cpp",
            "torch/csrc/api/src/data/datasets/mnist.cpp",
            "torch/csrc/jit/mobile/quantization.cpp",
            "torch/csrc/jit/mobile/train/export_data.cpp",
            "torch/csrc/jit/mobile/train/optim/sgd.cpp",
            "torch/csrc/jit/mobile/train/random.cpp",
            "torch/csrc/jit/mobile/train/sequential.cpp",
            # Generated sources (named outputs of the :gen_aten_libtorch rule).
            ":gen_aten_libtorch[autograd/generated/Functions.cpp]",
            ":gen_aten_libtorch[autograd/generated/ViewFuncs.cpp]",
        ],
        compiler_flags = get_pt_compiler_flags(),
        # -DUSE_MOBILE_CLASSTYPE is exported, so dependents are compiled with it too.
        exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
        # torch_mobile_train brings in the sources necessary to read and run a
        # mobile module and to save and load mobile params, along with autograd.
        # link_whole is enabled so that every object file is linked in: the
        # operators, registrations and autograd-related symbols are needed at
        # runtime even when nothing references them directly.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        visibility = ["PUBLIC"],
        deps = [
            ":aten_cpu",
            ":generated-autograd-headers",
            ":torch_headers",
            ":torch_mobile_deserialize",
            ":flatbuffers_serializer_mobile",
            C10,
        ],
    )
1522
    # Full-JIT runtime target. It compiles the c10-op and full-JIT prim-op
    # registration sources, and re-exports aten_cpu, torch_core and c10 so that
    # dependents of :torch see them transitively.
    pt_xplat_cxx_library(
        name = "torch",
        srcs = [
            "torch/csrc/jit/runtime/register_c10_ops.cpp",
            "torch/csrc/jit/runtime/register_prim_ops_fulljit.cpp",
        ],
        compiler_flags = get_pt_compiler_flags(),
        exported_preprocessor_flags = get_pt_preprocessor_flags(),
        # torch brings in all sources necessary to read and run a mobile module/jit module.
        # link_whole is enabled so that every object file is linked in: the
        # operators, registrations and a few other symbols are needed at runtime
        # even when nothing references them directly.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        visibility = ["PUBLIC"],
        deps = [
            # This is to have the autograd profiler available
            # in xplat/caffe2:torch, which some builds are using,
            # notably xplat/facegen:testsAndroid.
            ":torch_headers",
            ":torch_kineto_profiling",
        ],
        exported_deps = [
            ":aten_cpu",
            ":torch_core",
            C10,
        ],
    )
1550
    # Compiles torch/csrc/jit/mobile/import_data.cpp on top of the mobile core and
    # mobile training libraries, with -DUSE_MOBILE_CLASSTYPE exported to dependents.
    pt_xplat_cxx_library(
        name = "torch_mobile_train_import_data",
        srcs = [
            "torch/csrc/jit/mobile/import_data.cpp",
        ],
        compiler_flags = get_pt_compiler_flags(),
        exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DUSE_MOBILE_CLASSTYPE"],
        # torch_mobile_train_import_data brings in the sources necessary to read a mobile module.
        # link_whole is enabled so that every object file is linked in: the
        # operators and a few other symbols are needed at runtime.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        visibility = ["PUBLIC"],
        deps = [
            ":torch_headers",
            ":torch_mobile_observer",
            ":torch_mobile_core",
            ":torch_mobile_train",
        ],
    )
1571
    # Mobile model compatibility checks. This target uses fb_xplat_cxx_library
    # directly with an explicit flag list (exceptions and RTTI enabled) instead
    # of get_pt_compiler_flags().
    fb_xplat_cxx_library(
        name = "torch_mobile_compatibility",
        srcs = [
            # These .cpp files are brought in through core_sources_common,
            # so they must not be compiled again here:
            # "torch/csrc/jit/mobile/compatibility/runtime_compatibility.cpp",
            # "torch/csrc/jit/serialization/unpickler.cpp",
            "torch/csrc/jit/mobile/compatibility/model_compatibility.cpp",
        ],
        header_namespace = "",
        exported_headers = [
            "torch/csrc/jit/mobile/compatibility/backport.h",
            "torch/csrc/jit/mobile/compatibility/backport_manager.h",
            "torch/csrc/jit/mobile/compatibility/model_compatibility.h",
            "torch/csrc/jit/mobile/compatibility/runtime_compatibility.h",
        ],
        compiler_flags = [
            "-fexceptions",
            "-frtti",
            "-Wno-deprecated-declarations",
            "-Wno-global-constructors",
        ],
        labels = labels,
        visibility = ["PUBLIC"],
        deps = [
            ":torch_mobile_deserialize",
        ],
    )
1599
    # Saving of JIT modules: module_save, bytecode export and module export,
    # exposing torch/csrc/jit/serialization/export.h to dependents.
    pt_xplat_cxx_library(
        name = "jit_module_saving",
        srcs = [
            "torch/csrc/jit/api/module_save.cpp",
            "torch/csrc/jit/serialization/export_bytecode.cpp",
            "torch/csrc/jit/serialization/export_module.cpp",
        ],
        compiler_flags = get_pt_compiler_flags(),
        # FB_XPLAT_BUILD is defined (and exported) only for internal, non-OSS builds.
        exported_preprocessor_flags = get_pt_preprocessor_flags() +
                                      (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
        exported_headers = [
            "torch/csrc/jit/serialization/export.h",
        ],
        visibility = ["PUBLIC"],
        deps = [
            ":torch",
            ":torch_mobile_core",
            ":flatbuffers_serializer_mobile",
        ],
    )
1620
    # Model tracer: runs a mobile model via MobileModelRunner so the ops it uses
    # can be recorded. Only MobileModelRunner.h is exported to dependents;
    # TensorUtils.h stays private. Internal (non-OSS) builds additionally depend
    # on folly (molly) and re-export several FB custom-op libraries.
    pt_xplat_cxx_library(
        name = "torch_mobile_model_tracer",
        srcs = [
            "torch/csrc/jit/mobile/model_tracer/MobileModelRunner.cpp",
            "torch/csrc/jit/mobile/model_tracer/TensorUtils.cpp",
        ],
        headers = [
            "torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h",
            "torch/csrc/jit/mobile/model_tracer/TensorUtils.h",
        ],
        header_namespace = "",
        exported_headers = [
            "torch/csrc/jit/mobile/model_tracer/MobileModelRunner.h",
        ],
        compiler_flags = get_pt_compiler_flags(),
        # SYMBOLICATE_MOBILE_DEBUG_HANDLE is added only when eager symbolication is enabled.
        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
        # torch_mobile_model_tracer brings in the sources necessary to read and run a jit module
        # and trace the ops.
        # link_whole is enabled so that every object file is linked in: the
        # operators, registrations and a few other symbols are needed at runtime.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        deps = [
            ":caffe2_serialize",
            ":generated-autograd-headers",
            ":torch_mobile_headers",
            ":torch_mobile_observer",
            ":torch_mobile_core",
        ] + ([] if IS_OSS else ["//xplat/folly:molly"]),
        exported_deps = [
            ":aten_cpu",
            ":torch_common",
        ] + ([] if IS_OSS else [
            "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
            "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
            "//xplat/caffe2/fb/custom_ops/sparsenn:sparsenn-all",
        ]),
    )
1663
    # TODO(qihan): delete
    # Placeholder target with no sources or headers of its own. It only forwards
    # preprocessor flags and dependencies; presumably it is kept until existing
    # dependents are migrated (TODO confirm).
    pt_xplat_cxx_library(
        name = "torch_mobile_core_flatbuffer",
        srcs = [],
        header_namespace = "",
        exported_headers = [],
        compiler_flags = get_pt_compiler_flags(),
        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        deps = [
            ":generated-autograd-headers",
            ":torch_mobile_headers",
            ":torch_mobile_observer",
        ],
        exported_deps = [
            ":aten_cpu",
            ":torch_common",
        ],
    )
1688
    # Backend delegation interface plus backend debug info. Android builds
    # additionally get c2_fbandroid_xplat_compiler_flags.
    fb_xplat_cxx_library(
        name = "backend_interface_lib",
        srcs = [
            "torch/csrc/jit/backends/backend_debug_info.cpp",
            "torch/csrc/jit/backends/backend_interface.cpp",
        ],
        compiler_flags = get_pt_compiler_flags(),
        fbandroid_compiler_flags = c2_fbandroid_xplat_compiler_flags,
        # link_whole forces every object file of this library into the link.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        exported_deps = [
            ":aten_cpu",
            ":torch_common",
        ],
    )
1708
    # Profiler sources (libtorch_profiler_sources) built against Kineto.
    # -Wno-error keeps warnings in these sources from failing the build.
    pt_xplat_cxx_library(
        name = "torch_kineto_profiling",
        srcs = libtorch_profiler_sources,
        compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
        exported_preprocessor_flags = get_pt_preprocessor_flags() + [
            "-DUSE_KINETO",
            # Needed, otherwise USE_KINETO is undefined
            # on mobile.
            "-DEDGE_PROFILER_USE_KINETO",
        ],
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        deps = [
            third_party("glog"),
            third_party("kineto"),
        ],
        exported_deps = [
            ":aten_cpu",
            ":torch_common",
        ],
    )
1734
    # Edge (mobile) profiler, layered on :torch_kineto_profiling and
    # :torch_mobile_core, which are both re-exported to dependents.
    pt_xplat_cxx_library(
        name = "torch_edge_profiling",
        srcs = ["torch/csrc/jit/mobile/profiler_edge.cpp"],
        compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
        exported_preprocessor_flags = get_pt_preprocessor_flags() + [
            "-DUSE_KINETO",
            "-DEDGE_PROFILER_USE_KINETO",
        ],
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        exported_deps = [
            ":torch_common",
            ":torch_kineto_profiling",
            ":torch_mobile_core",
        ],
    )
1755
    # Runs flatc over mobile_bytecode.fbs to produce the C++ header.
    # flatc writes mobile_bytecode_generated.h into ${OUT}; the named output
    # "mobile_bytecode_generated_fbsource.h" maps to that file. Only the
    # :mobile_bytecode library may depend on this rule.
    fb_xplat_genrule(
        name = "mobile_bytecode_header",
        srcs = [
            "torch/csrc/jit/serialization/mobile_bytecode.fbs",
        ],
        outs = {
            "mobile_bytecode_generated_fbsource.h": ["mobile_bytecode_generated.h"],
        },
        cmd = "$(exe {})".format(third_party("flatc")) +
              " --cpp --gen-mutable --scoped-enums -o ${OUT} ${SRCS}",
        default_outs = ["."],
        visibility = [
            "{}:mobile_bytecode".format(ROOT),
        ],
    )
1771
    # Exposes the flatc-generated mobile bytecode header. third_party("flatbuffers-api")
    # is listed in exported_deps below, so dependents receive it transitively.
    # In OSS the header is exposed as mobile_bytecode_generated.h; internally it
    # is exposed as mobile_bytecode_generated_fbsource.h.
    fb_xplat_cxx_library(
        name = "mobile_bytecode",
        header_namespace = "",
        exported_headers = {
            ("torch/csrc/jit/serialization/mobile_bytecode_generated.h" if IS_OSS else "torch/csrc/jit/serialization/mobile_bytecode_generated_fbsource.h"): ":mobile_bytecode_header[mobile_bytecode_generated_fbsource.h]",
        },
        # Avoid leaking implementation details by only exposing this header to
        # the internals of the loader/serializer layer.
        visibility = [
            "{}:flatbuffer_loader".format(ROOT),
            "{}:flatbuffers_serializer_mobile".format(ROOT),
        ],
        exported_deps = [
            third_party("flatbuffers-api"),
        ],
    )
1790
    # Flatbuffer serializer for mobile modules. The compiler flags hard-code -O3
    # and -g0 (no debug info); FB_XPLAT_BUILD is defined only for internal,
    # non-OSS builds.
    # NOTE(review): ":mobile_bytecode" appears in both deps and exported_deps;
    # the exported_deps entry alone would suffice.
    fb_xplat_cxx_library(
        name = "flatbuffers_serializer_mobile",
        srcs = ["torch/csrc/jit/serialization/flatbuffer_serializer.cpp"],
        exported_headers = [
            "torch/csrc/jit/serialization/flatbuffer_serializer.h",
        ],
        compiler_flags = [
            "-g0",
            "-O3",
            "-fexceptions",
            "-frtti",
            "-Wno-deprecated-declarations",
        ] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
        visibility = ["PUBLIC"],
        deps = [
            ":mobile_bytecode",
            ":torch_mobile_module",
            C10,
        ],
        exported_deps = [
            ":torch_mobile_deserialize",
            ":mobile_bytecode",
        ],
    )
1815
    # TODO(qihan): delete
    # Header-only shim: exports torch/csrc/jit/mobile/flatbuffer_loader.h but
    # compiles no sources (srcs is empty). The implementation presumably lives
    # in another target — TODO confirm.
    pt_xplat_cxx_library(
        name = "flatbuffer_loader",
        srcs = [
        ],
        exported_headers = [
            "torch/csrc/jit/mobile/flatbuffer_loader.h",
        ],
        compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
        exported_preprocessor_flags = get_pt_preprocessor_flags() + [
            "-DUSE_KINETO",
            # Needed, otherwise USE_KINETO is undefined
            # on mobile.
            "-DEDGE_PROFILER_USE_KINETO",
        ] + (["-DFB_XPLAT_BUILD"] if not IS_OSS else []),
        extra_flags = {
            "fbandroid_compiler_flags": ["-frtti"],
        },
        # NOTE(review): this comment was copied from torch_mobile_deserialize.
        # With srcs empty, this target has no object files of its own for
        # link_whole to keep. The original intent: link_whole is enabled so that
        # all symbols necessary for the mobile module are linked, instead of only
        # the symbols used while loading; this prevents undefined-symbol errors
        # at runtime.
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        deps = [
            ":mobile_bytecode",
        ],
        exported_deps = [
            C10,
        ],
    )
1852
    # TODO(qihan): delete
    # Flatbuffer serializer for full-JIT modules. flatbuffer_serializer_jit.h
    # is a private header (listed under headers, not exported_headers).
    fb_xplat_cxx_library(
        name = "flatbuffers_serializer_jit",
        compiler_flags = [
            "-g0",
            "-O3",
            "-fexceptions",
            "-frtti",
            "-Wno-deprecated-declarations",
        ],
        headers = [
            "torch/csrc/jit/serialization/flatbuffer_serializer_jit.h",
        ],
        srcs = [
            "torch/csrc/jit/serialization/flatbuffer_serializer_jit.cpp",
        ],
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        deps = [
            ":flatbuffer_loader",
            ":flatbuffers_serializer_mobile",
            ":torch_core",
            ":torch_mobile_module",
            C10,
        ],
    )
1881
    # Aggregate target with no sources: re-exports the flatbuffer loader and
    # both the mobile and full-JIT flatbuffer serializers.
    fb_xplat_cxx_library(
        name = "flatbuffers_jit",
        visibility = ["PUBLIC"],
        exported_deps = [
            ":flatbuffer_loader",
            ":flatbuffers_serializer_mobile",
            ":flatbuffers_serializer_jit",
        ],
    )
1891
    # Aggregate target with no sources: re-exports the flatbuffer loader, the
    # mobile flatbuffer serializer and the mobile training library.
    fb_xplat_cxx_library(
        name = "flatbuffers_mobile",
        visibility = ["PUBLIC"],
        exported_deps = [
            ":flatbuffer_loader",
            ":flatbuffers_serializer_mobile",
            ":torch_mobile_train",
        ],
    )
1901
    # Supported-mobile-models registry. Its sources, headers and exported deps
    # are internal-only (NOT_OSS); in OSS builds the target is empty.
    pt_xplat_cxx_library(
        name = "torch_supported_mobile_models",
        srcs = [
            "fb/supported_mobile_models/SupportedMobileModels.cpp",
        ] if NOT_OSS else [],
        header_namespace = "",
        exported_headers = ["fb/supported_mobile_models/SupportedMobileModels.h"] if NOT_OSS else [],
        compiler_flags = get_pt_compiler_flags() + ["-Wno-error"],
        exported_preprocessor_flags = get_pt_preprocessor_flags() + (["-DSYMBOLICATE_MOBILE_DEBUG_HANDLE"] if get_enable_eager_symbolication() else []),
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        deps = [],
        exported_deps = [
            "//xplat/caffe2/fb/custom_ops/batch_box_cox:batch_box_cox",
            "//xplat/caffe2/fb/custom_ops/maskrcnn:maskrcnn",
        ] if NOT_OSS else [],
    )
    # Static Runtime sources, compiled with only -fexceptions as an explicit flag.
    # On Windows in arvr mode, static linkage is preferred.
    fb_xplat_cxx_library(
        name = "static_runtime",
        srcs = [
            "torch/csrc/jit/runtime/static/fusion.cpp",
            "torch/csrc/jit/runtime/static/generated_ops.cpp",
            "torch/csrc/jit/runtime/static/impl.cpp",
            "torch/csrc/jit/runtime/static/memory_planner.cpp",
            "torch/csrc/jit/runtime/static/native_ops.cpp",
            "torch/csrc/jit/runtime/static/ops.cpp",
            "torch/csrc/jit/runtime/static/passes.cpp",
            "torch/csrc/jit/runtime/static/te_wrapper.cpp",
        ],
        compiler_flags = ["-fexceptions"],
        labels = labels,
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        visibility = ["PUBLIC"],
        windows_preferred_linkage = "static" if is_arvr_mode() else None,
        deps = [
            ":aten_cpu",
            ":caffe2_headers",
            ":torch_core",
            C10,
        ],
    )
1949
    # aten_cpu and aten_native_cpu: two libraries that share every setting
    # except their sources.
    #   aten_cpu        - JIT core sources, ATen CPU sources, generated ATen
    #                     operator sources, and the embedding lookup kernel.
    #   aten_native_cpu - ATen native kernel sources (aten_native_source_list).
    for name, srcs in [
        ("aten_cpu", jit_core_sources + aten_cpu_source_list + [
            # Generated
            ":gen_aten[Functions.cpp]",
            ":gen_aten[Operators_0.cpp]",
            ":gen_aten[Operators_1.cpp]",
            ":gen_aten[Operators_2.cpp]",
            ":gen_aten[Operators_3.cpp]",
            ":gen_aten[Operators_4.cpp]",
            ":gen_aten[core/ATenOpList.cpp]",
            ":gen_aten[core/TensorMethods.cpp]",
            # Needed by ATen/native/EmbeddingBag.cpp
            "caffe2/perfkernels/embedding_lookup_idx.cc",
        ]),
        ("aten_native_cpu", aten_native_source_list),
    ]:
        fb_xplat_cxx_library(
            name = name,
            srcs = srcs,
            header_namespace = "",
            # link_whole forces every object file of the library into the link.
            # @lint-ignore BUCKLINT
            link_whole = True,
            visibility = ["PUBLIC"],
            deps = [
                third_party("omp"),
                third_party("cpuinfo"),
                third_party("glog"),
                third_party("XNNPACK"),
                third_party("pocketfft"),
            ] + select({
                "DEFAULT": [],
                # sleef_arm is only added on the fbcode arm64 runtime config.
                "ovr_config//runtime:fbcode-arm64": [
                    third_party("sleef_arm"),
                ],
            }),
            compiler_flags = get_aten_compiler_flags(),
            exported_preprocessor_flags = get_aten_preprocessor_flags(),
            exported_deps = [
                ":aten_header",
                ":caffe2_headers",
                ":common_core",
                ":generated_aten_config_header",
                ":generated_aten_headers_cpu",
                ":jit_core_headers",
                ":pthreadpool",
                third_party("fmt"),
                third_party("ruy"),
                C10,
                ROOT_PATH + "aten/src/ATen/native/quantized/cpu/qnnpack:pytorch_qnnpack",
            ],
            labels = labels,
            **aten_default_args
        )
2005    fb_xplat_cxx_library(
2006        name = "lean_runtime_with_flatbuffer",
2007        srcs = [
2008            "aten/src/ATen/core/DeprecatedTypePropertiesRegistry.cpp",
2009            "torch/csrc/jit/mobile/import.cpp",
2010            "torch/csrc/jit/mobile/module.cpp",
2011            "torch/csrc/jit/mobile/observer.cpp",
2012            "torch/csrc/jit/serialization/import_read.cpp",
2013        ],
2014        header_namespace = "",
2015        exported_headers = subdir_glob(
2016            [
2017                ("", "torch/csrc/jit/ir/*.h"),
2018                ("", "caffe2/serialize/*.h"),
2019                ("", "caffe2/utils/*.h"),
2020                ("", "caffe2/core/*.h"),
2021                ("", "torch/csrc/*.h"),
2022                ("", "torch/csrc/api/include/torch/*.h"),
2023                ("", "torch/csrc/autograd/*.h"),
2024                ("", "torch/csrc/autograd/*/*.h"),
2025                ("", "torch/csrc/jit/api/*.h"),
2026                ("", "torch/csrc/jit/backends/*.h"),
2027                ("", "torch/csrc/jit/mobile/*.h"),
2028                ("", "torch/csrc/jit/runtime/*.h"),
2029                ("", "torch/csrc/jit/passes/*.h"),
2030                ("", "torch/csrc/jit/python/*.h"),
2031                ("", "torch/csrc/jit/frontend/*.h"),
2032                ("", "torch/csrc/jit/serialization/*.h"),
2033                ("", "torch/csrc/profiler/**/*.h"),
2034                ("", "torch/csrc/utils/*.h"),
2035                ("", "aten/src/ATen/quantized/*.h"),
2036            ] + ([
2037                ("third_party/miniz-2.1.0", "*.h"),
2038            ] if NOT_OSS else []),
2039            exclude = [
2040                "torch/csrc/jit/serialization/mobile_bytecode_generated.h",
2041            ],
2042        ),
2043        compiler_flags = get_pt_compiler_flags() + select({
2044            "DEFAULT": [],
2045            "ovr_config//os:xtensa-xos": [
2046                "-fdata-sections",
2047                "-ffunction-sections",
2048            ],
2049        }),
2050        exported_preprocessor_flags = get_pt_preprocessor_flags() + [
2051            "-DMIN_EDGE_RUNTIME",
2052        ],
2053        linker_flags = [
2054            "-Wl,--no-as-needed",
2055        ] + select({
2056            "DEFAULT": [],
2057            "ovr_config//os:macos": [
2058                "-dead_strip",
2059            ],
2060            "ovr_config//os:xtensa-xos": [
2061                "-Wl,--gc-sections",
2062            ],
2063        }),
2064        visibility = ["PUBLIC"],
2065        exported_deps = [
2066            ":lean_runtime_with_tensor",
2067        ],
2068    )
2069
    # Minimal edge runtime, tensor layer: ATen context/tensor-creation sources
    # and the generated operator/TensorMethods sources, on top of
    # :lean_runtime_with_op.
    pt_xplat_cxx_library(
        name = "lean_runtime_with_tensor",
        srcs = [
            "aten/src/ATen/Context.cpp",
            "aten/src/ATen/EmptyTensor.cpp",
            "aten/src/ATen/Utils.cpp",
            "aten/src/ATen/detail/CUDAHooksInterface.cpp",
            "aten/src/ATen/detail/PrivateUse1HooksInterface.cpp",
            ":gen_aten[Operators_0.cpp]",
            ":gen_aten[Operators_1.cpp]",
            ":gen_aten[Operators_2.cpp]",
            ":gen_aten[Operators_3.cpp]",
            ":gen_aten[Operators_4.cpp]",
            ":gen_aten[core/TensorMethods.cpp]",
        ],
        header_namespace = "",
        exported_headers = [
            "torch/csrc/jit/runtime/custom_operator.h",
            ":gen_aten[core/TensorBody.h]",
        ],
        compiler_flags = get_pt_compiler_flags() + select({
            "DEFAULT": [],
            "ovr_config//os:xtensa-xos": [
                "-fdata-sections",
                "-ffunction-sections",
            ],
        }),
        # On xtensa-xos, -Dthread_local= defines the keyword away, so such
        # variables become ordinary (non-thread-local) variables. Presumably the
        # target lacks TLS support; TODO confirm.
        exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
            "DEFAULT": [],
            "ovr_config//os:xtensa-xos": [
                "-Dthread_local=",
            ],
        }),
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        exported_deps = [
            ":generated_aten_config_header",
            ":lean_runtime_with_op",
            ":aten_header",
            C10,
        ] + (["//xplat/caffe2/fb/embedded:experimental"] if NOT_OSS else []),
    )
2116
    # Minimal edge runtime, operator layer: dispatcher, operator registration,
    # function-schema parsing and mobile op-parsing sources, on top of
    # :min_runtime_lib.
    pt_xplat_cxx_library(
        name = "lean_runtime_with_op",
        srcs = [
            "aten/src/ATen/SequenceNumber.cpp",
            "aten/src/ATen/core/boxing/KernelFunction.cpp",
            "aten/src/ATen/core/custom_class.cpp",
            "aten/src/ATen/core/dispatch/DispatchKeyExtractor.cpp",
            "aten/src/ATen/core/dispatch/Dispatcher.cpp",
            "aten/src/ATen/core/dispatch/ObservedOperators.cpp",
            "aten/src/ATen/core/dispatch/OperatorEntry.cpp",
            "aten/src/ATen/core/PythonOpRegistrationTrampoline.cpp",
            "aten/src/ATen/core/interned_strings.cpp",
            "aten/src/ATen/core/library.cpp",
            "aten/src/ATen/core/op_registration/infer_schema.cpp",
            "aten/src/ATen/core/function_schema.cpp",
            "aten/src/ATen/core/operator_name.cpp",
            "aten/src/ATen/core/register_symbols.cpp",
            "aten/src/ATen/core/tensor_type.cpp",
            "aten/src/ATen/core/union_type.cpp",
            "aten/src/ATen/record_function.cpp",
            "torch/csrc/jit/frontend/edit_distance.cpp",
            "torch/csrc/jit/frontend/error_report.cpp",
            "torch/csrc/jit/frontend/function_schema_parser.cpp",
            "torch/csrc/jit/frontend/lexer.cpp",
            "torch/csrc/jit/frontend/schema_type_parser.cpp",
            "torch/csrc/jit/frontend/source_range.cpp",
            "torch/csrc/jit/frontend/strtod.cpp",
            "torch/csrc/jit/mobile/parse_operators.cpp",
            "torch/csrc/jit/mobile/prim_ops_registery.cpp",
            "torch/csrc/jit/runtime/operator.cpp",
            "torch/csrc/jit/runtime/slice_indices_adjust.cpp",
        ],
        header_namespace = "",
        exported_headers = [
            "torch/csrc/jit/frontend/edit_distance.h",
            "torch/csrc/jit/runtime/slice_indices_adjust.h",
        ],
        compiler_flags = get_pt_compiler_flags() + select({
            "DEFAULT": [],
            "ovr_config//os:xtensa-xos": [
                "-fdata-sections",
                "-ffunction-sections",
            ],
        }),
        # On xtensa-xos, thread_local is defined away (see -Dthread_local=).
        exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
            "DEFAULT": [],
            "ovr_config//os:xtensa-xos": [
                "-Dthread_local=",
            ],
        }),
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        exported_deps = [
            ":min_runtime_lib",
            C10,
        ],
    )
2178
    # Base layer of the minimal edge runtime: IValue/type system, the mobile
    # interpreter, bytecode parsing and the prim-op helpers. On xtensa-xos it is
    # additionally compiled with -fexceptions.
    pt_xplat_cxx_library(
        name = "min_runtime_lib",
        srcs = [
            "aten/src/ATen/ScalarOps.cpp",
            "aten/src/ATen/core/Dict.cpp",
            "aten/src/ATen/core/List.cpp",
            "aten/src/ATen/core/class_type.cpp",
            "aten/src/ATen/core/dynamic_type.cpp",
            "aten/src/ATen/core/ivalue.cpp",
            "aten/src/ATen/core/type.cpp",
            "aten/src/ATen/core/type_factory.cpp",
            "aten/src/ATen/native/prim_native_functions.cpp",
            "torch/csrc/jit/mobile/function.cpp",
            "torch/csrc/jit/mobile/interpreter.cpp",
            "torch/csrc/jit/mobile/parse_bytecode.cpp",
            "torch/csrc/jit/mobile/promoted_prim_ops.cpp",
            "torch/csrc/jit/mobile/register_ops_common_utils.cpp",
            "torch/csrc/jit/mobile/type_parser.cpp",
            "torch/csrc/jit/runtime/instruction.cpp",
            "torch/csrc/jit/runtime/jit_exception.cpp",
            "torch/csrc/jit/runtime/vararg_functions.cpp",
        ],
        header_namespace = "",
        exported_headers = [
            "caffe2/serialize/versions.h",
            "torch/csrc/jit/backends/backend_exception.h",
            "torch/csrc/jit/mobile/register_ops_common_utils.h",
            "torch/csrc/jit/runtime/instruction.h",
            "torch/csrc/jit/runtime/jit_exception.h",
            "torch/csrc/jit/runtime/operator.h",
            "torch/csrc/jit/runtime/operator_options.h",
            "torch/csrc/jit/runtime/vararg_functions.h",
            "torch/csrc/jit/serialization/import_export_constants.h",
            "torch/csrc/jit/serialization/import_export_functions.h",
        ],
        compiler_flags = get_pt_compiler_flags() + select({
            "DEFAULT": [],
            "ovr_config//os:xtensa-xos": [
                "-fexceptions",
                "-fdata-sections",
                "-ffunction-sections",
            ],
        }),
        # On xtensa-xos, thread_local is defined away (see -Dthread_local=).
        exported_preprocessor_flags = get_pt_preprocessor_flags() + ["-DMIN_EDGE_RUNTIME"] + select({
            "DEFAULT": [],
            "ovr_config//os:xtensa-xos": [
                "-Dthread_local=",
            ],
        }),
        # @lint-ignore BUCKLINT link_whole
        link_whole = True,
        linker_flags = [
            "-Wl,--no-as-needed",
        ],
        visibility = ["PUBLIC"],
        exported_deps = [
            ":aten_header",
            ":generated_aten_headers_cpu",
            ":jit_core_headers",
            ":torch_mobile_headers",
            C10,
        ],
    )
2242