package xiangshan.backend.fu

import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.fu.fpu.util.{C22, C32, C53}

/** Control signals that travel alongside a multiply operation.
  *
  * Within this file only `isW` and `isHi` are consumed; `sign` is declared for
  * users of the bundle outside this file.
  */
class MulDivCtrl extends Bundle{
  val sign = Bool()
  // When set, the low 32 bits of the selected result are sign-extended to the output width.
  val isW = Bool()
  // When set, the upper half of the double-width product is returned instead of the lower half.
  val isHi = Bool()
}

/** Common base of the multipliers below: a `FunctionUnit` configured as a
  * fixed-latency multiply unit with an extra `ctrl` input port.
  *
  * @param len     operand width handed to `FunctionUnit`
  * @param latency fixed latency advertised through `CertainLatency`
  */
class AbstractMultiplier(len: Int, latency: Int = 2) extends FunctionUnit(
  FuConfig(FuType.mul, 2, 0, writeIntRf = true, writeFpRf = false, hasRedirect = false, CertainLatency(latency)),
  len
){
  val ctrl = IO(Input(new MulDivCtrl))
}

/** Multiplier that uses the `*` operator on the signed operands and then
  * passes the product and the control bundle through `latency` pipeline stages.
  *
  * @param len     operand width
  * @param latency number of `PipelineReg` stages between input and output
  */
class NaiveMultiplier(len: Int, latency: Int = 3)
  extends AbstractMultiplier(len, latency)
  with HasPipelineReg
{

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  // Both operands are interpreted as signed here.
  // NOTE(review): presumably the caller has already sign-/zero-extended them - confirm.
  val mulRes = src1.asSInt() * src2.asSInt()

  // Element i of each vector is the value as seen after i pipeline stages.
  var dataVec = Seq(mulRes.asUInt())
  var ctrlVec = Seq(ctrl)

  for(i <- 1 to latency){
    dataVec = dataVec :+ PipelineReg(i)(dataVec(i-1))
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }

  val xlen = io.out.bits.data.getWidth
  // Pick the upper or lower half of the product, then optionally apply the 32-bit "W" sign extension.
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1,0))
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}

/** Multiplier built from radix-4 Booth partial products that are reduced
  * column by column with the `C22` / `C32` / `C53` compressor modules until at
  * most two bits remain per column; the two remaining rows are then added.
  *
  * @param len   operand width
  * @param doReg reduction-layer indices (0-based `depth` in `addAll`) after which
  *              a pipeline register is inserted. Its size is used as the unit's
  *              latency, and the control bundle is delayed by that many stages.
  *              NOTE(review): an index larger than the number of reduction layers
  *              actually generated never inserts a register - confirm callers
  *              only pass reachable depths.
  */
class ArrayMultiplier(len: Int, doReg: Seq[Int]) extends AbstractMultiplier(len, doReg.size) with HasPipelineReg {

  val doRegSorted = doReg.sortWith(_ < _)
  // Elaboration-time print of the register placement (kept from the original).
  println(doRegSorted)

  val (a, b) = (io.in.bits.src(0), io.in.bits.src(1))

  // Candidate multiples of b. The negative ones are plain bitwise inversions;
  // the missing "+1" / "+2" is injected later through `t`.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // 3-bit Booth window around bit i: an implicit 0 below bit 0, and a
    // sign-extended 2-bit window when only one bit is left at the top.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // Correction for the PREVIOUS partial product: inverted b needs +1,
    // inverted-and-shifted b needs +2 to become a true negation.
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // `pp` is the partial product with its sign-handling prefix and the
    // correction bits appended; `weight` is the column of pp's bit 0.
    // NOTE(review): the (~s, s, s) / (1, ~s) prefixes look like the usual
    // sign-extension compaction - confirm against the design notes.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // Bits that would land beyond column 2*len-1 are dropped.
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  /** Compresses the bits of one column by one layer.
    *
    * @param col bits currently in this column
    * @param cin carry bits produced by the previous (lower) column in this same
    *            layer; at most one of them is consumed per group of four bits,
    *            the rest are passed through into the returned sum bits
    * @return (bits staying in this column,
    *          carries handed to the next column within this layer,
    *          carries placed into the next column of the next layer)
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 0 =>
        // Empty column: nothing to compress, just keep the incoming carries.
        // Without this case an empty column would fall into `case n` and
        // recurse forever on `col take 4`.
        sum = cin
      case 1 => // do nothing
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool() +: cin
        cout2 = Seq(c22.io.out(1).asBool())
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool() +: cin
        cout2 = Seq(c32.io.out(1).asBool())
      case 4 =>
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        // The fifth compressor input takes one carry-in when available.
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool()) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool())
        cout2 = Seq(c53.io.out(2).asBool())
      case n =>
        // More than four bits: handle the first four, then the remainder,
        // giving the first carry-in to the first group.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of a non-empty collection. */
  def max(in: Iterable[Int]): Int = in.reduce((a, b) => if(a>b) a else b)

  /** Recursively reduces all columns until no column holds more than two bits.
    *
    * @param cols  current bit columns, index = bit weight
    * @param depth 0-based index of the reduction layer about to be generated
    * @return two rows of width `cols.length` whose sum is the product
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Final layer: first bit of every column forms one row, second bit (if
      // any) the other. Missing bits are padded with 0, so columns of size 0
      // or 1 anywhere in the array are handled without indexing errors.
      val sum = Cat(cols.map(_.headOption.getOrElse(false.B)).reverse)
      val carry = Cat(cols.map(c => if(c.size > 1) c(1) else false.B).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        // cout1 of column i-1 feeds column i in this layer;
        // cout2 of column i-1 is placed in column i of the next layer.
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      // Insert a pipeline stage after this layer when requested; the stage
      // number is the 1-based position of `depth` in the sorted list.
      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg)
        columns_next.map(_.map(PipelineReg(doRegSorted.indexOf(depth) + 1)(_)))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  val result = sum + carry

  // Delay the control bundle by the same number of stages as the data path.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = io.out.bits.data.getWidth
  // Pick the upper or lower half of the product, then optionally apply the 32-bit "W" sign extension.
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}