xref: /XiangShan/src/main/scala/device/AXI4RAM.scala (revision bd53bc375131ff94680cab81a6f71d853243a1d5)
1package device
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import freechips.rocketchip.diplomacy.{AddressSet, LazyModule, LazyModuleImp, RegionType}
7import xiangshan.HasXSParameter
8import utils.{MaskExpand}
9
/** Port declaration of an externally implemented memory helper.
  *
  * This is a Chisel `BlackBox`: only the interface is described here, the behaviour is
  * supplied by an implementation that is not part of this file.
  *
  * Ports (all index, data and mask ports are `DataBits` wide):
  *  - `clk`, `en`                     clock and enable inputs
  *  - `rIdx`  -> `rdata`              read index in, read data out
  *  - `wIdx`, `wdata`, `wmask`, `wen` write index, data, mask and write-enable inputs
  *
  * @param memByte not referenced anywhere in this declaration
  *                (NOTE(review): presumably the capacity in bytes expected from the
  *                external model - confirm against the black-box implementation)
  */
class RAMHelper(memByte: BigInt) extends BlackBox with HasXSParameter {
  // Common type of every non-control port. Declared as a `def` so that each port gets
  // its own fresh UInt instance.
  // `DataBits` is not defined in this file; presumably it comes from HasXSParameter.
  private def dataWord = UInt(DataBits.W)

  val io = IO(new Bundle {
    val clk   = Input(Clock())
    val en    = Input(Bool())
    val rIdx  = Input(dataWord)
    val rdata = Output(dataWord)
    val wIdx  = Input(dataWord)
    val wdata = Input(dataWord)
    val wmask = Input(dataWord)
    val wen   = Input(Bool())
  })
}
22
/** AXI4 slave that behaves as a RAM.
  *
  * Storage is either a Chisel `Mem` of byte vectors (one entry per beat) or, when
  * `useBlackBox` is set, a group of [[RAMHelper]] black boxes with one helper per
  * 64-bit lane of the data bus.
  *
  * @param address     address ranges of this slave, forwarded to `AXI4SlaveModule`
  * @param memByte     capacity in bytes; addresses are wrapped to `log2Up(memByte)` bits
  *                    and writes whose beat index is beyond `memByte / beatBytes` are dropped
  * @param useBlackBox select the [[RAMHelper]] black-box backend instead of a Chisel `Mem`
  * @param executable  forwarded to `AXI4SlaveModule`
  * @param beatBytes   data-bus width in bytes; must be at least 8
  * @param burstLen    forwarded to `AXI4SlaveModule`
  */
class AXI4RAM
(
  address: Seq[AddressSet],
  memByte: Long,
  useBlackBox: Boolean = false,
  executable: Boolean = true,
  beatBytes: Int = 8,
  burstLen: Int = 16
)(implicit p: Parameters)
  extends AXI4SlaveModule(address, executable, beatBytes, burstLen)
{

  override lazy val module = new AXI4SlaveModuleImp(this){

    // Must be checked before `split` is used as a divisor: for beatBytes < 8 `split` is 0
    // and `memByte / split` would abort with an ArithmeticException instead of this message.
    require(beatBytes >= 8, s"AXI4RAM requires beatBytes >= 8, got $beatBytes")

    // Number of 64-bit lanes in one beat, and the share of the capacity held by each lane.
    val split = beatBytes / 8
    val bankByte = memByte / split
    val offsetBits = log2Up(memByte)
    // Computed on BigInt: an Int shift only uses the low 5 bits of the shift count, so
    // `1 << offsetBits` would yield a bogus mask as soon as offsetBits reaches 32.
    val offsetMask = (BigInt(1) << offsetBits) - 1

    // Converts a bus address into a beat index: remove the base, wrap to the memory size,
    // then drop the byte-within-beat bits.
    // NOTE(review): the base 0x80000000 is hard-coded rather than taken from `address` -
    // confirm that every instantiation maps the RAM at that address.
    def index(addr: UInt) = (((addr - 0x80000000L.U) & offsetMask.U) >> log2Ceil(beatBytes)).asUInt()

    // True when the beat index lies inside the configured capacity.
    def inRange(idx: UInt) = idx < (memByte / beatBytes).U

    // `waddr`, `raddr`, `writeBeatCnt`, `readBeatCnt` and `in` are not defined in this file;
    // presumably they are provided by AXI4SlaveModuleImp (burst start addresses, per-burst
    // beat counters and the AXI4 port).
    val wIdx = index(waddr) + writeBeatCnt
    val rIdx = index(raddr) + readBeatCnt
    // Only writes that fall inside the memory are performed.
    val wen = in.w.fire() && inRange(wIdx)

    val rdata = if (useBlackBox) {
      val mems = (0 until split).map {_ => Module(new RAMHelper(bankByte))}
      mems.zipWithIndex.foreach { case (mem, i) =>
        mem.io.clk   := clock
        mem.io.en    := !reset.asBool()
        // Each helper is addressed in 64-bit words: beat index * split + lane number.
        mem.io.rIdx  := (rIdx << log2Up(split)) + i.U
        mem.io.wIdx  := (wIdx << log2Up(split)) + i.U
        mem.io.wdata := in.w.bits.data((i + 1) * 64 - 1, i * 64)
        // MaskExpand comes from `utils`; presumably it widens each strobe bit to a byte mask.
        mem.io.wmask := MaskExpand(in.w.bits.strb((i + 1) * 8 - 1, i * 8))
        mem.io.wen   := wen
      }
      val lanes = mems.map {mem => mem.io.rdata}
      // Cat places its first argument in the most significant bits, hence the reverse so
      // that lane 0 ends up in the least significant 64 bits.
      Cat(lanes.reverse)
    } else {
      // One entry per beat, stored as bytes so the write strobe can be used as the mask.
      val mem = Mem(memByte / beatBytes, Vec(beatBytes, UInt(8.W)))

      val wdata = VecInit.tabulate(beatBytes) { i => in.w.bits.data(8 * (i + 1) - 1, 8 * i) }
      when(wen) {
        mem.write(wIdx, wdata, in.w.bits.strb.asBools())
      }

      // Byte 0 goes to the least significant position.
      Cat(mem.read(rIdx).reverse)
    }
    in.r.bits.data := rdata
  }
}
77