xref: /aosp_15_r20/external/ComputeLibrary/examples/cl_sgemm.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2017-2020 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef ARM_COMPUTE_CL /* Needed by Utils.cpp to handle OpenCL exceptions properly */
25*c217d954SCole Faust #error "This example needs to be built with -DARM_COMPUTE_CL"
26*c217d954SCole Faust #endif /* ARM_COMPUTE_CL */
27*c217d954SCole Faust 
28*c217d954SCole Faust #include "arm_compute/core/Types.h"
29*c217d954SCole Faust #include "arm_compute/runtime/CL/CLScheduler.h"
30*c217d954SCole Faust #include "arm_compute/runtime/CL/CLTuner.h"
31*c217d954SCole Faust #include "arm_compute/runtime/CL/functions/CLGEMM.h"
32*c217d954SCole Faust #include "utils/Utils.h"
33*c217d954SCole Faust 
34*c217d954SCole Faust #include <cstdlib>
35*c217d954SCole Faust 
36*c217d954SCole Faust using namespace arm_compute;
37*c217d954SCole Faust using namespace utils;
38*c217d954SCole Faust 
39*c217d954SCole Faust class CLSGEMMExample : public Example
40*c217d954SCole Faust {
41*c217d954SCole Faust public:
do_setup(int argc,char ** argv)42*c217d954SCole Faust     bool do_setup(int argc, char **argv) override
43*c217d954SCole Faust     {
44*c217d954SCole Faust         NPYLoader npy0;
45*c217d954SCole Faust         NPYLoader npy1;
46*c217d954SCole Faust         NPYLoader npy2;
47*c217d954SCole Faust         alpha = 1.0f;
48*c217d954SCole Faust         beta  = 0.0f;
49*c217d954SCole Faust 
50*c217d954SCole Faust         CLScheduler::get().default_init(&tuner);
51*c217d954SCole Faust 
52*c217d954SCole Faust         std::ifstream stream;
53*c217d954SCole Faust         if(argc > 1)
54*c217d954SCole Faust         {
55*c217d954SCole Faust             stream.open(argv[1], std::fstream::in);
56*c217d954SCole Faust         }
57*c217d954SCole Faust 
58*c217d954SCole Faust         if(argc < 3 || (argc < 4 && stream.bad()))
59*c217d954SCole Faust         {
60*c217d954SCole Faust             // Print help
61*c217d954SCole Faust             std::cout << "Usage: 1) ./build/cl_sgemm input_matrix_1.npy input_matrix_2.npy [input_matrix_3.npy] [alpha = 1] [beta = 0]\n";
62*c217d954SCole Faust             std::cout << "       2) ./build/cl_sgemm M N K [alpha = 1.0f] [beta = 0.0f]\n\n";
63*c217d954SCole Faust             std::cout << "Too few or no input_matrices provided. Using M=7, N=3, K=5, alpha=1.0f and beta=0.0f\n\n";
64*c217d954SCole Faust 
65*c217d954SCole Faust             src0.allocator()->init(TensorInfo(TensorShape(5U, 7U), 1, DataType::F32));
66*c217d954SCole Faust             src1.allocator()->init(TensorInfo(TensorShape(3U, 5U), 1, DataType::F32));
67*c217d954SCole Faust             src2.allocator()->init(TensorInfo(TensorShape(3U, 7U), 1, DataType::F32));
68*c217d954SCole Faust         }
69*c217d954SCole Faust         else
70*c217d954SCole Faust         {
71*c217d954SCole Faust             if(stream.good()) /* case file1.npy file2.npy [file3.npy] [alpha = 1.0f] [beta = 0.0f] */
72*c217d954SCole Faust             {
73*c217d954SCole Faust                 npy0.open(argv[1]);
74*c217d954SCole Faust                 npy0.init_tensor(src0, DataType::F32);
75*c217d954SCole Faust                 npy1.open(argv[2]);
76*c217d954SCole Faust                 npy1.init_tensor(src1, DataType::F32);
77*c217d954SCole Faust 
78*c217d954SCole Faust                 if(argc > 3)
79*c217d954SCole Faust                 {
80*c217d954SCole Faust                     stream.close();
81*c217d954SCole Faust                     stream.clear();
82*c217d954SCole Faust                     stream.open(argv[3], std::fstream::in);
83*c217d954SCole Faust                     if(stream.good()) /* case with third file */
84*c217d954SCole Faust                     {
85*c217d954SCole Faust                         npy2.open(argv[3]);
86*c217d954SCole Faust                         npy2.init_tensor(src2, DataType::F32);
87*c217d954SCole Faust 
88*c217d954SCole Faust                         if(argc > 4)
89*c217d954SCole Faust                         {
90*c217d954SCole Faust                             // Convert string to float
91*c217d954SCole Faust                             alpha = strtof(argv[4], nullptr);
92*c217d954SCole Faust 
93*c217d954SCole Faust                             if(argc > 5)
94*c217d954SCole Faust                             {
95*c217d954SCole Faust                                 // Convert string to float
96*c217d954SCole Faust                                 beta = strtof(argv[5], nullptr);
97*c217d954SCole Faust                             }
98*c217d954SCole Faust                         }
99*c217d954SCole Faust                     }
100*c217d954SCole Faust                     else /* case without third file */
101*c217d954SCole Faust                     {
102*c217d954SCole Faust                         alpha = strtof(argv[3], nullptr);
103*c217d954SCole Faust 
104*c217d954SCole Faust                         if(argc > 4)
105*c217d954SCole Faust                         {
106*c217d954SCole Faust                             beta = strtof(argv[4], nullptr);
107*c217d954SCole Faust                         }
108*c217d954SCole Faust                     }
109*c217d954SCole Faust                 }
110*c217d954SCole Faust             }
111*c217d954SCole Faust             else /* case M N K [alpha = 1.0f] [beta = 0.0f] */
112*c217d954SCole Faust             {
113*c217d954SCole Faust                 size_t M = strtol(argv[1], nullptr, 10);
114*c217d954SCole Faust                 size_t N = strtol(argv[2], nullptr, 10);
115*c217d954SCole Faust                 size_t K = strtol(argv[3], nullptr, 10);
116*c217d954SCole Faust 
117*c217d954SCole Faust                 src0.allocator()->init(TensorInfo(TensorShape(K, M), 1, DataType::F32));
118*c217d954SCole Faust                 src1.allocator()->init(TensorInfo(TensorShape(N, K), 1, DataType::F32));
119*c217d954SCole Faust                 src2.allocator()->init(TensorInfo(TensorShape(N, M), 1, DataType::F32));
120*c217d954SCole Faust 
121*c217d954SCole Faust                 if(argc > 4)
122*c217d954SCole Faust                 {
123*c217d954SCole Faust                     alpha = strtof(argv[4], nullptr);
124*c217d954SCole Faust 
125*c217d954SCole Faust                     if(argc > 5)
126*c217d954SCole Faust                     {
127*c217d954SCole Faust                         beta = strtof(argv[5], nullptr);
128*c217d954SCole Faust                     }
129*c217d954SCole Faust                 }
130*c217d954SCole Faust             }
131*c217d954SCole Faust         }
132*c217d954SCole Faust 
133*c217d954SCole Faust         init_sgemm_output(dst, src0, src1, DataType::F32);
134*c217d954SCole Faust 
135*c217d954SCole Faust         // Configure function
136*c217d954SCole Faust         sgemm.configure(&src0, &src1, (src2.info()->total_size() > 0) ? &src2 : nullptr, &dst, alpha, beta);
137*c217d954SCole Faust 
138*c217d954SCole Faust         // Allocate all the images
139*c217d954SCole Faust         src0.allocator()->allocate();
140*c217d954SCole Faust         src1.allocator()->allocate();
141*c217d954SCole Faust         dst.allocator()->allocate();
142*c217d954SCole Faust 
143*c217d954SCole Faust         // Fill the input images with either the data provided or random data
144*c217d954SCole Faust         if(npy0.is_open())
145*c217d954SCole Faust         {
146*c217d954SCole Faust             npy0.fill_tensor(src0);
147*c217d954SCole Faust             npy1.fill_tensor(src1);
148*c217d954SCole Faust 
149*c217d954SCole Faust             output_filename = "sgemm_out.npy";
150*c217d954SCole Faust             is_fortran      = npy0.is_fortran();
151*c217d954SCole Faust 
152*c217d954SCole Faust             if(npy2.is_open())
153*c217d954SCole Faust             {
154*c217d954SCole Faust                 src2.allocator()->allocate();
155*c217d954SCole Faust                 npy2.fill_tensor(src2);
156*c217d954SCole Faust             }
157*c217d954SCole Faust         }
158*c217d954SCole Faust         else
159*c217d954SCole Faust         {
160*c217d954SCole Faust             src2.allocator()->allocate();
161*c217d954SCole Faust 
162*c217d954SCole Faust             fill_random_tensor(src0, -1.f, 1.f);
163*c217d954SCole Faust             fill_random_tensor(src1, -1.f, 1.f);
164*c217d954SCole Faust             fill_random_tensor(src2, -1.f, 1.f);
165*c217d954SCole Faust         }
166*c217d954SCole Faust 
167*c217d954SCole Faust         // Dummy run for CLTuner
168*c217d954SCole Faust         sgemm.run();
169*c217d954SCole Faust 
170*c217d954SCole Faust         return true;
171*c217d954SCole Faust     }
do_run()172*c217d954SCole Faust     void do_run() override
173*c217d954SCole Faust     {
174*c217d954SCole Faust         // Execute the function
175*c217d954SCole Faust         sgemm.run();
176*c217d954SCole Faust 
177*c217d954SCole Faust         // Make sure all the OpenCL jobs are done executing:
178*c217d954SCole Faust         CLScheduler::get().sync();
179*c217d954SCole Faust     }
do_teardown()180*c217d954SCole Faust     void do_teardown() override
181*c217d954SCole Faust     {
182*c217d954SCole Faust         if(!output_filename.empty()) /* Save to .npy file */
183*c217d954SCole Faust         {
184*c217d954SCole Faust             save_to_npy(dst, output_filename, is_fortran);
185*c217d954SCole Faust         }
186*c217d954SCole Faust     }
187*c217d954SCole Faust 
188*c217d954SCole Faust private:
189*c217d954SCole Faust     CLTensor    src0{};
190*c217d954SCole Faust     CLTensor    src1{};
191*c217d954SCole Faust     CLTensor    src2{};
192*c217d954SCole Faust     CLTensor    dst{};
193*c217d954SCole Faust     CLGEMM      sgemm{};
194*c217d954SCole Faust     CLTuner     tuner{};
195*c217d954SCole Faust     float       alpha{}, beta{};
196*c217d954SCole Faust     bool        is_fortran{};
197*c217d954SCole Faust     std::string output_filename{};
198*c217d954SCole Faust };
199*c217d954SCole Faust 
200*c217d954SCole Faust /** Main program for sgemm test
201*c217d954SCole Faust  *
202*c217d954SCole Faust  * @param[in] argc Number of arguments
203*c217d954SCole Faust  * @param[in] argv Arguments ( [optional] Matrix A, [optional] Matrix B, [optional] Matrix C, [optional] alpha, [optional] beta )
204*c217d954SCole Faust  */
main(int argc,char ** argv)205*c217d954SCole Faust int main(int argc, char **argv)
206*c217d954SCole Faust {
207*c217d954SCole Faust     return utils::run_example<CLSGEMMExample>(argc, argv);
208*c217d954SCole Faust }
209