1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
17
18 #include <cstdint>
19 #include <tuple>
20 #include <vector>
21
22 #include "tensorflow/compiler/xla/client/lib/slicing.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/test_macros.h"
29
30 namespace xla {
31 namespace tridiagonal {
32 namespace {
33
34 class TridiagonalTest
35 : public ClientLibraryTestBase,
36 public ::testing::WithParamInterface<std::tuple<int, int, int>> {};
37
// Multiplies a single 3x3 tridiagonal matrix by a 3x4 right-hand side and
// compares the product element by element. The test parameter is not used.
XLA_TEST_P(TridiagonalTest, SimpleTridiagonalMatMulOk) {
  xla::XlaBuilder builder(TestName());

  // Together the three diagonals describe the dense matrix
  //   [[21,  34,  0],
  //    [10,  22, 35],
  //    [ 0, 100, 23]].
  // The last superdiagonal entry and the first subdiagonal entry lie outside
  // the matrix, so their 999 placeholders are ignored (treated as 0).
  Array3D<float> superdiag{{{34, 35, 999}}};
  Array3D<float> diag{{{21, 22, 23}}};
  Array3D<float> subdiag{{{999, 10, 100}}};
  Array3D<float> dense_rhs{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}};

  XlaOp superdiag_op;
  XlaOp diag_op;
  XlaOp subdiag_op;
  XlaOp rhs_op;

  // Parameter numbers 0..3 must match the argument order handed to
  // ComputeAndTransfer below.
  auto superdiag_handle = CreateR3Parameter<float>(
      superdiag, 0, "upper_diagonal", &builder, &superdiag_op);
  auto diag_handle = CreateR3Parameter<float>(diag, 1, "main_diagonal",
                                              &builder, &diag_op);
  auto subdiag_handle = CreateR3Parameter<float>(
      subdiag, 2, "lower_diagonal", &builder, &subdiag_op);
  auto rhs_handle =
      CreateR3Parameter<float>(dense_rhs, 3, "rhs", &builder, &rhs_op);

  TF_ASSERT_OK_AND_ASSIGN(
      XlaOp product,
      TridiagonalMatMul(superdiag_op, diag_op, subdiag_op, rhs_op));

  ASSERT_EQ(product.builder()->first_error(), Status::OK());
  ASSERT_TRUE(product.valid());

  // Row-major product of the dense matrix above with `dense_rhs`.
  const std::vector<int64_t> want_dims = {1, 3, 4};
  const std::vector<float> want_values = {191, 246, 301, 356, 435, 502,
                                          569, 636, 707, 830, 953, 1076};

  TF_ASSERT_OK_AND_ASSIGN(
      auto computed,
      ComputeAndTransfer(product.builder(),
                         {superdiag_handle.get(), diag_handle.get(),
                          subdiag_handle.get(), rhs_handle.get()}));
  EXPECT_EQ(computed.shape().dimensions(), want_dims);
  EXPECT_EQ(computed.data<float>({}), want_values);
}
79
// Checks that TridiagonalMatMul rejects diagonals whose rank differs from the
// rank of the right-hand side. The test parameter is not used.
XLA_TEST_P(TridiagonalTest, TridiagonalMatMulWrongShape) {
  xla::XlaBuilder builder(TestName());

  // Rank-3 diagonals paired with a rank-4 right-hand side, all filled with 1.
  Array<float> superdiag({5, 3, 7}, 1);
  Array<float> diag({5, 3, 7}, 1);
  Array<float> subdiag({5, 3, 7}, 1);
  Array<float> rank4_rhs({5, 3, 7, 6}, 1);

  XlaOp superdiag_op;
  XlaOp diag_op;
  XlaOp subdiag_op;
  XlaOp rhs_op;

  // The returned handles are kept alive for the duration of the test even
  // though nothing is executed.
  auto superdiag_handle = CreateParameter<float>(
      superdiag, 0, "upper_diagonal", &builder, &superdiag_op);
  auto diag_handle =
      CreateParameter<float>(diag, 1, "main_diagonal", &builder, &diag_op);
  auto subdiag_handle = CreateParameter<float>(
      subdiag, 2, "lower_diagonal", &builder, &subdiag_op);
  auto rhs_handle =
      CreateParameter<float>(rank4_rhs, 3, "rhs", &builder, &rhs_op);

  const auto status_or_product =
      TridiagonalMatMul(superdiag_op, diag_op, subdiag_op, rhs_op);
  ASSERT_EQ(status_or_product.status(),
            InvalidArgument(
                "superdiag must have same rank as rhs, but got 3 and 4."));
}
107
// Solves randomly generated tridiagonal systems with the Thomas algorithm and
// checks that every equation's relative residual
//   |(a_i * x_{i-1} + b_i * x_i + c_i * x_{i+1} - d_i) / d_i|
// is below a fixed tolerance. The test parameter is
// (batch_size, num_eqs, num_rhs).
XLA_TEST_P(TridiagonalTest, Solves) {
  const auto& spec = GetParam();
  xla::XlaBuilder builder(TestName());

  const int64_t batch_size = std::get<0>(spec);
  const int64_t num_eqs = std::get<1>(spec);
  const int64_t num_rhs = std::get<2>(spec);

  // The diagonals have a single row per batch element; the right-hand side has
  // num_rhs rows. The equation index is the minor dimension in both.
  Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
  Array3D<float> main_diagonal(batch_size, 1, num_eqs);
  Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
  Array3D<float> rhs(batch_size, num_rhs, num_eqs);

  // Deterministic seeds, distinct per array. The main diagonal is centred on
  // 1.0 with a small spread; the other inputs are centred on 0.0.
  lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
  main_diagonal.FillRandom(0.05, /*mean=*/1.0,
                           /*seed=*/batch_size * num_eqs);
  upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
                            /*seed=*/2 * batch_size * num_eqs);
  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);

  XlaOp lower_diagonal_xla;
  XlaOp main_diagonal_xla;
  XlaOp upper_diagonal_xla;
  XlaOp rhs_xla;

  auto lower_diagonal_data = CreateR3Parameter<float>(
      lower_diagonal, 0, "lower_diagonal", &builder, &lower_diagonal_xla);
  auto main_diagonal_data = CreateR3Parameter<float>(
      main_diagonal, 1, "main_diagonal", &builder, &main_diagonal_xla);
  auto upper_diagonal_data = CreateR3Parameter<float>(
      upper_diagonal, 2, "upper_diagonal", &builder, &upper_diagonal_xla);
  auto rhs_data = CreateR3Parameter<float>(rhs, 3, "rhs", &builder, &rhs_xla);

  TF_ASSERT_OK_AND_ASSIGN(
      XlaOp x, TridiagonalSolver(kThomas, lower_diagonal_xla, main_diagonal_xla,
                                 upper_diagonal_xla, rhs_xla));

  // Slices element i of the minor (equation) dimension, keeping the rank.
  auto Coefficient = [](auto operand, auto i) {
    return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1});
  };

  std::vector<XlaOp> relative_errors;
  relative_errors.reserve(num_eqs);
  for (int64_t i = 0; i < num_eqs; ++i) {
    const XlaOp d_i = Coefficient(rhs_xla, i);
    // Accumulate the left-hand side in the order a, b, c. The sub-diagonal term
    // does not exist for the first equation, nor the super-diagonal term for
    // the last one. Guarding each term separately also keeps a one-equation
    // system from slicing x out of range.
    XlaOp lhs = Coefficient(main_diagonal_xla, i) * Coefficient(x, i);
    if (i > 0) {
      lhs = Coefficient(lower_diagonal_xla, i) * Coefficient(x, i - 1) + lhs;
    }
    if (i + 1 < num_eqs) {
      lhs = lhs + Coefficient(upper_diagonal_xla, i) * Coefficient(x, i + 1);
    }
    relative_errors.push_back((lhs - d_i) / d_i);
  }
  // XlaBuilder uses the most recently added op as the computation root, so
  // ComputeAndTransfer returns these absolute relative errors.
  Abs(ConcatInDim(&builder, relative_errors, 2));

  TF_ASSERT_OK_AND_ASSIGN(
      auto result,
      ComputeAndTransfer(&builder,
                         {lower_diagonal_data.get(), main_diagonal_data.get(),
                          upper_diagonal_data.get(), rhs_data.get()}));

  const auto result_data = result.data<float>({});
  // Guard against a vacuous pass: one error per (batch, rhs, equation).
  ASSERT_EQ(result_data.size(),
            static_cast<size_t>(batch_size * num_rhs * num_eqs));

  constexpr double kTolerance = 5e-3;
  for (const float relative_error : result_data) {
    EXPECT_LT(relative_error, kTolerance);
  }
}
184
185 INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
186 ::testing::Combine(::testing::Values(1, 12),
187 ::testing::Values(4, 8),
188 ::testing::Values(1, 12)));
189
190 } // namespace
191 } // namespace tridiagonal
192 } // namespace xla
193