/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  * Copyright (c) 2020-2021 Peng Cheng Laboratory
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.mem

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import utils._
import utility._
import xiangshan._
import xiangshan.backend.rob.RobPtr
import xiangshan.backend.Bundles._
import xiangshan.backend.fu.FuType
/**
  * Commonly used parameters and functions in the VLSU
  */
trait VLSUConstants {
  val VLEN = 128
  // for packing unit-stride flows
  val AlignedNum = 4 // 1/2/4/8
  def VLENB = VLEN/8
  def vOffsetBits = log2Up(VLENB) // bit-width of a byte offset inside a vector reg
  lazy val vlmBindexBits = 8 // will be overridden later
  lazy val vsmBindexBits = 8 // will be overridden later

  def alignTypes = 5 // eew/sew = 1/2/4/8 bytes; the last type indicates a 128-bit element
  def alignTypeBits = log2Up(alignTypes)
  def maxMUL = 8
  def maxFields = 8
  /**
    * In the most extreme case, e.g. a segment indexed instruction with eew=64, emul=8,
    * sew=8, lmul=1, and nf=8, each data reg is mapped to 8 index regs and there are
    * 8 data regs in total, one for each field. Therefore an instruction can be divided
    * into at most 64 uops.
    */
  def maxUopNum = maxMUL * maxFields // 64
  def maxFlowNum = 16
  def maxElemNum = maxMUL * maxFlowNum // 128
  // def uopIdxBits = log2Up(maxUopNum) // to index a uop inside a robIdx
  def elemIdxBits = log2Up(maxElemNum) + 1 // to index an element within an instruction
  def flowIdxBits = log2Up(maxFlowNum) + 1 // to index a flow within a uop
  def fieldBits = log2Up(maxFields) + 1 // 4 bits to indicate 1~8

  def ewBits = 3 // bit-width of EEW/SEW
  def mulBits = 3 // bit-width of emul/lmul

  def getSlice(data: UInt, i: Int, alignBits: Int): UInt = {
    require(data.getWidth >= (i+1) * alignBits)
    data((i+1) * alignBits - 1, i * alignBits)
  }
  def getNoAlignedSlice(data: UInt, i: Int, alignBits: Int): UInt = {
    data(i * 8 + alignBits - 1, i * 8)
  }

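  // A worked example of the slice helpers below: with VLEN = 128,
  // getWord(data, 2) selects data(95, 64), i.e. the third 32-bit lane, while
  // getNoAlignedSlice(data, 2, 32) selects the 32 bits starting at byte 2,
  // i.e. data(47, 16).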
  def getByte(data: UInt, i: Int = 0) = getSlice(data, i, 8)
  def getHalfWord(data: UInt, i: Int = 0) = getSlice(data, i, 16)
  def getWord(data: UInt, i: Int = 0) = getSlice(data, i, 32)
  def getDoubleWord(data: UInt, i: Int = 0) = getSlice(data, i, 64)
  def getDoubleDoubleWord(data: UInt, i: Int = 0) = getSlice(data, i, 128)
}

trait HasVLSUParameters extends HasXSParameter with VLSUConstants {
  override val VLEN = coreParams.VLEN
  override lazy val vlmBindexBits = log2Up(coreParams.VlMergeBufferSize)
  override lazy val vsmBindexBits = log2Up(coreParams.VsMergeBufferSize)
  def isUnitStride(instType: UInt) = instType(1, 0) === "b00".U
  def isStrided(instType: UInt) = instType(1, 0) === "b10".U
  def isIndexed(instType: UInt) = instType(0) === "b1".U
  def isNotIndexed(instType: UInt) = instType(0) === "b0".U
  def isSegment(instType: UInt) = instType(2) === "b1".U
  def is128Bit(alignedType: UInt) = alignedType(2) === "b1".U

  def mergeDataWithMask(oldData: UInt, newData: UInt, mask: UInt): Vec[UInt] = {
    require(oldData.getWidth == newData.getWidth)
    require(oldData.getWidth == mask.getWidth * 8)
    VecInit(mask.asBools.zipWithIndex.map { case (en, i) =>
      Mux(en, getByte(newData, i), getByte(oldData, i))
    })
  }
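  // For example, mask = "b0011".U takes bytes 0 and 1 of the result from
  // newData and the remaining bytes from oldData.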

  // def asBytes(data: UInt) = {
  //   require(data.getWidth % 8 == 0)
  //   (0 until data.getWidth/8).map(i => getByte(data, i))
  // }

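  /**
    * Merges new elements into oldData at the lanes selected by elemIdx, with
    * lane width chosen by alignedType (b00 = byte ... b11 = double word).
    * Note (as read from the code): ParallelPosteriorityMux defaults to the old
    * data and appears to give later valid entries in the sequence priority.
    */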
  def mergeDataWithElemIdx(
    oldData: UInt,
    newData: Seq[UInt],
    alignedType: UInt,
    elemIdx: Seq[UInt],
    valids: Seq[Bool]
  ): UInt = {
    require(newData.length == elemIdx.length)
    require(newData.length == valids.length)
    LookupTree(alignedType, List(
      "b00".U -> VecInit(elemIdx.map(e => UIntToOH(e(3, 0)).asBools).transpose.zipWithIndex.map { case (selVec, i) =>
        ParallelPosteriorityMux(
          true.B +: selVec.zip(valids).map(x => x._1 && x._2),
          getByte(oldData, i) +: newData.map(getByte(_))
        )}).asUInt,
      "b01".U -> VecInit(elemIdx.map(e => UIntToOH(e(2, 0)).asBools).transpose.zipWithIndex.map { case (selVec, i) =>
        ParallelPosteriorityMux(
          true.B +: selVec.zip(valids).map(x => x._1 && x._2),
          getHalfWord(oldData, i) +: newData.map(getHalfWord(_))
        )}).asUInt,
      "b10".U -> VecInit(elemIdx.map(e => UIntToOH(e(1, 0)).asBools).transpose.zipWithIndex.map { case (selVec, i) =>
        ParallelPosteriorityMux(
          true.B +: selVec.zip(valids).map(x => x._1 && x._2),
          getWord(oldData, i) +: newData.map(getWord(_))
        )}).asUInt,
      "b11".U -> VecInit(elemIdx.map(e => UIntToOH(e(0)).asBools).transpose.zipWithIndex.map { case (selVec, i) =>
        ParallelPosteriorityMux(
          true.B +: selVec.zip(valids).map(x => x._1 && x._2),
          getDoubleWord(oldData, i) +: newData.map(getDoubleWord(_))
        )}).asUInt
    ))
  }

  def mergeDataWithElemIdx(oldData: UInt, newData: UInt, alignedType: UInt, elemIdx: UInt): UInt = {
    mergeDataWithElemIdx(oldData, Seq(newData), alignedType, Seq(elemIdx), Seq(true.B))
  }
  /**
    * Merges 128-bit unit-stride data byte by byte.
    */
  object mergeDataByByte{
    def apply(oldData: UInt, newData: UInt, mask: UInt): UInt = {
      val selVec = Seq(mask).map(_.asBools).transpose
      VecInit(selVec.zipWithIndex.map{ case (selV, i) =>
        ParallelPosteriorityMux(
          true.B +: selV.map(x => x),
          getByte(oldData, i) +: Seq(getByte(newData, i))
        )}).asUInt
    }
  }

  /**
    * Merges unit-stride data into a 256-bit result, i.e. merges 128-bit data
    * into 256 bits. With 3 ports:
    *   port 0 is a 6-to-1 multiplexer -> (128'b0, data) or (data, 128'b0) or (data, port1data) or (port1data, data) or (data, port2data) or (port2data, data)
    *   port 1 is a 4-to-1 multiplexer -> (128'b0, data) or (data, 128'b0) or (data, port2data) or (port2data, data)
    *   port 2 is a 2-to-1 multiplexer -> (128'b0, data) or (data, 128'b0)
    *
    */
  object mergeDataByIndex{
    def apply(data: Seq[UInt], mask: Seq[UInt], index: UInt, valids: Seq[Bool]): (UInt, UInt) = {
      require(data.length == valids.length)
      require(data.length == mask.length)
      val muxLength = data.length
      val selDataMatrix = Wire(Vec(muxLength, Vec(2, UInt((VLEN * 2).W)))) // 3 * 2 * 256
      val selMaskMatrix = Wire(Vec(muxLength, Vec(2, UInt((VLENB * 2).W)))) // 3 * 2 * 16
      dontTouch(selDataMatrix)
      dontTouch(selMaskMatrix)
      for(i <- 0 until muxLength){
        if(i == 0){
          selDataMatrix(i)(0) := Cat(0.U(VLEN.W), data(i))
          selDataMatrix(i)(1) := Cat(data(i), 0.U(VLEN.W))
          selMaskMatrix(i)(0) := Cat(0.U(VLENB.W), mask(i))
          selMaskMatrix(i)(1) := Cat(mask(i), 0.U(VLENB.W))
        }
        else{
          selDataMatrix(i)(0) := Cat(data(i), data(0))
          selDataMatrix(i)(1) := Cat(data(0), data(i))
          selMaskMatrix(i)(0) := Cat(mask(i), mask(0))
          selMaskMatrix(i)(1) := Cat(mask(0), mask(i))
        }
      }
      val selIdxVec = (0 until muxLength).map(_.U)
      val selIdx    = PriorityMux(valids.reverse, selIdxVec.reverse)

      val selData = Mux(index === 0.U,
                        selDataMatrix(selIdx)(0),
                        selDataMatrix(selIdx)(1))
      val selMask = Mux(index === 0.U,
                        selMaskMatrix(selIdx)(0),
                        selMaskMatrix(selIdx)(1))
      (selData, selMask)
    }
  }
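  // Usage note (as read from the code): index selects which 128-bit half
  // port 0's data occupies (0 = low half, 1 = high half); the highest-indexed
  // valid port, if any, fills the other half, which is otherwise zero.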
  def mergeDataByIndex(data: UInt, mask: UInt, index: UInt): (UInt, UInt) = {
    mergeDataByIndex(Seq(data), Seq(mask), index, Seq(true.B))
  }
}
abstract class VLSUModule(implicit p: Parameters) extends XSModule
  with HasVLSUParameters
  with HasCircularQueuePtrHelper
abstract class VLSUBundle(implicit p: Parameters) extends XSBundle
  with HasVLSUParameters

class VLSUBundleWithMicroOp(implicit p: Parameters) extends VLSUBundle {
  val uop = new DynInst
}

class OnlyVecExuOutput(implicit p: Parameters) extends VLSUBundle {
  val isvec = Bool()
  val vecdata = UInt(VLEN.W)
  val mask = UInt(VLENB.W)
  // val rob_idx_valid = Vec(2, Bool())
  // val inner_idx = Vec(2, UInt(3.W))
  // val rob_idx = Vec(2, new RobPtr)
  // val offset = Vec(2, UInt(4.W))
  val reg_offset = UInt(vOffsetBits.W)
  val vecActive = Bool() // 1: active vector element, 0: inactive vector element
  val is_first_ele = Bool()
  val elemIdx = UInt(elemIdxBits.W) // element index
  val elemIdxInsideVd = UInt(elemIdxBits.W) // element index within the scope of vd
  // val uopQueuePtr = new VluopPtr
  // val flowPtr = new VlflowPtr
}

class VecExuOutput(implicit p: Parameters) extends MemExuOutput with HasVLSUParameters {
  val vec = new OnlyVecExuOutput
  val alignedType       = UInt(alignTypeBits.W)
  // feedback
  val vecFeedback       = Bool()
}

// class VecStoreExuOutput(implicit p: Parameters) extends MemExuOutput with HasVLSUParameters {
//   val elemIdx = UInt(elemIdxBits.W)
//   val uopQueuePtr = new VsUopPtr
//   val fieldIdx = UInt(fieldBits.W)
//   val segmentIdx = UInt(elemIdxBits.W)
//   val vaddr = UInt(VAddrBits.W)
//   // pack
//   val isPackage         = Bool()
//   val packageNum        = UInt((log2Up(VLENB) + 1).W)
//   val originAlignedType = UInt(alignTypeBits.W)
//   val alignedType       = UInt(alignTypeBits.W)
// }

class VecUopBundle(implicit p: Parameters) extends VLSUBundleWithMicroOp {
  val flowMask       = UInt(VLENB.W) // each bit for a flow
  val byteMask       = UInt(VLENB.W) // each bit for a byte
  val data           = UInt(VLEN.W)
  // val fof            = Bool() // fof is only used for vector loads
  val excp_eew_index = UInt(elemIdxBits.W)
  // val exceptionVec   = ExceptionVec() // uop has exceptionVec
  val baseAddr = UInt(VAddrBits.W)
  val stride = UInt(VLEN.W)
  val flow_counter = UInt(flowIdxBits.W)

  // instruction decode result
  val flowNum = UInt(flowIdxBits.W) // # of flows in a uop
  // val flowNumLog2 = UInt(log2Up(flowIdxBits).W) // log2(flowNum), for better timing of multiplication
  val nfields = UInt(fieldBits.W) // NFIELDS
  val vm = Bool() // whether vector masking is enabled
  val usWholeReg = Bool() // unit-stride, whole register load
  val usMaskReg = Bool() // unit-stride, masked store/load
  val eew = UInt(ewBits.W) // size of memory elements
  val sew = UInt(ewBits.W)
  val emul = UInt(mulBits.W)
  val lmul = UInt(mulBits.W)
  val vlmax = UInt(elemIdxBits.W)
  val instType = UInt(3.W)
  val vd_last_uop = Bool()
  val vd_first_uop = Bool()
}

class VecFlowBundle(implicit p: Parameters) extends VLSUBundleWithMicroOp {
  val vaddr             = UInt(VAddrBits.W)
  val mask              = UInt(VLENB.W)
  val alignedType       = UInt(alignTypeBits.W)
  val vecActive         = Bool()
  val elemIdx           = UInt(elemIdxBits.W)
  val is_first_ele      = Bool()

  // pack
  val isPackage         = Bool()
  val packageNum        = UInt((log2Up(VLENB) + 1).W)
  val originAlignedType = UInt(alignTypeBits.W)
}

class VecMemExuOutput(isVector: Boolean = false)(implicit p: Parameters) extends VLSUBundle{
  val output = new MemExuOutput(isVector)
  val vecFeedback = Bool()
  val mmio = Bool()
  val usSecondInv = Bool()
  val elemIdx = UInt(elemIdxBits.W)
  val alignedType = UInt(alignTypeBits.W)
  val mbIndex     = UInt(vsmBindexBits.W)
  val mask        = UInt(VLENB.W)
  val vaddr       = UInt(VAddrBits.W)
}

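/**
  * Decodes the 3-bit vtype-style LMUL/EMUL encoding into the number of vector
  * registers occupied by the group; fractional multipliers still occupy one
  * full register.
  */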
object MulNum {
  def apply (mul: UInt): UInt = { // mul means emul or lmul
    (LookupTree(mul,List(
      "b101".U -> 1.U , // 1/8
      "b110".U -> 1.U , // 1/4
      "b111".U -> 1.U , // 1/2
      "b000".U -> 1.U , // 1
      "b001".U -> 2.U , // 2
      "b010".U -> 4.U , // 4
      "b011".U -> 8.U   // 8
    )))}
}
/**
  * When emul is greater than or equal to 1, the entire register needs to be
  * written; otherwise, only the specified number of bytes is written. */
object MulDataSize {
  def apply (mul: UInt): UInt = { // mul means emul or lmul
    (LookupTree(mul,List(
      "b101".U -> 2.U  , // 1/8
      "b110".U -> 4.U  , // 1/4
      "b111".U -> 8.U  , // 1/2
      "b000".U -> 16.U , // 1
      "b001".U -> 16.U , // 2
      "b010".U -> 16.U , // 4
      "b011".U -> 16.U   // 8
    )))}
}

object OneRegNum {
  def apply (eew: UInt): UInt = { // the number of elements in one register, by eew
    (LookupTree(eew,List(
      "b000".U -> 16.U , // 1
      "b101".U -> 8.U , // 2
      "b110".U -> 4.U , // 4
      "b111".U -> 2.U   // 8
    )))}
}

// bytes read per element for indexed instructions (by sew)
object SewDataSize {
  def apply (sew: UInt): UInt = {
    (LookupTree(sew,List(
      "b000".U -> 1.U , // 1
      "b001".U -> 2.U , // 2
      "b010".U -> 4.U , // 4
      "b011".U -> 8.U   // 8
    )))}
}

// bytes read per element for strided instructions (by eew)
object EewDataSize {
  def apply (eew: UInt): UInt = {
    (LookupTree(eew,List(
      "b000".U -> 1.U , // 1
      "b101".U -> 2.U , // 2
      "b110".U -> 4.U , // 4
      "b111".U -> 8.U   // 8
    )))}
}
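
// Note: eew above uses the instruction's width encoding (b000/b101/b110/b111
// for 1/2/4/8 bytes), while sew uses the vtype encoding (b000..b011), which is
// why the two lookup tables differ.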

object loadDataSize {
  def apply (instType: UInt, emul: UInt, eew: UInt, sew: UInt): UInt = {
    (LookupTree(instType,List(
      "b000".U ->  MulDataSize(emul), // unit-stride
      "b010".U ->  EewDataSize(eew)  , // strided
      "b001".U ->  SewDataSize(sew)  , // indexed-unordered
      "b011".U ->  SewDataSize(sew)  , // indexed-ordered
      "b100".U ->  EewDataSize(eew)  , // segment unit-stride
      "b110".U ->  EewDataSize(eew)  , // segment strided
      "b101".U ->  SewDataSize(sew)  , // segment indexed-unordered
      "b111".U ->  SewDataSize(sew)    // segment indexed-ordered
    )))}
}

object storeDataSize {
  def apply (instType: UInt, eew: UInt, sew: UInt): UInt = {
    (LookupTree(instType,List(
      "b000".U ->  EewDataSize(eew)  , // unit-stride, do not use
      "b010".U ->  EewDataSize(eew)  , // strided
      "b001".U ->  SewDataSize(sew)  , // indexed-unordered
      "b011".U ->  SewDataSize(sew)  , // indexed-ordered
      "b100".U ->  EewDataSize(eew)  , // segment unit-stride
      "b110".U ->  EewDataSize(eew)  , // segment strided
      "b101".U ->  SewDataSize(sew)  , // segment indexed-unordered
      "b111".U ->  SewDataSize(sew)    // segment indexed-ordered
    )))}
}

object GenVecStoreMask {
  def apply (instType: UInt, eew: UInt, sew: UInt): UInt = {
    val mask = Wire(UInt(16.W))
    mask := UIntToOH(storeDataSize(instType = instType, eew = eew, sew = sew)) - 1.U
    mask
  }
}
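
// For example, an 8-byte element gives UIntToOH(8) - 1 = "b11111111".U, i.e.
// the low 8 byte lanes of the 16-bit mask enabled.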

/**
  * These are used to extract the per-flow index (address offset) for indexed instructions. */
object EewEq8 {
  def apply(index:UInt, flow_inner_idx: UInt): UInt = {
    (LookupTree(flow_inner_idx,List(
      0.U  -> index(7 ,0   ),
      1.U  -> index(15,8   ),
      2.U  -> index(23,16  ),
      3.U  -> index(31,24  ),
      4.U  -> index(39,32  ),
      5.U  -> index(47,40  ),
      6.U  -> index(55,48  ),
      7.U  -> index(63,56  ),
      8.U  -> index(71,64  ),
      9.U  -> index(79,72  ),
      10.U -> index(87,80  ),
      11.U -> index(95,88  ),
      12.U -> index(103,96 ),
      13.U -> index(111,104),
      14.U -> index(119,112),
      15.U -> index(127,120)
    )))}
}

object EewEq16 {
  def apply(index: UInt, flow_inner_idx: UInt): UInt = {
    (LookupTree(flow_inner_idx, List(
      0.U -> index(15, 0),
      1.U -> index(31, 16),
      2.U -> index(47, 32),
      3.U -> index(63, 48),
      4.U -> index(79, 64),
      5.U -> index(95, 80),
      6.U -> index(111, 96),
      7.U -> index(127, 112)
    )))}
}

object EewEq32 {
  def apply(index: UInt, flow_inner_idx: UInt): UInt = {
    (LookupTree(flow_inner_idx, List(
      0.U -> index(31, 0),
      1.U -> index(63, 32),
      2.U -> index(95, 64),
      3.U -> index(127, 96)
    )))}
}

object EewEq64 {
  def apply (index: UInt, flow_inner_idx: UInt): UInt = {
    (LookupTree(flow_inner_idx, List(
      0.U -> index(63, 0),
      1.U -> index(127, 64)
    )))}
}

object IndexAddr {
  def apply (index: UInt, flow_inner_idx: UInt, eew: UInt): UInt = {
    (LookupTree(eew,List(
      "b000".U -> EewEq8 (index = index, flow_inner_idx = flow_inner_idx ), // Imm is 1 Byte // TODO: the index may cross registers
      "b101".U -> EewEq16(index = index, flow_inner_idx = flow_inner_idx ), // Imm is 2 Byte
      "b110".U -> EewEq32(index = index, flow_inner_idx = flow_inner_idx ), // Imm is 4 Byte
      "b111".U -> EewEq64(index = index, flow_inner_idx = flow_inner_idx )  // Imm is 8 Byte
    )))}
}

object Log2Num {
  def apply (num: UInt): UInt = {
    (LookupTree(num,List(
      16.U -> 4.U,
      8.U  -> 3.U,
      4.U  -> 2.U,
      2.U  -> 1.U,
      1.U  -> 0.U
    )))}
}

object GenUopIdxInField {
  /**
   * Used in normal vector instructions.
   * */
  def apply (instType: UInt, emul: UInt, lmul: UInt, uopIdx: UInt): UInt = {
    val isIndexed = instType(0)
    val mulInField = Mux(
      isIndexed,
      Mux(lmul.asSInt > emul.asSInt, lmul, emul),
      emul
    )
    LookupTree(mulInField, List(
      "b101".U -> 0.U,
      "b110".U -> 0.U,
      "b111".U -> 0.U,
      "b000".U -> 0.U,
      "b001".U -> uopIdx(0),
      "b010".U -> uopIdx(1, 0),
      "b011".U -> uopIdx(2, 0)
    ))
  }
  /**
   * Only used in segment instructions.
   * */
  def apply (select: UInt, uopIdx: UInt): UInt = {
    LookupTree(select, List(
      "b101".U -> 0.U,
      "b110".U -> 0.U,
      "b111".U -> 0.U,
      "b000".U -> 0.U,
      "b001".U -> uopIdx(0),
      "b010".U -> uopIdx(1, 0),
      "b011".U -> uopIdx(2, 0)
    ))
  }
}

// eew decode
object EewLog2 extends VLSUConstants {
  // def apply (eew: UInt): UInt = {
  //   (LookupTree(eew,List(
  //     "b000".U -> "b000".U , // 1
  //     "b101".U -> "b001".U , // 2
  //     "b110".U -> "b010".U , // 4
  //     "b111".U -> "b011".U   // 8
  //   )))}
  def apply(eew: UInt): UInt = ZeroExt(eew(1, 0), ewBits)
}

/**
  * Unit-stride instructions don't use this method; other instructions generate
  * realFlowNum as EmulDataSize >> eew(1, 0), where EmulDataSize is the number
  * of bytes that need to be written to the register and eew(1, 0) is the log2
  * of the number of bytes written at once. */
object GenRealFlowNum {
  def apply (instType: UInt, emul: UInt, lmul: UInt, eew: UInt, sew: UInt): UInt = {
    require(instType.getWidth == 3, "The instType width must be 3, (isSegment, mop)")
    (LookupTree(instType,List(
      "b000".U ->  (MulDataSize(emul) >> eew(1,0)).asUInt, // store use, load do not use
      "b010".U ->  (MulDataSize(emul) >> eew(1,0)).asUInt, // strided
      "b001".U ->  Mux(emul.asSInt > lmul.asSInt, (MulDataSize(emul) >> eew(1,0)).asUInt, (MulDataSize(lmul) >> sew(1,0)).asUInt), // indexed-unordered
      "b011".U ->  Mux(emul.asSInt > lmul.asSInt, (MulDataSize(emul) >> eew(1,0)).asUInt, (MulDataSize(lmul) >> sew(1,0)).asUInt), // indexed-ordered
      "b100".U ->  (MulDataSize(emul) >> eew(1,0)).asUInt, // segment unit-stride
      "b110".U ->  (MulDataSize(emul) >> eew(1,0)).asUInt, // segment strided
      "b101".U ->  Mux(emul.asSInt > lmul.asSInt, (MulDataSize(emul) >> eew(1,0)).asUInt, (MulDataSize(lmul) >> sew(1,0)).asUInt), // segment indexed-unordered
      "b111".U ->  Mux(emul.asSInt > lmul.asSInt, (MulDataSize(emul) >> eew(1,0)).asUInt, (MulDataSize(lmul) >> sew(1,0)).asUInt)  // segment indexed-ordered
    )))}
}
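
// A worked example: a strided access with emul = b000 (1) and eew = b110
// (4-byte elements) yields MulDataSize(emul) >> eew(1, 0) = 16 >> 2 = 4 flows.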

/**
  * GenRealFlowLog2 = Log2(GenRealFlowNum)
  */
object GenRealFlowLog2 extends VLSUConstants {
  def apply(instType: UInt, emul: UInt, lmul: UInt, eew: UInt, sew: UInt): UInt = {
    require(instType.getWidth == 3, "The instType width must be 3, (isSegment, mop)")
    val emulLog2 = Mux(emul.asSInt >= 0.S, 0.U, emul)
    val lmulLog2 = Mux(lmul.asSInt >= 0.S, 0.U, lmul)
    val eewRealFlowLog2 = emulLog2 + log2Up(VLENB).U - eew(1, 0)
    val sewRealFlowLog2 = lmulLog2 + log2Up(VLENB).U - sew(1, 0)
    (LookupTree(instType, List(
      "b000".U -> eewRealFlowLog2, // unit-stride
      "b010".U -> eewRealFlowLog2, // strided
      "b001".U -> Mux(emul.asSInt > lmul.asSInt, eewRealFlowLog2, sewRealFlowLog2), // indexed-unordered
      "b011".U -> Mux(emul.asSInt > lmul.asSInt, eewRealFlowLog2, sewRealFlowLog2), // indexed-ordered
      "b100".U -> eewRealFlowLog2, // segment unit-stride
      "b110".U -> eewRealFlowLog2, // segment strided
      "b101".U -> Mux(emul.asSInt > lmul.asSInt, eewRealFlowLog2, sewRealFlowLog2), // segment indexed-unordered
      "b111".U -> Mux(emul.asSInt > lmul.asSInt, eewRealFlowLog2, sewRealFlowLog2)  // segment indexed-ordered
    )))
  }
}

/**
  * GenElemIdx generates an element index within an instruction, given a certain
  * uopIdx and a known flowIdx inside the uop.
  */
object GenElemIdx extends VLSUConstants {
  def apply(instType: UInt, emul: UInt, lmul: UInt, eew: UInt, sew: UInt,
    uopIdx: UInt, flowIdx: UInt): UInt = {
    val isIndexed = instType(0).asBool
    val eewUopFlowsLog2 = Mux(emul.asSInt > 0.S, 0.U, emul) + log2Up(VLENB).U - eew(1, 0)
    val sewUopFlowsLog2 = Mux(lmul.asSInt > 0.S, 0.U, lmul) + log2Up(VLENB).U - sew(1, 0)
    val uopFlowsLog2 = Mux(
      isIndexed,
      Mux(emul.asSInt > lmul.asSInt, eewUopFlowsLog2, sewUopFlowsLog2),
      eewUopFlowsLog2
    )
    LookupTree(uopFlowsLog2, List(
      0.U -> uopIdx,
      1.U -> uopIdx ## flowIdx(0),
      2.U -> uopIdx ## flowIdx(1, 0),
      3.U -> uopIdx ## flowIdx(2, 0),
      4.U -> uopIdx ## flowIdx(3, 0)
    ))
  }
}
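
// For instance, when uopFlowsLog2 = 3 the result is uopIdx ## flowIdx(2, 0),
// i.e. elemIdx = uopIdx * 8 + flowIdx.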

/**
  * GenVLMAX calculates VLMAX, i.e. LMUL * VLEN / SEW
  */
object GenVLMAXLog2 extends VLSUConstants {
  def apply(lmul: UInt, sew: UInt): UInt = lmul + log2Up(VLENB).U - sew
}
object GenVLMAX {
  def apply(lmul: UInt, sew: UInt): UInt = 1.U << GenVLMAXLog2(lmul, sew)
}
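// For example, lmul = b001 (2) and sew = b010 (32-bit) with VLEN = 128 give
// GenVLMAXLog2 = 1 + 4 - 2 = 3, hence VLMAX = 8 = 2 * 128 / 32.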
/**
 * Generates a mask based on vlmax.
 * Example: vlmax = b100 gives mask = b011.
 * */
object GenVlMaxMask{
  def apply(vlmax: UInt, length: Int): UInt = (vlmax - 1.U)(length-1, 0)
}

object GenUSWholeRegVL extends VLSUConstants {
  def apply(nfields: UInt, eew: UInt): UInt = {
    LookupTree(eew(1, 0), List(
      "b00".U -> (nfields << (log2Up(VLENB) - 0)),
      "b01".U -> (nfields << (log2Up(VLENB) - 1)),
      "b10".U -> (nfields << (log2Up(VLENB) - 2)),
      "b11".U -> (nfields << (log2Up(VLENB) - 3))
    ))
  }
}
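// GenUSWholeRegVL computes evl = NFIELDS * VLEN / EEW for whole-register
// accesses; e.g. nfields = 2 and eew = b01 (16-bit) with VLEN = 128 gives
// 2 << 3 = 16.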
object GenUSWholeEmul extends VLSUConstants{
  def apply(nf: UInt): UInt={
    LookupTree(nf,List(
      "b000".U -> "b000".U(mulBits.W),
      "b001".U -> "b001".U(mulBits.W),
      "b011".U -> "b010".U(mulBits.W),
      "b111".U -> "b011".U(mulBits.W)
    ))
  }
}

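/**
  * evl for unit-stride mask loads/stores: one mask bit per element, so
  * evl = ceil(vl / 8).
  */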
object GenUSMaskRegVL extends VLSUConstants {
  def apply(vl: UInt): UInt = {
    Mux(vl(2,0) === 0.U , (vl >> 3.U), ((vl >> 3.U) + 1.U))
  }
}

object GenUopByteMask {
  def apply(flowMask: UInt, alignedType: UInt): UInt = {
    LookupTree(alignedType, List(
      "b000".U -> flowMask,
      "b001".U -> FillInterleaved(2, flowMask),
      "b010".U -> FillInterleaved(4, flowMask),
      "b011".U -> FillInterleaved(8, flowMask),
      "b100".U -> FillInterleaved(16, flowMask)
    ))
  }
}
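// Expands a per-flow mask into a per-byte mask; e.g. flowMask = "b0101".U with
// alignedType = b001 (16-bit elements) yields "b00110011".U.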

object GenVdIdxInField extends VLSUConstants {
  def apply(instType: UInt, emul: UInt, lmul: UInt, uopIdx: UInt): UInt = {
    val vdIdx = Wire(UInt(log2Up(maxMUL).W))
    when (instType(1,0) === "b00".U || instType(1,0) === "b10".U || lmul.asSInt > emul.asSInt) {
      // Unit-stride or strided, or indexed with lmul > emul
      vdIdx := uopIdx
    }.otherwise {
      // Indexed with lmul <= emul
      val multiple = emul - lmul
      val uopIdxWidth = uopIdx.getWidth
      vdIdx := LookupTree(multiple, List(
        0.U -> uopIdx,
        1.U -> (uopIdx >> 1),
        2.U -> (uopIdx >> 2),
        3.U -> (uopIdx >> 3)
      ))
    }
    vdIdx
  }
}
/**
* Uses start and vl to generate the flow active mask.
* mod = true: select active (masked-on) elements in [start, vl);
* mod = false: select inactive (masked-off) elements in [start, vl).
*/
object GenFlowMask extends VLSUConstants {
  def apply(elementMask: UInt, start: UInt, vl: UInt , mod: Boolean): UInt = {
    val startMask = ~UIntToMask(start, VLEN)
    val vlMask = UIntToMask(vl, VLEN)
    val maskVlStart = vlMask & startMask
    if(mod){
      elementMask & maskVlStart
    }
    else{
      (~elementMask).asUInt & maskVlStart
    }
  }
}

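/**
  * Returns a 5-bit alignment vector for addr: bit 4 is always set (8-bit is
  * always aligned), bit 3 = 16-bit aligned, bit 2 = 32-bit, bit 1 = 64-bit,
  * bit 0 = 128-bit.
  */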
object CheckAligned extends VLSUConstants {
  def apply(addr: UInt): UInt = {
    val aligned_16 = (addr(0) === 0.U) // 16-bit
    val aligned_32 = (addr(1,0) === 0.U) // 32-bit
    val aligned_64 = (addr(2,0) === 0.U) // 64-bit
    val aligned_128 = (addr(3,0) === 0.U) // 128-bit
    Cat(true.B, aligned_16, aligned_32, aligned_64, aligned_128)
  }
}

/**
  Searches whether mask has 'len' consecutive '1' bits starting from bit 0.
  mask: source mask
  len: search length
*/
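// Example: a 16-bit shiftMask whose low bits are b...0111 gives
// lead1 = lead2 = 1 and lead4 = lead8 = lead16 = 0, so
// packMask = Cat(lead1, lead2, lead4, lead8, lead16) = "b11000".U.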
object GenPackMask{
  def leadX(mask: Seq[Bool], len: Int): Bool = {
    if(len == 1){
      mask.head
    }
    else{
      leadX(mask.drop(1), len-1) & mask.head
    }
  }
  def leadOneVec(shiftMask: Seq[Bool]): UInt = {
    // the max pack width is 128 bits, so at most 16 8-bit flows can be packed

    val lead1 = leadX(shiftMask, 1) // 1 consecutive bit
    val lead2 = leadX(shiftMask, 2) // 2 consecutive bits
    val lead4 = leadX(shiftMask, 4) // 4 consecutive bits
    val lead8 = leadX(shiftMask, 8) // 8 consecutive bits
    val lead16 = leadX(shiftMask, 16) // 16 consecutive bits
    Cat(lead1, lead2, lead4, lead8, lead16)
  }

  def apply(shiftMask: UInt) = {
    // pack mask
    val packMask = leadOneVec(shiftMask.asBools)
    packMask
  }
}
/**
PackEnable = (LeadXVec >> eew) & alignedVec, where (per GenPackAlignedType below)
bit 0 represents the ability to merge into a 128-bit flow, bit 1 a 64-bit flow,
bit 2 a 32-bit flow, and so on.

example:
  addr = 0x0, activeMask = b00011100101111, flowIdx = 0, eew = 0(8-bit)

  step 0 : addrAlignedVec = (1, 1, 1, 1) elemIdxAligned = (1, 1, 1, 1)
  step 1 : activePackVec = (1, 1, 1, 0), inactivePackVec = (0, 0, 0, 0)
  step 2 : activePackEnable = (1, 1, 1, 0), inactivePackVec = (0, 0, 0, 0)

  we can pack 4 active 8-bit flows into a 32-bit flow.
*/
object GenPackVec extends VLSUConstants{
  def apply(addr: UInt, shiftMask: UInt, eew: UInt, elemIdx: UInt): UInt = {
    val addrAlignedVec = CheckAligned(addr)
    val elemIdxAligned = CheckAligned(elemIdx)
    val packMask = GenPackMask(shiftMask)
    // generate packVec
    val packVec = addrAlignedVec & elemIdxAligned & (packMask.asUInt >> eew)

    packVec
  }
}

object GenPackAlignedType extends VLSUConstants{
  def apply(packVec: UInt): UInt = {
    val packAlignedType = PriorityMux(Seq(
      packVec(0) -> "b100".U,
      packVec(1) -> "b011".U,
      packVec(2) -> "b010".U,
      packVec(3) -> "b001".U,
      packVec(4) -> "b000".U
    ))
    packAlignedType
  }
}

object GenPackNum extends VLSUConstants{
  def apply(alignedType: UInt, packAlignedType: UInt): UInt = {
    (1.U << (packAlignedType - alignedType)).asUInt
  }
}
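// e.g. packing 8-bit flows (alignedType = b000) into a 32-bit flow
// (packAlignedType = b010) gives GenPackNum = 1 << 2 = 4 flows per package.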

object genVWmask128 {
  def apply(addr: UInt, sizeEncode: UInt): UInt = {
    (LookupTree(sizeEncode, List(
      "b000".U -> 0x1.U, // 0b1 << addr(3, 0)
      "b001".U -> 0x3.U, // 0b11
      "b010".U -> 0xf.U, // 0b1111
      "b011".U -> 0xff.U, // 0b11111111
      "b100".U -> 0xffff.U // 0b1111111111111111
    )) << addr(3, 0)).asUInt
  }
}
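// e.g. a 4-byte access (sizeEncode = b010) at a 16-byte-line offset of 4
// produces mask = 0xf << 4 = "b11110000".U.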
/*
* Only used when the maximum access width is 128 bits.
*/
object genVWdata {
  def apply(data: UInt, sizeEncode: UInt): UInt = {
    LookupTree(sizeEncode, List(
      "b000".U -> Fill(16, data(7, 0)),
      "b001".U -> Fill(8, data(15, 0)),
      "b010".U -> Fill(4, data(31, 0)),
      "b011".U -> Fill(2, data(63,0)),
      "b100".U -> data(127,0)
    ))
  }
}

object genUSSplitAddr{
  def apply(addr: UInt, index: UInt): UInt = {
    val tmpAddr = Cat(addr(38, 4), 0.U(4.W))
    val nextCacheline = tmpAddr + 16.U
    LookupTree(index, List(
      0.U -> tmpAddr,
      1.U -> nextCacheline
    ))
  }
}

object genUSSplitMask{
  def apply(mask: UInt, index: UInt, addrOffset: UInt): UInt = {
    val tmpMask = Cat(0.U(16.W), mask) << addrOffset // 32 bits
    LookupTree(index, List(
      0.U -> tmpMask(15, 0),
      1.U -> tmpMask(31, 16)
    ))
  }
}

object genUSSplitData{
  def apply(data: UInt, index: UInt, addrOffset: UInt): UInt = {
    val tmpData = WireInit(0.U(256.W))
    val lookupTable = (0 until 16).map{case i =>
      if(i == 0){
        i.U -> Cat(0.U(128.W), data)
      }else{
        i.U -> Cat(0.U(((16-i)*8).W), data, 0.U((i*8).W))
      }
    }
    tmpData := LookupTree(addrOffset, lookupTable).asUInt

    LookupTree(index, List(
      0.U -> tmpData(127, 0),
      1.U -> tmpData(255, 128)
    ))
  }
}

object genVSData extends VLSUConstants {
  def apply(data: UInt, elemIdx: UInt, alignedType: UInt): UInt = {
    LookupTree(alignedType, List(
      "b000".U -> ZeroExt(LookupTree(elemIdx(3, 0), List.tabulate(VLEN/8)(i => i.U -> getByte(data, i))), VLEN),
      "b001".U -> ZeroExt(LookupTree(elemIdx(2, 0), List.tabulate(VLEN/16)(i => i.U -> getHalfWord(data, i))), VLEN),
      "b010".U -> ZeroExt(LookupTree(elemIdx(1, 0), List.tabulate(VLEN/32)(i => i.U -> getWord(data, i))), VLEN),
      "b011".U -> ZeroExt(LookupTree(elemIdx(0), List.tabulate(VLEN/64)(i => i.U -> getDoubleWord(data, i))), VLEN),
      "b100".U -> data // if wider elements existed, this would break
    ))
  }
}

// TODO: more elegant
object genVStride extends VLSUConstants {
  def apply(uopIdx: UInt, stride: UInt): UInt = {
    LookupTree(uopIdx, List(
      0.U -> 0.U,
      1.U -> stride,
      2.U -> (stride << 1),
      3.U -> ((stride << 1).asUInt + stride),
      4.U -> (stride << 2),
      5.U -> ((stride << 2).asUInt + stride),
      6.U -> ((stride << 2).asUInt + (stride << 1)),
      7.U -> ((stride << 2).asUInt + (stride << 1) + stride)
    ))
  }
}
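// genVStride computes uopIdx * stride for uopIdx in 0..7 using shifts and adds
// instead of a multiplier.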
/**
 * Generates the uop offset; not used by segment instructions.
 * */
object genVUopOffset extends VLSUConstants {
  def apply(instType: UInt, isfof: Bool, uopidx: UInt, nf: UInt, eew: UInt, stride: UInt, alignedType: UInt): UInt = {
    val uopInsidefield = (uopidx >> nf).asUInt // when nf == 0, this equals uopidx

    val fofVUopOffset = (LookupTree(instType,List(
      "b000".U -> ( genVStride(uopInsidefield, stride) << (log2Up(VLENB).U - eew)   ) , // unit-stride fof
      "b100".U -> ( genVStride(uopInsidefield, stride) << (log2Up(VLENB).U - eew)   )   // segment unit-stride fof
    ))).asUInt

    val otherVUopOffset = (LookupTree(instType,List(
      "b000".U -> ( uopInsidefield << alignedType                                   ) , // unit-stride
      "b010".U -> ( genVStride(uopInsidefield, stride) << (log2Up(VLENB).U - eew)   ) , // strided
      "b001".U -> ( 0.U                                                             ) , // indexed-unordered
      "b011".U -> ( 0.U                                                             ) , // indexed-ordered
      "b100".U -> ( uopInsidefield << alignedType                                   ) , // segment unit-stride
      "b110".U -> ( genVStride(uopInsidefield, stride) << (log2Up(VLENB).U - eew)   ) , // segment strided
      "b101".U -> ( 0.U                                                             ) , // segment indexed-unordered
      "b111".U -> ( 0.U                                                             )   // segment indexed-ordered
    ))).asUInt

    Mux(isfof, fofVUopOffset, otherVUopOffset)
  }
}
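/**
  * Returns the index of the first set bit of mask (the first unmasked
  * element), or 0 when no bit is set; the second overload first shifts mask
  * right by regOffset.
  */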
object searchVFirstUnMask extends VLSUConstants {
  def apply(mask: UInt): UInt = {
    require(mask.getWidth == 16, "The mask width must be 16")
    val select = (0 until 16).zip(mask.asBools).map{case (i, v) =>
      (v, i.U)
    }
    PriorityMuxDefault(select, 0.U)
  }

  def apply(mask: UInt, regOffset: UInt): UInt = {
    require(mask.getWidth == 16, "The mask width must be 16")
    val realMask = (mask >> regOffset).asUInt
    val select = (0 until 16).zip(realMask.asBools).map{case (i, v) =>
      (v, i.U)
    }
    PriorityMuxDefault(select, 0.U)
  }
}