# xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/aot/BUILD (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
2load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library")
3load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test")
4load("//tensorflow/core/platform:build_config.bzl", "if_llvm_aarch64_available", "if_llvm_system_z_available")
5
# Targets in this package are private unless they declare their own
# `visibility`.
package(
    licenses = ["notice"],
    default_visibility = ["//visibility:private"],
)
10
# Do not depend on this target directly. It exists only so that the benchmark
# test generated by tf_library has a test main to link against.
cc_library(
    name = "tf_library_test_main",
    testonly = True,
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:test_main",
    ],
)
19
# Exposes quantize.h (also a public header of :tfcompile_lib) as a plain file
# target that any package may reference.
filegroup(
    name = "quantize_header",
    srcs = [
        "quantize.h",
    ],
    visibility = ["//visibility:public"],
)
25
# Core tfcompile library, built from codegen.cc, compile.cc and flags.cc.
# It is used by :tfcompile_main, :codegen_test and //tensorflow/python.
cc_library(
    name = "tfcompile_lib",
    srcs = [
        "codegen.cc",
        "compile.cc",
        "flags.cc",
    ],
    hdrs = [
        "codegen.h",
        "compile.h",
        "flags.h",
        "quantize.h",
    ],
    compatible_with = [],
    # The AArch64 and SystemZ LLVM backends are optional. These defines are set
    # under the same conditions that add the corresponding CodeGen deps at the
    # end of `deps` (presumably so the C++ sources can #ifdef the matching
    # target setup -- confirm in compile.cc).
    defines = if_llvm_aarch64_available([
        "TF_LLVM_AARCH64_AVAILABLE=1",
    ]) + if_llvm_system_z_available([
        "TF_LLVM_S390X_AVAILABLE=1",
    ]),
    visibility = ["//tensorflow/python:__pkg__"],
    # Kept in buildifier order (":" labels, then "//" labels, then "@" labels),
    # like the other targets in this file.
    deps = [
        ":aot_only_var_handle_op",
        ":embedded_protocol_buffers",
        "//tensorflow/compiler/tf2xla",
        "//tensorflow/compiler/tf2xla:mlir_tf2xla",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla:cpu_function_runtime",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:client_library",
        "//tensorflow/compiler/xla/client:compile_only_client",
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/compiler/xla/service/cpu:buffer_info_util",
        "//tensorflow/compiler/xla/service/cpu:cpu_compiler",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/platform:regexp",
        "@com_google_absl//absl/base",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        # Entries marked "fixdeps: keep" are LLVM backends that dependency
        # cleanup tooling must not prune (presumably because they are needed
        # for target registration rather than for direct includes).
        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ] + if_llvm_system_z_available([
        "@llvm-project//llvm:SystemZCodeGen",  # fixdeps: keep
    ]) + if_llvm_aarch64_available([
        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
    ]),
)
86
# Test for codegen.{h,cc} in :tfcompile_lib. The two .golden files are runtime
# data for the test (presumably located via resource_loader below).
tf_cc_test(
    name = "codegen_test",
    srcs = ["codegen_test.cc"],
    data = [
        "codegen_test_h.golden",
        "codegen_test_o.golden",
    ],
    deps = [
        ":tfcompile_lib",
        "//tensorflow/compiler/xla:cpu_function_runtime",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/platform:resource_loader",
        "@com_google_absl//absl/strings",
        "@llvm-project//llvm:Support",  # fixdeps: keep
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ],
)
108
# The publicly visible `tfcompile` binary. It has no sources of its own; all
# of its code (presumably including main()) comes from :tfcompile_main.
tf_cc_binary(
    name = "tfcompile",
    visibility = ["//visibility:public"],
    deps = [
        ":tfcompile_main",
    ],
)
114
# Source-less library that bundles the LLVM backend deps into one target for
# //tensorflow/python. SystemZ and AArch64 are added only when the
# corresponding if_llvm_*_available() macro enables them.
cc_library(
    name = "llvm_targets",
    visibility = ["//tensorflow/python:__pkg__"],
    deps = [
        "@llvm-project//llvm:ARMCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:PowerPCCodeGen",  # fixdeps: keep
        "@llvm-project//llvm:Target",
        "@llvm-project//llvm:X86CodeGen",  # fixdeps: keep
    ] + if_llvm_system_z_available([
        "@llvm-project//llvm:SystemZCodeGen",  # fixdeps: keep
    ]) + if_llvm_aarch64_available([
        "@llvm-project//llvm:AArch64CodeGen",  # fixdeps: keep
    ]),
)
129
# Library built from tfcompile_main.cc. The :tfcompile binary consists solely
# of this library.
cc_library(
    name = "tfcompile_main",
    srcs = ["tfcompile_main.cc"],
    compatible_with = [],
    visibility = ["//visibility:public"],
    deps = [
        ":tfcompile_lib",
        "//tensorflow/compiler/tf2xla:tf2xla_proto_cc",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla/service:compiler",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:graph",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "@com_google_absl//absl/strings",
    ],
)
151
# NOTE: Most end-to-end tests live in the "tests" subdirectory. Keeping them
# there ensures that tfcompile.bzl works correctly when used from outside the
# package that defines it.

# A simple tf_library built from a text protobuf, used by :benchmark_test.
# The graph is incomplete: one of its nodes is not defined. Compilation still
# succeeds because the undefined node is a feed node.
tf_library(
    name = "test_graph_tfadd",
    testonly = True,
    config = "test_graph_tfadd.config.pbtxt",
    cpp_class = "AddComp",
    graph = "test_graph_tfadd.pbtxt",
    mlir_components = "None",
    tags = ["manual"],
)
170
# Same graph and config as :test_graph_tfadd, but compiled with
# mlir_components = "Bridge".
tf_library(
    name = "test_graph_tfadd_mlir_bridge",
    testonly = True,
    config = "test_graph_tfadd.config.pbtxt",
    cpp_class = "AddComp",
    graph = "test_graph_tfadd.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)
182
# Exercises tf_library on a graph containing an unknown op. Compilation
# succeeds because the node with the unknown op is not needed for the fetches.
tf_library(
    name = "test_graph_tfunknownop",
    testonly = True,
    config = "test_graph_tfunknownop.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "None",
    tags = ["manual"],
)
197
# Same graph and config as :test_graph_tfunknownop, but compiled with
# mlir_components = "Bridge".
tf_library(
    name = "test_graph_tfunknownop_mlir_bridge",
    testonly = True,
    config = "test_graph_tfunknownop.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)
209
# Exercises tf_library on a graph containing an unknown op. Compilation
# succeeds because the node with the unknown op is used only as an input of a
# feed node.
tf_library(
    name = "test_graph_tfunknownop2",
    testonly = True,
    config = "test_graph_tfunknownop2.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "None",
    tags = ["manual"],
)
224
# Same graph and config as :test_graph_tfunknownop2, but compiled with
# mlir_components = "Bridge".
tf_library(
    name = "test_graph_tfunknownop2_mlir_bridge",
    testonly = True,
    config = "test_graph_tfunknownop2.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)
236
# Exercises tf_library on a graph containing an unknown op. Compilation
# succeeds because the node with the unknown op is itself a feed node.
tf_library(
    name = "test_graph_tfunknownop3",
    testonly = True,
    config = "test_graph_tfunknownop3.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "None",
    tags = ["manual"],
)
250
# Same graph and config as :test_graph_tfunknownop3, but compiled with
# mlir_components = "Bridge".
tf_library(
    name = "test_graph_tfunknownop3_mlir_bridge",
    testonly = True,
    config = "test_graph_tfunknownop3.config.pbtxt",
    cpp_class = "UnknownOpAddComp",
    graph = "test_graph_tfunknownop.pbtxt",
    mlir_components = "Bridge",
    tags = ["manual"],
)
262
# Utility library for benchmark binaries. It is used by the *_benchmark rules
# that the tf_library macro (tfcompile.bzl) adds.
cc_library(
    name = "benchmark",
    srcs = ["benchmark.cc"],
    hdrs = ["benchmark.h"],
    visibility = ["//visibility:public"],
    deps = [
        # This library exists to support building an AOT binary with minimal
        # dependencies, demonstrating how small such a binary can be.
        #
        # KEEP THE DEPENDENCIES MINIMAL.
        "//tensorflow/core:framework_lite",
    ],
)
278
# Empty, manual-only placeholder library (no srcs or deps). It is presumably
# referenced by the Android benchmark setup in tfcompile.bzl -- confirm there.
cc_library(
    name = "benchmark_extra_android",
    tags = ["manual"],
    visibility = ["//visibility:public"],
)
286
# Package-private helper used by :tfcompile_lib, built from
# embedded_protocol_buffers.{h,cc}. It links LLVM Core/MC/Target (presumably
# to emit the protos as object-file data -- see the header for the contract).
cc_library(
    name = "embedded_protocol_buffers",
    srcs = ["embedded_protocol_buffers.cc"],
    hdrs = ["embedded_protocol_buffers.h"],
    deps = [
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla/service/llvm_ir:llvm_type_conversion_util",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@llvm-project//llvm:Core",
        "@llvm-project//llvm:MC",
        "@llvm-project//llvm:Support",
        "@llvm-project//llvm:Target",
    ],
)
305
# Built from aot_only_var_handle_op.{h,cc}. It is a dependency of
# :tfcompile_lib and is also visible to //tensorflow/compiler/tf2xla.
# alwayslink keeps its object file in the final link even when no symbol in it
# is referenced directly (presumably because it registers an op/kernel from a
# static initializer).
cc_library(
    name = "aot_only_var_handle_op",
    srcs = ["aot_only_var_handle_op.cc"],
    hdrs = ["aot_only_var_handle_op.h"],
    visibility = ["//tensorflow/compiler/tf2xla:__pkg__"],
    deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
        "//tensorflow/compiler/tf2xla:xla_context",
        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/core:framework",
    ],
    alwayslink = True,
)
321
# Manual-only test of the :benchmark utilities, using the :test_graph_tfadd
# AOT-compiled library.
tf_cc_test(
    name = "benchmark_test",
    srcs = ["benchmark_test.cc"],
    tags = ["manual"],
    deps = [
        ":benchmark",
        ":test_graph_tfadd",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ],
)
333
# Aggregates this package's tests together with //tensorflow/compiler/aot/tests.
# The *_test labels for the test_graph_* targets are not declared directly in
# this file. They are presumably the tests that tf_library generates
# (gen_test) -- confirm in tfcompile.bzl.
test_suite(
    name = "all_tests",
    tags = ["manual"],
    tests = [
        ":benchmark_test",
        ":codegen_test",
        # Generated tests for the tf_library targets above.
        ":test_graph_tfadd_mlir_bridge_test",
        ":test_graph_tfadd_test",
        ":test_graph_tfunknownop2_mlir_bridge_test",
        ":test_graph_tfunknownop2_test",
        ":test_graph_tfunknownop3_mlir_bridge_test",
        ":test_graph_tfunknownop3_test",
        ":test_graph_tfunknownop_mlir_bridge_test",
        ":test_graph_tfunknownop_test",
        # End-to-end tests that live outside this package.
        "//tensorflow/compiler/aot/tests:all_tests",
    ],
)
351
# Source files that the tf_library macro (tfcompile.bzl) references from other
# packages.
exports_files([
    # Used by tf_library(..., gen_benchmark=True).
    "benchmark_main.template",
    # Used by tf_library(..., gen_test=True).
    "test.cc",
])
356