xref: /aosp_15_r20/external/mesa3d/src/intel/compiler/brw_fs_lower_dpas.cpp (revision 6104692788411f58d303aa86923a9ff6ecaded22)
1 /*
2  * Copyright 2023 Intel Corporation
3  * SPDX-License-Identifier: MIT
4  */
5 
6 #include "brw_fs.h"
7 #include "brw_fs_builder.h"
8 
9 using namespace brw;
10 
11 static void
f16_using_mac(const fs_builder & bld,fs_inst * inst)12 f16_using_mac(const fs_builder &bld, fs_inst *inst)
13 {
14    /* We only intend to support configurations where the destination and
15     * accumulator have the same type.
16     */
17    if (!inst->src[0].is_null())
18       assert(inst->dst.type == inst->src[0].type);
19 
20    assert(inst->src[1].type == BRW_TYPE_HF);
21    assert(inst->src[2].type == BRW_TYPE_HF);
22 
23    const brw_reg_type src0_type = inst->dst.type;
24    const brw_reg_type src1_type = BRW_TYPE_HF;
25    const brw_reg_type src2_type = BRW_TYPE_HF;
26 
27    const brw_reg dest = inst->dst;
28    brw_reg src0 = inst->src[0];
29    const brw_reg src1 = retype(inst->src[1], src1_type);
30    const brw_reg src2 = retype(inst->src[2], src2_type);
31 
32    const unsigned dest_stride =
33       dest.type == BRW_TYPE_HF ? REG_SIZE / 2 : REG_SIZE;
34 
35    for (unsigned r = 0; r < inst->rcount; r++) {
36       brw_reg temp = bld.vgrf(BRW_TYPE_HF);
37 
38       for (unsigned subword = 0; subword < 2; subword++) {
39          for (unsigned s = 0; s < inst->sdepth; s++) {
40             /* The first multiply of the dot-product operation has to
41              * explicitly write the accumulator register. The successive MAC
42              * instructions will implicitly read *and* write the
43              * accumulator. Those MAC instructions can also optionally
44              * explicitly write some other register.
45              *
46              * FINISHME: The accumulator can actually hold 16 HF values. On
47              * Gfx12 there are two accumulators. It should be possible to do
48              * this in SIMD16 or even SIMD32. I was unable to get this to work
49              * properly.
50              */
51             if (s == 0 && subword == 0) {
52                const unsigned acc_width = 8;
53                brw_reg acc = suboffset(retype(brw_acc_reg(inst->exec_size), BRW_TYPE_UD),
54                                       inst->group % acc_width);
55 
56                if (bld.shader->devinfo->verx10 >= 125) {
57                   acc = subscript(acc, BRW_TYPE_HF, subword);
58                } else {
59                   acc = retype(acc, BRW_TYPE_HF);
60                }
61 
62                bld.MUL(acc,
63                        subscript(retype(byte_offset(src1, s * REG_SIZE),
64                                         BRW_TYPE_UD),
65                                  BRW_TYPE_HF, subword),
66                        component(retype(byte_offset(src2, r * REG_SIZE),
67                                         BRW_TYPE_HF),
68                                  s * 2 + subword))
69                   ->writes_accumulator = true;
70 
71             } else {
72                brw_reg result;
73 
74                /* As mentioned above, the MAC had an optional, explicit
75                 * destination register. Various optimization passes are not
76                 * clever enough to understand the intricacies of this
77                 * instruction, so only write the result register on the final
78                 * MAC in the sequence.
79                 */
80                if ((s + 1) == inst->sdepth && subword == 1)
81                   result = temp;
82                else
83                   result = retype(bld.null_reg_ud(), BRW_TYPE_HF);
84 
85                bld.MAC(result,
86                        subscript(retype(byte_offset(src1, s * REG_SIZE),
87                                         BRW_TYPE_UD),
88                                  BRW_TYPE_HF, subword),
89                        component(retype(byte_offset(src2, r * REG_SIZE),
90                                         BRW_TYPE_HF),
91                                  s * 2 + subword))
92                   ->writes_accumulator = true;
93             }
94          }
95       }
96 
97       if (!src0.is_null()) {
98          if (src0_type != BRW_TYPE_HF) {
99             brw_reg temp2 = bld.vgrf(src0_type);
100 
101             bld.MOV(temp2, temp);
102 
103             bld.ADD(byte_offset(dest, r * dest_stride),
104                     temp2,
105                     byte_offset(src0, r * dest_stride));
106          } else {
107             bld.ADD(byte_offset(dest, r * dest_stride),
108                     temp,
109                     byte_offset(src0, r * dest_stride));
110          }
111       } else {
112          bld.MOV(byte_offset(dest, r * dest_stride), temp);
113       }
114    }
115 }
116 
117 static void
int8_using_dp4a(const fs_builder & bld,fs_inst * inst)118 int8_using_dp4a(const fs_builder &bld, fs_inst *inst)
119 {
120    /* We only intend to support configurations where the destination and
121     * accumulator have the same type.
122     */
123    if (!inst->src[0].is_null())
124       assert(inst->dst.type == inst->src[0].type);
125 
126    assert(inst->src[1].type == BRW_TYPE_B ||
127           inst->src[1].type == BRW_TYPE_UB);
128    assert(inst->src[2].type == BRW_TYPE_B ||
129           inst->src[2].type == BRW_TYPE_UB);
130 
131    const brw_reg_type src1_type = inst->src[1].type == BRW_TYPE_UB
132       ? BRW_TYPE_UD : BRW_TYPE_D;
133 
134    const brw_reg_type src2_type = inst->src[2].type == BRW_TYPE_UB
135       ? BRW_TYPE_UD : BRW_TYPE_D;
136 
137    brw_reg dest = inst->dst;
138    brw_reg src0 = inst->src[0];
139    const brw_reg src1 = retype(inst->src[1], src1_type);
140    const brw_reg src2 = retype(inst->src[2], src2_type);
141 
142    const unsigned dest_stride = reg_unit(bld.shader->devinfo) * REG_SIZE;
143 
144    for (unsigned r = 0; r < inst->rcount; r++) {
145       if (!src0.is_null()) {
146          bld.MOV(dest, src0);
147          src0 = byte_offset(src0, dest_stride);
148       } else {
149          bld.MOV(dest, retype(brw_imm_d(0), dest.type));
150       }
151 
152       for (unsigned s = 0; s < inst->sdepth; s++) {
153          bld.DP4A(dest,
154                   dest,
155                   byte_offset(src1, s * inst->exec_size * 4),
156                   component(byte_offset(src2, r * inst->sdepth * 4), s))
157             ->saturate = inst->saturate;
158       }
159 
160       dest = byte_offset(dest, dest_stride);
161    }
162 }
163 
164 static void
int8_using_mul_add(const fs_builder & bld,fs_inst * inst)165 int8_using_mul_add(const fs_builder &bld, fs_inst *inst)
166 {
167    /* We only intend to support configurations where the destination and
168     * accumulator have the same type.
169     */
170    if (!inst->src[0].is_null())
171       assert(inst->dst.type == inst->src[0].type);
172 
173    assert(inst->src[1].type == BRW_TYPE_B ||
174           inst->src[1].type == BRW_TYPE_UB);
175    assert(inst->src[2].type == BRW_TYPE_B ||
176           inst->src[2].type == BRW_TYPE_UB);
177 
178    const brw_reg_type src0_type = inst->dst.type;
179 
180    const brw_reg_type src1_type = inst->src[1].type == BRW_TYPE_UB
181       ? BRW_TYPE_UD : BRW_TYPE_D;
182 
183    const brw_reg_type src2_type = inst->src[2].type == BRW_TYPE_UB
184       ? BRW_TYPE_UD : BRW_TYPE_D;
185 
186    brw_reg dest = inst->dst;
187    brw_reg src0 = inst->src[0];
188    const brw_reg src1 = retype(inst->src[1], src1_type);
189    const brw_reg src2 = retype(inst->src[2], src2_type);
190 
191    const unsigned dest_stride = REG_SIZE;
192 
193    for (unsigned r = 0; r < inst->rcount; r++) {
194       if (!src0.is_null()) {
195          bld.MOV(dest, src0);
196          src0 = byte_offset(src0, dest_stride);
197       } else {
198          bld.MOV(dest, retype(brw_imm_d(0), dest.type));
199       }
200 
201       for (unsigned s = 0; s < inst->sdepth; s++) {
202          brw_reg temp1 = bld.vgrf(BRW_TYPE_UD);
203          brw_reg temp2 = bld.vgrf(BRW_TYPE_UD);
204          brw_reg temp3 = bld.vgrf(BRW_TYPE_UD, 2);
205          const brw_reg_type temp_type =
206             (inst->src[1].type == BRW_TYPE_B ||
207              inst->src[2].type == BRW_TYPE_B)
208             ? BRW_TYPE_W : BRW_TYPE_UW;
209 
210          /* Expand 8 dwords of packed bytes into 16 dwords of packed
211           * words.
212           *
213           * FINISHME: Gfx9 should not need this work around. Gfx11
214           * may be able to use integer MAD. Both platforms may be
215           * able to use MAC.
216           */
217          bld.group(32, 0).MOV(retype(temp3, temp_type),
218                               retype(byte_offset(src2, r * REG_SIZE),
219                                      inst->src[2].type));
220 
221          bld.MUL(subscript(temp1, temp_type, 0),
222                  subscript(retype(byte_offset(src1, s * REG_SIZE),
223                                   BRW_TYPE_UD),
224                            inst->src[1].type, 0),
225                  subscript(component(retype(temp3, BRW_TYPE_UD),
226                                      s * 2),
227                            temp_type, 0));
228 
229          bld.MUL(subscript(temp1, temp_type, 1),
230                  subscript(retype(byte_offset(src1, s * REG_SIZE),
231                                   BRW_TYPE_UD),
232                            inst->src[1].type, 1),
233                  subscript(component(retype(temp3, BRW_TYPE_UD),
234                                      s * 2),
235                            temp_type, 1));
236 
237          bld.MUL(subscript(temp2, temp_type, 0),
238                  subscript(retype(byte_offset(src1, s * REG_SIZE),
239                                   BRW_TYPE_UD),
240                            inst->src[1].type, 2),
241                  subscript(component(retype(temp3, BRW_TYPE_UD),
242                                      s * 2 + 1),
243                            temp_type, 0));
244 
245          bld.MUL(subscript(temp2, temp_type, 1),
246                  subscript(retype(byte_offset(src1, s * REG_SIZE),
247                                   BRW_TYPE_UD),
248                            inst->src[1].type, 3),
249                  subscript(component(retype(temp3, BRW_TYPE_UD),
250                                      s * 2 + 1),
251                            temp_type, 1));
252 
253          bld.ADD(subscript(temp1, src0_type, 0),
254                  subscript(temp1, temp_type, 0),
255                  subscript(temp1, temp_type, 1));
256 
257          bld.ADD(subscript(temp2, src0_type, 0),
258                  subscript(temp2, temp_type, 0),
259                  subscript(temp2, temp_type, 1));
260 
261          bld.ADD(retype(temp1, src0_type),
262                  retype(temp1, src0_type),
263                  retype(temp2, src0_type));
264 
265          bld.ADD(dest, dest, retype(temp1, src0_type))
266             ->saturate = inst->saturate;
267       }
268 
269       dest = byte_offset(dest, dest_stride);
270    }
271 }
272 
273 bool
brw_fs_lower_dpas(fs_visitor & v)274 brw_fs_lower_dpas(fs_visitor &v)
275 {
276    bool progress = false;
277 
278    foreach_block_and_inst_safe(block, fs_inst, inst, v.cfg) {
279       if (inst->opcode != BRW_OPCODE_DPAS)
280          continue;
281 
282       const unsigned exec_size = v.devinfo->ver >= 20 ? 16 : 8;
283       const fs_builder bld = fs_builder(&v, block, inst).group(exec_size, 0).exec_all();
284 
285       if (brw_type_is_float(inst->dst.type)) {
286          f16_using_mac(bld, inst);
287       } else {
288          if (v.devinfo->ver >= 12) {
289             int8_using_dp4a(bld, inst);
290          } else {
291             int8_using_mul_add(bld, inst);
292          }
293       }
294 
295       inst->remove(block);
296       progress = true;
297    }
298 
299    if (progress)
300       v.invalidate_analysis(DEPENDENCY_INSTRUCTIONS);
301 
302    return progress;
303 }
304