// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Alu.scala (revision 4aa9ed342654d307178fb17faf8226c0d6136b80)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
16
17package xiangshan.backend.fu
18
19import chipsalliance.rocketchip.config.Parameters
20import chisel3._
21import chisel3.util._
22import utility.{LookupTree, LookupTreeDefault, ParallelMux, SignExt, ZeroExt}
23import xiangshan._
24
/** Produces the result of the vset-family ALU operations.
  *
  * The low 8 bits of `src1` are taken as `vtype` (`vlmul` = bits 2:0, `vsew` = bits 5:3)
  * and bits 14:10 of `src1` as an immediate AVL. `vconfig(15, 8)` supplies the previous
  * vl. The output is either the new vl or the packed word `{vl(7, 0), vtype}`, zero
  * extended to XLEN, depending on `func`.
  */
class VsetModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val lsrc0 = Input(UInt(6.W))
    val ldest = Input(UInt(6.W))
    val src0  = Input(UInt(XLEN.W))
    val src1  = Input(UInt(XLEN.W))
    val func  = Input(FuOpType())
    val vconfig = Input(UInt(16.W))

    val res   = Output(UInt(XLEN.W))
  })

  // Fields carried by the second operand.
  val vtype  = io.src1(7, 0)
  val vlmul  = vtype(2, 0)
  val vsew   = vtype(5, 3)
  val avlImm = Cat(0.U(3.W), io.src1(14, 10))

  // Previous vl, as carried in the incoming vconfig.
  val vlLast = io.vconfig(15, 8)

  val rd  = io.ldest
  val rs1 = io.lsrc0

  // vlen = 128
  // Entry i holds 16 << i for i in 0..3 and 16 >> (8 - i) for i in 4..7, i.e. the table
  // is indexed by the 3-bit two's-complement value of (vlmul - vsew).
  val vlmaxVec = Seq.tabulate(8) { i =>
    val elems = if (i < 4) 16 << i else 16 >> (8 - i)
    elems.U(8.W)
  }
  val negVsew = (~vsew).asUInt + 1.U
  val shamt = vlmul + negVsew
  val vlmax = ParallelMux(vlmaxVec.indices.map(i => i.U === shamt), vlmaxVec)

  // Saturate a requested AVL at vlmax.
  private def clampToVlmax(avl: UInt): UInt = Mux(avl > vlmax, vlmax, avl)

  val isVsetivli = Seq(ALUOpType.vsetivli2, ALUOpType.vsetivli1).map(io.func === _).reduce(_ || _)
  val vlWhenRs1Not0 = Mux(isVsetivli, clampToVlmax(avlImm), clampToVlmax(io.src0))
  // With rs1 == 0: keep the previous vl when rd is also 0, otherwise use vlmax.
  val vlWhenRs1Is0 = Mux(rd === 0.U, Cat(0.U(56.W), vlLast), vlmax)

  val vl = Wire(UInt(XLEN.W))
  vl := Mux(rs1 =/= 0.U, vlWhenRs1Not0, vlWhenRs1Is0)

  val vconfig = Wire(UInt(XLEN.W))
  vconfig := Cat(0.U(48.W), vl(7, 0), vtype)

  // These three op types return vl; every other op type returns the packed vconfig word.
  val selectsVl = Seq(ALUOpType.vsetvli2, ALUOpType.vsetvl2, ALUOpType.vsetivli2)
    .map(io.func === _)
    .reduce(_ || _)
  io.res := Mux(selectsVl, vl, vconfig)
}
63
/** Two adders sharing the second operand.
  *
  * `add` is the full-width sum `src(0) + src(1)`; `addw` is a half-width sum of `srcw`
  * and the low 32 bits of `src(1)`. Both sums wrap (no carry-out is kept).
  */
class AddModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val srcw = Input(UInt((XLEN/2).W))
    val add = Output(UInt(XLEN.W))
    val addw = Output(UInt((XLEN/2).W))
  })
  val lhs = io.src(0)
  val rhs = io.src(1)

  io.add := lhs + rhs

  // TODO: why this extra adder?
  val rhsLowWord = rhs(31, 0)
  io.addw := io.srcw + rhsLowWord
}
75
/** Subtractor computing `src(0) - src(1)` as `src(0) + ~src(1) + 1`.
  *
  * The output is one bit wider than XLEN so that the carry-out of the addition is
  * visible in bit XLEN.
  */
class SubModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val sub = Output(UInt((XLEN+1).W))
  })
  val minuend = io.src(0)
  val invertedSubtrahend = (~io.src(1)).asUInt

  // `+&` keeps the carry, giving an (XLEN + 1)-bit partial sum before the final increment.
  val partialSum = minuend +& invertedSubtrahend
  io.sub := partialSum + 1.U
}
83
/** Full-width logical left shifter that shifts the same source by two independent amounts.
  *
  * `sll` uses `shamt` and `revSll` uses `revShamt`; both results are truncated to XLEN bits
  * by the output connection.
  */
class LeftShiftModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(6.W))
    val revShamt = Input(UInt(6.W))
    val sllSrc = Input(UInt(XLEN.W))
    val sll = Output(UInt(XLEN.W))
    val revSll = Output(UInt(XLEN.W))
  })
  private def shiftLeftBy(amount: UInt): UInt = io.sllSrc << amount

  io.sll    := shiftLeftBy(io.shamt)
  io.revSll := shiftLeftBy(io.revShamt)
}
95
/** Half-width logical left shifter that shifts the same source by two independent
  * 5-bit amounts.
  *
  * `sllw` uses `shamt` and `revSllw` uses `revShamt`; both results are truncated to
  * XLEN/2 bits by the output connection.
  */
class LeftShiftWordModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(5.W))
    val revShamt = Input(UInt(5.W))
    val sllSrc = Input(UInt((XLEN/2).W))
    val sllw = Output(UInt((XLEN/2).W))
    val revSllw = Output(UInt((XLEN/2).W))
  })
  private def shiftLeftBy(amount: UInt): UInt = io.sllSrc << amount

  io.sllw    := shiftLeftBy(io.shamt)
  io.revSllw := shiftLeftBy(io.revShamt)
}
107
/** Full-width right shifter.
  *
  * `srl` and `revSrl` are logical shifts of `srlSrc` by `shamt` and `revShamt`
  * respectively; `sra` is an arithmetic (sign-propagating) shift of `sraSrc` by `shamt`.
  */
class RightShiftModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(6.W))
    val revShamt = Input(UInt(6.W))
    val srlSrc, sraSrc = Input(UInt(XLEN.W))
    val srl, sra = Output(UInt(XLEN.W))
    val revSrl = Output(UInt(XLEN.W))
  })
  // Logical shifts of the unsigned source.
  io.srl    := io.srlSrc >> io.shamt
  io.revSrl := io.srlSrc >> io.revShamt

  // Reinterpret as signed so that `>>` replicates the sign bit, then go back to UInt.
  val signedSrc = io.sraSrc.asSInt
  val arithShifted = signedSrc >> io.shamt
  io.sra := arithShifted.asUInt
}
120
/** Half-width right shifter.
  *
  * `srlw` and `revSrlw` are logical shifts of `srlSrc` by `shamt` and `revShamt`
  * respectively; `sraw` is an arithmetic (sign-propagating) shift of `sraSrc` by `shamt`.
  */
class RightShiftWordModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val shamt = Input(UInt(5.W))
    val revShamt = Input(UInt(5.W))
    val srlSrc, sraSrc = Input(UInt((XLEN/2).W))
    val srlw, sraw = Output(UInt((XLEN/2).W))
    val revSrlw = Output(UInt((XLEN/2).W))
  })

  // Logical shifts of the unsigned source.
  io.srlw    := io.srlSrc >> io.shamt
  io.revSrlw := io.srlSrc >> io.revShamt

  // Reinterpret as signed so that `>>` replicates the sign bit, then go back to UInt.
  val signedSrc = io.sraSrc.asSInt
  val arithShifted = signedSrc >> io.shamt
  io.sraw := arithShifted.asUInt
}
134
135
/** Final multiplexer for the logic / pack / byte-reverse / custom result group.
  *
  * Selection by `func`:
  *   - `func(5)` set: the and/or/xor/orcb result (chosen by `func(2, 1)`) ANDed with a
  *     16-bit mask whose bit 0 is 1 and whose bits 15:1 all equal `func(0)`.
  *   - else `func(4)` set: `func(3)` picks between the custom shifted-source results and
  *     the revb/rev8/pack/orh48 results, both indexed by `func(1, 0)`.
  *   - else: `func(3)` picks between sextb/packh/sexth/packw (indexed by `func(1, 0)`)
  *     and the and/or/xor/orcb result.
  */
class MiscResultSelect(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt(6.W))
    val and, or, xor, orcb, orh48, sextb, packh, sexth, packw, revb, rev8, pack = Input(UInt(XLEN.W))
    val src = Input(UInt(XLEN.W))
    val miscRes = Output(UInt(XLEN.W))
  })

  val lowSel = io.func(1, 0)

  // Basic group.
  val logicRes = VecInit(Seq(io.and, io.or, io.xor, io.orcb))(io.func(2, 1))
  val miscRes = VecInit(Seq(io.sextb, io.packh, io.sexth, io.packw))(lowSel)
  val logicBase = Mux(io.func(3), miscRes, logicRes)

  // Advanced group.
  val revRes = VecInit(Seq(io.revb, io.rev8, io.pack, io.orh48))(lowSel)
  // Low word of src shifted left by 1, 2 or 3 and zero-padded to 64 bits, followed by
  // byte 1 of src zero-extended.
  val shiftedLowWords = (1 to 3).map(k => Cat(0.U((32 - k).W), io.src(31, 0), 0.U(k.W)))
  val customRes = VecInit(shiftedLowWords :+ Cat(0.U(56.W), io.src(15, 8)))(lowSel)
  val logicAdv = Mux(io.func(3), customRes, revRes)

  // The mask is only 16 bits wide, so the AND also clears bits 63:16 of the logic result.
  val mask = Cat(Fill(15, io.func(0)), 1.U(1.W))
  val maskedLogicRes = mask & logicRes

  val unmaskedRes = Mux(io.func(4), logicAdv, logicBase)
  io.miscRes := Mux(io.func(5), maskedLogicRes, unmaskedRes)
}
166
/** Final multiplexer for the shift / single-bit / rotate result group.
  *
  * When `func(3)` is set the output is a rotate (`ror` if `func(1)`, else `rol`).
  * Otherwise `func(2, 0)` indexes the table
  * `sll, sll, bclr, bset, binv, srl, bext, sra` (index 0 first).
  */
class ShiftResultSelect(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt(4.W))
    val sll, srl, sra, rol, ror, bclr, bset, binv, bext = Input(UInt(XLEN.W))
    val shiftRes = Output(UInt(XLEN.W))
  })

  // A flat 8-entry table is used instead of a hand-built mux tree; entries 0 and 1 both
  // map to sll.
  val resultSource = VecInit(Seq(
    io.sll, io.sll, io.bclr, io.bset,
    io.binv, io.srl, io.bext, io.sra
  ))
  val simple = resultSource(io.func(2, 0))
  val rotate = Mux(io.func(1), io.ror, io.rol)

  io.shiftRes := Mux(io.func(3), rotate, simple)
}
191
/** Final multiplexer for the 32-bit ("word") result group.
  *
  * `func(3)` selects between the shift/rotate results and the add/sub results; the chosen
  * 32-bit value is sign-extended to XLEN.
  */
class WordResultSelect(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt())
    val sllw, srlw, sraw, rolw, rorw, addw, subw = Input(UInt((XLEN/2).W))
    val wordRes = Output(UInt(XLEN.W))
  })

  // Add / sub: subw is chosen only for func(2) == 0 and func(1) == 1.
  val isSubw = !io.func(2) && io.func(1)
  val addsubRes = Mux(isSubw, io.subw, io.addw)

  // Shifts: func(2) selects the rotates, otherwise func(1, 0) picks sraw / srlw / sllw.
  val rotateRes = Mux(io.func(0), io.rorw, io.rolw)
  val logicalShiftRes = Mux(io.func(0), io.srlw, io.sllw)
  val plainShiftRes = Mux(io.func(1), io.sraw, logicalShiftRes)
  val shiftRes = Mux(io.func(2), rotateRes, plainShiftRes)

  val wordRes = Mux(io.func(3), shiftRes, addsubRes)
  io.wordRes := SignExt(wordRes, XLEN)
}
205
206
/** Top-level ALU result multiplexer, steered by a 4-bit group selector.
  *
  *   - `func(3)` set        : vset result
  *   - `func(2, 1) == 0`    : word result if `func(0)`, else shift result
  *   - `func(2) == 0` (so `func(1) == 1`) : compare result if `func(0)`, else add result
  *   - otherwise            : misc result
  */
class AluResSel(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val func = Input(UInt(4.W))
    val addRes, shiftRes, miscRes, compareRes, wordRes, vsetRes = Input(UInt(XLEN.W))
    val aluRes = Output(UInt(XLEN.W))
  })

  val wordOrShift = Mux(io.func(0), io.wordRes, io.shiftRes)
  val compareOrAdd = Mux(io.func(0), io.compareRes, io.addRes)
  val miscOrArith = Mux(io.func(2), io.miscRes, compareOrAdd)

  // func(2, 1) is non-zero exactly when the result is not from the word/shift pair.
  val scalarRes = Mux(io.func(2, 1).orR, miscOrArith, wordOrShift)

  val res = Mux(io.func(3), io.vsetRes, scalarRes)
  io.aluRes := res
}
219
/** Combinational datapath of the ALU.
  *
  * All candidate results (shift, single-bit, rotate, vset, add, compare, word, logic,
  * pack, byte-reverse) are computed in parallel from `src(0)`, `src(1)` and `func`.
  * Four group selectors reduce them to one result per group, and `AluResSel` (driven by
  * `func(7, 4)`) picks the final `result`. Branch direction and misprediction are derived
  * from the comparison signals.
  */
class AluDataModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(FuOpType())
    val pred_taken, isBranch = Input(Bool())
    val result = Output(UInt(XLEN.W))
    val taken, mispredict = Output(Bool())
    val lsrc0 = Input(UInt(6.W))
    val ldest = Input(UInt(6.W))
    val vconfig = Input(UInt(16.W))
  })
  val src1 = io.src(0)
  val src2 = io.src(1)
  val func = io.func

  // Shift amount and its two's complement; the latter is used to build rotates.
  val shamt = src2(5, 0)
  val revShamt = ~shamt + 1.U

  // src1 with its upper 32 bits kept only when func(0) is set.
  val wordMaskAddSource = Cat(Fill(32, func(0)), Fill(32, 1.U)) & src1

  // Byte i of src1.
  private def byteOf(i: Int): UInt = src1(8 * i + 7, 8 * i)

  // ---------------------------------------------------------------------------
  // 64-bit shifts, single-bit operations and rotates
  // ---------------------------------------------------------------------------

  // slliuw, sll
  val leftShiftModule = Module(new LeftShiftModule)
  leftShiftModule.io.sllSrc   := wordMaskAddSource
  leftShiftModule.io.shamt    := shamt
  leftShiftModule.io.revShamt := revShamt
  val sll    = leftShiftModule.io.sll
  val revSll = leftShiftModule.io.revSll

  // srl, sra, bext
  val rightShiftModule = Module(new RightShiftModule)
  rightShiftModule.io.srlSrc   := src1
  rightShiftModule.io.sraSrc   := src1
  rightShiftModule.io.shamt    := shamt
  rightShiftModule.io.revShamt := revShamt
  val srl    = rightShiftModule.io.srl
  val revSrl = rightShiftModule.io.revSrl
  val sra    = rightShiftModule.io.sra
  val bext   = srl(0)

  // bclr, bset, binv: one-hot mask at the bit position given by the shift amount.
  val bitShift = 1.U << shamt
  val bclr = src1 & ~bitShift
  val bset = src1 | bitShift
  val binv = src1 ^ bitShift

  // A rotate is the OR of a shift by shamt and the opposite shift by -shamt.
  val rol = revSrl | sll
  val ror = srl | revSll

  val shiftResSel = Module(new ShiftResultSelect)
  shiftResSel.io.func := func(3, 0)
  shiftResSel.io.sll  := sll
  shiftResSel.io.srl  := srl
  shiftResSel.io.sra  := sra
  shiftResSel.io.rol  := rol
  shiftResSel.io.ror  := ror
  shiftResSel.io.bclr := bclr
  shiftResSel.io.bset := bset
  shiftResSel.io.binv := binv
  shiftResSel.io.bext := bext
  val shiftRes = shiftResSel.io.shiftRes

  // ---------------------------------------------------------------------------
  // vset
  // ---------------------------------------------------------------------------
  val vsetModule = Module(new VsetModule)
  vsetModule.io.lsrc0   := io.lsrc0
  vsetModule.io.ldest   := io.ldest
  vsetModule.io.src0    := io.src(0)
  vsetModule.io.src1    := io.src(1)
  vsetModule.io.func    := io.func
  vsetModule.io.vconfig := io.vconfig

  // ---------------------------------------------------------------------------
  // add / addw
  // ---------------------------------------------------------------------------
  val addModule = Module(new AddModule)

  // src1 shifted left by 1..4 (after the optional upper-word mask).
  val shaddSource = VecInit((1 to 4).map { k =>
    Cat(wordMaskAddSource(63 - k, 0), 0.U(k.W))
  })
  // src1 shifted right by 29..32, zero-filled.
  val sraddSource = VecInit((29 to 32).map { lo =>
    ZeroExt(src1(63, lo), XLEN)
  })
  // TODO: use decoder or other libraries to optimize timing
  // Now we assume shadd has the worst timing, so it enters at the last mux stage.
  val plainAddSource = Mux(func(1), ZeroExt(src1(0), XLEN), wordMaskAddSource)
  val nonShaddSource = Mux(func(2), sraddSource(func(1, 0)), plainAddSource)
  addModule.io.src(0) := Mux(func(3), shaddSource(func(2, 1)), nonShaddSource)
  addModule.io.src(1) := src2
  val add = addModule.io.add

  // Word adder operand: bit 0 of src1 or the low word of src1.
  val addwUsesLsb = !func(2) && func(0)
  addModule.io.srcw := Mux(addwUsesLsb, ZeroExt(src1(0), XLEN), src1(31, 0))
  val addwResultAll = VecInit(Seq(
    ZeroExt(addModule.io.addw(0), XLEN),
    ZeroExt(addModule.io.addw(7, 0), XLEN),
    ZeroExt(addModule.io.addw(15, 0), XLEN),
    SignExt(addModule.io.addw(15, 0), XLEN)
  ))
  val addw = Mux(func(2), addwResultAll(func(1, 0)), addModule.io.addw)

  // ---------------------------------------------------------------------------
  // sub / subw / compare
  // ---------------------------------------------------------------------------
  val subModule = Module(new SubModule)
  subModule.io.src(0) := src1
  subModule.io.src(1) := src2
  val sub  = subModule.io.sub
  val subw = subModule.io.sub

  // Bit XLEN of the subtractor output is the carry of src1 + ~src2 + 1; a clear carry
  // means src1 < src2 unsigned.
  val sltu    = !sub(XLEN)
  val slt     = src1(XLEN - 1) ^ src2(XLEN - 1) ^ sltu
  val maxMin  = Mux(slt ^ func(0), src2, src1)
  val maxMinU = Mux(sltu ^ func(0), src2, src1)
  val maxMinRes  = Mux(func(1), maxMin, maxMinU)
  val subOrSltu  = Mux(func(0), sltu, sub)
  val setLessRes = Mux(func(1), slt, subOrSltu)
  val compareRes = Mux(func(2), maxMinRes, setLessRes)

  // ---------------------------------------------------------------------------
  // 32-bit shifts and rotates
  // ---------------------------------------------------------------------------

  // sllw (the connections truncate src1 and the shift amounts to the port widths)
  val leftShiftWordModule = Module(new LeftShiftWordModule)
  leftShiftWordModule.io.sllSrc   := src1
  leftShiftWordModule.io.shamt    := shamt
  leftShiftWordModule.io.revShamt := revShamt
  val sllw    = leftShiftWordModule.io.sllw
  val revSllw = leftShiftWordModule.io.revSllw

  val rightShiftWordModule = Module(new RightShiftWordModule)
  rightShiftWordModule.io.srlSrc   := src1
  rightShiftWordModule.io.sraSrc   := src1
  rightShiftWordModule.io.shamt    := shamt
  rightShiftWordModule.io.revShamt := revShamt
  val srlw    = rightShiftWordModule.io.srlw
  val revSrlw = rightShiftWordModule.io.revSrlw
  val sraw    = rightShiftWordModule.io.sraw

  val rolw = revSrlw | sllw
  val rorw = srlw | revSllw

  val wordResSel = Module(new WordResultSelect)
  wordResSel.io.func := func
  wordResSel.io.addw := addw
  wordResSel.io.subw := subw
  wordResSel.io.sllw := sllw
  wordResSel.io.srlw := srlw
  wordResSel.io.sraw := sraw
  wordResSel.io.rolw := rolw
  wordResSel.io.rorw := rorw
  val wordRes = wordResSel.io.wordRes

  // ---------------------------------------------------------------------------
  // logic, pack, sign-extend and byte-reverse
  // ---------------------------------------------------------------------------

  // The second operand is inverted when func(5) is clear and func(0) is set.
  val logicSrc2 = Mux(!func(5) && func(0), ~src2, src2)
  val and     = src1 & logicSrc2
  val or      = src1 | logicSrc2
  val xor     = src1 ^ logicSrc2
  // Each byte becomes all ones if any of its bits is set, else all zeros.
  val orcb    = Cat((7 to 0 by -1).map(i => Fill(8, byteOf(i).orR)))
  val orh48   = Cat(src1(63, 8), 0.U(8.W)) | src2

  val sextb = SignExt(src1(7, 0), XLEN)
  val packh = Cat(src2(7,0), src1(7,0))
  val sexth = SignExt(src1(15, 0), XLEN)
  val packw = SignExt(Cat(src2(15, 0), src1(15, 0)), XLEN)

  // revb reverses the bits inside every byte; rev8 reverses the byte order (byte 0 ends
  // up most significant because Cat places its first element at the top).
  val revb = Cat((7 to 0 by -1).map(i => Reverse(byteOf(i))))
  val pack = Cat(src2(31, 0), src1(31, 0))
  val rev8 = Cat((0 until 8).map(i => byteOf(i)))

  val miscResSel = Module(new MiscResultSelect)
  miscResSel.io.func    := func(5, 0)
  miscResSel.io.and     := and
  miscResSel.io.or      := or
  miscResSel.io.xor     := xor
  miscResSel.io.orcb    := orcb
  miscResSel.io.orh48   := orh48
  miscResSel.io.sextb   := sextb
  miscResSel.io.packh   := packh
  miscResSel.io.sexth   := sexth
  miscResSel.io.packw   := packw
  miscResSel.io.revb    := revb
  miscResSel.io.rev8    := rev8
  miscResSel.io.pack    := pack
  miscResSel.io.src     := src1
  val miscRes = miscResSel.io.miscRes

  // ---------------------------------------------------------------------------
  // branch
  // ---------------------------------------------------------------------------
  val operandsEqual = !xor.orR
  val branchOpTable = List(
    ALUOpType.getBranchType(ALUOpType.beq)  -> operandsEqual,
    ALUOpType.getBranchType(ALUOpType.blt)  -> slt,
    ALUOpType.getBranchType(ALUOpType.bltu) -> sltu
  )
  val branchCondition = LookupTree(ALUOpType.getBranchType(func), branchOpTable)
  val taken = branchCondition ^ ALUOpType.isBranchInvert(func)

  // ---------------------------------------------------------------------------
  // final result select and outputs
  // ---------------------------------------------------------------------------
  val aluResSel = Module(new AluResSel)
  aluResSel.io.func       := func(7, 4)
  aluResSel.io.addRes     := add
  aluResSel.io.compareRes := compareRes
  aluResSel.io.shiftRes   := shiftRes
  aluResSel.io.miscRes    := miscRes
  aluResSel.io.wordRes    := wordRes
  aluResSel.io.vsetRes    := vsetModule.io.res
  val aluRes = aluResSel.io.aluRes

  io.result := aluRes
  io.taken := taken
  // A misprediction is only reported for branches whose resolved direction differs from
  // the predicted one.
  io.mispredict := io.isBranch && (io.pred_taken ^ taken)
}
422
/** ALU function unit wrapper around [[AluDataModule]].
  *
  * Forwards the first two source operands and the decoded control fields of the incoming
  * uop to the datapath, passes valid/ready straight through between input and output,
  * and raises a redirect whenever a valid output belongs to a branch.
  */
class Alu(implicit p: Parameters) extends FUWithRedirect {

  val uop = io.in.bits.uop

  val isBranch = ALUOpType.isBranch(uop.ctrl.fuOpType)
  val dataModule = Module(new AluDataModule)

  // Datapath inputs.
  dataModule.io.src        := io.in.bits.src.take(2)
  dataModule.io.func       := uop.ctrl.fuOpType
  dataModule.io.pred_taken := uop.cf.pred_taken
  dataModule.io.isBranch   := isBranch
  dataModule.io.lsrc0      := uop.ctrl.lsrc(0)
  dataModule.io.ldest      := uop.ctrl.ldest
  dataModule.io.vconfig    := uop.ctrl.vconfig

  // Redirect: every field not listed below is left as DontCare, so the blanket
  // assignment must stay ahead of the individual field connections.
  redirectOutValid := io.out.valid && isBranch
  redirectOut := DontCare
  redirectOut.level     := RedirectLevel.flushAfter
  redirectOut.robIdx    := uop.robIdx
  redirectOut.ftqIdx    := uop.cf.ftqPtr
  redirectOut.ftqOffset := uop.cf.ftqOffset
  redirectOut.cfiUpdate.isMisPred := dataModule.io.mispredict
  redirectOut.cfiUpdate.taken     := dataModule.io.taken
  redirectOut.cfiUpdate.predTaken := uop.cf.pred_taken

  // Handshake and result.
  io.in.ready  := io.out.ready
  io.out.valid := io.in.valid
  io.out.bits.uop <> io.in.bits.uop
  io.out.bits.data := dataModule.io.result
}
453