xref: /aosp_15_r20/external/ruy/example/parametrized_example.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <type_traits>
#include <vector>

#include "ruy/context.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/ruy.h"
11*bb86c7edSAndroid Build Coastguard Worker 
// Returns true when `value` is exactly one of the entries of the
// comma-separated list `allowed_values` (spaces around entries are ignored).
// A plain substring search is not enough: it would accept partial values such
// as "row" for "row-major" or "1, 2" for "0, 1, 2, 3".
static bool is_allowed_value(const char* allowed_values, const char* value) {
  const std::size_t value_len = std::strlen(value);
  const char* entry = allowed_values;
  while (*entry != '\0') {
    while (*entry == ' ') {
      ++entry;
    }
    const char* separator = std::strchr(entry, ',');
    const char* entry_end = separator ? separator : entry + std::strlen(entry);
    const char* trimmed_end = entry_end;
    while (trimmed_end > entry && trimmed_end[-1] == ' ') {
      --trimmed_end;
    }
    if (static_cast<std::size_t>(trimmed_end - entry) == value_len &&
        std::strncmp(entry, value, value_len) == 0) {
      return true;
    }
    if (!separator) {
      break;
    }
    entry = separator + 1;
  }
  return false;
}

// Reads the command-line flag `name` (given as `name=value`) and parses its
// value with sscanf(value, format, dst...).
//
// help:           when true, only prints one table row describing the flag to
//                 stderr and returns without touching `dst`.
// argc, argv:     the program's command line; argv[0] is skipped.
// name:           full flag name including leading dashes, e.g. "--shape".
// format:         sscanf format; it must contain exactly sizeof...(Dst)
//                 conversions.
// default_value:  value used when the flag does not appear on the command line.
// allowed_values: optional comma-separated list of accepted values; nullptr
//                 accepts anything.
// dst:            sscanf output pointers.
//
// The first argument of the exact form `name=...` wins. Exits the process with
// status 1 if the value is not allowed or cannot be fully parsed.
template <typename... Dst>
void read_cmdline_args(bool help, int argc, char* argv[], const char* name,
                       const char* format, const char* default_value,
                       const char* allowed_values, Dst... dst) {
  if (help) {
    fprintf(stderr, "%-20s %-12s %-16s %s\n", name, format, default_value,
            allowed_values ? allowed_values : "");
    return;
  }
  const std::size_t name_len = std::strlen(name);
  const char* value = default_value;
  for (int i = 1; i < argc; i++) {
    // Require the exact `name=` prefix. A longer flag sharing the same prefix
    // (e.g. "--shapes=" vs "--shape") is skipped instead of ending the search
    // and hiding a later, correctly spelled flag.
    if (std::strncmp(argv[i], name, name_len) == 0 &&
        argv[i][name_len] == '=') {
      value = argv[i] + name_len + 1;
      break;
    }
  }
  if (allowed_values && !is_allowed_value(allowed_values, value)) {
    fprintf(stderr, "Illegal value %s. The legal values are %s.\n", value,
            allowed_values);
    exit(1);
  }
  // sscanf returns an int (EOF on input failure); compare as int to avoid a
  // signed/unsigned comparison.
  if (static_cast<int>(sizeof...(Dst)) != sscanf(value, format, dst...)) {
    fprintf(stderr, "Failed to parse %s\n", value);
    exit(1);
  }
}
43*bb86c7edSAndroid Build Coastguard Worker 
// Run configuration filled from the command-line flags in main().
// Every field has a default member initializer, so a Params object never
// holds indeterminate values even if a field is not assigned.
struct Params {
  char types[100] = {};  // Type combination, e.g. "f32xf32->f32".
  int m = 0;             // Matmul shape is m*k*n: the LHS is m x k,
  int k = 0;             // the RHS is k x n,
  int n = 0;             // and the destination is m x n.
  int paths = 0;         // Cast to ruy::Path for set_runtime_enabled_paths.
  int num_threads = 0;   // Passed to Context::set_max_num_threads.
  int repeat = 0;        // Number of ruy::Mul calls.
  int lhs_cache_policy = 0;  // Cast to ruy::CachePolicy.
  int rhs_cache_policy = 0;  // Cast to ruy::CachePolicy.
  int lhs_stride = 0;    // 0 selects the tightly packed stride.
  int rhs_stride = 0;    // 0 selects the tightly packed stride.
  int dst_stride = 0;    // 0 selects the tightly packed stride.
  int lhs_zero_point = 0;
  int rhs_zero_point = 0;
  int dst_zero_point = 0;
  char lhs_order[100] = {};  // "row-major" or "column-major".
  char rhs_order[100] = {};  // "row-major" or "column-major".
  char dst_order[100] = {};  // "row-major" or "column-major".
};
62*bb86c7edSAndroid Build Coastguard Worker 
63*bb86c7edSAndroid Build Coastguard Worker template <typename LhsType, typename RhsType, typename DstType>
run(const Params & params)64*bb86c7edSAndroid Build Coastguard Worker void run(const Params& params) {
65*bb86c7edSAndroid Build Coastguard Worker   using AccumType =
66*bb86c7edSAndroid Build Coastguard Worker       typename std::conditional<std::is_floating_point<DstType>::value, DstType,
67*bb86c7edSAndroid Build Coastguard Worker                                 std::int32_t>::type;
68*bb86c7edSAndroid Build Coastguard Worker 
69*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<LhsType> lhs;
70*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<RhsType> rhs;
71*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<DstType> dst;
72*bb86c7edSAndroid Build Coastguard Worker 
73*bb86c7edSAndroid Build Coastguard Worker   auto parse_order = [](const char* name) {
74*bb86c7edSAndroid Build Coastguard Worker     if (!std::strcmp(name, "row-major")) {
75*bb86c7edSAndroid Build Coastguard Worker       return ruy::Order::kRowMajor;
76*bb86c7edSAndroid Build Coastguard Worker     } else if (!std::strcmp(name, "column-major")) {
77*bb86c7edSAndroid Build Coastguard Worker       return ruy::Order::kColMajor;
78*bb86c7edSAndroid Build Coastguard Worker     } else {
79*bb86c7edSAndroid Build Coastguard Worker       fprintf(stderr, "Failed to parse %s\n", name);
80*bb86c7edSAndroid Build Coastguard Worker       exit(1);
81*bb86c7edSAndroid Build Coastguard Worker     }
82*bb86c7edSAndroid Build Coastguard Worker   };
83*bb86c7edSAndroid Build Coastguard Worker 
84*bb86c7edSAndroid Build Coastguard Worker   auto make_layout = [](int rows, int cols, int stride, ruy::Order order,
85*bb86c7edSAndroid Build Coastguard Worker                         ruy::Layout* layout) {
86*bb86c7edSAndroid Build Coastguard Worker     layout->set_rows(rows);
87*bb86c7edSAndroid Build Coastguard Worker     layout->set_cols(cols);
88*bb86c7edSAndroid Build Coastguard Worker     layout->set_order(order);
89*bb86c7edSAndroid Build Coastguard Worker     int base_stride = order == ruy::Order::kRowMajor ? cols : rows;
90*bb86c7edSAndroid Build Coastguard Worker     layout->set_stride(stride ? stride : base_stride);
91*bb86c7edSAndroid Build Coastguard Worker   };
92*bb86c7edSAndroid Build Coastguard Worker 
93*bb86c7edSAndroid Build Coastguard Worker   make_layout(params.m, params.k, params.lhs_stride,
94*bb86c7edSAndroid Build Coastguard Worker               parse_order(params.lhs_order), lhs.mutable_layout());
95*bb86c7edSAndroid Build Coastguard Worker   make_layout(params.k, params.n, params.rhs_stride,
96*bb86c7edSAndroid Build Coastguard Worker               parse_order(params.rhs_order), rhs.mutable_layout());
97*bb86c7edSAndroid Build Coastguard Worker   make_layout(params.m, params.n, params.dst_stride,
98*bb86c7edSAndroid Build Coastguard Worker               parse_order(params.dst_order), dst.mutable_layout());
99*bb86c7edSAndroid Build Coastguard Worker 
100*bb86c7edSAndroid Build Coastguard Worker   lhs.set_zero_point(params.lhs_zero_point);
101*bb86c7edSAndroid Build Coastguard Worker   rhs.set_zero_point(params.rhs_zero_point);
102*bb86c7edSAndroid Build Coastguard Worker   dst.set_zero_point(params.dst_zero_point);
103*bb86c7edSAndroid Build Coastguard Worker 
104*bb86c7edSAndroid Build Coastguard Worker   lhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.lhs_cache_policy));
105*bb86c7edSAndroid Build Coastguard Worker   rhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.rhs_cache_policy));
106*bb86c7edSAndroid Build Coastguard Worker 
107*bb86c7edSAndroid Build Coastguard Worker   auto flat_size = [](const ruy::Layout& layout) {
108*bb86c7edSAndroid Build Coastguard Worker     int outer_size =
109*bb86c7edSAndroid Build Coastguard Worker         layout.order() == ruy::Order::kRowMajor ? layout.rows() : layout.cols();
110*bb86c7edSAndroid Build Coastguard Worker     return outer_size * layout.stride();
111*bb86c7edSAndroid Build Coastguard Worker   };
112*bb86c7edSAndroid Build Coastguard Worker 
113*bb86c7edSAndroid Build Coastguard Worker   std::vector<LhsType> lhs_buf(flat_size(lhs.layout()));
114*bb86c7edSAndroid Build Coastguard Worker   std::vector<RhsType> rhs_buf(flat_size(rhs.layout()));
115*bb86c7edSAndroid Build Coastguard Worker   std::vector<DstType> dst_buf(flat_size(dst.layout()));
116*bb86c7edSAndroid Build Coastguard Worker 
117*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_buf.data());
118*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_buf.data());
119*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_buf.data());
120*bb86c7edSAndroid Build Coastguard Worker 
121*bb86c7edSAndroid Build Coastguard Worker   ruy::Context context;
122*bb86c7edSAndroid Build Coastguard Worker   context.set_max_num_threads(params.num_threads);
123*bb86c7edSAndroid Build Coastguard Worker   context.set_runtime_enabled_paths(static_cast<ruy::Path>(params.paths));
124*bb86c7edSAndroid Build Coastguard Worker 
125*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<AccumType, DstType> mul_params;
126*bb86c7edSAndroid Build Coastguard Worker   // Here an actual application might set some mul_params fields.
127*bb86c7edSAndroid Build Coastguard Worker   // Quantization multipliers, bias-vector, clamp bounds, etc.
128*bb86c7edSAndroid Build Coastguard Worker 
129*bb86c7edSAndroid Build Coastguard Worker   for (int r = 0; r < params.repeat; r++) {
130*bb86c7edSAndroid Build Coastguard Worker     ruy::Mul(lhs, rhs, mul_params, &context, &dst);
131*bb86c7edSAndroid Build Coastguard Worker   }
132*bb86c7edSAndroid Build Coastguard Worker }
133*bb86c7edSAndroid Build Coastguard Worker 
main(int argc,char * argv[])134*bb86c7edSAndroid Build Coastguard Worker int main(int argc, char* argv[]) {
135*bb86c7edSAndroid Build Coastguard Worker   bool help = argc == 1 || (argc == 2 && !strcmp(argv[1], "--help"));
136*bb86c7edSAndroid Build Coastguard Worker   if (help) {
137*bb86c7edSAndroid Build Coastguard Worker     fprintf(stderr, "Command-line flags (all in the form --flag=value):\n");
138*bb86c7edSAndroid Build Coastguard Worker     fprintf(stderr, "%-20s %-12s %-16s %s\n", "flag", "format", "default",
139*bb86c7edSAndroid Build Coastguard Worker             "allowed");
140*bb86c7edSAndroid Build Coastguard Worker   }
141*bb86c7edSAndroid Build Coastguard Worker   Params params;
142*bb86c7edSAndroid Build Coastguard Worker   const char* allowed_types =
143*bb86c7edSAndroid Build Coastguard Worker       "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8, "
144*bb86c7edSAndroid Build Coastguard Worker       "i8xi16->i16, i16xi8->i16";
145*bb86c7edSAndroid Build Coastguard Worker   const char* allowed_orders = "row-major, column-major";
146*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--types", "%s", "f32xf32->f32",
147*bb86c7edSAndroid Build Coastguard Worker                     allowed_types, &params.types);
148*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--shape", "%dx%dx%d", "100x100x100",
149*bb86c7edSAndroid Build Coastguard Worker                     nullptr, &params.m, &params.k, &params.n);
150*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--paths", "%x", "0", nullptr,
151*bb86c7edSAndroid Build Coastguard Worker                     &params.paths);
152*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--num_threads", "%d", "1", nullptr,
153*bb86c7edSAndroid Build Coastguard Worker                     &params.num_threads);
154*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--repeat", "%d", "1", nullptr,
155*bb86c7edSAndroid Build Coastguard Worker                     &params.repeat);
156*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--lhs_cache_policy", "%d", "0",
157*bb86c7edSAndroid Build Coastguard Worker                     "0, 1, 2, 3", &params.lhs_cache_policy);
158*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--rhs_cache_policy", "%d", "0",
159*bb86c7edSAndroid Build Coastguard Worker                     "0, 1, 2, 3", &params.rhs_cache_policy);
160*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--lhs_stride", "%d", "0", nullptr,
161*bb86c7edSAndroid Build Coastguard Worker                     &params.lhs_stride);
162*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--rhs_stride", "%d", "0", nullptr,
163*bb86c7edSAndroid Build Coastguard Worker                     &params.rhs_stride);
164*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--dst_stride", "%d", "0", nullptr,
165*bb86c7edSAndroid Build Coastguard Worker                     &params.dst_stride);
166*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--lhs_zero_point", "%d", "0", nullptr,
167*bb86c7edSAndroid Build Coastguard Worker                     &params.lhs_zero_point);
168*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--rhs_zero_point", "%d", "0", nullptr,
169*bb86c7edSAndroid Build Coastguard Worker                     &params.rhs_zero_point);
170*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--dst_zero_point", "%d", "0", nullptr,
171*bb86c7edSAndroid Build Coastguard Worker                     &params.dst_zero_point);
172*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--lhs_order", "%s", "row-major",
173*bb86c7edSAndroid Build Coastguard Worker                     allowed_orders, &params.lhs_order);
174*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
175*bb86c7edSAndroid Build Coastguard Worker                     allowed_orders, &params.rhs_order);
176*bb86c7edSAndroid Build Coastguard Worker   read_cmdline_args(help, argc, argv, "--dst_order", "%s", "row-major",
177*bb86c7edSAndroid Build Coastguard Worker                     allowed_orders, &params.dst_order);
178*bb86c7edSAndroid Build Coastguard Worker 
179*bb86c7edSAndroid Build Coastguard Worker   if (help) {
180*bb86c7edSAndroid Build Coastguard Worker     exit(1);
181*bb86c7edSAndroid Build Coastguard Worker   }
182*bb86c7edSAndroid Build Coastguard Worker 
183*bb86c7edSAndroid Build Coastguard Worker   if (!strcmp(params.types, "f32xf32->f32")) {
184*bb86c7edSAndroid Build Coastguard Worker     run<float, float, float>(params);
185*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "i8xi8->i8")) {
186*bb86c7edSAndroid Build Coastguard Worker     run<std::int8_t, std::int8_t, std::int8_t>(params);
187*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "i8xi8->i16")) {
188*bb86c7edSAndroid Build Coastguard Worker     run<std::int8_t, std::int8_t, std::int16_t>(params);
189*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "i8xi8->i32")) {
190*bb86c7edSAndroid Build Coastguard Worker     run<std::int8_t, std::int8_t, std::int32_t>(params);
191*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "u8xu8->i16")) {
192*bb86c7edSAndroid Build Coastguard Worker     run<std::uint8_t, std::uint8_t, std::int16_t>(params);
193*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "u8xi8->u8")) {
194*bb86c7edSAndroid Build Coastguard Worker     run<std::uint8_t, std::int8_t, std::uint8_t>(params);
195*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "i8xi16->i16")) {
196*bb86c7edSAndroid Build Coastguard Worker     run<std::int8_t, std::int16_t, std::int16_t>(params);
197*bb86c7edSAndroid Build Coastguard Worker   } else if (!strcmp(params.types, "i16xi8->i16")) {
198*bb86c7edSAndroid Build Coastguard Worker     run<std::int16_t, std::int8_t, std::int16_t>(params);
199*bb86c7edSAndroid Build Coastguard Worker   } else {
200*bb86c7edSAndroid Build Coastguard Worker     fprintf(stderr, "Unknown types: %s\n", params.types);
201*bb86c7edSAndroid Build Coastguard Worker     exit(1);
202*bb86c7edSAndroid Build Coastguard Worker   }
203*bb86c7edSAndroid Build Coastguard Worker }
204