xref: /XiangShan/src/main/scala/xiangshan/backend/fu/Multiplier.scala (revision e3da8bad334fc71ba0d72f0607e2e93245ddaece)
1c6d43980SLemover/***************************************************************************************
2c6d43980SLemover* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3f320e0f0SYinan Xu* Copyright (c) 2020-2021 Peng Cheng Laboratory
4c6d43980SLemover*
5c6d43980SLemover* XiangShan is licensed under Mulan PSL v2.
6c6d43980SLemover* You can use this software according to the terms and conditions of the Mulan PSL v2.
7c6d43980SLemover* You may obtain a copy of Mulan PSL v2 at:
8c6d43980SLemover*          http://license.coscl.org.cn/MulanPSL2
9c6d43980SLemover*
10c6d43980SLemover* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11c6d43980SLemover* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12c6d43980SLemover* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13c6d43980SLemover*
14c6d43980SLemover* See the Mulan PSL v2 for more details.
15c6d43980SLemover***************************************************************************************/
16c6d43980SLemover
17cafb3558SLinJiaweipackage xiangshan.backend.fu
18cafb3558SLinJiawei
198891a219SYinan Xuimport org.chipsalliance.cde.config.Parameters
20cafb3558SLinJiaweiimport chisel3._
21cafb3558SLinJiaweiimport chisel3.util._
223c02ee8fSwakafaimport utility._
233b739f49SXuan Huimport utils._
243b739f49SXuan Huimport xiangshan._
257f1506e3SLinJiaweiimport xiangshan.backend.fu.util.{C22, C32, C53}
26cafb3558SLinJiawei
/**
  * Control bits that travel alongside a multiply/divide operation.
  *
  *  - `sign`: NOTE(review): presumably requests signed operand handling; confirm at the use site.
  *  - `isW`:  NOTE(review): presumably marks the word-sized ("W") variant; confirm at the use site.
  *  - `isHi`: per the original note, asks for the high bits of the result.
  *
  * Field order (sign, isW, isHi) is significant for the bundle layout and must be kept.
  */
class MulDivCtrl extends Bundle {
  val sign, isW, isHi = Bool()
}
32cafb3558SLinJiawei
/**
  * Radix-4 Booth array multiplier datapath.
  *
  * `a` and `b` are both treated as `len`-bit two's-complement numbers, and `result` is the low
  * `2 * len` bits of `a * b`. `a` is recoded into radix-4 Booth digits and one partial product
  * of `b` is formed per digit. The partial-product bits are then reduced column by column with
  * C22 / C32 / C53 compressor cells until no column holds more than two bits, and a final adder
  * sums the two remaining rows.
  *
  * Pipelining:
  *  - every partial-product bit is registered with `regEnables(0)`;
  *  - the output of the compression layer built at `depth == 4` (the fifth layer) is
  *    registered with `regEnables(1)`.
  * If the tree finishes in fewer layers, `regEnables(1)` is left unused.
  *
  * @param len operand width in bits; must be at least 3 so that there are two or more
  *            partial products.
  */
class ArrayMulDataModule(len: Int) extends Module {
  require(len >= 3, s"ArrayMulDataModule: len must be >= 3, got $len")

  val io = IO(new Bundle() {
    val a, b = Input(UInt(len.W))
    val regEnables = Input(Vec(2, Bool()))
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)

  // Multiples of b that a Booth digit can select. Each is len+1 bits wide so that 2b still
  // fits as a signed value. Negative multiples are plain ones' complements here; the +1 / +2
  // that completes the two's-complement negation is added separately (see negCorrection).
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt
  neg_bx2 := neg_b << 1

  // columns(j) collects every bit of weight 2^j that still has to be summed.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  /**
    * Completion term for the ones'-complement multiple picked by Booth group `x`.
    * It is +2 for -2b (x = 4), +1 for -b (x = 5 or 6), and 0 otherwise.
    */
  private def negCorrection(x: UInt): UInt = MuxLookup(x, 0.U(2.W))(Seq(
    4.U -> 2.U(2.W),
    5.U -> 1.U(2.W),
    6.U -> 1.U(2.W)
  ))

  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Booth group {a(i+1), a(i), a(i-1)}. a(-1) is taken as 0. When i is the MSB (odd len),
    // the group is completed by repeating the sign bit.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U)(Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // The previous group's completion term rides in the two LSBs of this partial product.
    // They sit at weight i-2, which is the previous product's LSB.
    val t = negCorrection(last_x)
    last_x = x
    // Instead of sign-extending every product to 2*len bits, each product gets a short prefix:
    // `~s s s` on the first, `1 ~s` on the middle ones and `~s` on the last.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Scatter the product's bits into their columns; bits at or above weight 2*len are dropped.
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  // Each group's completion term is carried by the next partial product, so the last group's
  // term has no product to ride in. It is added here, at that group's weight. Without it the
  // product is wrong whenever the top Booth digit is negative: e.g. len = 3, a = -4, b = -1
  // gave 0 instead of 4.
  val lastWeight = Range(0, len, 2).last
  val lastCorrection = negCorrection(last_x)
  for(k <- 0 until 2){
    columns(lastWeight + k) = columns(lastWeight + k) :+ lastCorrection(k)
  }

  /**
    * Compresses one column of the current layer.
    *
    * @param col bits of this column
    * @param cin carries produced by the previous column in this same layer (its `cout1`)
    * @return (bits that stay in this column for the next layer,
    *          carries handed to the next column as `cin` within this layer,
    *          carries placed in the next column of the next layer)
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>  // nothing to compress; incoming carries pass straight through
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool +: cin
        cout2 = Seq(c22.io.out(1).asBool)
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool +: cin
        cout2 = Seq(c32.io.out(1).asBool)
      case 4 =>
        // Four column bits plus the first incoming carry (or 0) feed the C53.
        // Any further incoming carries pass through.
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool)
        cout2 = Seq(c53.io.out(2).asBool)
      case n =>
        // Taller columns are split. The first four bits go with the first incoming carry;
        // the remainder is handled recursively with the other incoming carries.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest value in `in`; throws on an empty collection. */
  def max(in: Iterable[Int]): Int = in.reduce((a, b) => if(a>b) a else b)

  /**
    * Recursively builds compression layers until no column holds more than two bits.
    *
    * It then returns the two rows to add: `sum` is bit 0 of every column, and `carry` is bit 1
    * of every column from the first two-bit column upward, zero-filled below it.
    *
    * @param cols  current layer, one Seq of bits per weight
    * @param depth index of the layer built by this call; the layer built at depth 4 is
    *              registered with `regEnables(1)`
    */
  def addAll(cols: Seq[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      val sum = Cat(cols.map(_(0)).reverse)
      // Assumes all one-bit columns precede the first two-bit column.
      var k = 0
      while(cols(k).size == 1) k = k+1
      val carry = Cat(cols.drop(k).map(_(1)).reverse)
      (sum, Cat(carry, 0.U(k.W)))
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = depth == 4
      val toNextLayer = if(needReg)
        columns_next.map(_.map(x => RegEnable(x, io.regEnables(1))))
      else
        columns_next

      addAll(toNextLayer.toSeq, depth+1)
    }
  }

  // First pipeline stage: register every partial-product bit before compression starts.
  val columns_reg = columns.map(col => col.map(bit => RegEnable(bit, io.regEnables(0))))
  val (sum, carry) = addAll(cols = columns_reg.toSeq, depth = 0)

  io.result := sum + carry
}
153