// xref: /aosp_15_r20/external/executorch/kernels/optimized/test/libblas_test.cpp (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8 
#include <gtest/gtest.h>

#include <executorch/kernels/optimized/blas/CPUBlas.h>
#include <executorch/runtime/core/exec_aten/exec_aten.h>

#include <algorithm>
#include <cstdint>
#include <vector>
15 
// Calls the function template `_` once for each element type listed below,
// passing `N` as the second template argument:
//   _<double, N>(); _<float, N>(); ... _<exec_aten::BFloat16, N>();
// `_` must therefore name a template callable with no arguments whose template
// parameters are (type, value). Every generated call already ends in a
// semicolon, so the expansion is a sequence of complete statements.
#define TEST_FORALL_SUPPORTED_CTYPES(_, N) \
  _<double, N>();                          \
  _<float, N>();                           \
  _<int64_t, N>();                         \
  _<uint8_t, N>();                         \
  _<int32_t, N>();                         \
  _<exec_aten::BFloat16, N>();
23 
namespace {

// Sets every element of `arr` to one, converted to the element type `T`.
// An empty vector is left unchanged.
template <typename T>
void fill_ones(std::vector<T>& arr) {
  std::fill(arr.begin(), arr.end(), static_cast<T>(1));
}

// Returns true when no element of `arr` differs from `value`, where `value`
// is first converted to the element type `T` (so the comparison happens in
// `T`, not in float). An empty vector yields true.
//
// `arr` is only read, hence the const reference; callers holding a mutable
// vector can pass it unchanged.
template <typename T>
bool check_all_equal_to(const std::vector<T>& arr, const float value) {
  // Convert once instead of on every comparison.
  const T expected = static_cast<T>(value);
  return std::none_of(arr.begin(), arr.end(), [&expected](const T& elem) {
    return elem != expected;
  });
}

} // namespace
45 
46 template <class CTYPE, int64_t N>
test_matmul_ones()47 void test_matmul_ones() {
48   using executorch::cpublas::TransposeType;
49 
50   std::vector<CTYPE> in_1(N * N);
51   fill_ones(in_1);
52   std::vector<CTYPE> in_2(N * N);
53   fill_ones(in_2);
54 
55   std::vector<CTYPE> out(N * N);
56 
57   const CTYPE* in_1_data = in_1.data();
58   const CTYPE* in_2_data = in_2.data();
59 
60   CTYPE* out_data = out.data();
61 
62   // clang-format off
63   executorch::cpublas::gemm(
64       TransposeType::NoTranspose, TransposeType::NoTranspose,
65       N, N, N,
66       static_cast<CTYPE>(1),
67       in_1_data, N,
68       in_2_data, N,
69       static_cast<CTYPE>(0),
70       out_data, N);
71   // clang-format on
72 
73   EXPECT_TRUE(check_all_equal_to(out, static_cast<float>(N)));
74 }
75 
// Runs test_matmul_ones with 25 x 25 matrices once for every element type
// enumerated by TEST_FORALL_SUPPORTED_CTYPES.
TEST(BlasTest, MatmulOnes) {
  TEST_FORALL_SUPPORTED_CTYPES(test_matmul_ones, 25);
}
79