xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision 7f1506e34f4f1556f09fd3d96108d0b558ad4881)
1cafb3558SLinJiaweipackage xiangshan.backend.fu
2cafb3558SLinJiawei
3cafb3558SLinJiaweiimport chisel3._
4cafb3558SLinJiaweiimport chisel3.util._
5cafb3558SLinJiaweiimport xiangshan._
6b9fd1892SLinJiaweiimport utils._
7*7f1506e3SLinJiaweiimport xiangshan.backend.fu.util.{C22, C32, C53}
8cafb3558SLinJiawei
/**
  * Control signals that travel alongside a multiply/divide operation.
  *
  * Field order is significant: it determines the bit layout of the bundle.
  */
class MulDivCtrl extends Bundle {
  /** Signedness selector. NOTE(review): not read by the multipliers in this file — confirm its consumer. */
  val sign = Bool()

  /** When set, the multipliers sign-extend the low 32 bits of the selected result. */
  val isW = Bool()

  /** When set, the upper half of the double-width product is returned instead of the lower half. */
  val isHi = Bool()
}
14cafb3558SLinJiawei
/**
  * Common base of the multiplier implementations: a [[FunctionUnit]] of the
  * given operand width extended with a [[MulDivCtrl]] input port.
  *
  * @param len width parameter forwarded unchanged to [[FunctionUnit]]
  */
class AbstractMultiplier(len: Int) extends FunctionUnit(len) {
  /** Per-operation control input (result half selection, 32-bit mode, ...). */
  val ctrl = IO(Input(new MulDivCtrl))
}
208a4dc19aSLinJiawei
/**
  * Reference multiplier: forms the signed product of the two source operands
  * with the built-in `*` operator, then delays product and control together
  * through `latency` pipeline stages before selecting the output half.
  *
  * @param len     width parameter forwarded to [[AbstractMultiplier]]
  * @param latency number of pipeline register stages between input and output
  */
class NaiveMultiplier(len: Int, val latency: Int)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{

  val src1 = io.in.bits.src(0)
  val src2 = io.in.bits.src(1)

  // Both operands are reinterpreted as signed before multiplying.
  val mulRes = src1.asSInt() * src2.asSInt()

  // Element 0 is the un-registered value; element k is the value after stage k.
  var dataVec: Seq[UInt] = Seq(mulRes.asUInt())
  var ctrlVec: Seq[MulDivCtrl] = Seq(ctrl)

  (1 to latency).foreach { stage =>
    // Each stage registers the most recently appended element, data first, then control.
    dataVec :+= PipelineReg(stage)(dataVec.last)
    ctrlVec :+= PipelineReg(stage)(ctrlVec.last)
  }

  val xlen = io.out.bits.data.getWidth

  // Upper or lower xlen bits of the fully delayed product, chosen by isHi.
  val res = {
    val product = dataVec.last
    val wantHigh = ctrlVec.last.isHi
    Mux(wantHigh, product(2*xlen-1, xlen), product(xlen-1, 0))
  }

  // In W mode only the low 32 bits are kept, sign-extended to the output width.
  val wordResult = SignExt(res(31,0), xlen)
  io.out.bits.data := Mux(ctrlVec.last.isW, wordResult, res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}
448a4dc19aSLinJiawei
/**
  * Pipelined multiplier built from radix-4 Booth partial products that are
  * reduced column by column with C22/C32/C53 compressor cells until at most
  * two rows remain; a single adder then produces the final product.
  *
  * @param len   operand width; the product is 2*len bits wide
  * @param doReg reduction-layer depths after which a pipeline register is
  *              inserted; `latency` equals the number of entries.
  *              NOTE(review): entries are assumed to be distinct and to be
  *              depths the reduction tree actually reaches — otherwise the
  *              data path gets fewer register stages than `ctrlVec`. Confirm
  *              against callers.
  */
class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len) with HasPipelineReg {

  override def latency = doReg.size

  val doRegSorted = doReg.sorted
  println(doRegSorted)

  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))

  // Candidate multiples of b, each len+1 bits wide. The "negative" variants are
  // bitwise complements only; the missing +1 (or +2 for the doubled variant) is
  // supplied later through the `t` bits of the following partial product.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Booth window of the previous iteration; selects the correction bits `t`.
  // NOTE(review): the window of the final iteration is never consumed, so a
  // negative final Booth digit receives no correction. This appears to rely on
  // the top digit never being negative (e.g. operands pre-extended by one
  // bit) — confirm against callers.
  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Three-bit overlapping window a(i+1), a(i), a(i-1); a(-1) is taken as 0 and
    // the top window is sign-extended when len is odd.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // Correction for the previous partial product: +2 after a complemented 2b,
    // +1 after a complemented b, 0 otherwise.
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // The first row has no predecessor to correct; later rows carry `t` in their
    // two lowest bits and therefore start two columns lower (weight i-2). The
    // leading bits built from `s` stand in for the sign extension of the row.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Scatter the row into the columns, dropping bits beyond 2*len.
    for(k <- 0 until pp.getWidth if weight + k < columns.length){
      columns(weight + k) = columns(weight + k) :+ pp(k)
    }
  }

  /**
    * Compresses the bits of one column for one reduction layer.
    *
    * @param col bits of equal weight to be added
    * @param cin carries produced by the previous column within the same layer
    * @return (bits remaining in this column,
    *          carries for the next column within the same layer,
    *          carries for the next column in the next layer)
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 0 =>
        // An empty column only forwards its carry-ins. Without this case the
        // generic branch below would recurse on empty input forever.
        sum = cin
      case 1 =>  // nothing to compress
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool() +: cin
        cout2 = Seq(c22.io.out(1).asBool())
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool() +: cin
        cout2 = Seq(c32.io.out(1).asBool())
      case 4 =>
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        // The fifth input absorbs one same-layer carry, if any is pending.
        c53.io.in.last := cin.headOption.getOrElse(0.U)
        sum = Seq(c53.io.out(0).asBool()) ++ cin.drop(1)
        cout1 = Seq(c53.io.out(1).asBool())
        cout2 = Seq(c53.io.out(2).asBool())
      case _ =>
        // More than four bits: compress the first four together with one
        // carry-in, then handle the remainder with the remaining carry-ins.
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin.take(1))
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin.drop(1))
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of `in`; `in` must not be empty. */
  def max(in: Iterable[Int]): Int = in.max

  /**
    * Recursively applies reduction layers until every column holds at most two
    * bits, inserting pipeline registers after the layers listed in `doReg`.
    *
    * @param cols  current bits per column, index = bit weight
    * @param depth index of the layer about to be built (0 for the first)
    * @return the two remaining rows, each 2*len bits wide
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Row 0 takes the first bit of every column, row 1 the second. Columns
      // lacking a bit contribute a constant zero, so the extraction no longer
      // depends on where the two-bit columns start or whether any exist.
      val sum = Cat(cols.map(_.headOption.getOrElse(false.B)).reverse)
      val carry = Cat(cols.map(_.lift(1).getOrElse(false.B)).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        // cout1/cout2 still hold the carries of column i-1 at this point; the
        // carries out of the top column are discarded (result is mod 2^(2*len)).
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg) {
        // Pipeline stage numbers are 1-based positions within the sorted list.
        val stage = doRegSorted.indexOf(depth) + 1
        columns_next.map(_.map(PipelineReg(stage)(_)))
      } else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  val result = sum + carry

  // Control is delayed by `latency` stages to line up with the data path.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = io.out.bits.data.getWidth
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  // In W mode only the low 32 bits are kept, sign-extended to the output width.
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}