xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/MulUnit.scala (revision bb2f3f51dd67f6e16e0cc1ffe43368c9fc7e4aef)
1package xiangshan.backend.fu.wrapper
2
3import org.chipsalliance.cde.config.Parameters
4import chisel3._
5import utility.{LookupTree, SignExt, ZeroExt}
6import xiangshan.MULOpType
7import xiangshan.backend.fu.{ArrayMulDataModule, FuncUnit, HasPipelineReg, MulDivCtrl}
8import xiangshan.backend.fu.FuConfig
9
/**
  * Pipelined integer multiply functional unit.
  *
  * Both source operands are widened to `xlen + 1` bits before entering
  * [[ArrayMulDataModule]]. Each operation chooses sign- or zero-extension per
  * operand (see [[mulInputCvtTable]]), so one multiplier array serves the
  * signed, unsigned and mixed-sign variants. The `isHi` / `isW` control bits
  * are registered alongside the multiplier for `latency` stages so they line
  * up with the product when the final result is selected.
  *
  * @param cfg functional-unit configuration; `cfg.destDataBits` is used as XLEN
  */
class MulUnit(cfg: FuConfig)(implicit p: Parameters) extends FuncUnit(cfg)
  with HasPipelineReg
{
  override def latency: Int = 2

  private val xlen = cfg.destDataBits

  // Multiplier operand width: one extra bit beyond XLEN holds the sign- or
  // zero-extension that distinguishes signed from unsigned operands.
  private val len = xlen + 1

  // Width of the word (W-suffixed) result that is sign-extended to XLEN.
  private val wordBits = 32

  val func = io.in.bits.ctrl.fuOpType
  val src = io.in.bits.data.src

  /** Compressed operation code used as the lookup key of [[mulInputCvtTable]]. */
  val op = MULOpType.getOp(func)

  private val mdCtrl = Wire(new MulDivCtrl)
  mdCtrl.isW := MULOpType.isW(func)
  mdCtrl.isHi := MULOpType.isH(func)
  // Operand signedness is applied through the extension functions below, and
  // only isHi/isW are read at the output, so `sign` is left unconstrained.
  mdCtrl.sign := DontCare

  // Widen an operand to `len` bits as a signed (sext) or unsigned (zext) value.
  val sext = SignExt(_: UInt, len)
  val zext = ZeroExt(_: UInt, len)

  /**
    * Per-operation conversion of (src(0), src(1)) into multiplier inputs.
    * The low XLEN bits of a product do not depend on operand signedness,
    * which is why plain `mul` can use zero-extension for both operands.
    */
  val mulInputCvtTable: Seq[(UInt, (UInt => UInt, UInt => UInt))] = List(
    MULOpType.getOp(MULOpType.mul)    -> (zext, zext),
    MULOpType.getOp(MULOpType.mulh)   -> (sext, sext),
    MULOpType.getOp(MULOpType.mulhsu) -> (sext, zext),
    MULOpType.getOp(MULOpType.mulhu)  -> (zext, zext),
    // mulw7: only the low 7 bits of the first operand take part, zero-extended.
    MULOpType.getOp(MULOpType.mulw7)  -> ((a: UInt) => zext(a(6, 0)), zext),
  )

  private val mulDataModule = Module(new ArrayMulDataModule(len))

  mulDataModule.io.a := LookupTree(
    op,
    mulInputCvtTable.map { case (key, (cvtA, _)) => key -> cvtA(src(0)) }
  )

  mulDataModule.io.b := LookupTree(
    op,
    mulInputCvtTable.map { case (key, (_, cvtB)) => key -> cvtB(src(1)) }
  )

  mulDataModule.io.regEnables := VecInit((1 to latency).map(stage => regEnable(stage)))
  private val result = mulDataModule.io.result

  // Register the control bits once per pipeline stage so that, after
  // `latency` stages, they are aligned with the multiplier's product.
  private val mdCtrlOut = (1 to latency).foldLeft(mdCtrl) {
    (prev, stage) => PipelineReg(stage)(prev)
  }

  // High half for mulh/mulhsu/mulhu-style ops, low half otherwise.
  private val res = Mux(mdCtrlOut.isHi, result(2 * xlen - 1, xlen), result(xlen - 1, 0))

  // W-suffixed ops return the low word sign-extended to XLEN.
  io.out.bits.res.data := Mux(mdCtrlOut.isW, SignExt(res(wordBits - 1, 0), xlen), res)
}
62