xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision 45f43e6e5f88874a7573ff096d1e5c2855bd16c7)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.backend.fu
18
19import org.chipsalliance.cde.config.Parameters
20import chisel3._
21import chisel3.util._
22import xiangshan._
23import utils._
24import utility._
25import xiangshan.backend.fu.util.{C22, C32, C53}
26
/**
  * Control bundle accompanying a multiply/divide request.
  *
  * Fields, in declaration order:
  *  - `sign`: not read by any multiplier visible in this file
  *            (presumably selects signed operation elsewhere -- confirm against users).
  *  - `isW` : when set, the multipliers in this file sign-extend the low 32 bits of the result.
  *  - `isHi`: when set, the high half of the product is returned instead of the low half.
  */
class MulDivCtrl extends Bundle {
  val sign, isW, isHi = Bool()
}
32
/**
  * Shared base of the multiplier function units: a [[FunctionUnit]] constructed
  * with `len`, extended with a [[MulDivCtrl]] input port named `ctrl`.
  *
  * @param len value forwarded unchanged to the [[FunctionUnit]] constructor
  */
class AbstractMultiplier(len: Int)(implicit p: Parameters) extends FunctionUnit(len) {
  // Per-operation control signals, driven by whoever instantiates the unit.
  val ctrl = IO(Input(new MulDivCtrl))
}
38
/**
  * Multiplier that relies on a single `*` operator and then delays the raw
  * product (together with the control bundle) through `latency` pipeline stages.
  *
  * @param len     operand width handed to [[AbstractMultiplier]]
  * @param latency number of `PipelineReg` stages inserted after the product
  */
class NaiveMultiplier(len: Int, val latency: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{

  val src1 = io.in.bits.src(0)
  val src2 = io.in.bits.src(1)

  // Both operands are reinterpreted as signed before multiplying.
  val mulRes = src1.asSInt * src2.asSInt

  // Element 0 is the un-registered value; element k is the value after stage k.
  var dataVec = Seq(mulRes.asUInt)
  var ctrlVec = Seq(ctrl)

  (1 to latency).foreach { stage =>
    // The most recently appended element is the output of the previous stage.
    dataVec :+= PipelineReg(stage)(dataVec.last)
    ctrlVec :+= PipelineReg(stage)(ctrlVec.last)
  }

  val xlen = io.out.bits.data.getWidth

  // isHi picks the upper xlen bits of the delayed product, otherwise the lower xlen bits.
  val res = Mux(
    ctrlVec.last.isHi,
    dataVec.last(2*xlen-1, xlen),
    dataVec.last(xlen-1, 0)
  )

  // isW replaces the result with its low 32 bits sign-extended to xlen.
  io.out.bits.data := Mux(
    ctrlVec.last.isW,
    SignExt(res(31, 0), xlen),
    res
  )

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}
62
/**
  * Datapath of a two-stage pipelined multiplier.
  *
  * The module recodes `a` in overlapping 3-bit groups (radix-4 Booth style),
  * selects a multiple of `b` for every group, scatters the resulting
  * partial-product bits into per-weight columns, reduces the columns with
  * C22/C32/C53 cells until at most two bits remain per column, and finally adds
  * the two remaining rows.
  *
  * Pipeline registers:
  *  - every partial-product bit is registered with `io.regEnables(0)`;
  *  - every bit produced by the reduction layer processed at `depth == 4` is
  *    registered with `io.regEnables(1)`.
  * The final `sum + carry` addition is not registered inside this module.
  *
  * @param len width of both operands; the result is `2 * len` bits wide
  */
class ArrayMulDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    // Multiplicand / multiplier operands, `len` bits each.
    val a, b = Input(UInt(len.W))
    // Enables for the two internal register stages (index 0: partial products, index 1: mid-tree).
    val regEnables = Input(Vec(2, Bool()))
    // Product, `2 * len` bits.
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)

  // Candidate multiples of b, all held in (len+1)-bit wires:
  //   b_sext  : b sign-extended by one bit
  //   bx2     : b_sext shifted left by one (the shift result is connected to a (len+1)-bit wire)
  //   neg_b   : bitwise complement of b_sext (no +1 is added here)
  //   neg_bx2 : neg_b shifted left by one
  // The complemented forms are not full two's-complement negations; the missing
  // low-order correction is supplied by `t` in the following loop.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Selector of the previous group; starts at 0 so the first group sees no correction.
  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // 3-bit selector for the group starting at bit i:
    //   first group : a(1,0) with a zero appended below bit 0
    //   i+1 == len  : the two bits a(i, i-1) sign-extended to 3 bits (no a(i+1) exists)
    //   otherwise   : a(i+1, i-1)
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    // Selected multiple of b; selectors 0 and 7 fall through to the default 0.
    val pp_temp = MuxLookup(x, 0.U)(Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    // Top (sign) bit of the selected multiple.
    val s = pp_temp(len)
    // Correction derived from the PREVIOUS group's selector:
    // 2 when neg_bx2 was selected (4), 1 when neg_b was selected (5, 6), else 0.
    val t = MuxLookup(last_x, 0.U(2.W))(Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // Assemble the row and its starting weight.
    //  - first row: {~s, s, s, pp_temp} at weight 0 (no correction bits)
    //  - last row (i == len-1 or i == len-2): {~s, pp_temp, t} at weight i-2
    //  - other rows: {1, ~s, pp_temp, t} at weight i-2
    // The two low bits `t` sit at weight i-2, i.e. the weight of the previous row.
    // NOTE(review): the leading ~s / 1 bits look like the usual encoding that
    // avoids explicit sign extension of each row -- confirm against the design notes.
    // NOTE(review): no row follows the last group, so a negative multiple selected
    // by the last group would never receive its correction; presumably callers
    // guarantee the last group never selects 4/5/6 -- confirm.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Append each bit of the row to its column; bits at weight >= 2*len are dropped.
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  /**
    * Reduces the bits of one column by one layer.
    *
    * @param col bits of equal weight belonging to this column
    * @param cin carry bits arriving from the lower neighbour column within the same layer
    * @return `(sum, cout1, cout2)` where `sum` stays in this column for the next
    *         layer, `cout1` is handed to the next column as its `cin` in the same
    *         layer, and `cout2` is appended to the next column's output
    *         (see the use of the results in `addAll`)
    *
    * NOTE(review): C22 / C32 / C53 are imported cells; from their port usage here
    * they are presumably 2:2, 3:2 and 5:3 compressors -- confirm in their definition.
    * NOTE(review): an empty `col` matches `case n` and recurses on empty slices
    * without terminating; callers are expected to pass non-empty columns.
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>  // single bit: pass it through together with all incoming carries
        sum = col ++ cin
      case 2 =>
        // Two bits: out(0) stays in the column, out(1) goes out as cout2; cin is passed through.
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool +: cin
        cout2 = Seq(c22.io.out(1).asBool)
      case 3 =>
        // Three bits: same wiring as the two-bit case, using a C32 cell.
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool +: cin
        cout2 = Seq(c32.io.out(1).asBool)
      case 4 =>
        // Four bits plus the first incoming carry (0 when there is none) feed a C53 cell.
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        // Remaining incoming carries are passed through untouched.
        sum = Seq(c53.io.out(0).asBool) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool)
        cout2 = Seq(c53.io.out(2).asBool)
      case n =>
        // More than four bits: handle the first four with the first carry,
        // the rest with the remaining carries, and concatenate the results.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of `in`; `in` must be non-empty because `reduce` is used. */
  def max(in: Iterable[Int]): Int = in.reduce((a, b) => if(a>b) a else b)

  /**
    * Recursively reduces all columns until no column holds more than two bits.
    *
    * @param cols  current bits per column, index = weight
    * @param depth number of reduction layers already applied
    * @return two `UInt` rows whose sum is the value represented by `cols`
    *         (apart from the registers inserted on the way)
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Base case. First row: element 0 of every column (highest column becomes the MSB).
      val sum = Cat(cols.map(_(0)).reverse)
      // k = index of the first column that does not hold exactly one bit.
      var k = 0
      while(cols(k).size == 1) k = k+1
      // Second row: element 1 of every column from k upwards, padded with k low zeros.
      // This indexing requires every column from k upwards to hold two bits.
      val carry = Cat(cols.drop(k).map(_(1)).reverse)
      (sum, Cat(carry, 0.U(k.W)))
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        // cout1 of column i-1 is consumed as cin here; cout2 of column i-1 is
        // appended to this column's output. Carries out of the top column are dropped.
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      // Second pipeline register: placed on the outputs of the layer processed at depth 4.
      // NOTE(review): if the reduction finishes before depth 4 is reached, this
      // register stage is never instantiated -- confirm `len` is large enough.
      val needReg = depth == 4
      val toNextLayer = if(needReg)
        columns_next.map(_.map(x => RegEnable(x, io.regEnables(1))))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  // First pipeline register: every partial-product bit, enabled by regEnables(0).
  val columns_reg = columns.map(col => col.map(b => RegEnable(b, io.regEnables(0))))
  val (sum, carry) = addAll(cols = columns_reg, depth = 0)

  // Final addition of the two remaining rows.
  io.result := sum + carry
}
183
/**
  * Function-unit wrapper around [[ArrayMulDataModule]].
  *
  * The data module receives both source operands and one register enable per
  * pipeline stage; the control bundle is delayed through the same number of
  * `PipelineReg` stages so it lines up with the product at the output.
  *
  * @param len operand width; also used to size the data module
  */
class ArrayMultiplier(len: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len)
  with HasPipelineReg {

  override def latency = 2

  val mulDataModule = Module(new ArrayMulDataModule(len))
  mulDataModule.io.a := io.in.bits.src(0)
  mulDataModule.io.b := io.in.bits.src(1)
  // One enable per stage: regEnable(1) .. regEnable(latency).
  mulDataModule.io.regEnables := VecInit((1 to latency).map(stage => regEnable(stage)))
  val result = mulDataModule.io.result

  // Element 0 is the incoming ctrl; element k is ctrl after stage k.
  var ctrlVec = Seq(ctrl)
  (1 to latency).foreach { stage =>
    ctrlVec :+= PipelineReg(stage)(ctrlVec.last)
  }

  val xlen = len - 1

  // isHi picks bits [2*xlen-1, xlen] of the product, otherwise bits [xlen-1, 0].
  val res = Mux(
    ctrlVec.last.isHi,
    result(2*xlen-1, xlen),
    result(xlen-1, 0)
  )

  // isW replaces the result with its low 32 bits sign-extended to xlen.
  io.out.bits.data := Mux(
    ctrlVec.last.isW,
    SignExt(res(31, 0), xlen),
    res
  )

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}
206