// xref: /XiangShan/src/main/scala/xiangshan/backend/fu/wrapper/VIAluFix.scala (revision 195ef4a53ab54326d879e884c4e1568f424f2668)
1package xiangshan.backend.fu.wrapper
2
3import org.chipsalliance.cde.config.Parameters
4import chisel3.{VecInit, _}
5import chisel3.util._
6import chisel3.util.experimental.decode.{QMCMinimizer, TruthTable, decoder}
7import utility.DelayN
8import utils.XSError
9import xiangshan.XSCoreParamsKey
10import xiangshan.backend.fu.vector.Bundles.{VConfig, VSew, ma}
11import xiangshan.backend.fu.vector.{Mgu, Mgtu, VecPipedFuncUnit}
12import xiangshan.backend.fu.vector.Utils.VecDataToMaskDataVec
13import xiangshan.backend.fu.vector.utils.VecDataSplitModule
14import xiangshan.backend.fu.{FuConfig, FuType}
15import xiangshan.ExceptionNO
16import yunsuan.{OpType, VialuFixType}
17import yunsuan.vector.alu.{VIntFixpAlu64b, VIntFixpDecode, VIntFixpTable}
18import yunsuan.encoding.{VdType, Vs1IntType, Vs2IntType}
19import yunsuan.encoding.Opcode.VialuOpcode
20import yunsuan.vector.SewOH
21
/**
 * Port bundle of [[VIAluSrcTypeModule]].
 *
 * `in`  : decoded uop fields (fuOpType, vsew and a few decode flags).
 * `out` : element types chosen for vs2 / vs1 / vd, an illegal-combination flag and the
 *         vext.vf2 / vf4 / vf8 selectors.
 */
class VIAluSrcTypeIO extends Bundle {
  val in = Input(new Bundle {
    val fuOpType:  UInt = OpType()
    val vsew:      UInt = VSew()
    val isReverse: Bool = Bool() // reversed-operand forms: vrsub, vrdiv
    val isExt:     Bool = Bool() // extension uop (selects the VF2/VF4/VF8 type table)
    val isDstMask: Bool = Bool() // mask destination: vvm, vvvm, mmm
    val isMove:    Bool = Bool() // vmv.s.x, vmv.v.v, vmv.v.x, vmv.v.i
  })
  val out = Output(new Bundle {
    val vs1Type:  UInt = Vs1IntType()
    val vs2Type:  UInt = Vs2IntType()
    val vdType:   UInt = VdType()
    val illegal:  Bool = Bool()
    val isVextF2: Bool = Bool()
    val isVextF4: Bool = Bool()
    val isVextF8: Bool = Bool()
  })
}
41
/**
 * Derives the element types of vs2, vs1 and vd for one vector integer / fixed-point ALU uop
 * from its `fuOpType` (opcode, signedness, format) and `vsew`:
 *  - mask-destination uops (`isDstMask`) use the VVM / VVMM / MMM format table;
 *  - extension uops (`isExt`) read vs2/vs1 at SEW/2, SEW/4 or SEW/8 and write vd at SEW;
 *  - all other uops use the VVV / VVW / WVW / WVV format table.
 *
 * `illegal` flags SEW settings whose double-width or fractional-width operand does not exist.
 * It is not consumed by [[VIAluFix]] yet. `io.in.isReverse` and `io.in.isMove` are not read here.
 */
class VIAluSrcTypeModule extends Module {
  val io: VIAluSrcTypeIO = IO(new VIAluSrcTypeIO)

  private val vsew = io.in.vsew
  private val isExt = io.in.isExt
  private val isDstMask = io.in.isDstMask

  private val opcode = VialuFixType.getOpcode(io.in.fuOpType)
  private val isSign = VialuFixType.isSigned(io.in.fuOpType)
  private val format = VialuFixType.getFormat(io.in.fuOpType)

  // SEW of double-width operands and of vext's narrower sources. With the 2-bit VSew encoding
  // (e8..e64 = 0..3) out-of-range results wrap around; the illegal checks below rely on that.
  private val vsewX2 = vsew + 1.U
  private val vsewF2 = vsew - 1.U
  private val vsewF4 = vsew - 2.U
  private val vsewF8 = vsew - 3.U

  private val isAddSub = opcode === VialuOpcode.vadd || opcode === VialuOpcode.vsub
  // Right shifts are identified by their opcode (not by the format field).
  private val isShiftRight = Seq(VialuOpcode.vsrl, VialuOpcode.vsra, VialuOpcode.vssrl, VialuOpcode.vssra)
    .map(op => op === opcode).reduce(_ || _)
  private val isVext = opcode === VialuOpcode.vext

  private val isWiden = isAddSub && Seq(VialuFixType.FMT.VVW, VialuFixType.FMT.WVW).map(fmt => fmt === format).reduce(_ || _)
  private val isNarrow = isShiftRight && format === VialuFixType.FMT.WVV
  private val isVextF2 = isVext && format === VialuFixType.FMT.VF2
  private val isVextF4 = isVext && format === VialuFixType.FMT.VF4
  private val isVextF8 = isVext && format === VialuFixType.FMT.VF8

  // check illegal
  // Widening and narrowing both have a 2*SEW operand (vsewX2), which wraps to e8 when SEW = 64.
  private val widenIllegal = isWiden && vsewX2 === VSew.e8
  private val narrowIllegal = isNarrow && vsewX2 === VSew.e8
  // vext.vf2/vf4/vf8 read sources of SEW/2, SEW/4, SEW/8. A subtraction that goes below e8 wraps
  // to e64 only at the first underflowing step, so the checks are cumulative:
  // vf2 needs SEW >= 16, vf4 needs SEW >= 32, vf8 needs SEW = 64.
  private val f2Underflow = vsewF2 === VSew.e64
  private val f4Underflow = f2Underflow || vsewF4 === VSew.e64
  private val f8Underflow = f4Underflow || vsewF8 === VSew.e64
  private val vextIllegal = (isVextF2 && f2Underflow) ||
    (isVextF4 && f4Underflow) ||
    (isVextF8 && f8Underflow)
  // Todo: use it
  private val illegal = widenIllegal || narrowIllegal || vextIllegal

  // {0, isSign}: the integer-type prefix prepended to an element SEW.
  private val intType = Cat(0.U(1.W), isSign)

  private class Vs2Vs1VdSew extends Bundle {
    val vs2 = VSew()
    val vs1 = VSew()
    val vd = VSew()
  }

  private class Vs2Vs1VdType extends Bundle {
    val vs2 = Vs2IntType()
    val vs1 = Vs1IntType()
    val vd = VdType()
  }

  // SEWs for the regular (non-mask, non-extension) formats: W* means vs2 is 2*SEW, *W means vd is 2*SEW.
  private val addSubSews = Mux1H(Seq(
    (format === VialuFixType.FMT.VVV) -> Cat(vsew, vsew, vsew),
    (format === VialuFixType.FMT.VVW) -> Cat(vsew, vsew, vsewX2),
    (format === VialuFixType.FMT.WVW) -> Cat(vsewX2, vsew, vsewX2),
    (format === VialuFixType.FMT.WVV) -> Cat(vsewX2, vsew, vsew),
  )).asTypeOf(new Vs2Vs1VdSew)

  // SEWs for vext: sources at the fractional SEW, destination at SEW.
  private val vextSews = Mux1H(Seq(
    (format === VialuFixType.FMT.VF2) -> Cat(vsewF2, vsewF2, vsew),
    (format === VialuFixType.FMT.VF4) -> Cat(vsewF4, vsewF4, vsew),
    (format === VialuFixType.FMT.VF8) -> Cat(vsewF8, vsewF8, vsew),
  )).asTypeOf(new Vs2Vs1VdSew)

  // Full types for mask-destination formats (vd is always a mask; MMM also has mask sources).
  private val maskTypes = Mux1H(Seq(
    (format === VialuFixType.FMT.VVM) -> Cat(Cat(intType, vsew), Cat(intType, vsew), VdType.mask),
    (format === VialuFixType.FMT.VVMM) -> Cat(Cat(intType, vsew), Cat(intType, vsew), VdType.mask),
    (format === VialuFixType.FMT.MMM) -> Cat(Vs2IntType.mask, Vs1IntType.mask, VdType.mask),
  )).asTypeOf(new Vs2Vs1VdType)

  private val vs2Type = Mux1H(Seq(
    isDstMask -> maskTypes.vs2,
    isExt -> Cat(intType, vextSews.vs2),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vs2),
  ))
  private val vs1Type = Mux1H(Seq(
    isDstMask -> maskTypes.vs1,
    isExt -> Cat(intType, vextSews.vs1),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vs1),
  ))
  private val vdType = Mux1H(Seq(
    isDstMask -> maskTypes.vd,
    isExt -> Cat(intType, vextSews.vd),
    (!isExt && !isDstMask) -> Cat(intType, addSubSews.vd),
  ))

  io.out.vs2Type := vs2Type
  io.out.vs1Type := vs1Type
  io.out.vdType := vdType
  io.out.illegal := illegal
  io.out.isVextF2 := isVextF2
  io.out.isVextF4 := isVextF4
  io.out.isVextF8 := isVextF8
}
134
/**
 * Vector integer / fixed-point ALU function unit.
 *
 * vs2, vs1 and the old vd are split into `numVecModule` 64-bit lanes, each processed by a
 * [[VIntFixpAlu64b]]. The lane results (plain, vwsll-reordered, compare bits or narrowed) are
 * then merged and passed through [[Mgu]] for mask/tail handling. [[Mgtu]] is used for
 * mask-destination instructions.
 */
class VIAluFix(cfg: FuConfig)(implicit p: Parameters) extends VecPipedFuncUnit(cfg) {
  XSError(io.in.valid && io.in.bits.ctrl.fuOpType === VialuFixType.dummy, "VialuF OpType not supported")

  // config params
  private val dataWidth = cfg.destDataBits
  private val dataWidthOfDataModule = 64
  private val numVecModule = dataWidth / dataWidthOfDataModule

  // modules
  private val typeMod = Module(new VIAluSrcTypeModule)
  private val vs2Split = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val vs1Split = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val oldVdSplit = Module(new VecDataSplitModule(dataWidth, dataWidthOfDataModule))
  private val vIntFixpAlus = Seq.fill(numVecModule)(Module(new VIntFixpAlu64b))
  private val mgu = Module(new Mgu(dataWidth))
  private val mgtu = Module(new Mgtu(dataWidth))

  /**
   * [[typeMod]]'s in connection
   */
  typeMod.io.in.fuOpType := fuOpType
  typeMod.io.in.vsew := vsew
  typeMod.io.in.isReverse := isReverse
  typeMod.io.in.isExt := isExt
  typeMod.io.in.isDstMask := vecCtrl.isDstMask
  typeMod.io.in.isMove := isMove

  // Narrow-element views of the sources: lane 0 gets the even-indexed elements, lane 1 the
  // odd-indexed ones (see groupEvenOdd). Used for widening / vext / vwsll operands.
  private val vs2GroupedVec32b: Vec[UInt] = groupEvenOdd(vs2Split.io.outVec32b)
  private val vs2GroupedVec16b: Vec[UInt] = groupEvenOdd(vs2Split.io.outVec16b)
  private val vs2GroupedVec8b : Vec[UInt] = groupEvenOdd(vs2Split.io.outVec8b)
  private val vs1GroupedVec32b: Vec[UInt] = groupEvenOdd(vs1Split.io.outVec32b)
  private val vs1GroupedVec16b: Vec[UInt] = groupEvenOdd(vs1Split.io.outVec16b)
  private val vs1GroupedVec8b : Vec[UInt] = groupEvenOdd(vs1Split.io.outVec8b)

  /**
   * In connection of [[vs2Split]], [[vs1Split]] and [[oldVdSplit]]
   */
  vs2Split.io.inVecData := vs2
  vs1Split.io.inVecData := vs1
  oldVdSplit.io.inVecData := oldVd

  /**
   * [[vIntFixpAlus]]'s in connection
   */
  private val opcode = VialuFixType.getOpcode(inCtrl.fuOpType).asTypeOf(vIntFixpAlus.head.io.opcode)
  private val vs1Type = typeMod.io.out.vs1Type
  private val vs2Type = typeMod.io.out.vs2Type
  private val vdType = typeMod.io.out.vdType
  private val isVextF2 = typeMod.io.out.isVextF2
  private val isVextF4 = typeMod.io.out.isVextF4
  private val isVextF8 = typeMod.io.out.isVextF8

  private val truthTable = TruthTable(VIntFixpTable.table, VIntFixpTable.default)
  private val decoderOut = decoder(QMCMinimizer, Cat(opcode.op), truthTable)
  private val vIntFixpDecode = decoderOut.asTypeOf(new VIntFixpDecode)
  private val isFixp = Mux(vIntFixpDecode.misc, opcode.isScalingShift, opcode.isSatAdd || opcode.isAvgAdd)
  // Bits (1, 0) of a type hold its SEW: an add/sub whose vd SEW differs from vs1's is widening.
  private val widen = opcode.isAddSub && vs1Type(1, 0) =/= vdType(1, 0)
  private val widen_vs2 = widen && vs2Type(1, 0) =/= vdType(1, 0)
  private val eewVs1 = SewOH(vs1Type(1, 0))
  private val eewVd = SewOH(vdType(1, 0))
  private val isVwsll = opcode.isVwsll

  // Extension instructions
  private val vf2 = isVextF2
  private val vf4 = isVextF4
  private val vf8 = isVextF8

  // Per-lane operands: narrow-element groupings for widening / narrowing / vext / vwsll,
  // otherwise the plain 64-bit lanes.
  private val vs1VecUsed: Vec[UInt] = Wire(Vec(numVecModule, UInt(64.W)))
  private val vs2VecUsed: Vec[UInt] = Wire(Vec(numVecModule, UInt(64.W)))
  private val isVwsllEewVdIs64 = isVwsll && eewVd.is64
  private val isVwsllEewVdIs32 = isVwsll && eewVd.is32
  private val isVwsllEewVdIs16 = isVwsll && eewVd.is16
  when(widen || isNarrow || isVwsllEewVdIs64) {
    vs1VecUsed := vs1GroupedVec32b
  }.elsewhen(isVwsllEewVdIs32) {
    vs1VecUsed := vs1GroupedVec16b
  }.elsewhen(isVwsllEewVdIs16) {
    vs1VecUsed := vs1GroupedVec8b
  }.otherwise {
    vs1VecUsed := vs1Split.io.outVec64b
  }
  when(vf2 || isVwsllEewVdIs64) {
    vs2VecUsed := vs2GroupedVec32b
  }.elsewhen(vf4 || isVwsllEewVdIs32) {
    vs2VecUsed := vs2GroupedVec16b
  }.elsewhen(vf8 || isVwsllEewVdIs16) {
    vs2VecUsed := vs2GroupedVec8b
  }.otherwise {
    vs2VecUsed := vs2Split.io.outVec64b
  }

  private val vs2Adder = Mux(widen_vs2, vs2GroupedVec32b, vs2Split.io.outVec64b)

  // mask
  private val maskDataVec: Vec[UInt] = VecDataToMaskDataVec(srcMask, vsew)
  // Narrowing uops use the uop index halved to pick their mask chunk.
  private val maskIdx = Mux(isNarrow, (vuopIdx >> 1.U).asUInt, vuopIdx)
  private val eewVd_is_1b = vdType === VdType.mask
  private val maskUsed = splitMask(maskDataVec(maskIdx), Mux(eewVd_is_1b, eewVs1, eewVd))

  private val oldVdUsed = splitMask(VecDataToMaskDataVec(oldVd, vs1Type(1, 0))(vuopIdx), eewVs1)

  vIntFixpAlus.zipWithIndex.foreach {
    case (mod, i) =>
      mod.io.fire := io.in.valid
      mod.io.opcode := opcode

      mod.io.info.vm := vm
      mod.io.info.ma := vma
      mod.io.info.ta := vta
      mod.io.info.vlmul := vlmul
      mod.io.info.vl := vl
      mod.io.info.vstart := vstart
      mod.io.info.uopIdx := vuopIdx
      mod.io.info.vxrm := vxrm

      mod.io.srcType(0) := vs2Type
      mod.io.srcType(1) := vs1Type
      mod.io.vdType := vdType
      mod.io.narrow := isNarrow
      mod.io.isSub := vIntFixpDecode.sub
      mod.io.isMisc := vIntFixpDecode.misc
      mod.io.isFixp := isFixp
      mod.io.widen := widen
      mod.io.widen_vs2 := widen_vs2
      mod.io.vs1 := vs1VecUsed(i)
      mod.io.vs2_adder := vs2Adder(i)
      mod.io.vs2_misc := vs2VecUsed(i)
      mod.io.vmask := maskUsed(i)
      mod.io.oldVd := oldVdUsed(i)
  }

  /**
   * [[mgu]]'s in connection
   */
  // Input-stage values needed at the output stage. All are delayed with the same
  // pipeline-aware SNReg so they stay aligned with outCtrl/outVecCtrl for any `latency`.
  private val outIsVwsll = SNReg(isVwsll, latency)
  private val outIsVwsllEewVdIs64 = SNReg(isVwsllEewVdIs64, latency)
  private val outIsVwsllEewVdIs32 = SNReg(isVwsllEewVdIs32, latency)
  private val outIsVwsllEewVdIs16 = SNReg(isVwsllEewVdIs16, latency)
  private val outEewVs1 = SNReg(eewVs1, latency)

  private val outVdTmp = Cat(vIntFixpAlus.reverse.map(_.io.vd))
  // vwsll with a 32/16-bit destination: interleave the two lanes' results back into element
  // order. NOTE(review): the bit slices assume a 128-bit result (two 64-bit lanes).
  private val outVd = Mux1H(Seq(
    (outIsVwsllEewVdIs64 || !outIsVwsll) -> outVdTmp,
    outIsVwsllEewVdIs32 -> Cat(outVdTmp(127,  96), outVdTmp(63, 32), outVdTmp( 95, 64), outVdTmp(31,  0)),
    outIsVwsllEewVdIs16 -> Cat(outVdTmp(127, 112), outVdTmp(63, 48), outVdTmp(111, 96), outVdTmp(47, 32), outVdTmp(95, 80), outVdTmp(31, 16), outVdTmp(79, 64), outVdTmp(15,0)),
  ))
  // Compare results: each lane yields 8/4/2/1 valid bits for vs1 SEW 8/16/32/64.
  private val outCmp = Mux1H(outEewVs1.oneHot, Seq(8, 4, 2, 1).map(
    k => Cat(vIntFixpAlus.reverse.map(_.io.cmpOut(k - 1, 0)))))
  private val outNarrow = Cat(vIntFixpAlus.reverse.map(_.io.narrowVd))
  private val outOpcode = VialuFixType.getOpcode(outCtrl.fuOpType).asTypeOf(vIntFixpAlus.head.io.opcode)

  private val numBytes = dataWidth / 8
  private val maxMaskIdx = numBytes
  private val maxVdIdx = 8
  private val elementsInOneUop = Mux1H(outEewVs1.oneHot, Seq(1, 2, 4, 8).map(k => (numBytes / k).U(5.W)))
  private val vdIdx = outVecCtrl.vuopIdx(2, 0)
  private val elementsComputed = Mux1H(Seq.tabulate(maxVdIdx)(i => (vdIdx === i.U) -> (elementsInOneUop * i.U)))
  val outCmpWithTail = Wire(Vec(maxMaskIdx, UInt(1.W)))
  // Compare bits of elements whose index is >= vl are tail: force them to 1, i.e. always apply
  // the tail-agnostic policy regardless of vta.
  for (i <- 0 until maxMaskIdx) {
    when(elementsComputed +& i.U >= outVl) {
      outCmpWithTail(i) := 1.U
    }.otherwise {
      outCmpWithTail(i) := outCmp(i)
    }
  }

  /* insts whose mask is not used to generate 'agnosticEn' and 'activeEn' in mgu:
   * vadc, vmadc...
   * vmerge
   */
  private val needNoMask = VialuFixType.needNoMask(outCtrl.fuOpType)
  private val maskToMgu = Mux(needNoMask, allMaskTrue, outSrcMask)

  private val outFormat = VialuFixType.getFormat(outCtrl.fuOpType)
  private val outWiden = (outFormat === VialuFixType.FMT.VVW | outFormat === VialuFixType.FMT.WVW) & !outVecCtrl.isExt & !outVecCtrl.isDstMask
  private val narrow = outVecCtrl.isNarrow
  private val dstMask = outVecCtrl.isDstMask
  private val outVxsat = Mux(narrow, Cat(vIntFixpAlus.reverse.map(_.io.vxsat(3, 0))), Cat(vIntFixpAlus.reverse.map(_.io.vxsat)))

  // Narrowing results fill half of vd: odd uops write the upper half, even uops the lower half,
  // and the other half keeps the old vd.
  private val narrowNeedCat = outVecCtrl.vuopIdx(0).asBool && narrow
  private val outNarrowVd = Mux(narrowNeedCat, Cat(outNarrow, outOldVd(dataWidth / 2 - 1, 0)), Cat(outOldVd(dataWidth - 1, dataWidth / 2), outNarrow))
  private val outVxsatReal = Mux(narrowNeedCat, Cat(outVxsat(numBytes / 2 - 1, 0), 0.U((numBytes / 2).W)), outVxsat)

  private val outEew = Mux(outWiden, outVecCtrl.vsew + 1.U, outVecCtrl.vsew)

  /*
   * vmv.s.x only writes element 0, so Mgu is given vl = 1
   * (vl = 0 with vstart = 0 is handled by outVstartGeVl below).
   */
  private val outIsVmvsx = outOpcode.isVmvsx

  /*
   * when vstart >= vl, no need to update vd, the old value should be kept
   */
  private val outVstartGeVl = outVstart >= outVl

  mgu.io.in.vd := MuxCase(outVd, Seq(
    narrow -> outNarrowVd,
    dstMask -> outCmpWithTail.asUInt,
  ))
  mgu.io.in.oldVd := outOldVd
  mgu.io.in.mask := maskToMgu
  mgu.io.in.info.ta := outVecCtrl.vta
  mgu.io.in.info.ma := outVecCtrl.vma
  mgu.io.in.info.vl := Mux(outIsVmvsx, 1.U, outVl)
  mgu.io.in.info.vlmul := outVecCtrl.vlmul
  mgu.io.in.info.valid := validVec.last
  mgu.io.in.info.vstart := outVecCtrl.vstart
  mgu.io.in.info.eew := outEew
  mgu.io.in.info.vsew := outVecCtrl.vsew
  mgu.io.in.info.vdIdx := outVecCtrl.vuopIdx
  mgu.io.in.info.narrow := narrow
  mgu.io.in.info.dstMask := dstMask
  mgu.io.in.isIndexedVls := false.B

  /**
   * [[mgtu]]'s in connection, for vmask instructions
   */
  mgtu.io.in.vd := Mux(dstMask && !outVecCtrl.isOpMask, mgu.io.out.vd, outVd)
  mgtu.io.in.vl := outVl

  io.out.bits.res.data := Mux(outVstartGeVl, outOldVd, Mux(dstMask, mgtu.io.out.vd, mgu.io.out.vd))
  io.out.bits.res.vxsat.get := Mux(outVstartGeVl, false.B, (outVxsatReal & mgu.io.out.active).orR)
  io.out.bits.ctrl.exceptionVec.get(ExceptionNO.illegalInstr) := mgu.io.out.illegal && !outVstartGeVl

  // util function

  /**
   * Slices a per-uop mask into one byte per 8-bit output chunk: for SEW 8/16/32/64, each chunk
   * takes 8/4/2/1 consecutive mask bits (zero-extended to 8 bits).
   */
  def splitMask(maskIn: UInt, sew: SewOH): Vec[UInt] = {
    val maskWidth = maskIn.getWidth
    val result = Wire(Vec(maskWidth / 8, UInt(8.W)))
    for ((resultData, i) <- result.zipWithIndex) {
      resultData := Mux1H(Seq(
        sew.is8 -> maskIn(i * 8 + 7, i * 8),
        sew.is16 -> Cat(0.U((8 - 4).W), maskIn(i * 4 + 3, i * 4)),
        sew.is32 -> Cat(0.U((8 - 2).W), maskIn(i * 2 + 1, i * 2)),
        sew.is64 -> Cat(0.U((8 - 1).W), maskIn(i)),
      ))
    }
    result
  }

  /**
   * Regroups `vec` into two lanes. Lane 0 concatenates the even-indexed elements and lane 1 the
   * odd-indexed ones, with the highest index in the MSBs. The lane order is fixed explicitly
   * rather than taken from a Map's iteration order.
   */
  private def groupEvenOdd(vec: Seq[UInt]): Vec[UInt] = VecInit(Seq(0, 1).map { parity =>
    Cat(vec.zipWithIndex.collect { case (elem, idx) if idx % 2 == parity => elem }.reverse)
  })

}