xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
17 
18 #include <cstdint>
19 #include <tuple>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/client/lib/slicing.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/status.h"
27 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
28 #include "tensorflow/compiler/xla/tests/test_macros.h"
29 
30 namespace xla {
31 namespace tridiagonal {
32 namespace {
33 
34 class TridiagonalTest
35     : public ClientLibraryTestBase,
36       public ::testing::WithParamInterface<std::tuple<int, int, int>> {};
37 
// Multiplies one 3x3 tridiagonal matrix by a 3x4 right-hand side and checks
// the product against hand-computed values.
//
// The three diagonals below encode the dense matrix
//   [[21,  34,  0],
//    [10,  22, 35],
//    [ 0, 100, 23]].
// The last superdiagonal entry and the first subdiagonal entry have no
// position in that matrix. They hold a 999 sentinel, and the expected values
// only match if TridiagonalMatMul ignores them.
XLA_TEST_P(TridiagonalTest, SimpleTridiagonalMatMulOk) {
  xla::XlaBuilder builder(TestName());

  // Every input is a rank-3 array whose leading dimension is 1.
  Array3D<float> superdiag{{{34, 35, 999}}};  // Behaves as {{{34, 35, 0}}}.
  Array3D<float> maindiag{{{21, 22, 23}}};
  Array3D<float> subdiag{{{999, 10, 100}}};  // Behaves as {{{0, 10, 100}}}.
  Array3D<float> rhs{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}};

  XlaOp superdiag_op;
  XlaOp maindiag_op;
  XlaOp subdiag_op;
  XlaOp rhs_op;

  // The parameter numbers (0..3) must line up with the order of the argument
  // list passed to ComputeAndTransfer below.
  auto superdiag_arg = CreateR3Parameter<float>(
      superdiag, 0, "upper_diagonal", &builder, &superdiag_op);
  auto maindiag_arg = CreateR3Parameter<float>(
      maindiag, 1, "main_diagonal", &builder, &maindiag_op);
  auto subdiag_arg = CreateR3Parameter<float>(
      subdiag, 2, "lower_diagonal", &builder, &subdiag_op);
  auto rhs_arg = CreateR3Parameter<float>(rhs, 3, "rhs", &builder, &rhs_op);

  TF_ASSERT_OK_AND_ASSIGN(
      XlaOp product,
      TridiagonalMatMul(superdiag_op, maindiag_op, subdiag_op, rhs_op));
  ASSERT_EQ(product.builder()->first_error(), Status::OK());
  ASSERT_TRUE(product.valid());

  TF_ASSERT_OK_AND_ASSIGN(
      auto literal,
      ComputeAndTransfer(product.builder(),
                         {superdiag_arg.get(), maindiag_arg.get(),
                          subdiag_arg.get(), rhs_arg.get()}));

  // Row r of the product is sum_k matrix[r][k] * rhs[k], e.g. the first row
  // is 21 * {1, 2, 3, 4} + 34 * {5, 6, 7, 8}.
  const std::vector<int64_t> want_dims = {1, 3, 4};
  const std::vector<float> want_values = {191, 246, 301, 356,   //
                                          435, 502, 569, 636,   //
                                          707, 830, 953, 1076};
  EXPECT_EQ(literal.shape().dimensions(), want_dims);
  EXPECT_EQ(literal.data<float>({}), want_values);
}
78 }
79 
// Checks that TridiagonalMatMul rejects a right-hand side whose rank differs
// from the diagonals' rank: every diagonal here is rank 3 while rhs is rank 4.
XLA_TEST_P(TridiagonalTest, TridiagonalMatMulWrongShape) {
  xla::XlaBuilder builder(TestName());

  const std::vector<int64_t> diagonal_dims = {5, 3, 7};
  const std::vector<int64_t> rhs_dims = {5, 3, 7, 6};

  // Contents are irrelevant; only the shapes matter, so everything is 1.
  Array<float> superdiag(diagonal_dims, 1);
  Array<float> maindiag(diagonal_dims, 1);
  Array<float> subdiag(diagonal_dims, 1);
  Array<float> rhs(rhs_dims, 1);

  XlaOp superdiag_op;
  XlaOp maindiag_op;
  XlaOp subdiag_op;
  XlaOp rhs_op;

  auto superdiag_arg = CreateParameter<float>(
      superdiag, 0, "upper_diagonal", &builder, &superdiag_op);
  auto maindiag_arg = CreateParameter<float>(
      maindiag, 1, "main_diagonal", &builder, &maindiag_op);
  auto subdiag_arg = CreateParameter<float>(
      subdiag, 2, "lower_diagonal", &builder, &subdiag_op);
  auto rhs_arg = CreateParameter<float>(rhs, 3, "rhs", &builder, &rhs_op);

  // The superdiagonal is the operand reported in the rank-mismatch error.
  const Status expected_error = InvalidArgument(
      "superdiag must have same rank as rhs, but got 3 and 4.");
  auto product =
      TridiagonalMatMul(superdiag_op, maindiag_op, subdiag_op, rhs_op);
  ASSERT_EQ(product.status(), expected_error);
}
106 }
107 
XLA_TEST_P(TridiagonalTest,Solves)108 XLA_TEST_P(TridiagonalTest, Solves) {
109   const auto& spec = GetParam();
110   xla::XlaBuilder builder(TestName());
111 
112   // TODO(belletti): parametrize num_rhs.
113   const int64_t batch_size = std::get<0>(spec);
114   const int64_t num_eqs = std::get<1>(spec);
115   const int64_t num_rhs = std::get<2>(spec);
116 
117   Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
118   Array3D<float> main_diagonal(batch_size, 1, num_eqs);
119   Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
120   Array3D<float> rhs(batch_size, num_rhs, num_eqs);
121 
122   lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
123   main_diagonal.FillRandom(0.05, /*mean=*/1.0,
124                            /*seed=*/batch_size * num_eqs);
125   upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
126                             /*seed=*/2 * batch_size * num_eqs);
127   rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);
128 
129   XlaOp lower_diagonal_xla;
130   XlaOp main_diagonal_xla;
131   XlaOp upper_diagonal_xla;
132   XlaOp rhs_xla;
133 
134   auto lower_diagonal_data = CreateR3Parameter<float>(
135       lower_diagonal, 0, "lower_diagonal", &builder, &lower_diagonal_xla);
136   auto main_diagonal_data = CreateR3Parameter<float>(
137       main_diagonal, 1, "main_diagonal", &builder, &main_diagonal_xla);
138   auto upper_diagonal_data = CreateR3Parameter<float>(
139       upper_diagonal, 2, "upper_diagonal", &builder, &upper_diagonal_xla);
140   auto rhs_data = CreateR3Parameter<float>(rhs, 3, "rhs", &builder, &rhs_xla);
141 
142   TF_ASSERT_OK_AND_ASSIGN(
143       XlaOp x, TridiagonalSolver(kThomas, lower_diagonal_xla, main_diagonal_xla,
144                                  upper_diagonal_xla, rhs_xla));
145 
146   auto Coefficient = [](auto operand, auto i) {
147     return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1});
148   };
149 
150   std::vector<XlaOp> relative_errors(num_eqs);
151 
152   for (int64_t i = 0; i < num_eqs; i++) {
153     auto a_i = Coefficient(lower_diagonal_xla, i);
154     auto b_i = Coefficient(main_diagonal_xla, i);
155     auto c_i = Coefficient(upper_diagonal_xla, i);
156     auto d_i = Coefficient(rhs_xla, i);
157 
158     if (i == 0) {
159       relative_errors[i] =
160           (b_i * Coefficient(x, i) + c_i * Coefficient(x, i + 1) - d_i) / d_i;
161     } else if (i == num_eqs - 1) {
162       relative_errors[i] =
163           (a_i * Coefficient(x, i - 1) + b_i * Coefficient(x, i) - d_i) / d_i;
164     } else {
165       relative_errors[i] =
166           (a_i * Coefficient(x, i - 1) + b_i * Coefficient(x, i) +
167            c_i * Coefficient(x, i + 1) - d_i) /
168           d_i;
169     }
170   }
171   Abs(ConcatInDim(&builder, relative_errors, 2));
172 
173   TF_ASSERT_OK_AND_ASSIGN(
174       auto result,
175       ComputeAndTransfer(&builder,
176                          {lower_diagonal_data.get(), main_diagonal_data.get(),
177                           upper_diagonal_data.get(), rhs_data.get()}));
178 
179   auto result_data = result.data<float>({});
180   for (auto result_component : result_data) {
181     EXPECT_TRUE(result_component < 5e-3);
182   }
183 }
184 
185 INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
186                         ::testing::Combine(::testing::Values(1, 12),
187                                            ::testing::Values(4, 8),
188                                            ::testing::Values(1, 12)));
189 
190 }  // namespace
191 }  // namespace tridiagonal
192 }  // namespace xla
193