xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/MulUnit.scala (revision 730cfbc0bf03569aa07dd82ba3fb41eb7413e13c)
1package xiangshan.backend.fu.wrapper
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import utility.{LookupTree, SignExt, ZeroExt}
6import xiangshan.MULOpType
7import xiangshan.backend.fu.{ArrayMulDataModule, FuncUnit, HasPipelineReg, MulDivCtrl}
8import xiangshan.backend.fu.FuConfig
9
/**
  * Integer multiply functional unit with a fixed latency of 2 cycles.
  *
  * Both source operands are widened to `xlen + 1` bits according to the
  * decoded operation (see `mulInputCvtTable`) and multiplied by an
  * `ArrayMulDataModule`. The control bits `isW` / `isHi` are carried through a
  * parallel chain of `PipelineReg` stages and applied to the product at the
  * output:
  *  - `isHi` selects the upper XLEN bits of the product instead of the lower ones;
  *  - `isW` sign-extends the low 32 bits of the selected half to XLEN.
  *
  * @param cfg functional-unit configuration; `cfg.dataBits` is used as XLEN
  */
class MulUnit(cfg: FuConfig)(implicit p: Parameters) extends FuncUnit(cfg)
  with HasPipelineReg
{
  override def latency: Int = 2

  private val xlen = cfg.dataBits

  // Multiplier operand width: one bit wider than XLEN, so a zero-extended
  // (unsigned) operand and a sign-extended (signed) operand both fit.
  // Single source of truth for the extenders below and the multiplier module.
  // NOTE(review): presumably ArrayMulDataModule multiplies its inputs as signed
  // `len`-bit values; confirm against its definition.
  private val len = xlen + 1

  val func = io.in.bits.fuOpType
  val src = io.in.bits.src

  val op = MULOpType.getOp(func)

  private val ctrl = Wire(new MulDivCtrl)
  ctrl.isW := MULOpType.isW(func)
  ctrl.isHi := MULOpType.isH(func)
  // Signedness is applied by the per-op operand conversion below;
  // `sign` is never read in this unit.
  ctrl.sign := DontCare

  val sext = SignExt(_: UInt, len)
  val zext = ZeroExt(_: UInt, len)

  /** Per-operation conversion of (src(0), src(1)) into the two multiplier inputs. */
  val mulInputCvtTable: Seq[(UInt, (UInt => UInt, UInt => UInt))] = List(
    MULOpType.getOp(MULOpType.mul)    -> (zext, zext),
    MULOpType.getOp(MULOpType.mulh)   -> (sext, sext),
    MULOpType.getOp(MULOpType.mulhsu) -> (sext, zext),
    MULOpType.getOp(MULOpType.mulhu)  -> (zext, zext),
    // Only bits 6..0 of src(0) are used; the 7-bit value is zero-padded
    // when it drives the wider multiplier input.
    MULOpType.getOp(MULOpType.mulw7)  -> (_.asUInt(6, 0), zext),
  )

  private val mulDataModule = Module(new ArrayMulDataModule(len))

  mulDataModule.io.a := LookupTree(
    op,
    mulInputCvtTable.map { case (k, (cvtA, _)) => (k, cvtA(src(0))) }
  )

  mulDataModule.io.b := LookupTree(
    op,
    mulInputCvtTable.map { case (k, (_, cvtB)) => (k, cvtB(src(1))) }
  )

  // Enables for the multiplier's internal stages 1..latency.
  mulDataModule.io.regEnables := VecInit((1 to latency).map(i => regEnable(i)))
  private val result = mulDataModule.io.result

  // ctrlVec(0) is the undelayed control; ctrlVec(i) has passed through
  // PipelineReg(i). ctrlVec.last therefore uses the same stage indices
  // (1..latency) as the multiplier enables above.
  private val ctrlVec = (1 to latency).scanLeft(ctrl)((prev, i) => PipelineReg(i)(prev))

  // High half of the product for isHi ops, otherwise the low XLEN bits.
  private val res = Mux(ctrlVec.last.isHi, result(2 * xlen - 1, xlen), result(xlen - 1, 0))

  // W ops: keep the low 32 bits and sign-extend them to XLEN.
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31, 0), xlen), res)
  connectNonPipedCtrlSingal // Todo: make it piped
}
63