package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.fu.util.{C22, C32, C53}

/** Control bits that travel alongside a multiply (or divide) operation.
  *
  * Within this file only `isW` and `isHi` are consumed; `sign` is carried but not read here.
  */
class MulDivCtrl extends Bundle{
  val sign = Bool()
  val isW = Bool()  // result is the low 32 bits, sign-extended to the full width
  val isHi = Bool() // return hi bits of result ?
}

/** Common base of the multipliers below: a [[FunctionUnit]] of width `len` plus a `ctrl` input.
  *
  * @param len operand width handed to [[FunctionUnit]]
  */
class AbstractMultiplier(len: Int)(implicit p: Parameters) extends FunctionUnit(
  len
){
  val ctrl = IO(Input(new MulDivCtrl))
}

/** Multiplier built from the `*` operator, followed by `latency` pipeline stages.
  *
  * Both sources are interpreted as signed values. The full product and the control bundle are
  * delayed together, and the half/word selection is applied after the last stage.
  *
  * @param len     operand width
  * @param latency number of pipeline stages inserted after the multiplication
  */
class NaiveMultiplier(len: Int, val latency: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  val mulRes = src1.asSInt() * src2.asSInt()

  // Element i of each vector is the value after i pipeline stages (element 0 is unregistered).
  var dataVec = Seq(mulRes.asUInt())
  var ctrlVec = Seq(ctrl)

  for(i <- 1 to latency){
    dataVec = dataVec :+ PipelineReg(i)(dataVec(i-1))
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }

  val xlen = io.out.bits.data.getWidth
  // isHi picks the upper xlen bits of the product, otherwise the lower xlen bits.
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1,0))
  // isW keeps only the low 32 bits and sign-extends them back to xlen.
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}

/** Data path of an array multiplier: radix-4 Booth partial products reduced column by column
  * with C22 / C32 / C53 compressors, followed by one final addition.
  *
  * Reduction layers are numbered from 0. After every layer whose number appears in `doReg`,
  * all column bits are registered, gated by the matching entry of `io.regEnables` (entries are
  * matched to the sorted layer numbers).
  *
  * NOTE(review): the "+1 / +2" correction of a negative Booth digit is injected into the *next*
  * partial product, so the correction of the last digit is never added. This looks correct only
  * when the last digit cannot be negative-with-correction, e.g. odd `len` with a redundant sign
  * bit in `a` (which is what [[ArrayMultiplier]] appears to rely on) -- confirm before
  * instantiating with other widths.
  *
  * @param len   width of both operands; the result is `2 * len` bits wide
  * @param doReg reduction-layer numbers after which a register stage is inserted; entries must
  *              be distinct
  */
class ArrayMulDataModule(len: Int, doReg: Seq[Int]) extends Module {
  // A duplicated layer number would insert a single register stage while users of this module
  // derive their latency from doReg.size, silently desynchronising data and control.
  require(doReg.distinct.size == doReg.size, s"doReg must not contain duplicates: $doReg")

  val io = IO(new Bundle() {
    val a, b = Input(UInt(len.W))
    val regEnables = Input(Vec(doReg.size, Bool()))
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)
  val doRegSorted = doReg.sortWith(_ < _)

  // The four non-zero multiples of b. Negation is done by inversion only; the missing
  // "+1" (for ~b) or "+2" (for ~b << 1) is added later through `t`.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Booth digit: bits i+1, i, i-1 of a, with an implicit 0 below bit 0 and sign extension
    // when bit i+1 does not exist.
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // Correction owed by the previous digit: +2 after ~b << 1, +1 after ~b.
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // pp is the partial product with its prefix bits (derived from the sign bit s; presumably
    // the usual sign-extension-avoidance encoding -- confirm) and, except for the first digit,
    // the previous digit's correction in its two low bits. weight is the column of pp's bit 0.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    // weight is never negative (0 for the first digit, i-2 >= 0 afterwards); bits that would
    // land beyond the last column are dropped.
    for(j <- weight until math.min(weight + pp.getWidth, columns.length)){
      columns(j) = columns(j) :+ pp(j-weight)
    }
  }

  /** Compresses the bits of one column for one reduction layer.
    *
    * @param col bits of this column
    * @param cin carries produced by the previous column's C53 compressors in this same layer
    * @return (sum, cout1, cout2): `sum` stays in this column for the next layer, `cout1` is
    *         handed to the next column as `cin` within this layer, `cout2` is appended to the
    *         next column for the next layer
    */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    col.size match {
      case 0 =>
        // Nothing to compress; without this case the general branch would recurse forever.
        (cin, Nil, Nil)
      case 1 => // do nothing
        (col ++ cin, Nil, Nil)
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        (c22.io.out(0).asBool() +: cin, Nil, Seq(c22.io.out(1).asBool()))
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        (c32.io.out(0).asBool() +: cin, Nil, Seq(c32.io.out(1).asBool()))
      case 4 =>
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        // The fifth input absorbs one incoming carry; any remaining carries pass through.
        c53.io.in.last := cin.headOption.getOrElse(0.U)
        (
          Seq(c53.io.out(0).asBool()) ++ cin.drop(1),
          Seq(c53.io.out(1).asBool()),
          Seq(c53.io.out(2).asBool())
        )
      case _ =>
        // More than four bits: compress the first four with the first carry, then recurse on
        // the rest with the remaining carries.
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin take 1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin drop 1)
        (s_1 ++ s_2, c_1_1 ++ c_2_1, c_1_2 ++ c_2_2)
    }
  }

  /** Largest element of `in`; throws on an empty collection. */
  def max(in: Iterable[Int]): Int = in.max

  /** Reduces `cols` layer by layer until no column holds more than two bits.
    *
    * @param cols  current bits of every column
    * @param depth number of the layer about to be built (0 for the first)
    * @return two `cols.length`-bit vectors whose sum is the product
    */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      // Take the first bit of each column as `sum` and the second, where present, as `carry`.
      // Absent bits are padded with zero so that any mix of 0/1/2-bit columns is accepted.
      val zero = false.B
      val sum = Cat(cols.map(_.headOption.getOrElse(zero)).reverse)
      val carry = Cat(cols.map(c => if(c.size > 1) c(1) else zero).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        // cout2 still holds the previous column's carries here; they have this column's weight.
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = doRegSorted.contains(depth)
      val toNextLayer = if(needReg)
        columns_next.map(_.map(x => RegEnable(x, io.regEnables(doRegSorted.indexOf(depth)))))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  val (sum, carry) = addAll(cols = columns, depth = 0)
  io.result := sum + carry
}

/** Pipelined multiplier wrapping [[ArrayMulDataModule]].
  *
  * The latency equals the number of register stages requested in `doReg`; the control bundle
  * is delayed by the same number of stages so it lines up with the data-path result.
  *
  * @param len   operand width of the data module; the result width used here is `len - 1`
  *              (presumably operands arrive extended by one bit -- confirm against the caller)
  * @param doReg reduction layers after which the data module inserts a register stage
  */
class ArrayMultiplier(len: Int, doReg: Seq[Int])(implicit p: Parameters)
  extends AbstractMultiplier(len) with HasPipelineReg {

  override def latency = doReg.size

  val mulDataModule = Module(new ArrayMulDataModule(len, doReg))
  mulDataModule.io.a := io.in.bits.src(0)
  mulDataModule.io.b := io.in.bits.src(1)
  // Register stage k of the data module is enabled by regEnable(k + 1).
  mulDataModule.io.regEnables := VecInit((1 to doReg.size) map (i => regEnable(i)))
  val result = mulDataModule.io.result

  // Element i is the control bundle after i pipeline stages.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = len - 1
  // isHi picks the upper xlen bits of the product, otherwise the lower xlen bits.
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  // isW keeps only the low 32 bits and sign-extends them back to xlen.
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}