// Source: XiangShan src/main/scala/xiangshan/backend/fu/wrapper/VCVT.scala (revision 66c730349685672f9cb0a0a90fdd90d67e233f71)
1package xiangshan.backend.fu.wrapper
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import chisel3.util.experimental.decode._
7import utils.XSError
8import xiangshan.backend.fu.FuConfig
9import xiangshan.backend.fu.vector.{Mgu, VecPipedFuncUnit}
10import yunsuan.VfpuType
11
12
/**
 * Vector floating-point conversion functional unit.
 *
 * Decodes the conversion kind (single-width / widening / narrowing) from `fuOpType`, drives a
 * [[VectorCvtTop]] datapath with `vs2`, ORs together the exception flags of the elements that are
 * active under the mask and below `vl`, and applies mask/tail handling to the result via [[Mgu]].
 */
class VCVT(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VfpuType.dummy, "Vfcvt OpType not supported")

  // params alias
  private val dataWidth = cfg.dataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule

  // io alias
  private val opcode = fuOpType(7, 0)
  private val sew = vsew
  private val rm = frm
  private val lmul = vlmul // signed encoding: -3..3 => LMUL 1/8..8

  // opcode(4, 3) encodes the result width relative to SEW: 0 -> single, 1 -> widen, 2 -> narrow
  val widen = opcode(4, 3)
  val isSingleCvt = !widen(1) & !widen(0)
  val isWidenCvt = !widen(1) & widen(0)
  val isNarrowCvt = widen(1) & !widen(0)

  // One-hot result element width: bit0 = 8, bit1 = 16, bit2 = 32, bit3 = 64.
  val output1H = Wire(UInt(4.W))
  output1H := chisel3.util.experimental.decode.decoder(
    widen ## sew,
    TruthTable(
      Seq(
        BitPat("b00_01") -> BitPat("b0010"), // 16
        BitPat("b00_10") -> BitPat("b0100"), // 32
        BitPat("b00_11") -> BitPat("b1000"), // 64

        BitPat("b01_00") -> BitPat("b0010"), // 16
        BitPat("b01_01") -> BitPat("b0100"), // 32
        BitPat("b01_10") -> BitPat("b1000"), // 64

        BitPat("b10_00") -> BitPat("b0001"), // 8
        BitPat("b10_01") -> BitPat("b0010"), // 16
        BitPat("b10_10") -> BitPat("b0100"), // 32
      ),
      BitPat("b0000") // f8 does not exist
    )
  )
  dontTouch(output1H)
  val outputWidth1H = output1H

  // Result element width in SEW-style encoding (0 -> 8, 1 -> 16, 2 -> 32, 3 -> 64), delayed two
  // cycles; it is handed to mgu together with the out-stage control signals.
  val outEew = RegNext(RegNext(Mux1H(output1H, Seq(0, 1, 2, 3).map(i => i.U))))
  private val needNoMask = outVecCtrl.fpu.isFpToVecInst
  val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  // modules
  private val vfcvt = Module(new VectorCvtTop(dataWidth, dataWidthOfDataModule))
  private val mgu = Module(new Mgu(dataWidth))

  /**
   * [[vfcvt]]'s in connection
   */
  vfcvt.uopIdx := vuopIdx(0)
  // Reinterpret the source operand as the 64-bit lanes VectorCvtTop expects.
  vfcvt.src := vs2.asTypeOf(vfcvt.src)
  vfcvt.opType := opcode
  vfcvt.sew := sew
  vfcvt.rm := rm
  vfcvt.outputWidth1H := outputWidth1H
  vfcvt.isWiden := isWidenCvt
  vfcvt.isNarrow := isNarrowCvt
  val vfcvtResult = vfcvt.result
  val vfcvtFflags = vfcvt.fflags

  /**
   * fflags
   */
  // Number of elements handled per uop: the count of the wider of (input, output) elements that
  // fit in one vector register -- input elements for single/narrow, output elements for widen.
  // It uses 4 bits (8, 4, 2, 1) although a register always holds at least 2 such elements, so
  // that the LMUL scaling below can represent a count of 1 (e.g. 64 -> 64 single with LMUL 1/2).
  val eNum1H = chisel3.util.experimental.decode.decoder(sew ## (isWidenCvt || isNarrowCvt),
    TruthTable(
      Seq(                     // 8, 4, 2, 1
        BitPat("b000") -> BitPat("b1000"), //8,
        BitPat("b001") -> BitPat("b1000"), //8
        BitPat("b010") -> BitPat("b1000"), //8
        BitPat("b011") -> BitPat("b0100"), //4
        BitPat("b100") -> BitPat("b0100"), //4
        BitPat("b101") -> BitPat("b0010"), //2
        BitPat("b110") -> BitPat("b0010"), //2
        BitPat("b111") -> BitPat("b0010")  //2
      ),
      BitPat.N(4)
    )
  )
  // SEW-width elements per vector register. For widen/narrow, eNum1H counts elements of the
  // 2 * SEW operand, so the count is doubled (shifted before any fractional-LMUL right shift so
  // small counts are not lost).
  private val eNumPerReg1H = Mux(isWidenCvt || isNarrowCvt, eNum1H << 1, eNum1H)
  // Scale by LMUL: right shift by -lmul for fractional LMUL, left shift by lmul otherwise.
  // The full one-hot value is kept (no MSB is dropped) so counts of 8 and 16 survive scaling.
  private val eNumMaxFull1H = Mux(lmul.head(1).asBool,
    eNumPerReg1H >> (~lmul.tail(1) + 1.U),
    eNumPerReg1H << lmul
  )
  val eNumMax1H = eNumMaxFull1H(6, 0)
  val eNumMax = Mux1H(eNumMax1H, Seq(1, 2, 4, 8, 16, 32, 64).map(i => i.U)) // only for cvt instructions; 128 never occurs in cvt
  val eNumEffect = Mux(vl > eNumMax, eNumMax, vl)

  // Per-element fflags from the datapath; entry i is paired with element i of this uop below.
  val fflagsAll = Wire(Vec(4 * numVecModule, UInt(5.W)))
  fflagsAll := vfcvtFflags.asTypeOf(fflagsAll)

  // Mask bits of the elements covered by this uop, sliced at the per-uop element granularity.
  val mask = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => (srcMask >> vuopIdx * num.U)(num - 1, 0)))
  // An element contributes flags only if it is active under the mask and its (0-based) global
  // index is below the effective vl; index == vl is already a tail element.
  val fflagsEn = Wire(Vec(4 * numVecModule, Bool()))
  fflagsEn := mask.asBools.zipWithIndex.map { case (maskBit, i) =>
    val elemIdx = Mux1H(eNum1H, Seq(1, 2, 4, 8).map(num => vuopIdx * num.U + i.U))
    maskBit && (eNumEffect > elemIdx)
  }

  // Enables are computed from in-stage signals and delayed two cycles to meet the datapath flags.
  val fflagsEnCycle2 = RegNext(RegNext(fflagsEn))
  val fflags = fflagsEnCycle2.zip(fflagsAll).map { case (en, fflag) => Mux(en, fflag, 0.U(5.W)) }.reduce(_ | _)
  io.out.bits.res.fflags.get := fflags


  /**
   * [[mgu]]'s in connection
   * TODO: check whether mgu is required here.
   */
  val resultDataUInt = Wire(UInt(dataWidth.W))
  resultDataUInt := vfcvtResult.asUInt

  mgu.io.in.vd := resultDataUInt
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  mgu.io.in.info.vl := Mux(outVecCtrl.fpu.isFpToVecInst, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := io.out.valid
  mgu.io.in.info.vstart := Mux(outVecCtrl.fpu.isFpToVecInst, 0.U, outVecCtrl.vstart)
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := RegNext(RegNext(isNarrowCvt))
  mgu.io.in.info.dstMask := outVecCtrl.isDstMask

  io.out.bits.res.data := mgu.io.out.vd
}
145
/**
 * Two-lane conversion datapath used by [[VCVT]].
 *
 * `src` holds `vlen / xlen` lanes of `xlen` bits.
 * - For widening conversions, one source lane is consumed per uop. It is selected by `uopIdx`
 *   (1: high lane `src(1)`, 0: low lane `src(0)`). Its low and high 32-bit halves feed
 *   `vectorCvt0` and `vectorCvt1` respectively.
 * - Otherwise `src(0)` feeds `vectorCvt0` and `src(1)` feeds `vectorCvt1`.
 *
 * `result` and `fflags` are assembled from the converter outputs. The assembly uses `isNarrow`
 * and `outputWidth1H` delayed by two register stages. NOTE(review): this assumes [[VectorCvt]]
 * has a two-cycle latency; confirm against its implementation.
 *
 * @param vlen total source/result width in bits; must be `2 * xlen`
 * @param xlen width of one lane in bits; must be 64
 */
class VectorCvtTop(vlen: Int, xlen: Int) extends Module {
  // The datapath is hard-wired to two converters fed with 32-bit halves of 64-bit lanes.
  require(xlen == 64 && vlen == 2 * xlen,
    s"VectorCvtTop supports only vlen = 128, xlen = 64 (got vlen = $vlen, xlen = $xlen)")

  val uopIdx = IO(Input(UInt(1.W)))
  val src = IO(Input(Vec(vlen / xlen, UInt(xlen.W))))
  val opType = IO(Input(UInt(8.W)))
  val sew = IO(Input(UInt(2.W)))
  val rm = IO(Input(UInt(3.W)))
  val outputWidth1H = IO(Input(UInt(4.W)))
  val isWiden = IO(Input(Bool()))
  val isNarrow = IO(Input(Bool()))

  val result = IO(Output(Vec(vlen / xlen, UInt(xlen.W))))
  // xlen / 16 five-bit flag entries per lane
  val fflags = IO(Output(Vec(vlen / xlen * (xlen / 16), UInt(5.W))))

  // Widening: uopIdx picks the source lane; its low half goes to converter 0...
  val in0 = Mux(isWiden,
    Mux(uopIdx(0), src(1).tail(32), src(0).tail(32)),
    src(0)
  )

  // ...and its high half to converter 1.
  val in1 = Mux(isWiden,
    Mux(uopIdx(0), src(1).head(32), src(0).head(32)),
    src(1)
  )

  // NOTE(review): inputs are driven as direct members of VectorCvt while outputs are read via
  // `.io`; confirm this matches VectorCvt's port declaration.
  val vectorCvt0 = Module(new VectorCvt(xlen))
  vectorCvt0.src := in0
  vectorCvt0.opType := opType
  vectorCvt0.sew := sew
  vectorCvt0.rm := rm

  val vectorCvt1 = Module(new VectorCvt(xlen))
  vectorCvt1.src := in1
  vectorCvt1.opType := opType
  vectorCvt1.sew := sew
  vectorCvt1.rm := rm

  val isNarrowCycle2 = RegNext(RegNext(isNarrow))
  val outputWidth1HCycle2 = RegNext(RegNext(outputWidth1H))

  // cycle2
  // Narrowing keeps the low 32 bits of each converter's result.
  // Otherwise both full results are concatenated, converter 1 in the upper lane.
  result := Mux(isNarrowCycle2,
    vectorCvt1.io.result.tail(32) ## vectorCvt0.io.result.tail(32),
    vectorCvt1.io.result ## vectorCvt0.io.result
  ).asTypeOf(result)

  // Select the flag entries per result element width (one-hot: 8, 16, 32, 64).
  // Narrowing keeps fewer (lower) entries of each converter.
  fflags := Mux1H(outputWidth1HCycle2, Seq( // todo: map between fflags and result
    vectorCvt1.io.fflags ## vectorCvt0.io.fflags,
    Mux(isNarrowCycle2, vectorCvt1.io.fflags.tail(10) ## vectorCvt0.io.fflags.tail(10), vectorCvt1.io.fflags ## vectorCvt0.io.fflags),
    Mux(isNarrowCycle2, vectorCvt1.io.fflags(4, 0) ## vectorCvt0.io.fflags(4, 0), vectorCvt1.io.fflags.tail(10) ## vectorCvt0.io.fflags.tail(10)),
    vectorCvt1.io.fflags(4, 0) ## vectorCvt0.io.fflags(4, 0)
  )).asTypeOf(fflags)
}
197
198
199