xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/VCVT.scala (revision 8d081717cfe3863e6c0d38305abd6327c65a653a)
1package xiangshan.backend.fu.wrapper
2
3import org.chipsalliance.cde.config.Parameters
4import chisel3._
5import chisel3.util._
6import chisel3.util.experimental.decode._
7import utils.XSError
8import xiangshan.backend.fu.FuConfig
9import xiangshan.backend.fu.vector.{Mgu, VecPipedFuncUnit}
10import xiangshan.ExceptionNO
11import yunsuan.VfpuType
12import yunsuan.vector.VectorConvert.VectorCvt
13
14
/**
 * Vector floating-point conversion functional unit (single-width, widening and narrowing).
 *
 * Drives a [[VectorCvtTop]] built from `dataWidthOfDataModule`-bit conversion lanes. It computes
 * which per-element fflags are enabled: an element counts when its mask bit is set and its index
 * is below vl. It then merges the conversion result with the old vd through an [[Mgu]].
 *
 * Signals computed from the issue-stage inputs are delayed by two `RegNext`s. This lines them up
 * with the result, which [[VectorCvtTop]] assembles at its "cycle2".
 *
 * Instructions flagged `fpu.isFpToVecInst` are treated as a single unmasked element, both for the
 * Mgu (vl = 1, all-true mask) and for fflags.
 */
class VCVT(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VfpuType.dummy, "Vfcvt OpType not supported")

  // params alias
  private val dataWidth = cfg.dataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule
  // Number of 5-bit fflags groups reported by VectorCvtTop (4 per lane).
  private val numFflagsElem = 4 * numVecModule

  // io alias
  private val opcode = fuOpType(7, 0)
  private val sew = vsew

  // Rounding mode: RTZ/ROD op types use a fixed mode, every other op type uses the dynamic frm.
  private val isRtz = opcode(2) & opcode(1)
  private val isRod = opcode(2) & !opcode(1) & opcode(0)
  private val isFrm = !isRtz && !isRod
  private val rm = Mux1H(
    Seq(isRtz, isRod, isFrm),
    Seq(1.U, 6.U, frm) // 1 = RTZ; 6 is the value this unit passes for ROD
  )

  private val lmul = vlmul // -3->3 => 1/8 ->8

  val widen = opcode(4, 3) // 0->single 1->widen 2->narrow => width of result
  val isSingleCvt = !widen(1) & !widen(0)
  val isWidenCvt = !widen(1) & widen(0)
  val isNarrowCvt = widen(1) & !widen(0)
  private val isWidenOrNarrow = isWidenCvt || isNarrowCvt
  // Conversions flagged isFpToVecInst: handled as one unmasked element, like the Mgu inputs below.
  private val isFpToVec = vecCtrl.fpu.isFpToVecInst

  // Result element width, one-hot: bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64.
  // For single and widening ops `sew` is the source width; for narrowing ops it is the result width.
  val output1H = Wire(UInt(4.W))
  output1H := chisel3.util.experimental.decode.decoder(
    widen ## sew,
    TruthTable(
      Seq(
        BitPat("b00_01") -> BitPat("b0010"), // 16
        BitPat("b00_10") -> BitPat("b0100"), // 32
        BitPat("b00_11") -> BitPat("b1000"), // 64

        BitPat("b01_00") -> BitPat("b0010"), // 16
        BitPat("b01_01") -> BitPat("b0100"), // 32
        BitPat("b01_10") -> BitPat("b1000"), // 64

        BitPat("b10_00") -> BitPat("b0001"), // 8
        BitPat("b10_01") -> BitPat("b0010"), // 16
        BitPat("b10_10") -> BitPat("b0100"), // 32
      ),
      BitPat.N(4)
    )
  )
  if(backendParams.debugEn) {
    dontTouch(output1H)
  }
  val outputWidth1H = output1H

  // Result EEW encoding (0: 8, 1: 16, 2: 32, 3: 64), aligned with the result.
  val outEew = RegNext(RegNext(Mux1H(output1H, Seq(0,1,2,3).map(i => i.U))))
  private val needNoMask = outVecCtrl.fpu.isFpToVecInst
  val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  // modules
  private val vfcvt = Module(new VectorCvtTop(dataWidth, dataWidthOfDataModule))
  private val mgu = Module(new Mgu(dataWidth))

  val vs2Vec = Wire(Vec(numVecModule, UInt(dataWidthOfDataModule.W)))
  vs2Vec := vs2.asTypeOf(vs2Vec)

  /**
   * [[vfcvt]]'s in connection
   */
  vfcvt.uopIdx := vuopIdx(0)
  vfcvt.src := vs2Vec
  vfcvt.opType := opcode
  vfcvt.sew := sew
  vfcvt.rm := rm
  vfcvt.outputWidth1H := outputWidth1H
  vfcvt.isWiden := isWidenCvt
  vfcvt.isNarrow := isNarrowCvt
  val vfcvtResult = vfcvt.io.result
  val vfcvtFflags = vfcvt.io.fflags

  /** fflags:
   * VectorCvtTop reports 5 bits of fflags per element slot. A slot contributes to the final
   * fflags only if its element is active (mask bit set) and its index is below vl.
   */
  // Elements handled by one uop, one-hot: bit3 = 8, bit2 = 4, bit1 = 2, bit0 = 1.
  val eNum1H = chisel3.util.experimental.decode.decoder(sew ## isWidenOrNarrow,
    TruthTable(
      Seq(                     // 8, 4, 2, 1
        BitPat("b001") -> BitPat("b1000"), //8
        BitPat("b010") -> BitPat("b1000"), //8
        BitPat("b011") -> BitPat("b0100"), //4
        BitPat("b100") -> BitPat("b0100"), //4
        BitPat("b101") -> BitPat("b0010"), //2
        BitPat("b110") -> BitPat("b0010"), //2
      ),
      BitPat.N(4)
    )
  )
  // A widening uop reads only half a source register, and a narrowing uop writes only half a
  // destination register. So one LMUL=1 register group holds twice the elements of one uop;
  // without this doubling the clamp below would be VLMAX/2 and drop fflags of upper elements.
  private val eNumPerReg1H = Mux(isWidenOrNarrow, eNum1H << 1, eNum1H)
  // VLMAX as one-hot: elements per register scaled by LMUL (fractional LMUL shifts right).
  val eNumMax1H = Mux(lmul.head(1).asBool, eNumPerReg1H >> ((~lmul.tail(1)).asUInt +1.U), eNumPerReg1H << lmul.tail(1)).asUInt(6, 0)
  val eNumMax = Mux1H(eNumMax1H, Seq(1,2,4,8,16,32,64).map(i => i.U)) //only for cvt intr, don't exist 128 in cvt
  // Number of leading elements that may raise fflags (exactly one for isFpToVecInst).
  val eNumEffectIdx = Mux(isFpToVec, 1.U, Mux(vl > eNumMax, eNumMax, vl))

  val eNum = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num =>num.U))
  // Index of the first element handled by this uop.
  val eStart = vuopIdx * eNum
  // isFpToVecInst is unmasked, consistent with maskToMgu.
  val maskPart = Mux(isFpToVec, allMaskTrue, srcMask) >> eStart
  val mask =  Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => maskPart(num-1, 0)))
  val fflagsEn = Wire(Vec(numFflagsElem, Bool()))

  fflagsEn := mask.asBools.zipWithIndex.map { case (elemActive, i) => elemActive & (eNumEffectIdx > eStart + i.U) }

  val fflagsEnCycle2 = RegNext(RegNext(fflagsEn))
  val fflagsAll = Wire(Vec(numFflagsElem, UInt(5.W)))
  fflagsAll := vfcvtFflags.asTypeOf(fflagsAll)
  val fflags = fflagsEnCycle2.zip(fflagsAll).map{case(en, fflag) => Mux(en, fflag, 0.U(5.W))}.reduce(_ | _)
  io.out.bits.res.fflags.get := fflags


  /**
   * [[mgu]]'s in connection
   */
  val resultDataUInt = Wire(UInt(dataWidth.W))
  resultDataUInt := vfcvtResult

  // A narrowing uop yields half a register. On odd uops the result's low half is placed in the
  // upper half of vd, and the lower half is taken from old vd.
  private val narrow = RegNext(RegNext(isNarrowCvt))
  private val narrowNeedCat = outVecCtrl.vuopIdx(0).asBool && narrow
  private val outNarrowVd = Mux(narrowNeedCat, Cat(resultDataUInt(dataWidth / 2 - 1, 0), outOldVd(dataWidth / 2 - 1, 0)), resultDataUInt)

  mgu.io.in.vd := Mux(narrow, outNarrowVd, resultDataUInt)
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  mgu.io.in.info.vl := Mux(outVecCtrl.fpu.isFpToVecInst, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := io.out.valid
  mgu.io.in.info.vstart := Mux(outVecCtrl.fpu.isFpToVecInst, 0.U, outVecCtrl.vstart)
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := narrow
  mgu.io.in.info.dstMask := outVecCtrl.isDstMask
  mgu.io.in.isIndexedVls := false.B

  io.out.bits.res.data := mgu.io.out.vd
  io.out.bits.ctrl.exceptionVec.get(ExceptionNO.illegalInstr) := mgu.io.out.illegal
}
157
/**
 * Port bundle of [[VectorCvtTop]].
 *
 * @param vlen total datapath width in bits
 * @param xlen width of one conversion lane in bits
 */
class VectorCvtTopIO(vlen: Int, xlen: Int) extends Bundle {
  private val numLanes   = vlen / xlen
  private val fflagsBits = vlen / 16 * 5 // (vlen / 16) groups of 5 fflags bits

  // ---- inputs ----
  val uopIdx        = Input(Bool()) // picks the source chunk consumed by a widening uop
  val src           = Input(Vec(numLanes, UInt(xlen.W)))
  val opType        = Input(UInt(8.W))
  val sew           = Input(UInt(2.W))
  val rm            = Input(UInt(3.W))
  val outputWidth1H = Input(UInt(4.W)) // one-hot result element width
  val isWiden       = Input(Bool())
  val isNarrow      = Input(Bool())

  // ---- outputs ----
  val result = Output(UInt(vlen.W))
  val fflags = Output(UInt(fflagsBits.W))
}
171
172
173
/**
 * Two-lane vector conversion datapath built from yunsuan [[VectorCvt]] instances.
 *
 * For non-widening ops, lane i simply converts `src(i)`. For widening ops, each uop consumes a
 * single source chunk chosen by `uopIdx`: 1 selects `src(1)` (high64), 0 selects `src(0)`
 * (low64). The chunk's low part feeds lane 0 and its high part feeds lane 1.
 *
 * Results and fflags are assembled at "cycle2". The `isNarrow` and `outputWidth1H` controls are
 * delayed by two registers to match.
 */
class VectorCvtTop(vlen: Int, xlen: Int) extends Module{
  val io = IO(new VectorCvtTopIO(vlen, xlen))

  // Port aliases; the parent unit drives the inputs through these names.
  val uopIdx        = io.uopIdx
  val src           = io.src
  val opType        = io.opType
  val sew           = io.sew
  val rm            = io.rm
  val outputWidth1H = io.outputWidth1H
  val isWiden       = io.isWiden
  val isNarrow      = io.isNarrow

  // Source chunk converted by a widening uop.
  private val widenChunk = Mux(uopIdx, src(1), src(0))

  val in0 = Mux(isWiden, widenChunk.tail(32), src(0)) // widening: chunk without its top 32 bits
  val in1 = Mux(isWiden, widenChunk.head(32), src(1)) // widening: top 32 bits of the chunk

  // Every lane shares opType, sew and rm; only the data input differs.
  private def driveLane(lane: VectorCvt, laneIn: UInt): Unit = {
    lane.src := laneIn
    lane.opType := opType
    lane.sew := sew
    lane.rm := rm
  }

  val vectorCvt0 = Module(new VectorCvt(xlen))
  driveLane(vectorCvt0, in0)

  val vectorCvt1 = Module(new VectorCvt(xlen))
  driveLane(vectorCvt1, in1)

  val isNarrowCycle2 = RegNext(RegNext(isNarrow))
  val outputWidth1HCycle2 = RegNext(RegNext(outputWidth1H))

  //cycle2
  private val res0 = vectorCvt0.io.result
  private val res1 = vectorCvt1.io.result
  // Narrowing drops the top 32 bits of each lane's result before concatenating.
  io.result := Mux(isNarrowCycle2, res1.tail(32) ## res0.tail(32), res1 ## res0)

  // Lane fflags concatenated with lane 1 in the upper bits:
  //  fflagsFull  - all bits of both lanes
  //  fflagsDrop2 - each lane with its top 10 bits (two 5-bit groups) removed
  //  fflagsLow1  - only each lane's lowest 5-bit group
  private val flags0 = vectorCvt0.io.fflags
  private val flags1 = vectorCvt1.io.fflags
  private val fflagsFull  = flags1 ## flags0
  private val fflagsDrop2 = flags1.tail(10) ## flags0.tail(10)
  private val fflagsLow1  = flags1(4, 0) ## flags0(4, 0)

  io.fflags := Mux1H(outputWidth1HCycle2, Seq(
    fflagsFull,                                   // 8-bit results
    Mux(isNarrowCycle2, fflagsDrop2, fflagsFull), // 16-bit results
    Mux(isNarrowCycle2, fflagsLow1, fflagsDrop2), // 32-bit results
    fflagsLow1                                    // 64-bit results
  ))
}
219
220
221