xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision 8a4dc19a5e2f0e8738191c5d75a9cf98f58c36dd)
1cafb3558SLinJiaweipackage xiangshan.backend.fu
2cafb3558SLinJiawei
3cafb3558SLinJiaweiimport chisel3._
4cafb3558SLinJiaweiimport chisel3.util._
5cafb3558SLinJiaweiimport xiangshan._
6b9fd1892SLinJiaweiimport utils._
7*8a4dc19aSLinJiaweiimport xiangshan.backend.fu.fpu.util.{C22, C32, C53}
8cafb3558SLinJiawei
/** Decoded variant of a multiply/divide operation, carried alongside the operands.
  *
  *  - `sign`: not read by the multipliers in this file; presumably selects signed vs.
  *    unsigned operand handling upstream (NOTE(review): confirm with the caller).
  *  - `isW`:  word form; the multipliers here sign-extend bit 31 of the selected result.
  *  - `isHi`: pick the upper half of the double-width product instead of the lower half.
  */
class MulDivCtrl extends Bundle {
  // Declared in one statement; construction order (sign, isW, isHi) fixes the field order.
  val sign, isW, isHi = Bool()
}
14cafb3558SLinJiawei
/** Shared shell for the multiplier implementations.
  *
  * Describes the unit to the backend as an `FuType.mul` functional unit. The positional
  * arguments 2 and 0 are presumably the integer / FP source counts (confirm against
  * `FuConfig`). It writes the integer register file, never redirects, and reports a fixed
  * `CertainLatency(latency)`. It also adds the [[MulDivCtrl]] input port that subclasses
  * use to choose which slice of the product to return.
  *
  * @param len     forwarded to [[FunctionUnit]] (presumably the source operand width — confirm)
  * @param latency fixed pipeline depth advertised through `CertainLatency`
  */
class AbstractMultiplier(len: Int, latency: Int = 2)
  extends FunctionUnit(
    FuConfig(
      FuType.mul, 2, 0,
      writeIntRf = true, writeFpRf = false, hasRedirect = false,
      CertainLatency(latency)
    ),
    len
  ) {
  /** Operation variant (signedness, word form, high half) for the incoming operands. */
  val ctrl = IO(Input(new MulDivCtrl))
}
21*8a4dc19aSLinJiawei
/** Behavioural multiplier that leaves the multiplication array to synthesis.
  *
  * Both sources are reinterpreted as signed values and multiplied with Chisel's `*`.
  * The full-width product and the [[MulDivCtrl]] bundle then travel side by side
  * through `latency` pipeline registers, so the selection at the output sees the
  * ctrl that belongs to the same operation.
  *
  * @param len     forwarded to [[AbstractMultiplier]]
  * @param latency number of pipeline register stages between input and output
  */
class NaiveMultiplier(len: Int, latency: Int = 3)
  extends AbstractMultiplier(len, latency)
  with HasPipelineReg {

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  val mulRes = src1.asSInt() * src2.asSInt()

  // Entry k of each sequence is the value after k register stages; entry 0 is the
  // combinational input and the last entry is what reaches the output.
  var dataVec: Seq[UInt] = (1 to latency).scanLeft(mulRes.asUInt()) {
    (prev, stage) => PipelineReg(stage)(prev)
  }
  var ctrlVec: Seq[MulDivCtrl] = (1 to latency).scanLeft(ctrl) {
    (prev, stage) => PipelineReg(stage)(prev)
  }

  val xlen = io.out.bits.data.getWidth
  // isHi: upper xlen bits of the product; otherwise the lower xlen bits.
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1, 0))

  // Word variants keep only the low 32 bits, sign-extended back to xlen.
  io.out.bits.data := res
  when(ctrlVec.last.isW) {
    io.out.bits.data := SignExt(res(31, 0), xlen)
  }

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}
45*8a4dc19aSLinJiawei
/** Radix-4 Booth multiplier with a carry-save compression tree.
  *
  * How the product is formed:
  *  - The first source `a` is Booth-recoded two bits at a time.
  *  - Each digit selects 0, +b, +2b, -b or -2b. Negatives are formed as one's complements
  *    plus a separate +1 / +2 correction bit.
  *  - The partial-product bits are scattered into `2*len` weight columns.
  *  - The columns are reduced layer by layer with C22/C32/C53 compressors until every
  *    column holds at most two bits.
  *  - A single `+` then yields the product.
  *
  * Arithmetic is modulo 2^(2*len). `b` is sign-extended, and `a`'s top digit replicates
  * its sign bit, so both operands are treated as signed `len`-bit values.
  *
  * @param len   width of each source operand
  * @param doReg compression-layer indices after which a pipeline register is inserted.
  *              The unit's latency is `doReg.size`, so entries must be distinct and must
  *              name layers that exist; this is checked at elaboration.
  */
class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len, doReg.size) with HasPipelineReg {

  val doRegSorted = doReg.sorted

  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))

  // Candidate partial products, len+1 bits wide so 2b fits. The negative forms are one's
  // complements; the missing +1 (for -b) or +2 (for -2b) is added as a separate term.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Two's-complement completion for a Booth digit, at that digit's weight:
  // 2 after -2b (x = 4), 1 after -b (x = 5 or 6), otherwise 0.
  private def negCorrection(x: UInt): UInt = MuxLookup(x, 0.U(2.W), Seq(
    4.U -> 2.U(2.W),
    5.U -> 1.U(2.W),
    6.U -> 1.U(2.W)
  ))

  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Booth digit {a(i+1), a(i), a(i-1)}. a(-1) is taken as 0. For the top digit of an
    // odd len, a(len) is replicated from the sign bit a(len-1).
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // The previous digit's correction rides in the two low bits of this row (weight i-2).
    val t = negCorrection(last_x)
    last_x = x
    // Rows carry sign-extension-prevention bits instead of a full sign extension:
    // ~s,s,s on the first row, 1,~s on middle rows and ~s on the last row.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  // The top digit has no following row to carry its correction, so add it here at the
  // digit's own weight. Without it a negative top digit leaves the product short;
  // e.g. len = 3, a = -4, b = 1 would yield -8.
  val lastDigitPos = (len - 1) / 2 * 2
  val lastCorrection = negCorrection(last_x)
  for(k <- 0 until 2 if lastDigitPos + k < columns.length){
    columns(lastDigitPos + k) = columns(lastDigitPos + k) :+ lastCorrection(k)
  }

  /** Compresses the bits of one column.
    *
    * @param col bits of this column's weight
    * @param cin `cout1` bits from the next-lower column. Each 4-bit C53 group consumes
    *            at most one of them as its fifth input; the rest pass through to `sum`.
    * @return (sum, cout1, cout2). `sum` stays in this column for the next layer.
    *         `addAll` feeds `cout1` to the next-higher column as its `cin`, and appends
    *         `cout2` to that column in the next layer.
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>  // nothing to compress
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool() +: cin
        cout2 = Seq(c22.io.out(1).asBool())
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool() +: cin
        cout2 = Seq(c32.io.out(1).asBool())
      case 4 =>
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool()) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool())
        cout2 = Seq(c53.io.out(2).asBool())
      case n =>
        // Taller columns: compress the first four bits with one C53 (using one cin),
        // then recurse on the remainder with the leftover cin.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  def max(in: Iterable[Int]): Int = in.max

  // Register stages actually inserted by addAll; compared with doReg.size once the tree
  // is built so a bad doReg fails elaboration instead of misaligning data and ctrl.
  private var regStagesInserted = 0

  /** Reduces `cols` layer by layer until every column holds at most two bits.
    *
    * After the layer at index `depth`, the layer's outputs are registered if `depth` is
    * listed in `doReg`. Carries out of the top column are dropped, because the result is
    * modulo 2^(2*len).
    *
    * @return (sum row, carry row), each 2*len bits wide; their sum is the product
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Bit 0 of each column forms the sum row. Bit 1, when present, forms the carry
      // row; columns without a second bit contribute a constant 0.
      val sum = Cat(cols.map(_(0)).reverse)
      val carry = Cat(cols.map(col => if(col.size > 1) col(1) else 0.U(1.W)).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg) {
        regStagesInserted += 1
        columns_next.map(_.map(PipelineReg(doRegSorted.indexOf(depth) + 1)(_)))
      } else {
        columns_next
      }

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  require(regStagesInserted == doReg.size,
    s"ArrayMultiplier: doReg = $doReg requests ${doReg.size} register stages but only " +
    s"$regStagesInserted match a compression layer (entries must be distinct, non-negative and below the tree depth)")
  val result = sum + carry

  // Delay ctrl by the same number of stages as the data path.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = io.out.bits.data.getWidth
  // isHi: upper xlen bits of the product; otherwise the lower xlen bits.
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  // Word variants keep only the low 32 bits, sign-extended back to xlen.
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}