xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/VCVT.scala (revision ac5be754bbbd6fb4720d11aad404ec3d1c6117b7)
1package xiangshan.backend.fu.wrapper
2
3import org.chipsalliance.cde.config.Parameters
4import chisel3._
5import chisel3.util._
6import chisel3.util.experimental.decode._
7import utility.XSError
8import xiangshan.backend.fu.FuConfig
9import xiangshan.backend.fu.vector.{Mgu, VecPipedFuncUnit}
10import xiangshan.ExceptionNO
11import xiangshan.FuOpType
12import yunsuan.VfpuType
13import yunsuan.vector.VectorConvert.VectorCvt
14import yunsuan.util._
15
16
/**
 * Vector floating-point convert functional unit.
 *
 * Decodes the conversion kind (single-width / widening / narrowing) and a rounding-mode
 * override from `fuOpType`, feeds `vs2` (split into 64-bit lanes) to a [[VectorCvtTop]],
 * and merges the converted data with the old destination through [[Mgu]].
 * Converted data and fflags come back two register stages after `io.in.valid`, so every
 * input-side control value consumed on the output side is delayed by two stages
 * (or read from the `out*` aliases).
 */
class VCVT(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VfpuType.dummy, "Vfcvt OpType not supported")

  // params alias
  private val dataWidth = cfg.destDataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule
  // Result slots that can report fflags: 4 per 64-bit data module (one per 16 bits of data),
  // matching the width of VectorCvtTopIO.fflags (dataWidth / 16 * 5 bits).
  private val numFflagsSlots = 4 * numVecModule
  // NOTE(review): VectorCvtTop only instantiates two 64-bit converters and the per-uop mask
  // selection below is at most 8 bits wide, so this unit appears to assume dataWidth == 128.

  // io alias
  private val opcode = fuOpType(8, 0)
  private val sew = vsew

  // Rounding-mode override decoded from the opcode: RTZ variants force rm = 1 (RTZ in the
  // RISC-V rm encoding), ROD variants force rm = 6 (presumably the converter's round-to-odd
  // code — confirm against yunsuan), everything else uses the dynamic `rm`.
  private val isRtz = opcode(2) & opcode(1)
  private val isRod = opcode(2) & !opcode(1) & opcode(0)
  private val isFrm = !isRtz && !isRod
  private val vfcvtRm = Mux1H(
    Seq(isRtz, isRod, isFrm),
    Seq(1.U, 6.U, rm)
  )

  private val lmul = vlmul // -3->3 => 1/8 ->8

  val widen = opcode(4, 3) // 0->single 1->widen 2->narrow => width of result
  val isSingleCvt = !widen(1) & !widen(0)
  val isWidenCvt = !widen(1) & widen(0)
  val isNarrowCvt = widen(1) & !widen(0)
  val fire = io.in.valid
  val fireReg = GatedValidRegNext(fire)

  // One-hot output element width: bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64.
  // Decoded from (conversion kind, sew); combinations not listed decode to all zeros.
  val output1H = Wire(UInt(4.W))
  output1H := chisel3.util.experimental.decode.decoder(
    widen ## sew,
    TruthTable(
      Seq(
        BitPat("b00_01") -> BitPat("b0010"), // 16
        BitPat("b00_10") -> BitPat("b0100"), // 32
        BitPat("b00_11") -> BitPat("b1000"), // 64

        BitPat("b01_00") -> BitPat("b0010"), // 16
        BitPat("b01_01") -> BitPat("b0100"), // 32
        BitPat("b01_10") -> BitPat("b1000"), // 64

        BitPat("b10_00") -> BitPat("b0001"), // 8
        BitPat("b10_01") -> BitPat("b0010"), // 16
        BitPat("b10_10") -> BitPat("b0100"), // 32
      ),
      BitPat.N(4)
    )
  )
  if(backendParams.debugEn) {
    dontTouch(output1H)
  }
  val outputWidth1H = output1H
  // Output-side flag: the result elements are 32 bits wide (used to sign-extend scalar f2i results).
  val outIs32bits = RegNext(RegNext(outputWidth1H(2)))
  val outIsInt = !outCtrl.fuOpType(6)

  // May be useful in the future.
  // val outIsMvInst = outCtrl.fuOpType === FuOpType.FMVXF
  val outIsMvInst = false.B

  // Output element width as log2(bytes) (0 -> 8 bits ... 3 -> 64 bits), delayed to the output stage.
  val outEew = RegEnable(RegEnable(Mux1H(output1H, Seq(0,1,2,3).map(i => i.U)), fire), fireReg)
  private val needNoMask = outVecCtrl.fpu.isFpToVecInst
  // Scalar conversions ignore the source mask and treat every element as active.
  val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  // modules
  private val vfcvt = Module(new VectorCvtTop(dataWidth, dataWidthOfDataModule))
  private val mgu = Module(new Mgu(dataWidth))

  val vs2Vec = Wire(Vec(numVecModule, UInt(dataWidthOfDataModule.W)))
  vs2Vec := vs2.asTypeOf(vs2Vec)

  /**
   * [[vfcvt]]'s in connection
   */
  vfcvt.uopIdx := vuopIdx(0)
  vfcvt.src := vs2Vec
  vfcvt.opType := opcode(7,0)
  vfcvt.sew := sew
  vfcvt.rm := vfcvtRm
  vfcvt.outputWidth1H := outputWidth1H
  vfcvt.isWiden := isWidenCvt
  vfcvt.isNarrow := isNarrowCvt
  vfcvt.fire := fire
  vfcvt.isFpToVecInst := vecCtrl.fpu.isFpToVecInst
  val vfcvtResult = vfcvt.io.result
  val vfcvtFflags = vfcvt.io.fflags

  /** fflags:
   * A result slot may contribute fflags only when its element is active in the mask and its
   * index is below min(vl, VLMAX). Enables are computed on the input side and delayed two
   * stages to line up with the converters' fflags.
   */
  // Elements handled per uop, one-hot (bit3 = 8, bit2 = 4, bit1 = 2, bit0 = 1),
  // keyed by sew and whether the conversion changes width.
  val eNum1H = chisel3.util.experimental.decode.decoder(sew ## (isWidenCvt || isNarrowCvt),
    TruthTable(
      Seq(                     // 8, 4, 2, 1
        BitPat("b001") -> BitPat("b1000"), //8
        BitPat("b010") -> BitPat("b1000"), //8
        BitPat("b011") -> BitPat("b0100"), //4
        BitPat("b100") -> BitPat("b0100"), //4
        BitPat("b101") -> BitPat("b0010"), //2
        BitPat("b110") -> BitPat("b0010"), //2
      ),
      BitPat.N(4)
    )
  )
  // Width-changing conversions spread one LMUL=1 register's worth of elements over two uops,
  // so double the per-uop count to get the element count at LMUL = 1.
  val eNum1HEffect = Mux(isWidenCvt || isNarrowCvt, eNum1H << 1, eNum1H)
  // Scale by LMUL (shift right for fractional LMUL) to get VLMAX, still one-hot.
  val eNumMax1H = Mux(lmul.head(1).asBool, eNum1HEffect >> ((~lmul.tail(1)).asUInt +1.U), eNum1HEffect << lmul.tail(1)).asUInt(6, 0)
  val eNumMax = Mux1H(eNumMax1H, Seq(1,2,4,8,16,32,64).map(i => i.U)) //only for cvt intr, don't exist 128 in cvt
  // Scalar conversions produce exactly one element.
  val vlForFflags = Mux(vecCtrl.fpu.isFpToVecInst, 1.U, vl)
  // min(vl, VLMAX)
  val eNumEffectIdx = Mux(vlForFflags > eNumMax, eNumMax, vlForFflags)

  val eNum = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num =>num.U))
  // Index of this uop's first element within the register group.
  val eStart = vuopIdx * eNum
  val maskForFflags = Mux(vecCtrl.fpu.isFpToVecInst, allMaskTrue, srcMask)
  val maskPart = maskForFflags >> eStart
  val mask =  Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => maskPart(num-1, 0)))
  val fflagsEn = Wire(Vec(numFflagsSlots, Bool()))

  // Slot i is enabled when its mask bit is set and element (eStart + i) is below the effective vl.
  fflagsEn := mask.asBools.zipWithIndex.map{case(active, i) => active & (eNumEffectIdx > eStart + i.U) }

  val fflagsEnCycle2 = RegEnable(RegEnable(fflagsEn, fire), fireReg)
  // Sized like fflagsEn so the zip below pairs every enable with its slot.
  val fflagsAll = Wire(Vec(numFflagsSlots, UInt(5.W)))
  fflagsAll := vfcvtFflags.asTypeOf(fflagsAll)
  // OR together the fflags of every enabled slot.
  val fflags = fflagsEnCycle2.zip(fflagsAll).map{case(en, fflag) => Mux(en, fflag, 0.U(5.W))}.reduce(_ | _)
  io.out.bits.res.fflags.get := Mux(outIsMvInst, 0.U, fflags)


  /**
   * [[mgu]]'s in connection
   */
  val resultDataUInt = Wire(UInt(dataWidth.W))
  resultDataUInt := vfcvtResult

  // Narrowing uops produce half a register: odd uops place their result in the upper half and
  // keep the lower half of old vd; even uops keep the upper half of old vd.
  private val narrow = RegEnable(RegEnable(isNarrowCvt, fire), fireReg)
  private val narrowNeedCat = outVecCtrl.vuopIdx(0).asBool && narrow
  private val outNarrowVd = Mux(narrowNeedCat, Cat(resultDataUInt(dataWidth / 2 - 1, 0), outOldVd(dataWidth / 2 - 1, 0)),
                                              Cat(outOldVd(dataWidth - 1, dataWidth / 2), resultDataUInt(dataWidth / 2 - 1, 0)))

  // mgu.io.in.vd := resultDataUInt
  mgu.io.in.vd := Mux(narrow, outNarrowVd, resultDataUInt)
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  mgu.io.in.info.vl := Mux(outVecCtrl.fpu.isFpToVecInst, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := io.out.valid
  mgu.io.in.info.vstart := Mux(outVecCtrl.fpu.isFpToVecInst, 0.U, outVecCtrl.vstart)
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := narrow
  mgu.io.in.info.dstMask := outVecCtrl.isDstMask
  mgu.io.in.isIndexedVls := false.B

  // for scalar f2i cvt inst
  val isFp2VecForInt = outVecCtrl.fpu.isFpToVecInst && outIs32bits && outIsInt
  // for f2i mv inst
  val result = Mux(outIsMvInst, RegNext(RegNext(vs2.tail(64))), mgu.io.out.vd)

  // Scalar conversions to a 32-bit integer are sign-extended to 64 bits.
  io.out.bits.res.data := Mux(isFp2VecForInt,
    Fill(32, result(31)) ## result(31, 0),
    result
  )
  io.out.bits.ctrl.exceptionVec.get(ExceptionNO.illegalInstr) := mgu.io.out.illegal
}
181
/**
 * IO of [[VectorCvtTop]].
 *
 * @param vlen width of the whole operand/result of one uop
 * @param xlen width of one converter lane (`src` is split into vlen / xlen lanes)
 */
class VectorCvtTopIO(vlen: Int, xlen: Int) extends Bundle{
  // Input strobe: forwarded to both converters and used to enable VectorCvtTop's first pipeline registers.
  val fire = Input(Bool())
  // For widening conversions, selects the source lane: 1 -> src(1), 0 -> src(0).
  val uopIdx = Input(Bool())
  // Source operand, split into vlen / xlen lanes of xlen bits.
  val src = Input(Vec(vlen / xlen, UInt(xlen.W)))
  // Conversion opcode (VCVT drives the low 8 bits of fuOpType here).
  val opType = Input(UInt(8.W))
  val sew = Input(UInt(2.W))
  // Rounding mode (VCVT passes the dynamic rm or an opcode-selected override).
  val rm = Input(UInt(3.W))
  // One-hot result element width: bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64.
  val outputWidth1H = Input(UInt(4.W))
  val isWiden = Input(Bool())
  val isNarrow = Input(Bool())
  // Scalar FP conversion routed through this unit; lane 0 then always reads src(0).
  val isFpToVecInst = Input(Bool())

  // Converted data, assembled two register stages after `fire`; lane 1 occupies the upper bits.
  val result = Output(UInt(vlen.W))
  // 5 fflags bits per result slot, one slot per 16 bits of vlen; lane 0's slots in the low bits.
  val fflags = Output(UInt((vlen/16*5).W))
}
197
198
199
/**
 * Two-lane wrapper around yunsuan's [[VectorCvt]] for one `vlen`-bit uop.
 *
 * Lane selection, according to uopIdx (1: high 64 bits, 0: low 64 bits):
 *  - widening: both lanes read the 64-bit word picked by `uopIdx` — lane 0 its low
 *    32 bits, lane 1 its high 32 bits (scalar ops keep lane 0 on src(0));
 *  - otherwise: lane i reads src(i).
 * Results and fflags are assembled after two register stages, lane 1 in the upper bits.
 */
class VectorCvtTop(vlen: Int, xlen: Int) extends Module {
  val io = IO(new VectorCvtTopIO(vlen, xlen))

  // Public aliases of the IO fields; VCVT drives the inputs through these names.
  val fire = io.fire
  val uopIdx = io.uopIdx
  val src = io.src
  val opType = io.opType
  val sew = io.sew
  val rm = io.rm
  val outputWidth1H = io.outputWidth1H
  val isWiden = io.isWiden
  val isNarrow = io.isNarrow
  val isFpToVecInst = io.isFpToVecInst

  val fireReg = GatedValidRegNext(fire)

  // 64-bit source word consumed by widening conversions.
  private val widenSrc = Mux(uopIdx, src(1), src(0))

  // Lane 0 operand: low half of the widening word, or src(0).
  val in0 = Mux(isWiden && !isFpToVecInst, widenSrc.tail(32), src(0))
  // Lane 1 operand: high half of the widening word, or src(1).
  val in1 = Mux(isWiden, widenSrc.head(32), src(1))

  // Drives one converter with the shared controls and its own lane operand.
  private def hookUp(cvt: VectorCvt, operand: UInt): Unit = {
    cvt.fire := fire
    cvt.src := operand
    cvt.opType := opType
    cvt.sew := sew
    cvt.rm := rm
    cvt.isFpToVecInst := isFpToVecInst
    cvt.isFround := 0.U
    cvt.isFcvtmod := false.B
  }

  val vectorCvt0 = Module(new VectorCvt(xlen))
  hookUp(vectorCvt0, in0)

  val vectorCvt1 = Module(new VectorCvt(xlen))
  hookUp(vectorCvt1, in1)

  // Delays an input-side control value by the two stages the converters take.
  private def toCycle2[T <: Data](sig: T): T = RegEnable(RegEnable(sig, fire), fireReg)

  val isNarrowCycle2 = toCycle2(isNarrow)
  val outputWidth1HCycle2 = toCycle2(outputWidth1H)

  // Concatenates the same view of both lanes, lane 1 in the upper bits.
  private def joinLanes(view: VectorCvt => Bits): UInt = view(vectorCvt1) ## view(vectorCvt0)

  // cycle 2: narrowing keeps only the low 32 bits produced by each lane.
  io.result := Mux(isNarrowCycle2,
    joinLanes(_.io.result.tail(32)),
    joinLanes(_.io.result))

  // Per-lane fflags slices: all slots, the low two slots, or only slot 0.
  private val allFlags = joinLanes(_.io.fflags)
  private val lowTwoFlags = joinLanes(_.io.fflags.tail(10))
  private val lowOneFlag = joinLanes(_.io.fflags(4, 0))

  io.fflags := Mux1H(outputWidth1HCycle2, Seq(
    allFlags,
    Mux(isNarrowCycle2, lowTwoFlags, allFlags),
    Mux(isNarrowCycle2, lowOneFlag, lowTwoFlags),
    lowOneFlag
  ))
}
254
255
256