package xiangshan.backend.fu

import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.fu.util.{C22, C32, C53}

/** Control signals that accompany an operation through the multiplier pipeline. */
class MulDivCtrl extends Bundle{
  val sign = Bool() // not read by the multipliers defined in this file
  val isW = Bool()  // when set, the low 32 bits of the selected result are sign-extended to the output width
  val isHi = Bool() // return hi bits of result ?
}

/** Common base of the multipliers below: a `FunctionUnit` of width `len`
  * with an extra `ctrl` input port.
  *
  * @param len operand width forwarded to `FunctionUnit`
  */
class AbstractMultiplier(len: Int) extends FunctionUnit(
  len
){
  val ctrl = IO(Input(new MulDivCtrl))
}

/** Multiplier built from the `*` operator followed by `latency` pipeline stages.
  *
  * Both sources are interpreted as signed. The full product and the control
  * bundle travel through the same `PipelineReg` stages so that the output
  * selection uses the control bits belonging to the data at the last stage.
  *
  * @param len     operand width
  * @param latency number of `PipelineReg` stages applied to data and control
  */
class NaiveMultiplier(len: Int, val latency: Int)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  val mulRes = src1.asSInt() * src2.asSInt()

  // Element i of each vector is the value after i pipeline stages.
  var dataVec = Seq(mulRes.asUInt())
  var ctrlVec = Seq(ctrl)

  for(i <- 1 to latency){
    dataVec = dataVec :+ PipelineReg(i)(dataVec(i-1))
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }

  val xlen = io.out.bits.data.getWidth
  // isHi selects bits [2*xlen-1, xlen] of the product, otherwise bits [xlen-1, 0].
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1,0))
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}

/** Multiplier that generates radix-4 Booth partial products and reduces them
  * column by column with C22/C32/C53 compressor cells, followed by one final
  * addition.
  *
  * @param len   operand width
  * @param doReg reduction-layer depths after which a pipeline register is
  *              inserted; `latency` equals the number of entries. Entries must
  *              be distinct (checked at elaboration time).
  *              NOTE(review): depths larger than the number of reduction layers
  *              actually generated are never reached, so they add to `latency`
  *              without adding a data register - confirm callers avoid this.
  */
class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len) with HasPipelineReg {

  override def latency = doReg.size

  val doRegSorted = doReg.sorted
  // The data path uses stage index `doRegSorted.indexOf(depth) + 1`, while the
  // control path below goes through every stage 1..latency. A duplicated depth
  // would make the data path skip a stage index and fall out of step with the
  // control path, so reject it up front.
  require(
    doRegSorted.distinct.size == doRegSorted.size,
    s"ArrayMultiplier: doReg must not contain duplicate depths, got $doReg"
  )
  println(doRegSorted)

  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))

  // Multiples of b selectable per Booth digit. The "negative" versions are only
  // bitwise complements; the missing +1 / +2 is supplied later through `t`.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // 3-bit window of `a` for this digit: a zero is appended below bit 0 for the
    // first digit, and the window is sign-extended when it would run past the MSB.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // Correction for the PREVIOUS digit's complemented multiple:
    // +2 after neg_bx2, +1 after neg_b, 0 otherwise.
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // `pp` is the row actually inserted and `weight` the column of its bit 0.
    // Rows after the first start two columns lower so that `t` lands on the
    // previous digit's weight. The leading bits derived from `s` stand in for
    // sign extension of the row.
    // NOTE(review): the prefix constants were kept as-is and not re-derived here.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Scatter the row into the columns; bits beyond column 2*len-1 are dropped.
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  /** Reduces the bits of one column by one compressor layer.
    *
    * Columns of more than four bits are split into a group of four plus the
    * remainder, each handled recursively; the first carry-in goes to the first
    * group and the rest to the remainder.
    *
    * @param col bits of this column (expected non-empty)
    * @param cin carry bits arriving from the lower neighbouring column within
    *            the same layer
    * @return `(sum, cout1, cout2)`: `sum` stays in this column, `cout1` is passed
    *         by the caller as `cin` to the next higher column in the same layer,
    *         and `cout2` is added to the next higher column of the next layer
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>  // do nothing
        sum = col ++ cin
      case 2 =>
        // presumably a 2:2 compressor (half adder) - see C22
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool() +: cin
        cout2 = Seq(c22.io.out(1).asBool())
      case 3 =>
        // presumably a 3:2 compressor (full adder) - see C32
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool() +: cin
        cout2 = Seq(c32.io.out(1).asBool())
      case 4 =>
        // Four column bits plus one carry-in feed the C53 cell; remaining
        // carry-ins, if any, pass through untouched.
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool()) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool())
        cout2 = Seq(c53.io.out(2).asBool())
      case n =>
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of `in`; throws on an empty collection. */
  def max(in: Iterable[Int]): Int = in.max

  /** Applies compressor layers until no column holds more than two bits.
    *
    * @param cols  current bit columns, index = bit weight
    * @param depth index of the layer about to be generated; when it appears in
    *              `doReg`, the layer's outputs are registered
    * @return two operands whose sum is the product
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // First operand: bit 0 of every column. Second operand: bit 1 of every
      // column from the first column that is not single-bit onwards, with
      // zeros below it.
      val sum = Cat(cols.map(_(0)).reverse)
      var k = 0
      while(cols(k).size == 1) k = k+1
      val carry = Cat(cols.drop(k).map(_(1)).reverse)
      (sum, Cat(carry, 0.U(k.W)))
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      // Walk from the lowest weight upwards so carries ripple into the next
      // column; carries out of the top column are discarded.
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg)
        columns_next.map(_.map(PipelineReg(doRegSorted.indexOf(depth) + 1)(_)))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  val result = sum + carry

  // Control bits pass through every pipeline stage 1..latency.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = io.out.bits.data.getWidth
  // isHi selects bits [2*xlen-1, xlen] of the product, otherwise bits [xlen-1, 0].
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}