// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision 7f1506e34f4f1556f09fd3d96108d0b558ad4881)
1package xiangshan.backend.fu
2
3import chisel3._
4import chisel3.util._
5import xiangshan._
6import utils._
7import xiangshan.backend.fu.util.{C22, C32, C53}
8
/**
 * Control bits that travel alongside a multiply/divide request.
 *
 * Field order is significant for hardware packing and is intentionally kept.
 */
class MulDivCtrl extends Bundle {
  /** Signedness flag. Neither multiplier in this file reads it. */
  val sign = Bool()
  /** Word operation: the multipliers sign-extend bit 31 of the selected half to full width. */
  val isW = Bool()
  /** Select the high half of the double-width product instead of the low half. */
  val isHi = Bool()
}

/**
 * Shared base of the multipliers in this file: a `FunctionUnit` built with
 * width parameter `len`, plus an extra [[MulDivCtrl]] input port.
 *
 * @param len width parameter handed unchanged to `FunctionUnit`
 */
class AbstractMultiplier(len: Int) extends FunctionUnit(len) {
  /** Per-request control bits (sign / word-op / high-half select). */
  val ctrl = IO(Input(new MulDivCtrl))
}

/**
 * Reference multiplier built on the `*` operator.
 *
 * Both sources are reinterpreted as signed (`asSInt`) and multiplied. The full
 * product and the [[MulDivCtrl]] bundle are then delayed through `latency`
 * `PipelineReg` stages before the output half is selected. `ctrl.sign` is not
 * consulted here.
 *
 * @param len     width parameter forwarded to [[AbstractMultiplier]]
 * @param latency number of pipeline register stages
 */
class NaiveMultiplier(len: Int, val latency: Int)
  extends AbstractMultiplier(len)
  with HasPipelineReg {

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  // Full-width signed product.
  val mulRes = src1.asSInt() * src2.asSInt()

  // Element k of each sequence is the value after k register stages.
  var dataVec = Seq(mulRes.asUInt())
  var ctrlVec = Seq(ctrl)
  (1 to latency).foreach { stage =>
    dataVec :+= PipelineReg(stage)(dataVec.last)
    ctrlVec :+= PipelineReg(stage)(ctrlVec.last)
  }

  val xlen = io.out.bits.data.getWidth

  private val product = dataVec.last
  private val lateCtrl = ctrlVec.last

  // Pick the requested half, then sign-extend bit 31 for word operations.
  val res = Mux(lateCtrl.isHi, product(2 * xlen - 1, xlen), product(xlen - 1, 0))
  io.out.bits.data := Mux(lateCtrl.isW, SignExt(res(31, 0), xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}

/**
 * Radix-4 Booth multiplier whose partial products are reduced by a tree of
 * C22/C32/C53 cells, with optional pipeline registers between reduction layers.
 *
 * Operand `a = src(0)` is Booth-recoded in overlapping 3-bit groups. Each group
 * selects 0, +b, +2b, -b or -2b of `b = src(1)`. Negative selections are one's
 * complements. The missing +1 (for -b) or +2 (for -2b, built as `(~b) << 1`) is
 * injected as the two-bit field `t` at the bottom of the NEXT group's row.
 *
 * NOTE(review): no row follows the last Booth group, so its correction is never
 * injected. The product is therefore only exact when that group selects 0, +b or
 * +2b. For odd `len` this means the two most-significant bits of `a` are not
 * `10`, which holds e.g. if the caller widens 64-bit operands to 65 bits by sign
 * or zero extension. Confirm against the instantiating code.
 *
 * @param len   operand width assumed by the Booth recoder; also forwarded to
 *              [[AbstractMultiplier]]
 * @param doReg reduction layers (0-based) after which a pipeline register is
 *              inserted. Entries must be distinct, non-negative and smaller
 *              than the number of layers the tree actually builds;
 *              `latency == doReg.size`.
 */
class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len) with HasPipelineReg {

  override def latency = doReg.size

  // The ctrl pipeline below always has `latency` stages, so every doReg entry must
  // produce exactly one data register. Duplicate or negative depths would silently
  // misalign data and control.
  require(doReg.forall(_ >= 0), s"ArrayMultiplier: negative register depth in $doReg")
  require(doReg.distinct.size == doReg.size, s"ArrayMultiplier: duplicate register depth in $doReg")

  val doRegSorted = doReg.sorted

  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))

  // Booth candidates for b, each declared len+1 bits wide.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()   // -b - 1; the +1 arrives via the next row's t
  neg_bx2 := neg_b << 1         // -2b - 2; the +2 arrives via the next row's t

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Booth group of the previous iteration; selects this row's correction field t.
  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Group a(i+1, i-1) with a(-1) = 0; a lone top bit pair (odd len) is sign-extended.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // The ~s / s / 1 prefix bits stand in for sign-extending each row to full width.
    // Rows after the first carry t in their two lowest bits, hence weight i-2.
    // The top rows omit the leading 1, which would land above bit 2*len-1.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  /**
   * Compresses one column of equal-weight bits.
   *
   * C22/C32/C53 are defined in `xiangshan.backend.fu.util`. Here `out(0)` is used
   * as the sum bit and the remaining outputs as carries.
   *
   * @param col bits of this column in the current layer
   * @param cin carries from the next-lower column in the same layer. Where a C53
   *            is used, the first one feeds it; otherwise they pass into `sum`.
   * @return (sum, cout1, cout2): bits kept in this column for the next layer,
   *         carries for the next-higher column in the SAME layer, and carries
   *         for the next-higher column in the NEXT layer
   */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 0 => // empty column: incoming carries simply stay here
        sum = cin
      case 1 => // nothing to compress
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool() +: cin
        cout2 = Seq(c22.io.out(1).asBool())
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool() +: cin
        cout2 = Seq(c32.io.out(1).asBool())
      case 4 =>
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool()) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool())
        cout2 = Seq(c53.io.out(2).asBool())
      case n =>
        // Split off groups of four; the first group takes one carry-in, the rest go on.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of `in`; throws on an empty collection. */
  def max(in: Iterable[Int]): Int = in.max

  /**
   * Compresses `cols` layer by layer until every column holds at most two bits.
   * A pipeline register is inserted after each layer listed in `doReg`.
   *
   * @param cols  per-column bit lists (index = bit weight)
   * @param depth index of the layer about to be built (0 on the first call)
   * @return (sum, carry), each `cols.length` bits wide, whose sum (mod
   *         2^cols.length) equals the total of all bits in `cols`
   */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Only layers that exist can insert registers. A doReg entry >= depth would
      // leave the data path shorter than the `latency`-stage ctrl pipeline.
      require(doRegSorted.forall(_ < depth),
        s"ArrayMultiplier: register depths $doRegSorted exceed the $depth reduction layers")
      // First and (if present) second bit of every column; missing bits read as 0.
      // Unlike a "skip leading single-bit columns" scan, this never indexes past a
      // column's end or past the array.
      val sum = Cat(cols.map(c => if(c.nonEmpty) c(0) else false.B).reverse)
      val carry = Cat(cols.map(c => if(c.size > 1) c(1) else false.B).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        // cout2 still holds column i-1's next-layer carries, which join column i.
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg)
        columns_next.map(_.map(PipelineReg(doRegSorted.indexOf(depth) + 1)(_)))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  // `+` keeps the operand width, i.e. the result is taken modulo 2^(2*len).
  val result = sum + carry

  // Delay ctrl by `latency` stages to line up with the registered data path.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = io.out.bits.data.getWidth
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  // Word ops: sign-extend bit 31 of the selected half.
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}