1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Generated by the tf_library build rule. DO NOT EDIT!
17 //
18 // This file contains a test and benchmark for the function generated by
19 // tfcompile. All tokens of the form `{{TFCOMPILE_*}}` must be rewritten to
20 // real values before this file can be compiled.
21 //
22 // TFCOMPILE_HEADER : Path to the header file generated by tfcompile.
23 // TFCOMPILE_CPP_CLASS : Name of the C++ class generated by tfcompile.
24 // TFCOMPILE_NAME : Name for tests and benchmarks.
25 //
26 // The tf_library bazel macro in tfcompile.bzl performs the token rewriting, and
27 // generates a cc_test rule for you.
28
// These macros must be defined before eigen files are included.
// They are placed ahead of the generated header as well, presumably because
// that header may itself pull in Eigen (NOTE(review): confirm).
// EIGEN_USE_THREADS makes Eigen's threaded device support available; this file
// uses Eigen::ThreadPool / Eigen::ThreadPoolDevice below.
#define EIGEN_USE_THREADS
#define EIGEN_USE_CUSTOM_THREAD_POOL
32
33 // clang-format off
34 #include "{{TFCOMPILE_HEADER}}" // NOLINT(whitespace/braces)
35 // clang-format on
36
#include <cstring>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
42
// Macros that expand to tokens based on the entry point name. The
// `{{TFCOMPILE_*}}` placeholders are textually rewritten by the tf_library
// build rule before this file is compiled.
//   CPP_CLASS - the C++ class generated by tfcompile (the code under test).
//   TEST_NAME - gtest suite name: <name>Test.
//   BM_NAME   - benchmark function name: BM_<name>.
// clang-format off
#define CPP_CLASS {{TFCOMPILE_CPP_CLASS}} // NOLINT(whitespace/braces)
#define TEST_NAME {{TFCOMPILE_NAME}}Test // NOLINT(whitespace/braces)
#define BM_NAME BM_{{TFCOMPILE_NAME}} // NOLINT(whitespace/braces)
// clang-format on
49
50 namespace tensorflow {
51 namespace tfcompile {
52 namespace {
53
zero_buffers(XlaCompiledCpuFunction * computation)54 void zero_buffers(XlaCompiledCpuFunction* computation) {
55 for (int i = 0; i < computation->num_args(); ++i) {
56 memset(computation->arg_data(i), 0, computation->arg_size(i));
57 }
58 }
59
60 // Trivial test that runs the generated function to ensure it doesn't crash.
// Smoke test: builds the generated computation, backs it with a thread pool
// sized to the machine's parallelism, zeroes its arguments, and runs it once,
// expecting Run() to report success.
TEST(TEST_NAME, NoCrash) {
  // Thread pool + device through which the generated code may parallelize.
  Eigen::ThreadPool worker_pool(port::MaxParallelism());
  Eigen::ThreadPoolDevice pool_device(&worker_pool, worker_pool.NumThreads());

  // Computation under test, wired to the device and given all-zero inputs.
  CPP_CLASS computation;
  computation.set_thread_pool(&pool_device);
  zero_buffers(&computation);

  EXPECT_TRUE(computation.Run());
}
71
72 // Simple benchmark that repeatedly runs the generated function.
BM_NAME(benchmark::State & state)73 void BM_NAME(benchmark::State& state) {
74 Eigen::ThreadPool pool(port::MaxParallelism());
75 Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
76
77 CPP_CLASS computation;
78 computation.set_thread_pool(&device);
79 zero_buffers(&computation);
80
81 for (auto s : state) {
82 computation.Run();
83 }
84 }
85 BENCHMARK(BM_NAME);
86
87 } // namespace
88 } // namespace tfcompile
89 } // namespace tensorflow
90