1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #define EIGEN_USE_THREADS
17
18 // TF_PIP_INTEGRATION_TEST is defined in the integration test for the support
19 // for AOT compilation in the PIP package. We don't have access to
20 // platform/logging, nor to platform/test, but we can use gtest.h instead.
21 // LINT.IfChange
22 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
23 #ifndef TF_PIP_INTEGRATION_TEST
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/test.h"
26 #else
27 #include "gtest/gtest.h"
28 #endif
29 #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic.h"
30 #include "tensorflow/python/tools/aot_compiled_vars_and_arithmetic_frozen.h"
31 #include "tensorflow/python/tools/aot_compiled_x_matmul_y_large.h"
32 #include "tensorflow/python/tools/aot_compiled_x_matmul_y_large_multithreaded.h"
33 #include "tensorflow/python/tools/aot_compiled_x_matmul_y_small.h"
34 #include "tensorflow/python/tools/aot_compiled_x_plus_y.h"
35 // LINT.ThenChange(//tensorflow/tools/pip_package/xla_build/pip_test/run_xla_aot_test.sh)
36
37 namespace tensorflow {
38 namespace {
TEST(AOTCompiledSavedModelTest, XPlusY) {
  // The AOT-compiled graph adds its two scalar feeds: output_0 = x + y.
  XPlusY adder;
  auto* const x_slot = adder.arg_feed_x_data();
  *x_slot = 3.0f;
  auto* const y_slot = adder.arg_feed_y_data();
  *y_slot = 4.0f;

  ASSERT_TRUE(adder.Run());
  ASSERT_NEAR(adder.result_fetch_output_0(), 3.0f + 4.0f,
              /*abs_error=*/1e-6f);
}
47
TEST(AOTCompiledSavedModelTest, XMatmulYLarge) {
  XMatmulYLarge model;
  // Calculation is: output_0 = x @ y, with x: [3000, 5000], y: [5000, 4000].
  EXPECT_EQ(model.arg_feed_x_count(), 3000 * 5000);
  EXPECT_EQ(model.arg_feed_y_count(), 5000 * 4000);
  EXPECT_EQ(model.result0_count(), 3000 * 4000);

  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_x(3000, 5000);
  Eigen::Tensor<float, 2, Eigen::RowMajor> arg_feed_y(5000, 4000);
  arg_feed_x.setRandom();
  arg_feed_y.setRandom();

  // Set up dimensions for standard matmul: contract dim 1 of x with dim 0 of y.
  const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
      Eigen::IndexPair<int>(1, 0)};
  // Ground truth matmul.
  const Eigen::Tensor<float, 2, Eigen::RowMajor> expected_output0 =
      arg_feed_x.contract(arg_feed_y, product_dims);

  model.set_arg_feed_x_data(arg_feed_x.data());
  model.set_arg_feed_y_data(arg_feed_y.data());
  ASSERT_TRUE(model.Run());

  // Each output element is a 5000-term float dot product. The compiled kernel
  // and Eigen's contraction may accumulate those terms in different orders,
  // so the two results can differ by a few ULPs. An absolute bound of 1e-6 is
  // already below float spacing for any value of magnitude 16 or more, which
  // lets this check fail spuriously. Use the same bound as the
  // XMatmulYLargeMultithreaded test, which checks the identical
  // [3000, 5000] @ [5000, 4000] product.
  constexpr float kAbsError = 1e-3f;
  EXPECT_NEAR(model.result_fetch_output_0(0, 0), expected_output0(0, 0),
              kAbsError);
  EXPECT_NEAR(model.result_fetch_output_0(2999, 3999),
              expected_output0(2999, 3999), kAbsError);
}
76
TEST(AOTCompiledSavedModelTest, XMatmulYLargeMultithreaded) {
  using RowMajorMatrix = Eigen::Tensor<float, 2, Eigen::RowMajor>;
  constexpr int kRows = 3000;   // Rows of x and of output_0.
  constexpr int kInner = 5000;  // Columns of x == rows of y.
  constexpr int kCols = 4000;   // Columns of y and of output_0.
  constexpr int kNumThreads = 2;
  constexpr float kAbsError = 1e-3f;

  XMatmulYLargeMultithreaded model;

  // Hand the model a two-thread Eigen device. set_thread_pool takes a
  // pointer (the model presumably keeps it), so both pool and device are
  // locals that stay alive until after Run() returns.
  Eigen::ThreadPool pool(kNumThreads);
  Eigen::ThreadPoolDevice device(&pool, pool.NumThreads());
  model.set_thread_pool(&device);

  // The compiled graph computes output_0 = x @ y; verify its buffer sizes.
  EXPECT_EQ(model.arg_feed_x_count(), kRows * kInner);
  EXPECT_EQ(model.arg_feed_y_count(), kInner * kCols);
  EXPECT_EQ(model.result0_count(), kRows * kCols);

  RowMajorMatrix x(kRows, kInner);
  RowMajorMatrix y(kInner, kCols);
  x.setRandom();
  y.setRandom();

  // Reference result computed by Eigen: contract dim 1 of x with dim 0 of y.
  const Eigen::array<Eigen::IndexPair<int>, 1> inner_dims = {
      Eigen::IndexPair<int>(1, 0)};
  const RowMajorMatrix reference = x.contract(y, inner_dims);

  model.set_arg_feed_x_data(x.data());
  model.set_arg_feed_y_data(y.data());
  ASSERT_TRUE(model.Run());

  // Spot-check the first and the last element of the output.
  EXPECT_NEAR(model.result_fetch_output_0(0, 0), reference(0, 0), kAbsError);
  EXPECT_NEAR(model.result_fetch_output_0(kRows - 1, kCols - 1),
              reference(kRows - 1, kCols - 1), kAbsError);
}
110
TEST(AOTCompiledSavedModelTest, XMatmulYSmall) {
  using RowMajorMatrix = Eigen::Tensor<float, 2, Eigen::RowMajor>;
  constexpr int kRows = 3;   // Rows of x and of output_0.
  constexpr int kInner = 5;  // Columns of x == rows of y.
  constexpr int kCols = 4;   // Columns of y and of output_0.
  constexpr float kAbsError = 1e-6f;

  // The compiled graph computes output_0 = x @ y; verify its buffer sizes.
  XMatmulYSmall model;
  EXPECT_EQ(model.arg_feed_x_count(), kRows * kInner);
  EXPECT_EQ(model.arg_feed_y_count(), kInner * kCols);
  EXPECT_EQ(model.result0_count(), kRows * kCols);

  RowMajorMatrix x(kRows, kInner);
  RowMajorMatrix y(kInner, kCols);
  x.setRandom();
  y.setRandom();

  // Reference result computed by Eigen: contract dim 1 of x with dim 0 of y.
  const Eigen::array<Eigen::IndexPair<int>, 1> inner_dims = {
      Eigen::IndexPair<int>(1, 0)};
  const RowMajorMatrix reference = x.contract(y, inner_dims);

  model.set_arg_feed_x_data(x.data());
  model.set_arg_feed_y_data(y.data());
  ASSERT_TRUE(model.Run());

  // Spot-check the first and the last element of the output.
  EXPECT_NEAR(model.result_fetch_output_0(0, 0), reference(0, 0), kAbsError);
  EXPECT_NEAR(model.result_fetch_output_0(kRows - 1, kCols - 1),
              reference(kRows - 1, kCols - 1), kAbsError);
}
138
TEST(AOTCompiledSavedModelTest, VarsAndArithmetic) {
  // Both compiled models evaluate
  //   output_0 = [(a + variable_x) * (b + variable_y) / child_variable] + 5.0
  // The frozen model has {variable_x, variable_y, child_variable} baked in at
  // their initial values {1.0, 2.0, 3.0}.
  constexpr float kFeedA = 1.0f;
  constexpr float kFeedB = 2.0f;
  constexpr float kOverriddenX = 4.0f;
  constexpr float kAbsError = 1e-6f;

  VarsAndArithmeticFrozen frozen_model;
  *frozen_model.arg_feed_a_data() = kFeedA;
  *frozen_model.arg_feed_b_data() = kFeedB;
  ASSERT_TRUE(frozen_model.Run());
  ASSERT_NEAR(frozen_model.result_fetch_output_0(),
              (kFeedA + 1.0f) * (kFeedB + 2.0f) / 3.0f + 5.0f, kAbsError);

  // In the non-frozen model variable_x is a caller-supplied parameter; the
  // expected value still uses 2.0 for variable_y and 3.0 for child_variable.
  VarsAndArithmetic nonfrozen_model;
  *nonfrozen_model.arg_feed_a_data() = kFeedA;
  *nonfrozen_model.arg_feed_b_data() = kFeedB;
  // Handed over by pointer, so it must remain alive until Run() returns.
  float variable_x_value = kOverriddenX;
  nonfrozen_model.set_var_param_variable_x_data(&variable_x_value);
  ASSERT_TRUE(nonfrozen_model.Run());
  ASSERT_NEAR(nonfrozen_model.result_fetch_output_0(),
              (kFeedA + kOverriddenX) * (kFeedB + 2.0f) / 3.0f + 5.0f,
              kAbsError);
}
161 } // namespace
162 } // namespace tensorflow
163