/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.backend.fu

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import utility._
import xiangshan.backend.fu.util.{C22, C32, C53}

/** Control fields selecting which part of a multiply result is produced. */
class MulDivCtrl extends Bundle {
  // Not read by the multipliers in this file.
  // NOTE(review): presumably used by the divider / operand preparation -- confirm.
  val sign = Bool()
  // Word operation: the multipliers below sign-extend the low 32 bits of the
  // selected half to the full output width.
  val isW = Bool()
  // Select the high half of the double-width product instead of the low half.
  val isHi = Bool()
}

/** Common base of the multipliers: a `FunctionUnit` of width `len` plus a [[MulDivCtrl]] input. */
class AbstractMultiplier(len: Int)(implicit p: Parameters) extends FunctionUnit(
  len
) {
  val ctrl = IO(Input(new MulDivCtrl))
}

/**
 * Behavioural multiplier: forms the full signed product `src1 * src2` with the
 * Chisel `*` operator and delays it (and `ctrl`) by `latency` PipelineReg stages.
 *
 * @param len     function-unit width passed to [[AbstractMultiplier]]
 * @param latency number of pipeline stages between input and output
 */
class NaiveMultiplier(len: Int, val latency: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len)
  with HasPipelineReg
{

  val (src1, src2) = (io.in.bits.src(0), io.in.bits.src(1))

  val mulRes = src1.asSInt * src2.asSInt

  // dataVec(i) / ctrlVec(i) are the product and control after i stages.
  var dataVec = Seq(mulRes.asUInt)
  var ctrlVec = Seq(ctrl)

  for(i <- 1 to latency){
    dataVec = dataVec :+ PipelineReg(i)(dataVec(i-1))
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }

  val xlen = io.out.bits.data.getWidth
  // Pick the high or low xlen bits of the product, then optionally
  // sign-extend the low 32 bits of that half (word operation).
  val res = Mux(ctrlVec.last.isHi, dataVec.last(2*xlen-1, xlen), dataVec.last(xlen-1,0))
  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}

/**
 * Two's-complement `len` x `len` -> `2*len` bit multiplier datapath.
 *
 * Structure:
 *  1. Radix-4 Booth partial products of `a` times `b`, bucketed by bit weight
 *     into `columns`, then registered with `regEnables(0)`.
 *  2. A column compressor tree built from C22 / C32 / C53 cells; the output of
 *     the fifth layer (depth == 4) is registered with `regEnables(1)`.
 *  3. A final carry-propagate add of the remaining sum and carry rows.
 *
 * @param len operand width in bits
 */
class ArrayMulDataModule(len: Int) extends Module {
  val io = IO(new Bundle() {
    val a, b = Input(UInt(len.W))
    // Enables of the two internal register stages described above.
    val regEnables = Input(Vec(2, Bool()))
    val result = Output(UInt((2 * len).W))
  })
  val (a, b) = (io.a, io.b)

  // Booth multiples of b, one bit wider than b so that 2*b fits.
  // Negative multiples are one's complements: ~b == -b - 1 and
  // (~b) << 1 == -2b - 2; the missing +1 / +2 is added back through `t` below.
  val b_sext, bx2, neg_b, neg_bx2 = Wire(UInt((len+1).W))
  b_sext := SignExt(b, len+1)
  bx2 := b_sext << 1
  neg_b := (~b_sext).asUInt
  neg_bx2 := neg_b << 1

  // columns(j) collects every partial-product bit of weight 2^j.
  val columns: Array[Seq[Bool]] = Array.fill(2*len)(Seq())

  // Booth digit of the previous row; its negative-multiple correction `t`
  // occupies the two low bits of the current row (weight i-2).
  var last_x = WireInit(0.U(3.W))
  for(i <- Range(0, len, 2)){
    // Booth digit from the overlapping triplet a(i+1), a(i), a(i-1), with a(-1) = 0.
    // For odd len the top triplet duplicates the sign bit a(len-1).
    val x = if(i==0) Cat(a(1,0), 0.U(1.W)) else if(i+1==len) SignExt(a(i, i-1), 3) else a(i+1, i-1)
    // Digit -> multiple: 1,2 -> +b; 3 -> +2b; 4 -> -2b; 5,6 -> -b; 0,7 -> 0.
    val pp_temp = MuxLookup(x, 0.U)(Seq(
      1.U -> b_sext,
      2.U -> b_sext,
      3.U -> bx2,
      4.U -> neg_bx2,
      5.U -> neg_b,
      6.U -> neg_b
    ))
    val s = pp_temp(len)
    // Correction owed by the previous row: +2 after -2b, +1 after -b.
    val t = MuxLookup(last_x, 0.U(2.W))(Seq(
      4.U -> 2.U(2.W),
      5.U -> 1.U(2.W),
      6.U -> 1.U(2.W)
    ))
    last_x = x
    // Rows are not sign-extended to 2*len bits. Instead each row carries the
    // ~s / constant-1 prefix bits below (Booth sign-extension-avoidance
    // encoding). Bits at or above 2*len are dropped by the column loop.
    val (pp, weight) = i match {
      case 0 =>
        (Cat(~s, s, s, pp_temp), 0)
      case n if (n==len-1) || (n==len-2) =>
        // Last row: only ~s on top.
        (Cat(~s, pp_temp, t), i-2)
      case _ =>
        (Cat(1.U(1.W), ~s, pp_temp, t), i-2)
    }
    for(j <- columns.indices){
      if(j >= weight && j < (weight + pp.getWidth)){
        columns(j) = columns(j) :+ pp(j-weight)
      }
    }
  }

  /**
   * Compresses one column for one layer of the tree.
   *
   * @param col bits of this column
   * @param cin cout1 carries arriving from the previous column in this same layer
   * @return (bits remaining in this column for the next layer,
   *          cout1 carries for the next column in this same layer,
   *          cout2 carries for the next column in the next layer)
   */
  def addOneColumn(col: Seq[Bool], cin: Seq[Bool]): (Seq[Bool], Seq[Bool], Seq[Bool]) = {
    var sum = Seq[Bool]()
    var cout1 = Seq[Bool]()
    var cout2 = Seq[Bool]()
    col.size match {
      case 1 =>
        // A lone bit needs no compressor: pass it (and any carries-in) through.
        sum = col ++ cin
      case 2 =>
        val c22 = Module(new C22)
        c22.io.in := col
        sum = c22.io.out(0).asBool +: cin
        cout2 = Seq(c22.io.out(1).asBool)
      case 3 =>
        val c32 = Module(new C32)
        c32.io.in := col
        sum = c32.io.out(0).asBool +: cin
        cout2 = Seq(c32.io.out(1).asBool)
      case 4 =>
        // C53: the four column bits plus one carry-in from this layer.
        val c53 = Module(new C53)
        for((x, y) <- c53.io.in.take(4) zip col){
          x := y
        }
        c53.io.in.last := (if(cin.nonEmpty) cin.head else 0.U)
        sum = Seq(c53.io.out(0).asBool) ++ (if(cin.nonEmpty) cin.drop(1) else Nil)
        cout1 = Seq(c53.io.out(1).asBool)
        cout2 = Seq(c53.io.out(2).asBool)
      case n =>
        // More than four bits: handle the first four with one carry-in, and
        // recurse on the rest with the remaining carries-in.
        val cin_1 = if(cin.nonEmpty) Seq(cin.head) else Nil
        val cin_2 = if(cin.nonEmpty) cin.drop(1) else Nil
        val (s_1, c_1_1, c_1_2) = addOneColumn(col take 4, cin_1)
        val (s_2, c_2_1, c_2_2) = addOneColumn(col drop 4, cin_2)
        sum = s_1 ++ s_2
        cout1 = c_1_1 ++ c_2_1
        cout2 = c_1_2 ++ c_2_2
    }
    (sum, cout1, cout2)
  }

  /** Largest element of `in` (throws on an empty collection). */
  def max(in: Iterable[Int]): Int = in.reduce((a, b) => if(a>b) a else b)

  /**
   * Applies compression layers recursively until every column holds at most
   * two bits, then returns the two rows (sum, carry) for the final adder.
   *
   * @param cols  current bit columns, one per weight
   * @param depth index of the layer about to be built; the output of layer 4
   *              is registered with `io.regEnables(1)`
   */
  def addAll(cols: Array[Seq[Bool]], depth: Int): (UInt, UInt) = {
    if(max(cols.map(_.size)) <= 2){
      val sum = Cat(cols.map(_(0)).reverse)
      // The carry row starts at the first column that still holds a second bit.
      val k = cols.indexWhere(_.size > 1)
      if (k < 0) {
        // Fully reduced to a single row: nothing left to carry.
        (sum, 0.U)
      } else {
        // A single-bit column above k contributes a zero carry bit rather than
        // failing elaboration on a missing second element.
        val carry = Cat(cols.drop(k).map(col => col.lift(1).getOrElse(false.B)).reverse)
        (sum, Cat(carry, 0.U(k.W)))
      }
    } else {
      val columns_next = Array.fill(2*len)(Seq[Bool]())
      var cout1, cout2 = Seq[Bool]()
      for( i <- cols.indices){
        // cout1 of column i-1 is consumed by column i in this layer; its cout2
        // is appended to column i for the next layer.
        val (s, c1, c2) = addOneColumn(cols(i), cout1)
        columns_next(i) = s ++ cout2
        cout1 = c1
        cout2 = c2
      }

      val needReg = depth == 4
      val toNextLayer = if(needReg)
        columns_next.map(_.map(x => RegEnable(x, io.regEnables(1))))
      else
        columns_next

      addAll(toNextLayer, depth+1)
    }
  }

  // First register stage: the raw partial-product bits.
  val columns_reg = columns.map(col => col.map(b => RegEnable(b, io.regEnables(0))))
  val (sum, carry) = addAll(cols = columns_reg, depth = 0)

  io.result := sum + carry
}

/**
 * Pipelined multiplier built on [[ArrayMulDataModule]], with a fixed latency of 2.
 *
 * The data module multiplies `len`-bit operands, while the output word is
 * `len - 1` bits wide. NOTE(review): presumably the caller widens XLEN-bit
 * operands by one bit so signed and unsigned multiplies share this signed
 * datapath -- confirm against the instantiation site.
 */
class ArrayMultiplier(len: Int)(implicit p: Parameters)
  extends AbstractMultiplier(len) with HasPipelineReg {

  override def latency = 2

  val mulDataModule = Module(new ArrayMulDataModule(len))
  mulDataModule.io.a := io.in.bits.src(0)
  mulDataModule.io.b := io.in.bits.src(1)
  // One enable per internal register stage of the data module.
  mulDataModule.io.regEnables := VecInit((1 to latency) map (i => regEnable(i)))
  val result = mulDataModule.io.result

  // Pass ctrl through the same `latency` PipelineReg stages so it lines up with `result`.
  var ctrlVec = Seq(ctrl)
  for(i <- 1 to latency){
    ctrlVec = ctrlVec :+ PipelineReg(i)(ctrlVec(i-1))
  }
  val xlen = len - 1
  val res = Mux(ctrlVec.last.isHi, result(2*xlen-1, xlen), result(xlen-1,0))

  io.out.bits.data := Mux(ctrlVec.last.isW, SignExt(res(31,0),xlen), res)

  XSDebug(p"validVec:${Binary(Cat(validVec))} flushVec:${Binary(Cat(flushVec))}\n")
}