# Source: /aosp_15_r20/external/executorch/extension/training/test/targets.bzl (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Workerload("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2*523fa7a6SAndroid Build Coastguard Worker
def define_common_targets(is_fbcode = False):
    """Declares the training-extension test targets shared by fbcode and xplat.

    Both the TARGETS and BUCK files in this directory are expected to call
    this function.

    Args:
      is_fbcode: True when called from the fbcode TARGETS file. The test
        target is declared only when this is True and the build is not OSS;
        otherwise the function declares nothing.
    """

    # Bail out early outside of non-OSS fbcode: the test's env points at a
    # .pte produced by an fbcode-only target.
    # TODO(dbort): Find a way to make these run for ANDROID/APPLE in xplat. The
    # android and ios test determinators don't like the reference to the model
    # file in fbcode. See https://fburl.com/9esapdmd
    if runtime.is_oss or not is_fbcode:
        return

    runtime.cxx_test(
        name = "training_loop_test",
        srcs = ["training_loop_test.cpp"],
        deps = [
            "//executorch/runtime/executor:program",
            "//executorch/extension/data_loader:file_data_loader",
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
            "//executorch/extension/evalue_util:print_evalue",
            "//executorch/runtime/executor/test:managed_memory_manager",
            "//executorch/extension/training/optimizer:sgd",
            "//executorch/extension/training/module:training_module",
            "//executorch/kernels/portable:generated_lib",
        ],
        env = {
            # Read by the test to find the program file to load. The value is
            # an fbcode target path because the authoring/export tools are
            # host-only and intentionally don't work in xplat.
            "ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
        },
    )
38