// xref: /XiangShan/src/main/scala/xiangshan/mem/MaskedDataModule.scala (revision deb6421e9ab9b7980dc6c429456fc7bd2161357b)
/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
16
17package xiangshan.mem
18
19import chisel3._
20import chisel3.util._
21import xiangshan._
22import utils._
23import xiangshan.cache._
24
/** Register-based data module with address-indexed ports and one-hot masked ports.
  *
  * Reads are synchronous: the read address (or one-hot read mask) is registered and the
  * selected entry is driven onto the output in the following cycle. Writes update the
  * entry registers at the next clock edge.
  *
  * Write priority for a single entry in the same cycle: a masked write beats any
  * address-indexed write; among address-indexed ports the highest-numbered port wins,
  * although such address conflicts are rejected by an assertion.
  *
  * @param gen        type of each stored entry (any Chisel `Data`, including Bundles)
  * @param numEntries number of entries
  * @param numRead    number of address-indexed read ports
  * @param numWrite   number of address-indexed write ports
  * @param numMRead   number of one-hot masked read ports
  * @param numMWrite  number of masked write ports (0 disables masked writes)
  */
class MaskedSyncDataModuleTemplate[T <: Data](
  gen: T,
  numEntries: Int,
  numRead: Int,
  numWrite: Int,
  numMRead: Int = 0,
  numMWrite: Int = 0
) extends Module {
  val io = IO(new Bundle {
    // address indexed sync read
    val raddr = Input(Vec(numRead, UInt(log2Up(numEntries).W)))
    val rdata = Output(Vec(numRead, gen))
    // masked sync read (1H)
    val mrmask = Input(Vec(numMRead, Vec(numEntries, Bool())))
    val mrdata = Output(Vec(numMRead, gen))
    // address indexed write
    val wen   = Input(Vec(numWrite, Bool()))
    val waddr = Input(Vec(numWrite, UInt(log2Up(numEntries).W)))
    val wdata = Input(Vec(numWrite, gen))
    // masked write
    val mwmask = Input(Vec(numMWrite, Vec(numEntries, Bool())))
    val mwdata = Input(Vec(numMWrite, gen))
  })

  val data = Reg(Vec(numEntries, gen))

  // read ports: the address is registered, the entry is selected in the next cycle
  for (i <- 0 until numRead) {
    io.rdata(i) := data(RegNext(io.raddr(i)))
  }

  // masked read ports: the one-hot mask is registered, the entry is selected in the next cycle
  for (i <- 0 until numMRead) {
    io.mrdata(i) := Mux1H(RegNext(io.mrmask(i)), data)
  }

  // write ports: by Chisel last-connect semantics a higher-numbered port wins on an
  // address conflict (such conflicts are asserted against below)
  for (i <- 0 until numWrite) {
    when (io.wen(i)) {
      data(io.waddr(i)) := io.wdata(i)
    }
  }

  // masked write: connected after the address-indexed writes, so it takes priority for
  // the same entry. Skipped when there are no masked write ports, because VecInit and
  // Mux1H reject empty sequences at elaboration time.
  if (numMWrite > 0) {
    for (j <- 0 until numEntries) {
      val sel = io.mwmask.map(_(j))
      when (VecInit(sel).asUInt.orR) {
        // Mux1H keeps the result typed as `gen`, so non-UInt entry types also work.
        // If more than one masked write port selects entry j, Mux1H's result is undefined.
        data(j) := Mux1H(sel, io.mwdata)
      }
    }
  }

  // DataModuleTemplate should not be used when there're any write conflicts
  for (i <- 0 until numWrite) {
    for (j <- i + 1 until numWrite) {
      assert(!(io.wen(i) && io.wen(j) && io.waddr(i) === io.waddr(j)))
    }
  }
}
85
/** Banked variant of [[MaskedSyncDataModuleTemplate]] with a registered write stage.
  *
  * The `numEntries` entries are split into `numWBanks` write banks of equal size. Every
  * write request (address-indexed or masked) is decoded per bank in s0 and captured in
  * per-bank registers. The selected entries are then updated from those registers in s1.
  * An entry is therefore updated one cycle later than it would be with a direct register
  * write.
  *
  * Reads are synchronous: the decoded read address (or one-hot read mask) is registered
  * and the entry is selected in the next cycle.
  *
  * When an address-indexed write and a masked write hit the same entry in the same cycle,
  * the masked write wins. Address-indexed write ports must not target the same address
  * in the same cycle (asserted).
  *
  * @param gen        type of each stored entry
  * @param numEntries number of entries; must be a multiple of `numWBanks`
  * @param numRead    number of address-indexed read ports
  * @param numWrite   number of address-indexed write ports (0 disables them)
  * @param numMRead   number of one-hot masked read ports
  * @param numMWrite  number of masked write ports (0 disables them)
  * @param numWBanks  number of write banks; a power of two, at least 2
  */
class MaskedBankedSyncDataModuleTemplate[T <: Data](
  gen: T,
  numEntries: Int,
  numRead: Int,
  numWrite: Int,
  numMRead: Int = 0,
  numMWrite: Int = 0,
  numWBanks: Int = 2
) extends Module {
  val io = IO(new Bundle {
    // address indexed sync read
    val raddr = Input(Vec(numRead, UInt(log2Up(numEntries).W)))
    val rdata = Output(Vec(numRead, gen))
    // masked sync read (1H)
    val mrmask = Input(Vec(numMRead, Vec(numEntries, Bool())))
    val mrdata = Output(Vec(numMRead, gen))
    // address indexed write
    val wen   = Input(Vec(numWrite, Bool()))
    val waddr = Input(Vec(numWrite, UInt(log2Up(numEntries).W)))
    val wdata = Input(Vec(numWrite, gen))
    // masked write
    val mwmask = Input(Vec(numMWrite, Vec(numEntries, Bool())))
    val mwdata = Input(Vec(numMWrite, gen))
  })

  require(isPow2(numWBanks))
  require(numWBanks >= 2)
  // otherwise the last numEntries % numWBanks entries would belong to no bank and
  // could never be written
  require(numEntries % numWBanks == 0,
    s"numEntries ($numEntries) must be a multiple of numWBanks ($numWBanks)")

  val numEntryPerBank = numEntries / numWBanks

  val data = Reg(Vec(numEntries, gen))

  // read ports: register the one-hot decoded address, select in the next cycle
  for (i <- 0 until numRead) {
    val raddr_dec = RegNext(UIntToOH(io.raddr(i)))
    io.rdata(i) := Mux1H(raddr_dec, data)
  }

  // masked read ports: register the one-hot mask, select in the next cycle
  for (i <- 0 until numMRead) {
    io.mrdata(i) := Mux1H(RegNext(io.mrmask(i)), data)
  }

  val waddr_dec = io.waddr.map(a => UIntToOH(a))

  /** Returns the slice of a per-entry bit vector `in` that belongs to write bank `bank`. */
  def selectBankMask(in: UInt, bank: Int): UInt = {
    in((bank + 1) * numEntryPerBank - 1, bank * numEntryPerBank)
  }

  for (bank <- 0 until numWBanks) {
    // address indexed write ports
    // s0: decode which write ports target this bank
    val s0_bank_waddr_dec = waddr_dec.map(a => selectBankMask(a, bank))
    val s0_bank_write_en = io.wen.zip(s0_bank_waddr_dec).map { case (en, dec) => en && dec.orR }
    s0_bank_waddr_dec.zipWithIndex.foreach { case (d, i) => d.suggestName(s"s0_bank_waddr_dec${bank}_$i") }
    s0_bank_write_en.zipWithIndex.foreach { case (e, i) => e.suggestName(s"s0_bank_write_en${bank}_$i") }
    // s1: bank-level request registers (per-port RegNext instead of RegNext(VecInit(...))
    // so that numWrite == 0 elaborates; note these enables have no reset value)
    val s1_bank_waddr_dec = s0_bank_waddr_dec.zip(s0_bank_write_en).map { case (dec, en) => RegEnable(dec, en) }
    val s1_bank_wen = s0_bank_write_en.map(en => RegNext(en))
    val s1_wdata = io.wdata.zip(s0_bank_write_en).map { case (d, en) => RegEnable(d, en) }
    s1_bank_waddr_dec.zipWithIndex.foreach { case (d, i) => d.suggestName(s"s1_bank_waddr_dec${bank}_$i") }
    s1_bank_wen.zipWithIndex.foreach { case (e, i) => e.suggestName(s"s1_bank_wen${bank}_$i") }
    s1_wdata.zipWithIndex.foreach { case (d, i) => d.suggestName(s"s1_wdata${bank}_$i") }

    // masked write ports
    // s0: slice each write mask down to this bank
    val s0_bank_mwmask = io.mwmask.map(a => selectBankMask(a.asUInt, bank))
    val s0_bank_mwrite_en = s0_bank_mwmask.map(_.orR)
    s0_bank_mwmask.zipWithIndex.foreach { case (m, i) => m.suggestName(s"s0_bank_mwmask${bank}_$i") }
    s0_bank_mwrite_en.zipWithIndex.foreach { case (e, i) => e.suggestName(s"s0_bank_mwrite_en${bank}_$i") }
    // s1: bank-level request registers
    val s1_bank_mwmask = s0_bank_mwmask.map(m => RegNext(m))
    val s1_mwdata = io.mwdata.zip(s0_bank_mwrite_en).map { case (d, en) => RegEnable(d, en) }
    s1_bank_mwmask.zipWithIndex.foreach { case (m, i) => m.suggestName(s"s1_bank_mwmask${bank}_$i") }
    s1_mwdata.zipWithIndex.foreach { case (d, i) => d.suggestName(s"s1_mwdata${bank}_$i") }

    // entry write
    for (entry <- 0 until numEntryPerBank) {
      val entryIdx = bank * numEntryPerBank + entry

      // address indexed write (skipped without write ports: VecInit/Mux1H reject empty input)
      val s1_entry_write_en_vec = s1_bank_wen.zip(s1_bank_waddr_dec).map { case (en, dec) => en && dec(entry) }
      s1_entry_write_en_vec.zipWithIndex.foreach { case (e, i) =>
        e.suggestName(s"s1_entry_write_en_vec${bank}_${entry}_$i")
      }
      if (numWrite > 0) {
        val s1_entry_write_en = VecInit(s1_entry_write_en_vec).asUInt.orR
        val s1_entry_write_data = Mux1H(s1_entry_write_en_vec, s1_wdata)
        s1_entry_write_en.suggestName(s"s1_entry_write_en${bank}_$entry")
        s1_entry_write_data.suggestName(s"s1_entry_write_data${bank}_$entry")
        when (s1_entry_write_en) {
          data(entryIdx) := s1_entry_write_data
        }
      }

      // masked write: connected last, so it takes priority over an address indexed write
      // to the same entry instead of OR-ing the two data values together
      val s1_bank_mwrite_en_vec = s1_bank_mwmask.map(_(entry))
      s1_bank_mwrite_en_vec.zipWithIndex.foreach { case (e, i) =>
        e.suggestName(s"s1_bank_mwrite_en_vec${bank}_${entry}_$i")
      }
      if (numMWrite > 0) {
        val s1_bank_mwrite_en = VecInit(s1_bank_mwrite_en_vec).asUInt.orR
        val s1_bank_mwrite_data = Mux1H(s1_bank_mwrite_en_vec, s1_mwdata)
        s1_bank_mwrite_en.suggestName(s"s1_bank_mwrite_en${bank}_$entry")
        s1_bank_mwrite_data.suggestName(s"s1_bank_mwrite_data${bank}_$entry")
        when (s1_bank_mwrite_en) {
          data(entryIdx) := s1_bank_mwrite_data
        }
      }
    }
  }

  // DataModuleTemplate should not be used when there're any write conflicts
  for (i <- 0 until numWrite) {
    for (j <- i + 1 until numWrite) {
      assert(!(io.wen(i) && io.wen(j) && io.waddr(i) === io.waddr(j)))
    }
  }
}
208