xref: /aosp_15_r20/external/ruy/example/example.cc (revision bb86c7ed5fb1b98a7eac808e443a46cc8b90dfc0)
1*bb86c7edSAndroid Build Coastguard Worker /* Copyright 2019 Google LLC. All Rights Reserved.
2*bb86c7edSAndroid Build Coastguard Worker 
3*bb86c7edSAndroid Build Coastguard Worker Licensed under the Apache License, Version 2.0 (the "License");
4*bb86c7edSAndroid Build Coastguard Worker you may not use this file except in compliance with the License.
5*bb86c7edSAndroid Build Coastguard Worker You may obtain a copy of the License at
6*bb86c7edSAndroid Build Coastguard Worker 
7*bb86c7edSAndroid Build Coastguard Worker     http://www.apache.org/licenses/LICENSE-2.0
8*bb86c7edSAndroid Build Coastguard Worker 
9*bb86c7edSAndroid Build Coastguard Worker Unless required by applicable law or agreed to in writing, software
10*bb86c7edSAndroid Build Coastguard Worker distributed under the License is distributed on an "AS IS" BASIS,
11*bb86c7edSAndroid Build Coastguard Worker WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12*bb86c7edSAndroid Build Coastguard Worker See the License for the specific language governing permissions and
13*bb86c7edSAndroid Build Coastguard Worker limitations under the License.
14*bb86c7edSAndroid Build Coastguard Worker ==============================================================================*/
15*bb86c7edSAndroid Build Coastguard Worker 
16*bb86c7edSAndroid Build Coastguard Worker #include <cstdint>
17*bb86c7edSAndroid Build Coastguard Worker #include <iostream>
18*bb86c7edSAndroid Build Coastguard Worker 
19*bb86c7edSAndroid Build Coastguard Worker #include "ruy/ruy.h"
20*bb86c7edSAndroid Build Coastguard Worker 
ExampleMulFloat(ruy::Context * context)21*bb86c7edSAndroid Build Coastguard Worker void ExampleMulFloat(ruy::Context *context) {
22*bb86c7edSAndroid Build Coastguard Worker   const float lhs_data[] = {1, 2, 3, 4};
23*bb86c7edSAndroid Build Coastguard Worker   const float rhs_data[] = {1, 2, 3, 4};
24*bb86c7edSAndroid Build Coastguard Worker   float dst_data[4];
25*bb86c7edSAndroid Build Coastguard Worker 
26*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<float> lhs;
27*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
28*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_data);
29*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<float> rhs;
30*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
31*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_data);
32*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<float> dst;
33*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
34*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_data);
35*bb86c7edSAndroid Build Coastguard Worker 
36*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<float, float> mul_params;
37*bb86c7edSAndroid Build Coastguard Worker   ruy::Mul(lhs, rhs, mul_params, context, &dst);
38*bb86c7edSAndroid Build Coastguard Worker 
39*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Example Mul, float:\n";
40*bb86c7edSAndroid Build Coastguard Worker   std::cout << "LHS:\n" << lhs;
41*bb86c7edSAndroid Build Coastguard Worker   std::cout << "RHS:\n" << rhs;
42*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Result:\n" << dst << "\n";
43*bb86c7edSAndroid Build Coastguard Worker }
44*bb86c7edSAndroid Build Coastguard Worker 
ExampleMulFloatWithBiasAddAndClamp(ruy::Context * context)45*bb86c7edSAndroid Build Coastguard Worker void ExampleMulFloatWithBiasAddAndClamp(ruy::Context *context) {
46*bb86c7edSAndroid Build Coastguard Worker   const float lhs_data[] = {1, 2, 3, 4};
47*bb86c7edSAndroid Build Coastguard Worker   const float rhs_data[] = {1, 2, 3, 4};
48*bb86c7edSAndroid Build Coastguard Worker   const float bias_data[] = {1, 0};
49*bb86c7edSAndroid Build Coastguard Worker   float dst_data[4];
50*bb86c7edSAndroid Build Coastguard Worker 
51*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<float> lhs;
52*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
53*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_data);
54*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<float> rhs;
55*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
56*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_data);
57*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<float> dst;
58*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
59*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_data);
60*bb86c7edSAndroid Build Coastguard Worker 
61*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<float, float> mul_params;
62*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_bias(bias_data);
63*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_clamp_min(0);
64*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_clamp_max(15);
65*bb86c7edSAndroid Build Coastguard Worker   ruy::Mul(lhs, rhs, mul_params, context, &dst);
66*bb86c7edSAndroid Build Coastguard Worker 
67*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Example Mul, float with bias addition and clamp:\n";
68*bb86c7edSAndroid Build Coastguard Worker   std::cout << "LHS:\n" << lhs;
69*bb86c7edSAndroid Build Coastguard Worker   std::cout << "RHS:\n" << rhs;
70*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Result:\n" << dst << "\n";
71*bb86c7edSAndroid Build Coastguard Worker }
72*bb86c7edSAndroid Build Coastguard Worker 
ExampleMulUint8AsymmetricQuantized(ruy::Context * context)73*bb86c7edSAndroid Build Coastguard Worker void ExampleMulUint8AsymmetricQuantized(ruy::Context *context) {
74*bb86c7edSAndroid Build Coastguard Worker   const std::uint8_t lhs_data[] = {124, 125, 126, 127};
75*bb86c7edSAndroid Build Coastguard Worker   const std::uint8_t rhs_data[] = {129, 130, 131, 132};
76*bb86c7edSAndroid Build Coastguard Worker   std::uint8_t dst_data[4];
77*bb86c7edSAndroid Build Coastguard Worker 
78*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::uint8_t> lhs;
79*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
80*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_data);
81*bb86c7edSAndroid Build Coastguard Worker   lhs.set_zero_point(125);
82*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::uint8_t> rhs;
83*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
84*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_data);
85*bb86c7edSAndroid Build Coastguard Worker   rhs.set_zero_point(132);
86*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::uint8_t> dst;
87*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
88*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_data);
89*bb86c7edSAndroid Build Coastguard Worker   dst.set_zero_point(129);
90*bb86c7edSAndroid Build Coastguard Worker 
91*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<std::int32_t, std::uint8_t> mul_params;
92*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_multiplier_fixedpoint(1 << 30);
93*bb86c7edSAndroid Build Coastguard Worker 
94*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_multiplier_exponent(0);
95*bb86c7edSAndroid Build Coastguard Worker   ruy::Mul(lhs, rhs, mul_params, context, &dst);
96*bb86c7edSAndroid Build Coastguard Worker 
97*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Example Mul, uint8 quantized with asymmetric zero points:\n";
98*bb86c7edSAndroid Build Coastguard Worker   std::cout << "LHS:\n" << lhs;
99*bb86c7edSAndroid Build Coastguard Worker   std::cout << "RHS:\n" << rhs;
100*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Result:\n" << dst << "\n";
101*bb86c7edSAndroid Build Coastguard Worker }
ExampleMulInt8PerChannelQuantized(ruy::Context * context)102*bb86c7edSAndroid Build Coastguard Worker void ExampleMulInt8PerChannelQuantized(ruy::Context *context) {
103*bb86c7edSAndroid Build Coastguard Worker   const std::int8_t lhs_data[] = {1, 2, 3, 4};
104*bb86c7edSAndroid Build Coastguard Worker   const std::int8_t rhs_data[] = {1, 2, 3, 4};
105*bb86c7edSAndroid Build Coastguard Worker   const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
106*bb86c7edSAndroid Build Coastguard Worker   const int exponent_data[] = {1, -2};
107*bb86c7edSAndroid Build Coastguard Worker   std::int8_t dst_data[4];
108*bb86c7edSAndroid Build Coastguard Worker 
109*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int8_t> lhs;
110*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
111*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_data);
112*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int8_t> rhs;
113*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
114*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_data);
115*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int8_t> dst;
116*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
117*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_data);
118*bb86c7edSAndroid Build Coastguard Worker 
119*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<std::int32_t, std::int8_t> mul_params;
120*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
121*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_multiplier_exponent_perchannel(exponent_data);
122*bb86c7edSAndroid Build Coastguard Worker   ruy::Mul(lhs, rhs, mul_params, context, &dst);
123*bb86c7edSAndroid Build Coastguard Worker 
124*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Example Mul, int8 quantized with per-channel multipliers\n";
125*bb86c7edSAndroid Build Coastguard Worker   std::cout << "LHS:\n" << lhs;
126*bb86c7edSAndroid Build Coastguard Worker   std::cout << "RHS:\n" << rhs;
127*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Result:\n" << dst << "\n";
128*bb86c7edSAndroid Build Coastguard Worker }
129*bb86c7edSAndroid Build Coastguard Worker 
ExampleMulInt8GetRawAccumulators(ruy::Context * context)130*bb86c7edSAndroid Build Coastguard Worker void ExampleMulInt8GetRawAccumulators(ruy::Context *context) {
131*bb86c7edSAndroid Build Coastguard Worker   const std::int8_t lhs_data[] = {1, 2, 3, 4};
132*bb86c7edSAndroid Build Coastguard Worker   const std::int8_t rhs_data[] = {1, 2, 3, 4};
133*bb86c7edSAndroid Build Coastguard Worker   std::int32_t dst_data[4];
134*bb86c7edSAndroid Build Coastguard Worker 
135*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int8_t> lhs;
136*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
137*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_data);
138*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int8_t> rhs;
139*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
140*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_data);
141*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int32_t> dst;
142*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
143*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_data);
144*bb86c7edSAndroid Build Coastguard Worker 
145*bb86c7edSAndroid Build Coastguard Worker   // When Dst is int32, mul_params is unused.
146*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<std::int32_t, std::int32_t> mul_params;
147*bb86c7edSAndroid Build Coastguard Worker   ruy::Mul(lhs, rhs, mul_params, context, &dst);
148*bb86c7edSAndroid Build Coastguard Worker 
149*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Example Mul, returning raw int32 accumulators:\n";
150*bb86c7edSAndroid Build Coastguard Worker   std::cout << "LHS:\n" << lhs;
151*bb86c7edSAndroid Build Coastguard Worker   std::cout << "RHS:\n" << rhs;
152*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Result:\n" << dst << "\n";
153*bb86c7edSAndroid Build Coastguard Worker }
154*bb86c7edSAndroid Build Coastguard Worker 
ExampleMulInt8TimesInt16PerChannelQuantized(ruy::Context * context)155*bb86c7edSAndroid Build Coastguard Worker void ExampleMulInt8TimesInt16PerChannelQuantized(ruy::Context *context) {
156*bb86c7edSAndroid Build Coastguard Worker   const std::int8_t lhs_data[] = {1, 2, 3, 4};
157*bb86c7edSAndroid Build Coastguard Worker   const std::int16_t rhs_data[] = {1000, 2000, 3000, 4000};
158*bb86c7edSAndroid Build Coastguard Worker   const std::int32_t multiplier_data[] = {3 << 28, 5 << 28};
159*bb86c7edSAndroid Build Coastguard Worker   const int exponent_data[] = {1, -2};
160*bb86c7edSAndroid Build Coastguard Worker   std::int16_t dst_data[4];
161*bb86c7edSAndroid Build Coastguard Worker 
162*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int8_t> lhs;
163*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kRowMajor, lhs.mutable_layout());
164*bb86c7edSAndroid Build Coastguard Worker   lhs.set_data(lhs_data);
165*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int16_t> rhs;
166*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, rhs.mutable_layout());
167*bb86c7edSAndroid Build Coastguard Worker   rhs.set_data(rhs_data);
168*bb86c7edSAndroid Build Coastguard Worker   ruy::Matrix<std::int16_t> dst;
169*bb86c7edSAndroid Build Coastguard Worker   ruy::MakeSimpleLayout(2, 2, ruy::Order::kColMajor, dst.mutable_layout());
170*bb86c7edSAndroid Build Coastguard Worker   dst.set_data(dst_data);
171*bb86c7edSAndroid Build Coastguard Worker 
172*bb86c7edSAndroid Build Coastguard Worker   ruy::MulParams<std::int32_t, std::int16_t> mul_params;
173*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_multiplier_fixedpoint_perchannel(multiplier_data);
174*bb86c7edSAndroid Build Coastguard Worker   mul_params.set_multiplier_exponent_perchannel(exponent_data);
175*bb86c7edSAndroid Build Coastguard Worker   ruy::Mul(lhs, rhs, mul_params, context, &dst);
176*bb86c7edSAndroid Build Coastguard Worker 
177*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Example Mul, int8 times int16 quantized with per-channel "
178*bb86c7edSAndroid Build Coastguard Worker                "multipliers\n";
179*bb86c7edSAndroid Build Coastguard Worker   std::cout << "LHS:\n" << lhs;
180*bb86c7edSAndroid Build Coastguard Worker   std::cout << "RHS:\n" << rhs;
181*bb86c7edSAndroid Build Coastguard Worker   std::cout << "Result:\n" << dst << "\n";
182*bb86c7edSAndroid Build Coastguard Worker }
183*bb86c7edSAndroid Build Coastguard Worker 
main()184*bb86c7edSAndroid Build Coastguard Worker int main() {
185*bb86c7edSAndroid Build Coastguard Worker   ruy::Context context;
186*bb86c7edSAndroid Build Coastguard Worker   ExampleMulFloat(&context);
187*bb86c7edSAndroid Build Coastguard Worker   ExampleMulFloatWithBiasAddAndClamp(&context);
188*bb86c7edSAndroid Build Coastguard Worker   ExampleMulUint8AsymmetricQuantized(&context);
189*bb86c7edSAndroid Build Coastguard Worker   ExampleMulInt8PerChannelQuantized(&context);
190*bb86c7edSAndroid Build Coastguard Worker   ExampleMulInt8GetRawAccumulators(&context);
191*bb86c7edSAndroid Build Coastguard Worker   ExampleMulInt8TimesInt16PerChannelQuantized(&context);
192*bb86c7edSAndroid Build Coastguard Worker }
193