// Source: XiangShan/src/main/scala/xiangshan/backend/fu/vector/Mgu.scala (revision 1a6cfb3dfdd5ad577761384bb9b026f24a9fd496)
/****************************************************************************************
 * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
 * Copyright (c) 2020-2021 Peng Cheng Laboratory
 *
 * XiangShan is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 *
 * See the Mulan PSL v2 for more details.
 ****************************************************************************************
 */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3._
23import chisel3.util._
24import chiseltest._
25import chiseltest.ChiselScalatestTester
26import org.scalatest.flatspec.AnyFlatSpec
27import org.scalatest.matchers.must.Matchers
28import top.{ArgParser, BaseConfig, DefaultConfig}
29import xiangshan._
30import xiangshan.backend.fu.vector.Bundles.{VSew, Vl}
31import xiangshan.backend.fu.vector.Utils.VecDataToMaskDataVec
32import yunsuan.vector._
33
/**
  * Mask/tail merge unit for one destination register of a vector uop.
  *
  * Merges the newly computed destination data `io.in.vd` with the previous destination contents
  * `io.in.oldVd` byte by byte. Bytes flagged `keepEn` by `ByteMaskTailGen` take the new data, and
  * bytes flagged `agnosticEn` are written with all ones. Every other byte keeps the old value.
  * Two special paths override that byte-level result:
  *  - mask-producing uops (`info.dstMask`) insert the new per-element result bits at the bit
  *    position of this uop's register. Bits above are filled with ones (ta) or the old value, and
  *    bits below are taken from `oldVd`.
  *  - narrowing uops with `info.vdIdx(0)` set place the lower half of the merged result in the
  *    upper half of the register and keep the lower half of `oldVd`.
  *
  * @param vlen vector register length in bits
  */
class Mgu(vlen: Int)(implicit p: Parameters) extends Module {
  private val numBytes = vlen / 8
  private val byteWidth = log2Up(numBytes)

  val io = IO(new MguIO(vlen))

  val in = io.in
  val out = io.out
  val info = in.info
  val vd = in.vd
  val oldVd = in.oldVd
  val narrow = io.in.info.narrow

  // Register index used for mask selection and for the vstart/vl comparisons.
  // For narrowing uops, vdIdx(2, 1) selects the register and vdIdx(0) selects the written half
  // (see narrowNeedCat below).
  private val vdIdx = Mux(narrow, info.vdIdx(2, 1), info.vdIdx)

  private val maskTailGen = Module(new ByteMaskTailGen(vlen))

  private val eewOH = SewOH(info.eew).oneHot

  // Index (within the register group) of the register holding element vstart / element vl.
  private val vstartMapVdIdx = elemIdxMapVdIdx(info.vstart)(2, 0) // 3bits 0~7
  private val vlMapVdIdx = elemIdxMapVdIdx(info.vl)(3, 0)         // 4bits 0~8
  // Number of eew-wide elements held by one register.
  private val uvlMax = numBytes.U >> info.eew
  private val maskDataVec: Vec[UInt] = VecDataToMaskDataVec(in.mask, info.eew)
  private val maskUsed = maskDataVec(vdIdx)

  // Local position of vstart in this register: 0 if vstart lies in an earlier register,
  // uvlMax if it lies in a later one.
  maskTailGen.io.in.begin := Mux1H(Seq(
    (vstartMapVdIdx < vdIdx) -> 0.U,
    (vstartMapVdIdx === vdIdx) -> elemIdxMapUElemIdx(info.vstart),
    (vstartMapVdIdx > vdIdx) -> uvlMax,
  ))
  // Local position of vl in this register, clamped the same way.
  maskTailGen.io.in.end := Mux1H(Seq(
    (vlMapVdIdx < vdIdx) -> 0.U,
    (vlMapVdIdx === vdIdx) -> elemIdxMapUElemIdx(info.vl),
    (vlMapVdIdx > vdIdx) -> uvlMax,
  ))
  maskTailGen.io.in.vma := info.ma
  maskTailGen.io.in.vta := info.ta
  maskTailGen.io.in.vsew := info.eew
  maskTailGen.io.in.maskUsed := maskUsed

  private val keepEn = maskTailGen.io.out.keepEn
  private val agnosticEn = maskTailGen.io.out.agnosticEn

  // the result of normal inst and narrow inst which does not need concat
  private val byte1s: UInt = (~0.U(8.W)).asUInt

  private val resVecByte = Wire(Vec(numBytes, UInt(8.W)))
  private val vdVecByte = vd.asTypeOf(resVecByte)
  private val oldVdVecByte = oldVd.asTypeOf(resVecByte)

  for (i <- 0 until numBytes) {
    resVecByte(i) := MuxCase(oldVdVecByte(i), Seq(
      keepEn(i) -> vdVecByte(i),
      agnosticEn(i) -> byte1s,
    ))
  }

  // the result of narrow inst which needs concat
  private val narrowNeedCat = info.vdIdx(0).asBool & narrow
  private val narrowResCat = Cat(resVecByte.asUInt(vlen / 2 - 1, 0), oldVd(vlen / 2 - 1, 0))

  // the result of mask-generating inst
  private val maxVdIdx = 8
  // Elements per register (= result bits contributed per uop) for sew = e8, e16, e32, e64.
  // Derived from vlen instead of hard-coding (16, 8, 4, 2), which only held for vlen = 128.
  private val meaningfulBitsSeq = Seq.tabulate(4)(sew => numBytes >> sew)
  private val allPossibleResBit = Wire(Vec(4, Vec(maxVdIdx, UInt(vlen.W))))
  // Bits above the inserted field: all ones when tail-agnostic, old contents otherwise.
  private val catData = Mux(info.ta, ~0.U(vlen.W), oldVd)

  // For every (sew, register index) pair, place the `bits` new result bits at offset bits * i.
  // Fill above with catData and below with oldVd. Empty slices are omitted, e.g. the top field
  // for e8 at index 7 when bits * 8 == vlen.
  for (sew <- 0 until 4; i <- 0 until maxVdIdx) {
    val bits = meaningfulBitsSeq(sew)
    val upper = if (bits * (i + 1) < vlen) Seq(catData(vlen - 1, bits * (i + 1))) else Seq.empty[UInt]
    val lower = if (i > 0) Seq(oldVd(bits * i - 1, 0)) else Seq.empty[UInt]
    allPossibleResBit(sew)(i) := Cat(upper ++ Seq(vd(bits - 1, 0)) ++ lower)
  }

  private val resVecBit = allPossibleResBit(info.eew)(vdIdx)

  io.out.vd := MuxCase(resVecByte.asUInt, Seq(
    info.dstMask -> resVecBit.asUInt,
    narrowNeedCat -> narrowResCat,
  ))
  io.out.keep := keepEn

  io.debugOnly.vstartMapVdIdx := vstartMapVdIdx
  io.debugOnly.vlMapVdIdx := vlMapVdIdx
  io.debugOnly.begin := maskTailGen.io.in.begin
  io.debugOnly.end := maskTailGen.io.in.end
  io.debugOnly.keepEn := keepEn
  io.debugOnly.agnosticEn := agnosticEn

  /**
    * Maps an element index to the index of the register (within the group) that holds it.
    * The result is elemIdx / (elements per register) for the current eew, as a 4-bit slice.
    */
  def elemIdxMapVdIdx(elemIdx: UInt) = {
    // The e8 slice reads up to bit byteWidth + 3, so at least byteWidth + 4 bits are required.
    require(elemIdx.getWidth >= byteWidth + 4,
      s"elemIdx must be at least ${byteWidth + 4} bits wide, got ${elemIdx.getWidth}")
    // 3 = log2(8)
    Mux1H(eewOH, Seq.tabulate(eewOH.getWidth)(x => elemIdx(byteWidth - x + 3, byteWidth - x)))
  }

  /** Maps an element index to its position inside its register (elemIdx mod elements per register). */
  def elemIdxMapUElemIdx(elemIdx: UInt) = {
    Mux1H(eewOH, Seq.tabulate(eewOH.getWidth)(x => elemIdx(byteWidth - x - 1, 0)))
  }
}
140
141
/**
  * Port bundle of `Mgu`.
  *
  * Field declaration order defines port order and `asUInt` packing, so keep it stable.
  *
  * @param vlen vector register length in bits
  */
class MguIO(vlen: Int)(implicit p: Parameters) extends Bundle {
  /** Operands and per-uop control information. */
  val in = new Bundle {
    // new destination data, previous destination contents, and the mask operand
    val vd, oldVd, mask = Input(UInt(vlen.W))
    val info = Input(new VecInfo)
  }

  /** Merged destination data and the per-byte keep flags. */
  val out = new Bundle {
    val vd = Output(UInt(vlen.W))
    // one flag per byte of the register
    val keep = Output(UInt((vlen / 8).W))
  }

  /** Internal signals exposed for testing; widths are inferred from their drivers. */
  val debugOnly = Output(new Bundle {
    val vstartMapVdIdx, vlMapVdIdx, begin, end, keepEn, agnosticEn = UInt()
  })
}
162
/**
  * Per-uop vector control information consumed by `Mgu`.
  *
  * Field declaration order defines the bundle's packing order, so keep it stable.
  */
class VecInfo(implicit p: Parameters) extends Bundle {
  // tail-agnostic and mask-agnostic policy bits
  val ta, ma = Bool()
  // vector length and start element index (both Vl-typed)
  val vl, vstart = Vl()
  // element width of the destination
  val eew = VSew()
  // index of this uop's destination register within the register group, 0~7
  val vdIdx = UInt(3.W)
  // set for narrowing instructions
  val narrow = Bool()
  // set for mask-generating instructions
  val dstMask = Bool()
}
173
/** Command-line entry point that elaborates a 128-bit `Mgu` into Verilog under build/vifu. */
object VerilogMgu extends App {
  println("Generating the Mgu hardware")

  // Only `config` is consumed below; the other parsed options are currently unused here.
  val (config, firrtlOpts, firrtlComplier, firtoolOpts) = ArgParser.parse(args)

  // Specialise the core parameters to the first tile's configuration.
  val p = config.alterPartial {
    case XSCoreParamsKey => config(XSTileKey).head
  }

  emitVerilog(
    new Mgu(128)(p),
    Array("--target-dir", "build/vifu", "--full-stacktrace")
  )
}
181
/**
  * Smoke test for `Mgu` on the Verilator backend.
  *
  * Drives a single e8 uop (vdIdx = 1, vstart = 18, vl = 23, ta = 1, ma = 0) through a 128-bit Mgu
  * and prints the merged result together with the debug probes. It does not compare the output
  * against an expected value.
  */
class MguTest extends AnyFlatSpec with ChiselScalatestTester with Matchers {

  val defaultConfig = (new DefaultConfig).alterPartial({
    case XSCoreParamsKey => XSCoreParameters()
  })

  println("test start")

  behavior of "Mgu"
  it should "run" in {
    test(new Mgu(128)(defaultConfig)).withAnnotations(Seq(VerilatorBackendAnnotation)) {
      m: Mgu =>
        m.io.in.vd.poke("h8765_4321_8765_4321_8765_4321_8765_4321".U)
        m.io.in.oldVd.poke("h7777_7777_7777_7777_7777_7777_7777_7777".U)
        m.io.in.mask.poke("h0000_0000_0000_0000_0000_0000_ffff_0000".U)
        m.io.in.info.ta.poke(true.B)
        m.io.in.info.ma.poke(false.B)
        m.io.in.info.vl.poke((16 + 7).U)
        m.io.in.info.vstart.poke((16 + 2).U)
        m.io.in.info.eew.poke(VSew.e8)
        m.io.in.info.vdIdx.poke(1.U)
        // Drive the remaining control inputs explicitly (plain byte-merge path), so the printed
        // result does not depend on the simulator's default value for unpoked inputs.
        m.io.in.info.narrow.poke(false.B)
        m.io.in.info.dstMask.poke(false.B)

        println("out.vd: " + m.io.out.vd.peek().litValue.toString(16))
        println("debugOnly.vstartMapVdIdx: " + m.io.debugOnly.vstartMapVdIdx.peek().litValue.toString(16))
        println("debugOnly.vlMapVdIdx: "     + m.io.debugOnly.vlMapVdIdx.peek().litValue.toString(16))
        println("debugOnly.begin: "          + m.io.debugOnly.begin.peek().litValue)
        println("debugOnly.end: "            + m.io.debugOnly.end.peek().litValue)
        println("debugOnly.keepEn: "         + m.io.debugOnly.keepEn.peek().litValue.toString(2))
        println("debugOnly.agnosticEn: "     + m.io.debugOnly.agnosticEn.peek().litValue.toString(2))
    }
    println("test done")
  }
}