xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/aot/benchmark_main.template (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1// Generated by the tf_library build rule.  DO NOT EDIT!
2//
3// This file contains the main function and logic for benchmarking code
4// generated by tfcompile.  All tokens of the form `{{TFCOMPILE_*}}` must be
5// rewritten to real values before this file can be compiled.
6//
7//    TFCOMPILE_HEADER    : Path to the header file generated by tfcompile.
8//    TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
9//
10// The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
11// generates a cc_binary rule for you.
12
// These macros must be defined before any Eigen headers are included.
14#define EIGEN_USE_THREADS
15#define EIGEN_USE_CUSTOM_THREAD_POOL
16
17// clang-format off
18#include "{{TFCOMPILE_HEADER}}"  // NOLINT(whitespace/braces)
19// clang-format on
20
21#include "tensorflow/compiler/aot/benchmark.h"
22#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23
// Macro that expands to the name of the C++ class generated by tfcompile.
25// clang-format off
26#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}}  // NOLINT(whitespace/braces)
27// clang-format on
28
29namespace tensorflow {
30namespace tfcompile {
31
32int Main(int argc, char** argv) {
33  Eigen::ThreadPool pool(1 /* num_threads */);
34  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
35
36  CPP_CLASS computation;
37  computation.set_thread_pool(&device);
38
39  benchmark::Options options;
40  benchmark::Stats stats;
41  benchmark::Benchmark(options, [&] { computation.Run(); }, &stats);
42  benchmark::DumpStatsToStdout(stats);
43  return 0;
44}
45
46}  // namespace tfcompile
47}  // namespace tensorflow
48
49int main(int argc, char** argv) {
50  return tensorflow::tfcompile::Main(argc, argv);
51}
52