package xiangshan.backend.fu.wrapper

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import chisel3.util.experimental.decode._
import utils.XSError
import xiangshan.backend.fu.FuConfig
import xiangshan.backend.fu.vector.{Mgu, VecPipedFuncUnit}
import yunsuan.VfpuType
import yunsuan.vector.VectorConvert.VectorCvt


/**
  * Vector convert function-unit wrapper.
  *
  * Decodes the low 8 bits of `fuOpType` into a rounding mode and a single/widen/narrow kind,
  * drives a [[VectorCvtTop]] with the `vs2` source, gates the per-element fflags coming back
  * from it with mask / vl / lmul, and passes the converted data through an [[Mgu]] before
  * driving `io.out`.
  *
  * Signals that accompany the converted data (`outEew`, the fflags enables, the narrow flag)
  * are delayed by two `RegNext` stages.
  * NOTE(review): presumably this matches the latency of the convert datapath (see the
  * `//cycle2` marker in [[VectorCvtTop]]) - confirm against yunsuan's `VectorCvt`.
  *
  * @param cfg function-unit configuration; `cfg.dataBits` is used as the datapath width
  */
class VCVT(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VfpuType.dummy, "Vfcvt OpType not supported")

  // params alias
  private val dataWidth = cfg.dataBits
  private val dataWidthOfDataModule = 64
  // number of 64-bit words in the datapath
  private val numVecModule = dataWidth / dataWidthOfDataModule

  // io alias
  private val opcode = fuOpType(7, 0)
  private val sew = vsew

  // Rounding-mode selection from the opcode: isRtz forces rm = 1, isRod forces rm = 6,
  // otherwise the dynamic rounding mode `frm` is used. The three selects are mutually
  // exclusive by construction, which is what Mux1H requires.
  private val isRtz = opcode(2) & opcode(1)
  private val isRod = opcode(2) & !opcode(1) & opcode(0)
  private val isFrm = !isRtz && !isRod
  private val rm = Mux1H(
    Seq(isRtz, isRod, isFrm),
    Seq(1.U, 6.U, frm)
  )

  private val lmul = vlmul // signed encoding: -3->3 => LMUL 1/8 -> 8

  val widen = opcode(4, 3) // 0->single 1->widen 2->narrow => width of result
  val isSingleCvt = !widen(1) & !widen(0)
  val isWidenCvt = !widen(1) & widen(0)
  val isNarrowCvt = widen(1) & !widen(0)

  // output width 8, 16, 32, 64
  // One-hot output element width (bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64), decoded from
  // {widen, sew}. Combinations that are not listed fall back to the default BitPat.N(4).
  val output1H = Wire(UInt(4.W))
  output1H := chisel3.util.experimental.decode.decoder(
    widen ## sew,
    TruthTable(
      Seq(
        BitPat("b00_01") -> BitPat("b0010"), // 16
        BitPat("b00_10") -> BitPat("b0100"), // 32
        BitPat("b00_11") -> BitPat("b1000"), // 64

        BitPat("b01_00") -> BitPat("b0010"), // 16
        BitPat("b01_01") -> BitPat("b0100"), // 32
        BitPat("b01_10") -> BitPat("b1000"), // 64

        BitPat("b10_00") -> BitPat("b0001"), // 8
        BitPat("b10_01") -> BitPat("b0010"), // 16
        BitPat("b10_10") -> BitPat("b0100"), // 32
      ),
      BitPat.N(4)
    )
  )
  dontTouch(output1H)
  val outputWidth1H = output1H

  // One-hot output width re-encoded as 0..3 and delayed by two register stages.
  val outEew = RegNext(RegNext(Mux1H(output1H, Seq(0,1,2,3).map(i => i.U))))
  // For isFpToVecInst operations the Mgu is fed an all-true mask instead of the source mask.
  private val needNoMask = outVecCtrl.fpu.isFpToVecInst
  val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  // modules
  private val vfcvt = Module(new VectorCvtTop(dataWidth, dataWidthOfDataModule))
  private val mgu = Module(new Mgu(dataWidth)) // TODO: check what work the mgu actually does here

  // vs2 viewed as numVecModule words of 64 bits
  val vs2Vec = Wire(Vec(numVecModule, UInt(dataWidthOfDataModule.W)))
  vs2Vec := vs2.asTypeOf(vs2Vec)

  /**
    * [[vfcvt]]'s in connection
    */
  vfcvt.uopIdx := vuopIdx(0)
  vfcvt.src := vs2Vec
  vfcvt.opType := opcode
  vfcvt.sew := sew
  vfcvt.rm := rm
  vfcvt.outputWidth1H := outputWidth1H
  vfcvt.isWiden := isWidenCvt
  vfcvt.isNarrow := isNarrowCvt
  val vfcvtResult = vfcvt.io.result
  val vfcvtFflags = vfcvt.io.fflags

  /** fflags:
    */
  // uopIdx: each uop has its own number of elements and its own position within the mask
  // vl: decides how many elements of the vector register group take part in the operation
  // 128/(sew+1)*8
  // val num = 16.U >> sew
  // val eNum = Mux(isWiden || isNorrow , num, num >> 1).asUInt // number of vector elements per uopIdx
  // num of vector element in each uopidx

  // Maximum number of elements per uopIdx, i.e. how many elements of width
  // max(inWidth, outWidth) fit into one vector register.
  // For single/narrow this is the number of input elements one vector register holds;
  // for widen it is the number of output elements one vector register holds.
  // todo: why 4 bits? A vector register always holds at least 2 elements, but a 64->64 single
  //       conversion with lmul = 1/2 exists, so keeping the "1" slot makes the eNumMax
  //       computation below easier (the max may be 1).
  // One-hot encoding: bit0 = 1, bit1 = 2, bit2 = 4, bit3 = 8 elements.
  val eNum1H = chisel3.util.experimental.decode.decoder(sew ## (isWidenCvt || isNarrowCvt),
    TruthTable(
      Seq(                     // 8, 4, 2, 1
        BitPat("b001") -> BitPat("b1000"), //8
        BitPat("b010") -> BitPat("b1000"), //8
        BitPat("b011") -> BitPat("b0100"), //4
        BitPat("b100") -> BitPat("b0100"), //4
        BitPat("b101") -> BitPat("b0010"), //2
        BitPat("b110") -> BitPat("b0010"), //2
      ),
      BitPat.N(4)
    )
  )
  // Scale the per-register element count by lmul: when the top bit of lmul is set the one-hot
  // value is shifted right by the two's-complement negation of the remaining lmul bits,
  // otherwise it is shifted left by those bits. Bits 6..0 of the result are kept.
  val eNumMax1H = Mux(lmul.head(1).asBool, eNum1H >> ((~lmul.tail(1)).asUInt +1.U), eNum1H << lmul.tail(1)).asUInt(6, 0)
  val eNumMax = Mux1H(eNumMax1H, Seq(1,2,4,8,16,32,64).map(i => i.U)) // only for cvt instructions; 128 elements never occurs for cvt
  // Number of elements that may raise flags: min(vl, eNumMax)
  val eNumEffectIdx = Mux(vl > eNumMax, eNumMax, vl)

  // mask, vl and lmul => whether a given input vector element is active
  // val numGroup = lmul * 128/eew
  // val mask = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => (srcMask >> vuopIdx * num.U)(num-1, 0))) // mask is derived from the max element count of each uop
  val eNum = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num =>num.U))
  // index of the first element handled by this uop
  val eStart = vuopIdx * eNum
  // mask bits belonging to this uop, aligned to bit 0 and truncated to eNum bits
  val maskPart = srcMask >> eStart
  val mask = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => maskPart(num-1, 0)))
  val fflagsEn = Wire(Vec(4 * numVecModule, Bool()))

  // fflagsEn := mask.asBools.zipWithIndex.map { case (mask, i) =>
  //  mask & (eNumEffect > Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => vuopIdx * num.U + i.U)))
  // }
  // vl: [0, vl)  eNumMax: [0, eNumMax) => eNumEffect is "index + 1", i.e. a length, while the
  // right-hand side is an index, i.e. "length - 1"; hence the strict `>`.
  // NOTE(review): the Mgu inputs below override mask/vl when isFpToVecInst is set, but this
  // gating uses srcMask and vl unconditionally - confirm that upstream drives them suitably
  // for that case.
  fflagsEn := mask.asBools.zipWithIndex.map{case(mask, i) => mask & (eNumEffectIdx > eStart + i.U) } // bounded by vl and lmul

  // enables delayed by two register stages before being applied to the returned fflags
  val fflagsEnCycle2 = RegNext(RegNext(fflagsEn))
  // fflags from the convert unit viewed as 8 lanes of 5 bits.
  // NOTE(review): the lane count is hard-coded to 8 while fflagsEn has 4 * numVecModule
  // entries; the two agree only when numVecModule == 2.
  val fflagsAll = Wire(Vec(8, UInt(5.W)))
  fflagsAll := vfcvtFflags.asTypeOf(fflagsAll)
  // OR together the flags of every enabled lane; disabled lanes contribute zero
  val fflags = fflagsEnCycle2.zip(fflagsAll).map{case(en, fflag) => Mux(en, fflag, 0.U(5.W))}.reduce(_ | _)
  io.out.bits.res.fflags.get := fflags


  /**
    * [[mgu]]'s in connection
    * TODO: check later whether this mgu is really required
    */
  val resultDataUInt = Wire(UInt(dataWidth.W)) //todo: from vfcvt
  resultDataUInt := vfcvtResult

  mgu.io.in.vd := resultDataUInt
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  // isFpToVecInst operations are handled as a single element: vl = 1, vstart = 0
  mgu.io.in.info.vl := Mux(outVecCtrl.fpu.isFpToVecInst, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := io.out.valid
  mgu.io.in.info.vstart := Mux(outVecCtrl.fpu.isFpToVecInst, 0.U, outVecCtrl.vstart)
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := RegNext(RegNext(isNarrowCvt))
  mgu.io.in.info.dstMask := outVecCtrl.isDstMask

  io.out.bits.res.data := mgu.io.out.vd
}

/**
  * IO bundle of [[VectorCvtTop]].
  *
  * @param vlen total datapath width in bits
  * @param xlen width in bits of one source word; `src` holds `vlen / xlen` of them
  */
class VectorCvtTopIO(vlen :Int, xlen: Int) extends Bundle{
  val uopIdx = Input(Bool())                       // selects the source word used when widening
  val src = Input(Vec(vlen / xlen, UInt(xlen.W)))
  val opType = Input(UInt(8.W))
  val sew = Input(UInt(2.W))
  val rm = Input(UInt(3.W))                        // rounding mode
  val outputWidth1H = Input(UInt(4.W))             // one-hot output element width
  val isWiden = Input(Bool())
  val isNarrow = Input(Bool())

  val result = Output(UInt(vlen.W))
  val fflags = Output(UInt((vlen/16*5).W))         // 5 flag bits per 16 bits of vlen
}



//according to uopindex, 1: high64 0:low64
/**
  * Wraps two `VectorCvt` instances and steers source data into them and results / fflags out
  * of them according to the widen / narrow kind.
  *
  * NOTE(review): only `src(0)` and `src(1)` are read, so this appears to assume
  * `vlen / xlen == 2` - confirm.
  *
  * @param vlen total datapath width in bits
  * @param xlen width in bits handled by each `VectorCvt`
  */
class VectorCvtTop(vlen :Int, xlen: Int) extends Module{
  val io = IO(new VectorCvtTopIO(vlen, xlen))

  // Short aliases for the IO ports; the parent module connects to the ports through these names.
  val (uopIdx, src, opType, sew, rm, outputWidth1H, isWiden, isNarrow) = (
    io.uopIdx, io.src, io.opType, io.sew, io.rm, io.outputWidth1H, io.isWiden, io.isNarrow
  )

  // When widening, both converters work on the single source word selected by uopIdx
  // (src(1) if set, src(0) otherwise): in0 takes what remains after dropping the 32 MSBs,
  // in1 takes the 32 MSBs. Otherwise each converter gets its own source word.
  val in0 = Mux(isWiden,
    Mux(uopIdx, src(1).tail(32), src(0).tail(32)),
    src(0)
  )

  val in1 = Mux(isWiden,
    Mux(uopIdx, src(1).head(32), src(0).head(32)),
    src(1)
  )

  val vectorCvt0 = Module(new VectorCvt(xlen))
  vectorCvt0.src := in0
  vectorCvt0.opType := opType
  vectorCvt0.sew := sew
  vectorCvt0.rm := rm

  val vectorCvt1 = Module(new VectorCvt(xlen))
  vectorCvt1.src := in1
  vectorCvt1.opType := opType
  vectorCvt1.sew := sew
  vectorCvt1.rm := rm

  // Control signals delayed by two register stages for use on the output side.
  val isNarrowCycle2 = RegNext(RegNext(isNarrow))
  val outputWidth1HCycle2 = RegNext(RegNext(outputWidth1H))

  //cycle2
  // Narrow: concatenate each converter's result with its 32 MSBs dropped.
  // Otherwise: concatenate the two full results (converter 1 in the upper part).
  io.result := Mux(isNarrowCycle2,
    vectorCvt1.io.result.tail(32) ## vectorCvt0.io.result.tail(32),
    vectorCvt1.io.result ## vectorCvt0.io.result)

  // Select how many fflags bits of each converter are forwarded, by output width
  // (entries in order: 8, 16, 32, 64-bit output) and by whether the operation narrows.
  // NOTE(review): the slicing assumes 5 flag bits per element slot in VectorCvt's fflags -
  // verify against yunsuan's VectorCvt.
  io.fflags := Mux1H(outputWidth1HCycle2, Seq( // todo: map between fflags and result
    vectorCvt1.io.fflags ## vectorCvt0.io.fflags,
    Mux(isNarrowCycle2, vectorCvt1.io.fflags.tail(10) ## vectorCvt0.io.fflags.tail(10), vectorCvt1.io.fflags ## vectorCvt0.io.fflags),
    Mux(isNarrowCycle2, vectorCvt1.io.fflags(4,0) ## vectorCvt0.io.fflags(4,0), vectorCvt1.io.fflags.tail(10) ## vectorCvt0.io.fflags.tail(10)),
    vectorCvt1.io.fflags(4,0) ## vectorCvt0.io.fflags(4,0)
  ))
}