/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import xiangshan.backend.fu.util.{C22, C32, C53}

/** Control bits that travel alongside an operation through the multiplier pipelines below. */
class MulDivCtrl extends Bundle{
  // Signedness flag. Not read by the multipliers in this file.
  // NOTE(review): presumably consumed by the caller or by a divider -- confirm.
  val sign = Bool()
  // When set, the result is the low 32 bits of the selected half, sign-extended to xlen.
  val isW = Bool()
  // When set, the high half of the double-width product is returned instead of the low half.
  val isHi = Bool()
}

/** Common base for multipliers: a FunctionUnit of width `len` plus a [[MulDivCtrl]] input. */
class AbstractMultiplier(len: Int)(implicit p: Parameters) extends FunctionUnit(
  len
){
  val ctrl = IO(Input(new MulDivCtrl))
}

/**
 * Reference multiplier: one combinational signed `*`, followed by `latency`
 * pipeline registers on both the product and the control bits.
 */
class NaiveMultiplier(len: Int, val latency: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  val mulRes = src1.asSInt() * src2.asSInt()

  // Element 0 is the un-registered value; element i is the value after pipeline stage i.
  val dataVec: Seq[UInt] = (1 to latency).scanLeft(mulRes.asUInt())((prev, i) => PipelineReg(i)(prev))
  val ctrlVec: Seq[MulDivCtrl] = (1 to latency).scanLeft(ctrl)((prev, i) => PipelineReg(i)(prev))

  // NOTE(review): xlen is the full output width here, whereas ArrayMultiplier uses len - 1.
  // If both are instantiated with the same len (operands pre-extended by one bit),
  // the hi/lo split differs between them -- confirm the intended width.
  val xlen = io.out.bits.data.getWidth
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1,0))
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}

/**
 * Pipelined array-multiplier datapath producing the full 2*len-bit product of
 * `a` and `b`, both treated as len-bit two's-complement values.
 *
 *   1. Radix-4 Booth recoding of `a` selects one multiple of `b` (0, +-b, +-2b) per row.
 *   2. The partial-product bits, grouped by column, are registered with
 *      `regEnables(0)`. They are then reduced by a tree of C22/C32/C53 counters.
 *      A second register stage, enabled by `regEnables(1)`, sits after reduction layer 4.
 *   3. The two remaining rows are combined with a carry-propagate adder.
 *
 * NOTE(review): the most significant Booth row never receives its own +1/+2
 * negation correction.
 *   - For odd `len`, its triplet is {a(len-1), a(len-1), a(len-2)}. This selects a
 *     negative multiple only when a(len-1) != a(len-2).
 *   - For even `len`, the top triplet is a general one, so a negative `a` can give a
 *     wrong product (e.g. len = 4, a = -8, b = 1).
 * The visible caller uses len = xlen + 1 -- confirm that operands are always
 * sign/zero-extended before using other widths.
 *
 * NOTE(review): if the reduction tree is shallower than five layers (small `len`),
 * the second register stage is never inserted and `regEnables(1)` is unused.
 */
class ArrayMulDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    val a, b = Input(UInt(len.W))
    val regEnables = Input(Vec(2, Bool()))
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)

  // Candidate multiples of b, each len+1 bits wide.
  // The negative multiples are one's complements:
  //   ~b = -b - 1,   (~b) << 1 = -2b - 2.
  // The missing +1 / +2 is added by the next row through its `t` bits.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt()
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Booth triplet {a(i+1), a(i), a(i-1)} for the row anchored at bit i.
  // a(-1) is taken as 0; for the topmost row, a(len) is taken as a copy of a(len-1).
  private def boothTriplet(i: Int): UInt =
    if (i == 0) Cat(a(1, 0), 0.U(1.W))
    else if (i + 1 == len) SignExt(a(i, i - 1), 3)
    else a(i + 1, i - 1)

  private val rowStarts = Range(0, len, 2)
  private val triplets = rowStarts.map(boothTriplet)

  for ((i, row) <- rowStarts.zipWithIndex) {
    val x = triplets(row)
    // Triplet of the previous (less significant) row; it determines this row's correction bits.
    val last_x = if (row == 0) 0.U(3.W) else triplets(row - 1)

    // Booth digit -> selected multiple of b (triplets 0 and 7 select 0).
    val pp_temp = MuxLookup(x, 0.U, Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // Completes the previous row's one's-complement negation:
    // +2 after a -2b row, +1 after a -b row.
    val t = MuxLookup(last_x, 0.U(2.W), Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))

    // Assemble the row. The leading constant/inverted-sign bits stand in for explicit
    // sign extension of each row:
    //   row 0: {~s, s, s};   middle rows: {1, ~s};   top rows: {~s}.
    // Rows other than row 0 carry the previous row's correction `t` in their two LSBs,
    // which is why they start at weight i - 2.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n == len - 1) || (n == len - 2) =>
        (Cat(~s, pp_temp, t), i - 2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i - 2)
    }
    // Scatter the row's bits into their columns; bits at weight 2*len or above are dropped.
    for (j <- weight until scala.math.min(weight + pp.getWidth, columns.length)) {
      columns(j) = columns(j) :+ pp(j - weight)
    }
  }

  /**
   * Compresses one column of bits within the current reduction layer.
   *
   * @param col bits of this column
   * @param cin fast carries arriving from the previous column in the same layer
   * @return (sum, cout1, cout2):
   *         - sum: bits that stay in this column for the next layer;
   *         - cout1: fast carries for the next column in this layer;
   *         - cout2: carries for the next column in the next layer.
   */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = col.size match {
    case 0 =>
      // Nothing to compress: forward the incoming carries.
      // Without this case, an empty column would recurse forever in the default case.
      (cin, Nil, Nil)
    case 1 =>
      (col ++ cin, Nil, Nil)
    case 2 =>
      val c22 = Module(new C22)
      c22.io.in := col
      (c22.io.out(0).asBool() +: cin, Nil, Seq(c22.io.out(1).asBool()))
    case 3 =>
      val c32 = Module(new C32)
      c32.io.in := col
      (c32.io.out(0).asBool() +: cin, Nil, Seq(c32.io.out(1).asBool()))
    case 4 =>
      val c53 = Module(new C53)
      for ((x, y) <- c53.io.in.take(4) zip col) {
        x := y
      }
      // The fifth input takes the first incoming fast carry (or 0); any others pass through.
      c53.io.in.last := (if (cin.nonEmpty) cin.head else 0.U)
      (
        Seq(c53.io.out(0).asBool()) ++ cin.drop(1),
        Seq(c53.io.out(1).asBool()),
        Seq(c53.io.out(2).asBool())
      )
    case _ =>
      // Wider columns are split into a 4-bit chunk, which consumes the first carry,
      // and the remainder, which receives the rest.
      val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin.take(1))
      val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin.drop(1))
      (s_1 ++ s_2, c_1_1 ++ c_2_1, c_1_2 ++ c_2_2)
  }

  def max(in: Iterable[Int]): Int = in.max

  /**
   * Reduces `cols` layer by layer until every column holds at most two bits.
   *
   * @param cols  current columns (index j has weight 2^j)
   * @param depth number of reduction layers already applied; the second pipeline
   *              register stage is inserted after layer 4
   * @return the two remaining rows (sum, carry), each 2*len bits wide
   */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if (max(cols.map(_.size)) <= 2) {
      // A column with fewer than two bits contributes 0 to the missing row(s).
      // This avoids assuming that the two-bit columns form one contiguous upper run.
      val sum = Cat(cols.map(_.lift(0).getOrElse(false.B)).reverse)
      val carry = Cat(cols.map(_.lift(1).getOrElse(false.B)).reverse)
      (sum, carry)
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      // cout1 feeds the next column's compressor in this layer.
      // cout2 joins the next column in the next layer.
      var cout1, cout2 = Seq[Bool]()
      for (i <- cols.indices) {
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = depth == 4
      val toNextLayer =
        if (needReg) columns_next.map(_.map(x => RegEnable(x, io.regEnables(1))))
        else columns_next

      addAll(toNextLayer, depth + 1)
    }
  }

  // First pipeline stage: register every partial-product bit before reduction.
  val columns_reg = columns.map(col => col.map(bit => RegEnable(bit, io.regEnables(0))))
  val (sum, carry) = addAll(cols = columns_reg, depth = 0)

  io.result := sum + carry
}

/**
 * Two-stage pipelined multiplier built on [[ArrayMulDataModule]].
 * The result is the high or low xlen bits of the product (xlen = len - 1), optionally
 * reduced to a sign-extended 32-bit value.
 */
class ArrayMultiplier(len: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len) with HasPipelineReg {

  // Must match the two register stages (regEnables(0), regEnables(1)) inside ArrayMulDataModule.
  override def latency = 2

  val mulDataModule = Module(new ArrayMulDataModule(len))
  mulDataModule.io.a := io.in.bits.src(0)
  mulDataModule.io.b := io.in.bits.src(1)
  mulDataModule.io.regEnables := VecInit((1 to latency) map (i => regEnable(i)))
  val result = mulDataModule.io.result

  // Control bits delayed to line up with the datapath; element i is after pipeline stage i.
  val ctrlVec: Seq[MulDivCtrl] = (1 to latency).scanLeft(ctrl)((prev, i) => PipelineReg(i)(prev))

  val xlen = len - 1
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}