xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/ByteMaskTailGen.scala (revision bb2f3f51dd67f6e16e0cc1ffe43368c9fc7e4aef)
1package xiangshan.backend.fu.vector
2
3import org.chipsalliance.cde.config.Parameters
4import chisel3._
5import chisel3.util._
6import xiangshan.backend.fu.vector.Bundles.{VSew, Vl}
7import xiangshan.backend.fu.vector.utils.{MaskExtractor, UIntToContLow0s, UIntToContLow1s}
8import utility.XSDebug
9import yunsuan.vector.SewOH
10import yunsuan.util.LookupTree
11
12
/**
  * IO bundle of [[ByteMaskTailGen]].
  *
  * Every per-byte enable is `numBytes` (= vlen / 8) bits wide: one bit per byte of a single
  * vector register. Element indices are wide enough to represent any element count of a full
  * register group, i.e. `maxVLMUL * numBytes` (the element count at LMUL = 8 with 8-bit elements).
  *
  * @param vlen vector register length in bits
  */
class ByteMaskTailGenIO(vlen: Int)(implicit p: Parameters) extends Bundle {
  private val numBytes = vlen / 8
  private val maxVLMUL = 8
  // Largest element count of a register group: maxVLMUL registers of numBytes one-byte elements.
  // For vlen = 128 this is 8 * 16 = 128, matching the previously hard-coded value.
  private val maxVLMAX = maxVLMUL * numBytes
  // "+ 1" so that the exclusive end index can be equal to maxVLMAX itself.
  private val elemIdxWidth = log2Up(maxVLMAX + 1)
  println(s"elemIdxWidth: $elemIdxWidth")

  val in = Input(new Bundle {
    // First element index of the body (inclusive); elements below it are prestart.
    val begin = UInt(elemIdxWidth.W)
    // End element index of the body (exclusive); elements at or above it are tail.
    val end = UInt(elemIdxWidth.W)
    // Mask-agnostic flag: when set, masked-off body bytes are reported in out.agnosticEn.
    val vma = Bool()
    // Tail-agnostic flag: when set, tail bytes are reported in out.agnosticEn.
    val vta = Bool()
    // Element width selector; ByteMaskTailGen uses it to scale element indices to byte offsets.
    val vsew = VSew()
    // Mask bits handed to MaskExtractor together with vsew (presumably one bit per element).
    val maskUsed = UInt(numBytes.W)
    // Which register of the group (0 until maxVLMUL) the numBytes-wide outputs describe.
    val vdIdx = UInt(log2Up(maxVLMUL).W)
  })
  val out = Output(new Bundle {
    // Body bytes of register vdIdx whose mask bit is set.
    val activeEn   = UInt(numBytes.W)
    // Bytes of register vdIdx that may be written with the agnostic value (masked-off body / tail).
    val agnosticEn = UInt(numBytes.W)
  })
  // Intermediate signals exported for debugging; widths are inferred from the module's connections.
  val debugOnly = Output(new Bundle {
    val startBytes = UInt()
    val vlBytes = UInt()
    val prestartEn = UInt()
    val bodyEn = UInt()
    val tailEn = UInt()
    val maskEn = UInt()
    val maskAgnosticEn = UInt()
    val tailAgnosticEn = UInt()
    val agnosticEn = UInt()
  })
}
45
/**
  * Generates per-byte enables for one destination register of a vector register group.
  *
  * The element range [begin, end) is converted to byte offsets according to vsew. Those offsets split
  * the bytes of the whole group (maxVLMUL registers) into three regions:
  *  - prestart: below begin;
  *  - body: within [begin, end);
  *  - tail: at or above end.
  *
  * The numBytes-wide slice that belongs to register vdIdx is then combined with the mask:
  *  - activeEn   = body bytes whose mask bit is set;
  *  - agnosticEn = masked-off body bytes (when vma) | tail bytes (when vta).
  *
  * Both outputs are forced to zero when begin >= end.
  *
  * @param vlen vector register length in bits; must be a power of two
  */
class ByteMaskTailGen(vlen: Int)(implicit p: Parameters) extends Module {
  require(isPow2(vlen))

  private val numBytes = vlen / 8
  private val byteWidth = log2Up(numBytes) // vlen=128, numBytes=16, byteWidth=log2(16)=4
  private val maxVLMUL = 8
  // Bytes in a full register group (maxVLMUL registers). It must cover every per-register slice
  // taken below. For vlen = 128 this is 8 * 16 = 128, the previously hard-coded value.
  private val maxVLMAX = maxVLMUL * numBytes
  // "+ 1" so the exclusive end offset can equal maxVLMAX itself.
  private val elemIdxWidth = log2Up(maxVLMAX + 1)

  println(s"numBytes: ${numBytes}, byteWidth: ${byteWidth}")

  val io = IO(new ByteMaskTailGenIO(vlen))

  private val eewOH = SewOH(io.in.vsew).oneHot

  // Element index -> byte offset: one-hot SEW bit x scales the index by 2^x. The index is truncated
  // to (elemIdxWidth - x) bits first, so every candidate (and the result) stays elemIdxWidth bits wide.
  private val startBytes = Mux1H(eewOH, Seq.tabulate(4)(x => io.in.begin(elemIdxWidth - 1 - x, 0) << x)).asUInt
  private val vlBytes    = Mux1H(eewOH, Seq.tabulate(4)(x => io.in.end(elemIdxWidth - 1 - x, 0) << x)).asUInt
  private val vdIdx      = io.in.vdIdx

  // Selects the numBytes-wide slice of a group-wide byte vector that belongs to register vdIdx.
  private def sliceOfVd(groupEn: UInt): UInt =
    LookupTree(vdIdx, (0 until maxVLMUL).map(i => i.U -> groupEn((i + 1) * numBytes - 1, i * numBytes)))

  // Group-wide (maxVLMAX-bit) region enables.
  private val prestartEn = UIntToContLow1s(startBytes, maxVLMAX)
  private val bodyEn = UIntToContLow0s(startBytes, maxVLMAX) & UIntToContLow1s(vlBytes, maxVLMAX)
  private val tailEn = UIntToContLow0s(vlBytes, maxVLMAX)
  // The same regions restricted to register vdIdx.
  private val prestartEnInVd = sliceOfVd(prestartEn)
  private val bodyEnInVd = sliceOfVd(bodyEn)
  private val tailEnInVd = sliceOfVd(tailEn)

  private val maskEn = MaskExtractor(vlen)(io.in.maskUsed, io.in.vsew)
  private val maskOffEn = (~maskEn).asUInt
  private val maskAgnosticEn = Mux(io.in.vma, maskOffEn, 0.U) & bodyEnInVd

  private val tailAgnosticEn = Mux(io.in.vta, tailEnInVd, 0.U)

  // An empty range (begin >= end) writes nothing: neither active nor agnostic (tail) bytes.
  private val emptyRange = io.in.begin >= io.in.end
  private val activeEn = Mux(emptyRange, 0.U(numBytes.W), bodyEnInVd & maskEn)
  private val agnosticEn = Mux(emptyRange, 0.U(numBytes.W), maskAgnosticEn | tailAgnosticEn)

  // TODO: delete me later
  dontTouch(eewOH)
  dontTouch(startBytes)
  dontTouch(vlBytes)
  dontTouch(vdIdx)
  dontTouch(prestartEn)
  dontTouch(bodyEn)
  dontTouch(tailEn)
  dontTouch(prestartEnInVd)
  dontTouch(bodyEnInVd)
  dontTouch(tailEnInVd)
  dontTouch(maskEn)
  dontTouch(maskOffEn)
  dontTouch(maskAgnosticEn)
  dontTouch(tailAgnosticEn)

  io.out.activeEn := activeEn
  io.out.agnosticEn := agnosticEn

  io.debugOnly.startBytes := startBytes
  io.debugOnly.vlBytes := vlBytes
  io.debugOnly.prestartEn := prestartEnInVd
  // NOTE(review): unlike prestartEn/tailEn, this exports the group-wide body enable rather than
  // bodyEnInVd. Kept as-is because the inferred debug port width depends on it; confirm intent.
  io.debugOnly.bodyEn := bodyEn
  io.debugOnly.tailEn := tailEnInVd
  io.debugOnly.maskEn := maskEn
  io.debugOnly.maskAgnosticEn := maskAgnosticEn
  io.debugOnly.tailAgnosticEn := tailAgnosticEn
  io.debugOnly.agnosticEn := agnosticEn
}
110
111