xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/VIAluFix.scala (revision cdf8c16ccc1a06b1752d392919d932662936e61b)
1package xiangshan.backend.fu.wrapper
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3.{VecInit, _}
5import chisel3.util._
6import chisel3.util.experimental.decode.{QMCMinimizer, TruthTable, decoder}
7import utils.XSError
8import xiangshan.XSCoreParamsKey
9import xiangshan.backend.fu.vector.Bundles.{VConfig, VSew, ma}
10import xiangshan.backend.fu.vector.{Mgu, VecPipedFuncUnit}
11import xiangshan.backend.fu.vector.Utils.VecDataToMaskDataVec
12import xiangshan.backend.fu.vector.utils.VecDataSplitModule
13import xiangshan.backend.fu.{FuConfig, FuType}
14import yunsuan.{OpType, VialuFixType}
15import yunsuan.vector.alu.{VIntFixpAlu64b, VIntFixpDecode, VIntFixpTable}
16import yunsuan.encoding.{VdType, Vs1IntType, Vs2IntType}
17import yunsuan.encoding.Opcode.VialuOpcode
18import yunsuan.vector.SewOH
19
/**
 * Interface of [[VIAluSrcTypeModule]].
 *
 * The input and output bundles are named classes (declared in the companion object) rather
 * than anonymous `new Bundle { ... }` refinements. This makes accesses such as `io.in.vsew`
 * ordinary member selections instead of reflective structural-type calls.
 *
 * Field names and declaration order are unchanged, so the generated ports and the
 * `asUInt` bit layout are identical to the anonymous-bundle version.
 */
class VIAluSrcTypeIO extends Bundle {
  val in: VIAluSrcTypeIO.In = Input(new VIAluSrcTypeIO.In)
  val out: VIAluSrcTypeIO.Out = Output(new VIAluSrcTypeIO.Out)
}

object VIAluSrcTypeIO {

  /** Decoded instruction information consumed by [[VIAluSrcTypeModule]]. */
  class In extends Bundle {
    val fuOpType: UInt = OpType()
    val vsew: UInt = VSew()
    val isReverse: Bool = Bool() // vrsub, vrdiv
    val isExt: Bool = Bool()
    val isDstMask: Bool = Bool() // vvm, vvvm, mmm
    val isMove: Bool = Bool() // vmv.s.x, vmv.v.v, vmv.v.x, vmv.v.i
  }

  /** Element types of vs1/vs2/vd plus decode flags produced by [[VIAluSrcTypeModule]]. */
  class Out extends Bundle {
    val vs1Type: UInt = Vs1IntType()
    val vs2Type: UInt = Vs2IntType()
    val vdType: UInt = VdType()
    val illegal: Bool = Bool()
    val isVextF2: Bool = Bool()
    val isVextF4: Bool = Bool()
    val isVextF8: Bool = Bool()
  }
}
39
/**
 * Decodes the element types of vs2, vs1 and vd for the vector integer/fixed-point ALU.
 *
 * Each non-mask type is `Cat(intType, sew)`, where `intType` carries the signedness bit
 * taken from `fuOpType`. Formats that consume or produce masks (VVM, VVMM, MMM) use the
 * dedicated mask encodings instead.
 *
 * Widening formats (VVW/WVW), the WVV format and the vext fractions (VF2/VF4/VF8) scale
 * the per-operand SEW relative to `vsew`.
 *
 * `illegal` flags instructions whose derived element width would leave the e8..e64 range.
 * It is computed and exported but not yet acted upon (see the Todo below).
 */
class VIAluSrcTypeModule extends Module {
  val io: VIAluSrcTypeIO = IO(new VIAluSrcTypeIO)

  private val vsew = io.in.vsew
  private val isExt = io.in.isExt
  private val isDstMask = io.in.isDstMask

  private val opcode = VialuFixType.getOpcode(io.in.fuOpType)
  private val isSign = VialuFixType.isSigned(io.in.fuOpType)
  private val format = VialuFixType.getFormat(io.in.fuOpType)

  // Scaled SEWs. These are plain UInt add/sub and can wrap for out-of-range vsew values.
  // Those cases are reported by the illegal checks below, which compare vsew directly
  // so they do not depend on the wrap-around behaviour or the width of VSew.
  private val vsewX2 = vsew + 1.U
  private val vsewF2 = vsew - 1.U
  private val vsewF4 = vsew - 2.U
  private val vsewF8 = vsew - 3.U

  private val isAddSub = opcode === VialuOpcode.vadd || opcode === VialuOpcode.vsub
  // The shift-right family is identified by the opcode field (not the format field).
  private val isShiftRight =
    Seq(VialuOpcode.vsrl, VialuOpcode.vsra, VialuOpcode.vssrl, VialuOpcode.vssra).map(_ === opcode).reduce(_ || _)
  private val isVext = opcode === VialuOpcode.vext

  private val isWiden = isAddSub && Seq(VialuFixType.FMT.VVW, VialuFixType.FMT.WVW).map(fmt => fmt === format).reduce(_ || _)
  private val isNarrow = isShiftRight && format === VialuFixType.FMT.WVV
  private val isVextF2 = isVext && format === VialuFixType.FMT.VF2
  private val isVextF4 = isVext && format === VialuFixType.FMT.VF4
  private val isVextF8 = isVext && format === VialuFixType.FMT.VF8

  // check illegal: some operand's derived EEW would fall outside e8..e64.
  // Widening: vd (and vs2 for WVW) use 2*SEW, so SEW must be below e64.
  private val widenIllegal = isWiden && vsew === VSew.e64
  // Narrowing (WVV): vs2 uses 2*SEW while vd uses SEW, so SEW must be below e64.
  private val narrowIllegal = isNarrow && vsew === VSew.e64
  // vext.vf2/vf4/vf8: vs2 uses SEW/2, SEW/4, SEW/8, so SEW must be at least e16/e32/e64.
  private val vextIllegal = (isVextF2 && vsew < VSew.e16) ||
    (isVextF4 && vsew < VSew.e32) ||
    (isVextF8 && vsew < VSew.e64)
  // Todo: use it
  private val illegal = widenIllegal || narrowIllegal || vextIllegal

  // Upper two bits of every non-mask operand type: {0, isSigned}.
  private val intType = Cat(0.U(1.W), isSign)

  private class Vs2Vs1VdSew extends Bundle {
    val vs2 = VSew()
    val vs1 = VSew()
    val vd = VSew()
  }

  private class Vs2Vs1VdType extends Bundle {
    val vs2 = Vs2IntType()
    val vs1 = Vs1IntType()
    val vd = VdType()
  }

  // Per-operand SEW for arithmetic/shift formats (vs2, vs1, vd packed MSB-first).
  private val addSubSews = Mux1H(Seq(
    (format === VialuFixType.FMT.VVV) -> Cat(vsew, vsew, vsew),
    (format === VialuFixType.FMT.VVW) -> Cat(vsew, vsew, vsewX2),
    (format === VialuFixType.FMT.WVW) -> Cat(vsewX2, vsew, vsewX2),
    (format === VialuFixType.FMT.WVV) -> Cat(vsewX2, vsew, vsew),
  )).asTypeOf(new Vs2Vs1VdSew)

  // Per-operand SEW for vext: sources are a fraction of the destination SEW.
  private val vextSews = Mux1H(Seq(
    (format === VialuFixType.FMT.VF2) -> Cat(vsewF2, vsewF2, vsew),
    (format === VialuFixType.FMT.VF4) -> Cat(vsewF4, vsewF4, vsew),
    (format === VialuFixType.FMT.VF8) -> Cat(vsewF8, vsewF8, vsew),
  )).asTypeOf(new Vs2Vs1VdSew)

  // Full operand types for mask-producing / mask-only formats.
  private val maskTypes = Mux1H(Seq(
    (format === VialuFixType.FMT.VVM) -> Cat(Cat(intType, vsew), Cat(intType, vsew), VdType.mask),
    (format === VialuFixType.FMT.VVMM) -> Cat(Cat(intType, vsew), Cat(intType, vsew), VdType.mask),
    (format === VialuFixType.FMT.MMM) -> Cat(Vs2IntType.mask, Vs1IntType.mask, VdType.mask),
  )).asTypeOf(new Vs2Vs1VdType)

  private val vs2Type = Mux1H(Seq(
    isDstMask -> maskTypes.vs2,
    isExt -> Cat(intType, vextSews.vs2),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vs2),
  ))
  private val vs1Type = Mux1H(Seq(
    isDstMask -> maskTypes.vs1,
    isExt -> Cat(intType, vextSews.vs1),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vs1),
  ))
  private val vdType = Mux1H(Seq(
    isDstMask -> maskTypes.vd,
    isExt -> Cat(intType, vextSews.vd),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vd),
  ))

  io.out.vs2Type := vs2Type
  io.out.vs1Type := vs1Type
  io.out.vdType := vdType
  io.out.illegal := illegal
  io.out.isVextF2 := isVextF2
  io.out.isVextF4 := isVextF4
  io.out.isVextF8 := isVextF8
}
132
/**
 * Vector integer / fixed-point ALU functional unit.
 *
 * Datapath:
 *  - The `cfg.dataBits`-wide operands are split into 64-bit lanes (`numVecModule` of them).
 *  - Each lane is processed by its own [[VIntFixpAlu64b]].
 *  - For widening ops and vext.vf2/vf4/vf8, the narrower source elements are first regrouped
 *    across lanes (see `interleaveLanes`).
 *  - Lane results (plain, narrowed or compare-mask) are concatenated and passed through [[Mgu]],
 *    which combines them with the old destination, the mask and the tail/mask policies.
 */
class VIAluFix(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VialuFixType.dummy, "VialuF OpType not supported")

  // config params
  private val dataWidth = cfg.dataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule

  // modules
  private val typeMod = Module(new VIAluSrcTypeModule)
  private val vs2Split = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val vs1Split = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val oldVdSplit = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val vIntFixpAlus = Seq.fill(numVecModule)(Module(new VIntFixpAlu64b))
  private val mgu = Module(new Mgu(dataWidth))

  /**
   * [[typeMod]]'s in connection
   */
  typeMod.io.in.fuOpType := fuOpType
  typeMod.io.in.vsew := vsew
  typeMod.io.in.isReverse := isReverse
  typeMod.io.in.isExt := isExt
  typeMod.io.in.isDstMask := vecCtrl.isDstMask
  typeMod.io.in.isMove := isMove

  // Narrow source elements regrouped so that lane `l` receives elements l, l + numVecModule, ...
  // (lowest index in the least-significant bits). The previous groupBy-based construction
  // produced the same grouping, but its lane order depended on hash-map iteration order.
  private val vs2GroupedVec32b: Vec[UInt] = interleaveLanes(vs2Split.io.outVec32b, numVecModule)
  private val vs2GroupedVec16b: Vec[UInt] = interleaveLanes(vs2Split.io.outVec16b, numVecModule)
  private val vs2GroupedVec8b: Vec[UInt] = interleaveLanes(vs2Split.io.outVec8b, numVecModule)
  private val vs1GroupedVec: Vec[UInt] = interleaveLanes(vs1Split.io.outVec32b, numVecModule)

  /**
   * In connection of [[vs2Split]], [[vs1Split]] and [[oldVdSplit]]
   */
  vs2Split.io.inVecData := vs2
  vs1Split.io.inVecData := vs1
  oldVdSplit.io.inVecData := oldVd

  /**
   * [[vIntFixpAlus]]'s in connection
   */
  private val opcode = VialuFixType.getOpcode(inCtrl.fuOpType).asTypeOf(vIntFixpAlus.head.io.opcode)
  private val vs1Type = typeMod.io.out.vs1Type
  private val vs2Type = typeMod.io.out.vs2Type
  private val vdType = typeMod.io.out.vdType
  private val isVextF2 = typeMod.io.out.isVextF2
  private val isVextF4 = typeMod.io.out.isVextF4
  private val isVextF8 = typeMod.io.out.isVextF8

  private val truthTable = TruthTable(VIntFixpTable.table, VIntFixpTable.default)
  private val decoderOut = decoder(QMCMinimizer, Cat(opcode.op), truthTable)
  private val vIntFixpDecode = decoderOut.asTypeOf(new VIntFixpDecode)
  private val isFixp = Mux(vIntFixpDecode.misc, opcode.isScalingShift, opcode.isSatAdd || opcode.isAvgAdd)
  // Widening add/sub: vs1's SEW differs from vd's; vs2 is additionally narrow for the VVW form.
  private val widen = opcode.isAddSub && vs1Type(1, 0) =/= vdType(1, 0)
  private val widen_vs2 = widen && vs2Type(1, 0) =/= vdType(1, 0)
  private val eewVs1 = SewOH(vs1Type(1, 0))
  private val eewVd = SewOH(vdType(1, 0))

  // Extension instructions
  private val vf2 = isVextF2
  private val vf4 = isVextF4
  private val vf8 = isVextF8

  private val vs1VecUsed: Vec[UInt] = Mux(widen || isNarrow, vs1GroupedVec, vs1Split.io.outVec64b)
  private val vs2VecUsed = Wire(Vec(numVecModule, UInt(64.W)))
  when(vf2 || widen_vs2) {
    vs2VecUsed := vs2GroupedVec32b
  }.elsewhen(vf4) {
    vs2VecUsed := vs2GroupedVec16b
  }.elsewhen(vf8) {
    vs2VecUsed := vs2GroupedVec8b
  }.otherwise {
    vs2VecUsed := vs2Split.io.outVec64b
  }

  // mask
  private val maskDataVec: Vec[UInt] = VecDataToMaskDataVec(srcMask, vsew)
  // Narrowing uops select the mask chunk at half rate.
  private val maskIdx = Mux(isNarrow, (vuopIdx >> 1.U).asUInt, vuopIdx)
  // A mask destination has no element width of its own, so the mask is split by vs1's EEW instead.
  private val eewVd_is_1b = vdType === VdType.mask
  private val maskUsed = splitMask(maskDataVec(maskIdx), Mux(eewVd_is_1b, eewVs1, eewVd))

  private val oldVdUsed = splitMask(VecDataToMaskDataVec(oldVd, vs1Type(1, 0))(vuopIdx), eewVs1)

  vIntFixpAlus.zipWithIndex.foreach {
    case (mod, i) =>
      mod.io.opcode := opcode

      mod.io.info.vm := vm
      mod.io.info.ma := vma
      mod.io.info.ta := vta
      mod.io.info.vlmul := vlmul
      mod.io.info.vl := vl
      mod.io.info.vstart := vstart
      mod.io.info.uopIdx := vuopIdx
      mod.io.info.vxrm := vxrm

      mod.io.srcType(0) := vs2Type
      mod.io.srcType(1) := vs1Type
      mod.io.vdType := vdType
      mod.io.narrow := isNarrow
      mod.io.isSub := vIntFixpDecode.sub
      mod.io.isMisc := vIntFixpDecode.misc
      mod.io.isFixp := isFixp
      mod.io.widen := widen
      mod.io.widen_vs2 := widen_vs2
      mod.io.vs1 := vs1VecUsed(i)
      mod.io.vs2 := vs2VecUsed(i)
      mod.io.vmask := maskUsed(i)
      mod.io.oldVd := oldVdUsed(i)
  }

  /**
   * [[mgu]]'s in connection
   */
  // vs1's EEW delayed by one cycle to select the compare-result width at the output side.
  // NOTE(review): assumes the ALU outputs arrive exactly one cycle after its inputs — TODO confirm against cfg.latency.
  private val eewVs1S1 = RegNext(eewVs1)

  private val outVd = Cat(vIntFixpAlus.reverse.map(_.io.vd))
  private val outCmp = Mux1H(eewVs1S1.oneHot, Seq(8, 4, 2, 1).map(
    k => Cat(vIntFixpAlus.reverse.map(_.io.cmpOut(k - 1, 0)))))
  private val outNarrow = Cat(vIntFixpAlus.reverse.map(_.io.narrowVd))

  /* insts whose mask is not used to generate 'agnosticEn' and 'keepEn' in mgu:
   * vadc, vmadc...
   * vmerge
   */
  private val needNoMask = VialuFixType.needNoMask(outCtrl.fuOpType)
  private val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  private val outFormat = VialuFixType.getFormat(outCtrl.fuOpType)
  private val outWiden = (outFormat === VialuFixType.FMT.VVW || outFormat === VialuFixType.FMT.WVW) &&
    !outVecCtrl.isExt && !outVecCtrl.isDstMask
  private val narrow = outVecCtrl.isNarrow
  private val dstMask = outVecCtrl.isDstMask

  // Destination EEW seen by the mgu: 2*SEW for widening ops, SEW otherwise.
  private val outEew = Mux(outWiden, outVecCtrl.vsew + 1.U, outVecCtrl.vsew)

  mgu.io.in.vd := MuxCase(outVd, Seq(
    narrow -> outNarrow,
    dstMask -> outCmp,
  ))
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  mgu.io.in.info.vl := outVl
  mgu.io.in.info.vstart := outVecCtrl.vstart
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := narrow
  mgu.io.in.info.dstMask := dstMask

  io.out.bits.res.data := mgu.io.out.vd
  io.out.bits.res.vxsat.get := Cat(vIntFixpAlus.map(_.io.vxsat)).orR

  // util function

  /**
   * Split a mask chunk into one byte per 64-bit lane.
   *
   * Lane `i` receives the mask bits of its elements for the given element width:
   * 8, 4, 2 or 1 bits for SEW 8/16/32/64, zero-extended to 8 bits.
   *
   * @param maskIn mask bits covering all lanes
   * @param sew    one-hot element width used to slice `maskIn`
   * @return a Vec with `maskIn.getWidth / 8` byte-wide entries
   */
  def splitMask(maskIn: UInt, sew: SewOH): Vec[UInt] = {
    val maskWidth = maskIn.getWidth
    val result = Wire(Vec(maskWidth / 8, UInt(8.W)))
    for ((resultData, i) <- result.zipWithIndex) {
      resultData := Mux1H(Seq(
        sew.is8 -> maskIn(i * 8 + 7, i * 8),
        sew.is16 -> Cat(0.U(4.W), maskIn(i * 4 + 3, i * 4)),
        sew.is32 -> Cat(0.U(6.W), maskIn(i * 2 + 1, i * 2)),
        sew.is64 -> Cat(0.U(7.W), maskIn(i)),
      ))
    }
    result
  }

  /**
   * Deterministically regroup equally-sized elements into `numGroups` lanes.
   *
   * Lane `g` is the concatenation of every element whose index is congruent to `g` modulo
   * `numGroups`, with the lowest-index element in the least-significant bits.
   */
  private def interleaveLanes(elems: Seq[UInt], numGroups: Int): Vec[UInt] =
    VecInit((0 until numGroups).map { g =>
      Cat(elems.zipWithIndex.collect { case (elem, idx) if idx % numGroups == g => elem }.reverse)
    })

}