xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision 2225d46ebbe2fd16b9b29963c27a7d0385a42709)
1package xiangshan.backend.fu
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import xiangshan._
7import utils._
8import xiangshan.backend.fu.util.{C22, C32, C53}
9
/**
  * Control signals that accompany a multiply/divide request through the pipeline.
  *
  * Keep the field order as declared: Chisel derives the bundle's bit layout from it.
  */
class MulDivCtrl extends Bundle {
  /** Operand signedness flag (not read by the multipliers in this file). */
  val sign: Bool = Bool()
  /** Word variant: the low 32 bits of the selected result are sign-extended to the result width. */
  val isW: Bool = Bool()
  /** Return the high half of the double-width product instead of the low half. */
  val isHi: Bool = Bool()
}
15
/**
  * Common base of the multiplier implementations: a [[FunctionUnit]] of width `len`
  * that additionally exposes a `ctrl` input port carrying [[MulDivCtrl]] signals.
  *
  * @param len operand width passed through to [[FunctionUnit]]
  */
class AbstractMultiplier(len: Int)(implicit p: Parameters) extends FunctionUnit(len) {
  /** Per-request control (signedness, word variant, high/low half select). */
  val ctrl: MulDivCtrl = IO(Input(new MulDivCtrl))
}
21
/**
  * Behavioural multiplier: a single signed `*` followed by `latency` pipeline registers.
  *
  * Implements the same [[AbstractMultiplier]] contract as [[ArrayMultiplier]]. The two
  * `len`-bit sources are multiplied as signed values, and the architectural result is
  * `len - 1` bits wide. This presumably relies on the caller having sign- or
  * zero-extended XLEN-bit operands to `len = XLEN + 1` bits (TODO confirm at the
  * instantiation site).
  *
  * @param len     operand width in bits
  * @param latency number of pipeline stages between input and output
  */
class NaiveMultiplier(len: Int, val latency: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{
  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  val mulRes = src1.asSInt() * src2.asSInt()

  // Element k of each sequence is the value after k pipeline registers; element 0 is the input.
  val dataVec: Seq[UInt] = (1 to latency).foldLeft(Seq(mulRes.asUInt())) { (stages, i) =>
    stages :+ PipelineReg(i)(stages.last)
  }
  val ctrlVec: Seq[MulDivCtrl] = (1 to latency).foldLeft(Seq(ctrl)) { (stages, i) =>
    stages :+ PipelineReg(i)(stages.last)
  }

  // Result width is derived from the operand width, exactly as in ArrayMultiplier.
  // The output port is `len` bits wide (assumed from the FunctionUnit(len) contract),
  // so deriving it from the port width would select the high half one bit too high.
  val xlen = len - 1
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1,0))
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}
45
/**
  * Radix-4 Booth multiplier data path with a carry-save compressor tree, optionally pipelined.
  *
  * Partial products of `a * b` are produced by radix-4 Booth recoding of `a`. Their bits
  * are dropped into per-weight columns (index = bit position, `2 * len` columns; higher
  * bits are discarded). The columns are then reduced layer by layer with C22/C32/C53
  * compressors until every column holds at most two bits. The final sum and carry
  * vectors are added with a single `+`, so `result` is the product modulo 2^(2*len).
  *
  * @param len   operand width in bits; `io.result` is `2 * len` bits wide
  * @param doReg compressor-tree layer depths (0-based) whose outputs are registered.
  *              `io.regEnables(k)` is the enable of the register at the k-th smallest
  *              depth in `doReg`.
  *
  * NOTE(review): the Booth correction `t` of each group is folded into the *next*
  * group's partial product, so the most-significant group's correction is never added.
  * The product therefore appears correct only when the top Booth digit is non-negative.
  * For odd `len` that means a(len-1) == a(len-2) or a(len-1) == 0. Presumably callers
  * pre-extend XLEN-bit operands to satisfy this — confirm.
  * NOTE(review): a depth in `doReg` that the tree never reaches gets no register, and its
  * `regEnables` entry is ignored.
  */
class ArrayMulDataModule(len: Int, doReg: Seq[Int]) extends Module {
  val io = IO(new Bundle() {
    val a, b = Input(UInt(len.W))
    val regEnables = Input(Vec(doReg.size, Bool()))
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)
  // Sorted so that io.regEnables(k) corresponds to the k-th shallowest registered layer.
  val doRegSorted = doReg.sortWith(_ < _)

  // Candidate multiples of b (b is treated as signed). neg_b / neg_bx2 are one's
  // complements: neg_b = -b - 1 and neg_bx2 = -2b - 2. The missing +1 / +2 is supplied
  // by `t` in the following group's partial product.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Booth triplet of the previous group, used to emit its negation correction.
  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Booth triplet {a(i+1), a(i), a(i-1)}; a(-1) is 0 and the top group is sign-extended.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    // Triplet -> digit: 0,7 -> 0; 1,2 -> +b; 3 -> +2b; 4 -> -2b; 5,6 -> -b.
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    // Sign bit of this partial product.
    val s = pp_temp(len)
    // Correction completing the previous group's two's-complement negation (+2 for -2b, +1 for -b).
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // The s / ~s / leading-1 bits appear to be the usual sign-extension-elimination
    // encoding. `t` occupies bits i-2 and i-1, so for i > 0 the row starts at weight
    // i-2 and pp_temp sits at weight i.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Scatter the row into its columns; bits at weight >= 2*len are dropped.
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  /**
    * Compress one column of one tree layer.
    *
    * @param col bits of this column at the current layer
    * @param cin `cout1` bits produced by the previous (lower-weight) column in this same layer
    * @return (bits staying in this column for the next layer,
    *          bits passed as `cin` to the next column in this layer (from C53 only),
    *          bits passed to the next column in the next layer)
    *
    * NOTE(review): a zero-size column matches `case n` and recurses on an empty `col`
    * forever. Columns built in this module are never empty, as far as visible here.
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>  // single bit: pass it (and any incoming cin) straight through
        sum = col ++ cin
      case 2 =>
        // C22: out(0) stays in this column, out(1) moves up one weight (per its use here).
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool() +: cin
        cout2 = Seq(c22.io.out(1).asBool())
      case 3 =>
        // C32: out(0) stays in this column, out(1) moves up one weight (per its use here).
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool() +: cin
        cout2 = Seq(c32.io.out(1).asBool())
      case 4 =>
        // C53: four column bits plus the first cin (or 0). out(1) feeds the next column's
        // cin in this layer; out(2) goes to the next column in the next layer.
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool()) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool())
        cout2 = Seq(c53.io.out(2).asBool())
      case n =>
        // More than four bits: compress the first four (with the first cin) and recurse on the rest.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest value in `in`; throws on an empty collection (uses `reduce`). */
  def max(in: Iterable[Int]): Int = in.reduce((a, b) => if(a>b) a else b)

  /**
    * Reduce `cols` one compressor layer per call until every column holds at most two bits.
    *
    * @param cols  bits of each weight at this layer (index = bit position)
    * @param depth index of the layer being built; if it is in `doRegSorted`, that layer's
    *              outputs are registered with the matching `io.regEnables` entry
    * @return (sum, carry) vectors whose addition yields the product
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Base case. Sum takes bit 0 of every column. Carry starts at the first two-bit
      // column k and is zero-padded below. Assumes every column from k upward holds
      // exactly two bits.
      // NOTE(review): elaboration fails (index out of bounds) if no column holds two bits,
      // or if a single-bit column follows column k.
      val sum = Cat(cols.map(_(0)).reverse)
      var k = 0
      while(cols(k).size == 1) k = k+1
      val carry = Cat(cols.drop(k).map(_(1)).reverse)
      (sum, Cat(carry, 0.U(k.W)))
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      // Walk columns from low to high weight. cout1 is consumed by the next column in this
      // layer; cout2 of column i-1 is appended to column i of the next layer. Carries out
      // of the top column are dropped (the result is modulo 2^(2*len)).
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg)
        columns_next.map(_.map(x => RegEnable(x, io.regEnables(doRegSorted.indexOf(depth)))))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  io.result := sum + carry
}
165
/**
  * Pipelined Booth / compressor-tree multiplier built on [[ArrayMulDataModule]].
  *
  * The two `len`-bit sources feed the data path. The `2 * len`-bit product is then cut
  * down to a `len - 1`-bit result: the high or low half per `ctrl.isHi`, optionally with
  * the low 32 bits sign-extended per `ctrl.isW`.
  *
  * @param len   operand width in bits; the architectural result is `len - 1` bits
  * @param doReg compressor-tree depths after which a pipeline register is inserted.
  *              Each entry adds one cycle of latency, so entries must be distinct and
  *              non-negative.
  */
class ArrayMultiplier(len: Int, doReg: Seq[Int])(implicit p: Parameters)
  extends AbstractMultiplier(len) with HasPipelineReg {

  // The data path inserts one register per distinct non-negative depth, while `latency`
  // counts every entry. A duplicate or negative depth would leave valid/ctrl out of step
  // with the data, so reject such configurations at elaboration time.
  require(doReg.distinct.size == doReg.size, s"ArrayMultiplier: duplicate depths in doReg $doReg")
  require(doReg.forall(_ >= 0), s"ArrayMultiplier: negative depth in doReg $doReg")

  override def latency = doReg.size

  val mulDataModule = Module(new ArrayMulDataModule(len, doReg))
  mulDataModule.io.a := io.in.bits.src(0)
  mulDataModule.io.b := io.in.bits.src(1)
  // regEnables(k) gates the register at the k-th shallowest depth, i.e. pipeline stage k + 1.
  mulDataModule.io.regEnables := VecInit((1 to doReg.size) map (i => regEnable(i)))
  val result = mulDataModule.io.result

  // Element k is the control after k pipeline registers; the last one lines up with `result`.
  val ctrlVec: Seq[MulDivCtrl] = (1 to latency).foldLeft(Seq(ctrl)) { (stages, i) =>
    stages :+ PipelineReg(i)(stages.last)
  }
  val xlen = len - 1
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}
187}