xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/VCVT.scala (revision 9d3cebe77f43ab9001b88cc6e6fd8b0d98dc0737)
1package xiangshan.backend.fu.wrapper
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import chisel3.util.experimental.decode._
7import utils.XSError
8import xiangshan.backend.fu.FuConfig
9import xiangshan.backend.fu.vector.{Mgu, VecPipedFuncUnit}
10import yunsuan.VfpuType
11import yunsuan.vector.VectorConvert.VectorCvt
12
13
/**
 * Vector convert function unit.
 *
 * Decodes the convert opcode into a rounding mode and a single/widen/narrow kind,
 * feeds the source operand to a [[VectorCvtTop]], gates the per-element fflags by
 * mask / vl / lmul, and post-processes the result data through an [[Mgu]].
 *
 * Values that accompany the result (output EEW, fflags enables, narrow flag) are
 * delayed through two back-to-back `RegNext`s before use.
 *
 * @param cfg function-unit configuration; `cfg.dataBits` sets the datapath width
 */
class VCVT(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VfpuType.dummy, "Vfcvt OpType not supported")

  /** Builds decoder rows from (input pattern, output pattern) string pairs. */
  private def bitPatRows(rows: (String, String)*): Seq[(BitPat, BitPat)] =
    rows.map { case (inPat, outPat) => BitPat(inPat) -> BitPat(outPat) }

  // params alias
  private val dataWidth = cfg.dataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule

  // io alias
  private val opcode = fuOpType(7, 0)
  private val sew = vsew

  // Rounding mode: isRtz selects the constant 1, isRod selects the constant 6,
  // anything else falls back to the dynamic frm.
  private val isRtz = opcode(2) & opcode(1)
  private val isRod = opcode(2) & !opcode(1) & opcode(0)
  private val isFrm = !(isRtz || isRod)
  private val rm = Mux1H(Seq(
    isRtz -> 1.U,
    isRod -> 6.U,
    isFrm -> frm
  ))

  private val lmul = vlmul // -3->3 => 1/8 ->8

  // 0 -> single, 1 -> widen, 2 -> narrow (describes the width of the result)
  val widen = opcode(4, 3)
  val isSingleCvt = widen === 0.U
  val isWidenCvt = widen === 1.U
  val isNarrowCvt = widen === 2.U

  // One-hot output width: bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64.
  // Decoder input is widen ## sew; combinations not listed decode to all zeros.
  val output1H = Wire(UInt(4.W))
  output1H := chisel3.util.experimental.decode.decoder(
    widen ## sew,
    TruthTable(
      bitPatRows(
        // single
        "b00_01" -> "b0010", // 16
        "b00_10" -> "b0100", // 32
        "b00_11" -> "b1000", // 64
        // widen
        "b01_00" -> "b0010", // 16
        "b01_01" -> "b0100", // 32
        "b01_10" -> "b1000", // 64
        // narrow
        "b10_00" -> "b0001", // 8
        "b10_01" -> "b0010", // 16
        "b10_10" -> "b0100"  // 32
      ),
      BitPat.N(4)
    )
  )
  dontTouch(output1H)
  val outputWidth1H = output1H

  // Output EEW encoded as 0..3, delayed by two registers.
  val outEew = RegNext(RegNext(Mux1H(output1H, (0 until 4).map(_.U))))
  private val needNoMask = outVecCtrl.fpu.isFpToVecInst
  val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  // modules
  private val vfcvt = Module(new VectorCvtTop(dataWidth, dataWidthOfDataModule))
  private val mgu = Module(new Mgu(dataWidth)) // TODO: check what work the Mgu actually does

  // vs2 reinterpreted as a vector of 64-bit chunks
  val vs2Vec = Wire(Vec(numVecModule, UInt(dataWidthOfDataModule.W)))
  vs2Vec := vs2.asTypeOf(vs2Vec)

  /**
   * [[vfcvt]]'s in connection
   */
  vfcvt.uopIdx := vuopIdx(0)
  vfcvt.src := vs2Vec
  vfcvt.opType := opcode
  vfcvt.sew := sew
  vfcvt.rm := rm
  vfcvt.outputWidth1H := outputWidth1H
  vfcvt.isWiden := isWidenCvt
  vfcvt.isNarrow := isNarrowCvt
  val vfcvtResult = vfcvt.io.result
  val vfcvtFflags = vfcvt.io.fflags

  /** fflags:
   */
  // uopIdx: every uopIdx has its own element count and its own position in the mask.
  // vl: decides how many elements of the vector register group take part in the operation.
  //
  // eNum1H is the maximum element count handled by one uop, i.e. how many elements of
  // width Max(inWidth, outWidth) fit in one vector register:
  //   - single / narrow: number of input elements held by one vector register
  //   - widen:           number of output elements held by one vector register
  // todo: why 4 bits? A vector register holds at least 2 elements, but a 64->64 single
  //       convert with lmul = 1/2 exists, and the extra bit makes the eNumMax computation
  //       below (which may yield 1) more convenient.
  // Encoding: bit3 = 8, bit2 = 4, bit1 = 2, bit0 = 1 elements.
  val eNum1H = chisel3.util.experimental.decode.decoder(
    sew ## (isWidenCvt || isNarrowCvt),
    TruthTable(
      bitPatRows(
        "b001" -> "b1000", // 8
        "b010" -> "b1000", // 8
        "b011" -> "b0100", // 4
        "b100" -> "b0100", // 4
        "b101" -> "b0010", // 2
        "b110" -> "b0010"  // 2
      ),
      BitPat.N(4)
    )
  )

  // Scale the per-uop element count by lmul: when the top bit of lmul is set the count
  // is shifted right by the two's-complement magnitude of the low bits, otherwise it is
  // shifted left by the low bits.
  private val lmulIsFrac = lmul.head(1).asBool
  private val lmulLow = lmul.tail(1)
  private val fracShift = (~lmulLow).asUInt + 1.U
  val eNumMax1H = Mux(lmulIsFrac, eNum1H >> fracShift, eNum1H << lmulLow).asUInt(6, 0)
  // only for cvt instructions: 128 never occurs for cvt
  val eNumMax = Mux1H(eNumMax1H, Seq.tabulate(7)(i => (1 << i).U))
  // min(vl, eNumMax)
  val eNumEffectIdx = Mux(vl > eNumMax, eNumMax, vl)

  // mask, vl and lmul together decide whether a given input element is valid.
  // The mask slice for this uop starts at uopIdx * (elements per uop).
  private val eNumChoices = Seq(1, 2, 4, 8)
  val eNum = Mux1H(eNum1H, eNumChoices.map(_.U))
  val eStart = vuopIdx * eNum
  val maskPart = srcMask >> eStart
  val mask = Mux1H(eNum1H, eNumChoices.map(cnt => maskPart(cnt - 1, 0)))
  val fflagsEn = Wire(Vec(4 * numVecModule, Bool()))

  // vl covers [0, vl) and eNumMax covers [0, eNumMax), so eNumEffectIdx is a length
  // (last index + 1) while the right-hand side is an element index; hence the strict '>'.
  // An element contributes flags only if its mask bit is set and it lies within the
  // vl / lmul bound.
  fflagsEn := mask.asBools.zipWithIndex.map { case (maskBit, idx) =>
    maskBit & (eNumEffectIdx > eStart + idx.U)
  }

  val fflagsEnCycle2 = RegNext(RegNext(fflagsEn))
  val fflagsAll = Wire(Vec(8, UInt(5.W)))
  fflagsAll := vfcvtFflags.asTypeOf(fflagsAll)
  // OR together the flags of all enabled elements; disabled elements contribute zero.
  val fflags = fflagsEnCycle2
    .zip(fflagsAll)
    .map { case (enable, elemFlags) => Mux(enable, elemFlags, 0.U(5.W)) }
    .reduce(_ | _)
  io.out.bits.res.fflags.get := fflags


  /**
   * [[mgu]]'s in connection (TODO: check later whether this mgu is really required)
   */
  val resultDataUInt = Wire(UInt(dataWidth.W)) //todo: from vfcvt
  resultDataUInt := vfcvtResult

  mgu.io.in.vd := resultDataUInt
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  // fp-to-vec instructions are handled as vl = 1, vstart = 0
  mgu.io.in.info.vl := Mux(needNoMask, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := io.out.valid
  mgu.io.in.info.vstart := Mux(needNoMask, 0.U, outVecCtrl.vstart)
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := RegNext(RegNext(isNarrowCvt))
  mgu.io.in.info.dstMask := outVecCtrl.isDstMask

  io.out.bits.res.data := mgu.io.out.vd
}
164
/**
 * Port bundle of [[VectorCvtTop]].
 *
 * @param vlen total width in bits of the source operand and of the result
 * @param xlen width in bits of one source chunk; `src` holds `vlen / xlen` chunks
 */
class VectorCvtTopIO(vlen: Int, xlen: Int) extends Bundle {
  // ---- inputs ----
  /** Selects which source chunk is consumed when widening. */
  val uopIdx = Input(Bool())
  /** Source operand split into `xlen`-bit chunks. */
  val src = Input(Vec(vlen / xlen, UInt(xlen.W)))
  /** Convert operation type, passed through to the convert lanes. */
  val opType = Input(UInt(8.W))
  /** Source element width selector, passed through to the convert lanes. */
  val sew = Input(UInt(2.W))
  /** Rounding mode, passed through to the convert lanes. */
  val rm = Input(UInt(3.W))
  /** One-hot output element width selector (4 bits). */
  val outputWidth1H = Input(UInt(4.W))
  /** Set for widening converts. */
  val isWiden = Input(Bool())
  /** Set for narrowing converts. */
  val isNarrow = Input(Bool())

  // ---- outputs ----
  /** Converted data, `vlen` bits wide. */
  val result = Output(UInt(vlen.W))
  /** Packed exception flags, `(vlen / 16) * 5` bits wide. */
  val fflags = Output(UInt((vlen / 16 * 5).W))
}
178
179
180
181//according to uopindex, 1: high64 0:low64
/**
 * Drives two `VectorCvt` lanes and packs their results and exception flags.
 *
 * For widening converts both lanes read from the single source chunk chosen by
 * `uopIdx` (lane 0 gets the slice left after dropping the upper 32 bits, lane 1
 * gets the upper 32 bits). Otherwise lane 0 reads `src(0)` and lane 1 reads `src(1)`.
 *
 * The narrow flag and the one-hot output width are delayed through two `RegNext`s
 * before they steer the output packing.
 *
 * NOTE(review): the fixed 32-bit slices and the use of only `src(0)` / `src(1)`
 * look tailored to xlen = 64 and vlen = 128 — confirm before using other sizes.
 *
 * @param vlen total operand / result width in bits
 * @param xlen width in bits handled by each convert lane
 */
class VectorCvtTop(vlen: Int, xlen: Int) extends Module {
  val io = IO(new VectorCvtTopIO(vlen, xlen))

  // Port aliases; kept public because users connect through them.
  val uopIdx = io.uopIdx
  val src = io.src
  val opType = io.opType
  val sew = io.sew
  val rm = io.rm
  val outputWidth1H = io.outputWidth1H
  val isWiden = io.isWiden
  val isNarrow = io.isNarrow

  // Source chunk used by both lanes when widening.
  private val widenSrc = Mux(uopIdx, src(1), src(0))

  val in0 = Mux(isWiden, widenSrc.tail(32), src(0))
  val in1 = Mux(isWiden, widenSrc.head(32), src(1))

  val vectorCvt0 = Module(new VectorCvt(xlen))
  val vectorCvt1 = Module(new VectorCvt(xlen))

  // Both lanes share every control input and differ only in their data source.
  Seq(vectorCvt0 -> in0, vectorCvt1 -> in1).foreach { case (lane, laneSrc) =>
    lane.src := laneSrc
    lane.opType := opType
    lane.sew := sew
    lane.rm := rm
  }

  val isNarrowCycle2 = RegNext(RegNext(isNarrow))
  val outputWidth1HCycle2 = RegNext(RegNext(outputWidth1H))

  //cycle2
  private val result0 = vectorCvt0.io.result
  private val result1 = vectorCvt1.io.result
  // Narrow: each lane contributes only what is left after dropping its upper 32 bits.
  io.result := Mux(isNarrowCycle2,
    result1.tail(32) ## result0.tail(32),
    result1 ## result0)

  private val flags0 = vectorCvt0.io.fflags
  private val flags1 = vectorCvt1.io.fflags
  // Three ways of packing the lane flags, from most to fewest bits per lane.
  private val flagsFull = flags1 ## flags0
  private val flagsTail10 = flags1.tail(10) ## flags0.tail(10) // upper 10 bits of each lane dropped
  private val flagsLow5 = flags1(4, 0) ## flags0(4, 0)         // lowest 5 bits of each lane

  // todo: map between fflags and result
  // Entry i is chosen when bit i of the delayed one-hot output width is set.
  io.fflags := Mux1H(outputWidth1HCycle2, Seq(
    flagsFull,
    Mux(isNarrowCycle2, flagsTail10, flagsFull),
    Mux(isNarrowCycle2, flagsLow5, flagsTail10),
    flagsLow5
  ))
}
226
227
228