xref: /XiangShan/src/main/scala/xiangshan/mem/MaskedDataModule.scala (revision 039cdc35f5f3b68b6295ec5ace90f22a77322e02)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.mem
18
19import chisel3._
20import chisel3.util._
21import xiangshan._
22import utils._
23import utility._
24import xiangshan.cache._
25
/**
 * Register-array data module with synchronous (registered-address) reads.
 *
 * Ports:
 *  - `numRead` address-indexed read ports: `raddr` is registered for one cycle before it
 *    indexes `data`, so `rdata` in a cycle reflects the address presented in the previous cycle.
 *  - `numMRead` mask-selected read ports: `mrmask` (expected to be one-hot) is registered
 *    for one cycle, then selects an entry via `Mux1H`.
 *  - `numWrite` address-indexed write ports and `numMWrite` mask-selected write ports;
 *    writes update `data` at the next clock edge.
 *
 * Write priority for a single entry within one cycle:
 *  - masked writes override address-indexed writes (they are connected last);
 *  - among address-indexed ports the highest port index wins, although same-address
 *    conflicts between address-indexed ports are flagged by an assertion;
 *  - when several masked write ports select the same entry, their data is bitwise ORed.
 *
 * @param gen        element type stored in every entry (any Chisel `Data`)
 * @param numEntries number of entries
 * @param numRead    number of address-indexed read ports
 * @param numWrite   number of address-indexed write ports
 * @param numMRead   number of mask-selected read ports
 * @param numMWrite  number of mask-selected write ports (0 disables the masked-write logic)
 */
class MaskedSyncDataModuleTemplate[T <: Data](
  gen: T,
  numEntries: Int,
  numRead: Int,
  numWrite: Int,
  numMRead: Int = 0,
  numMWrite: Int = 0
) extends Module {
  val io = IO(new Bundle {
    // address indexed sync read
    val raddr = Input(Vec(numRead, UInt(log2Up(numEntries).W)))
    val rdata = Output(Vec(numRead, gen))
    // masked sync read (1H)
    val mrmask = Input(Vec(numMRead, Vec(numEntries, Bool())))
    val mrdata = Output(Vec(numMRead, gen))
    // address indexed write
    val wen   = Input(Vec(numWrite, Bool()))
    val waddr = Input(Vec(numWrite, UInt(log2Up(numEntries).W)))
    val wdata = Input(Vec(numWrite, gen))
    // masked write
    val mwmask = Input(Vec(numMWrite, Vec(numEntries, Bool())))
    val mwdata = Input(Vec(numMWrite, gen))
  })

  val data = Reg(Vec(numEntries, gen))

  // read ports: the address is registered, then indexes the register array
  for (i <- 0 until numRead) {
    io.rdata(i) := data(RegNext(io.raddr(i)))
  }

  // masked read ports: the registered (one-hot) mask selects an entry
  for (i <- 0 until numMRead) {
    io.mrdata(i) := Mux1H(RegNext(io.mrmask(i)), data)
  }

  // write ports (with priorities): last-connect semantics make the highest-indexed port win
  for (i <- 0 until numWrite) {
    when (io.wen(i)) {
      data(io.waddr(i)) := io.wdata(i)
    }
  }

  // masked write: connected after the address-indexed writes, so it overrides them.
  // Skipped entirely without masked write ports -- VecInit/reduce cannot take an empty Seq.
  if (numMWrite > 0) {
    for (j <- 0 until numEntries) {
      val hits = io.mwmask.map(_(j))
      when (hits.reduce(_ || _)) {
        // OR the data of every port selecting entry j; going through UInt and back with
        // asTypeOf lets this work for any gen type, not only UInt.
        val merged = io.mwdata.zip(hits).map { case (d, hit) =>
          Mux(hit, d.asUInt, 0.U)
        }.reduce(_ | _)
        data(j) := merged.asTypeOf(gen)
      }
    }
  }

  // DataModuleTemplate should not be used when there're any write conflicts
  for (i <- 0 until numWrite) {
    for (j <- i+1 until numWrite) {
      assert(!(io.wen(i) && io.wen(j) && io.waddr(i) === io.waddr(j)))
    }
  }
}
86
/**
 * Register-array data module whose writes are split into `numWBanks` banks and pipelined.
 *
 * Reads behave like a registered-address read: `raddr` is decoded to one-hot and registered,
 * and `mrmask` (expected one-hot) is registered; both then select an entry via `Mux1H`.
 *
 * Writes take two stages per bank:
 *  - s0: the decoded write address / write mask is sliced to the bank's entries, and the
 *    bank-level write enable is computed;
 *  - s1: address, enable and data are registered, then written into the bank's entries.
 * A write presented in cycle t therefore lands in `data` at the end of cycle t+1.
 *
 * Write priority for one entry within one cycle matches `MaskedSyncDataModuleTemplate`:
 * a masked write overrides an address-indexed write. Several address-indexed ports hitting
 * the same entry are combined with `Mux1H`, so such conflicts must be avoided by the user.
 *
 * @param gen        element type stored in every entry
 * @param numEntries number of entries; must be a multiple of `numWBanks`
 * @param numRead    number of address-indexed read ports
 * @param numWrite   number of address-indexed write ports
 * @param numMRead   number of mask-selected read ports
 * @param numMWrite  number of mask-selected write ports (0 disables the masked-write logic)
 * @param numWBanks  number of write banks; a power of two, at least 2
 */
class MaskedBankedSyncDataModuleTemplate[T <: Data](
  gen: T,
  numEntries: Int,
  numRead: Int,
  numWrite: Int,
  numMRead: Int = 0,
  numMWrite: Int = 0,
  numWBanks: Int = 2
) extends Module {
  val io = IO(new Bundle {
    // address indexed sync read
    val raddr = Input(Vec(numRead, UInt(log2Up(numEntries).W)))
    val rdata = Output(Vec(numRead, gen))
    // masked sync read (1H)
    val mrmask = Input(Vec(numMRead, Vec(numEntries, Bool())))
    val mrdata = Output(Vec(numMRead, gen))
    // address indexed write
    val wen   = Input(Vec(numWrite, Bool()))
    val waddr = Input(Vec(numWrite, UInt(log2Up(numEntries).W)))
    val wdata = Input(Vec(numWrite, gen))
    // masked write
    val mwmask = Input(Vec(numMWrite, Vec(numEntries, Bool())))
    val mwdata = Input(Vec(numMWrite, gen))
  })

  require(isPow2(numWBanks))
  require(numWBanks >= 2)
  // Entries are split evenly across banks; a remainder would leave tail entries never written.
  require(numEntries % numWBanks == 0,
    s"numEntries ($numEntries) must be a multiple of numWBanks ($numWBanks)")

  val numEntryPerBank = numEntries / numWBanks

  val data = Reg(Vec(numEntries, gen))

  // read ports: decode, register, then one-hot select
  for (i <- 0 until numRead) {
    val raddr_dec = RegNext(UIntToOH(io.raddr(i)))
    io.rdata(i) := Mux1H(raddr_dec, data)
  }

  // masked read ports
  for (i <- 0 until numMRead) {
    io.mrdata(i) := Mux1H(RegNext(io.mrmask(i)), data)
  }

  val waddr_dec = io.waddr.map(a => UIntToOH(a))

  /** Extracts the `numEntryPerBank` one-hot bits of `in` that belong to write bank `bank`. */
  def selectBankMask(in: UInt, bank: Int): UInt = {
    in((bank + 1) * numEntryPerBank - 1, bank * numEntryPerBank)
  }

  for (bank <- 0 until numWBanks) {
    // write ports
    // s0: write to bank level buffer
    val s0_bank_waddr_dec = waddr_dec.map(a => selectBankMask(a, bank))
    val s0_bank_write_en = io.wen.zip(s0_bank_waddr_dec).map { case (en, dec) => en && dec.orR }
    s0_bank_waddr_dec.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s0_bank_waddr_dec${bank}_$i")
    }
    s0_bank_write_en.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s0_bank_write_en${bank}_$i")
    }
    // s1: write data to entries
    val s1_bank_waddr_dec = s0_bank_waddr_dec.zip(s0_bank_write_en).map { case (dec, en) => RegEnable(dec, en) }
    // one register per port (equivalent to RegNext(VecInit(...)), but also valid for numWrite == 0)
    val s1_bank_wen = s0_bank_write_en.map(en => RegNext(en))
    val s1_wdata = io.wdata.zip(s0_bank_write_en).map { case (d, en) => RegEnable(d, en) }
    s1_bank_waddr_dec.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s1_bank_waddr_dec${bank}_$i")
    }
    s1_bank_wen.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s1_bank_wen${bank}_$i")
    }
    s1_wdata.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s1_wdata${bank}_$i")
    }
    // masked write ports
    // s0: write to bank level buffer
    val s0_bank_mwmask = io.mwmask.map(a => selectBankMask(a.asUInt, bank))
    val s0_bank_mwrite_en = s0_bank_mwmask.map(_.orR)
    s0_bank_mwmask.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s0_bank_mwmask${bank}_$i")
    }
    s0_bank_mwrite_en.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s0_bank_mwrite_en${bank}_$i")
    }
    // s1: write data to entries
    val s1_bank_mwmask = s0_bank_mwmask.map(a => RegNext(a))
    val s1_mwdata = io.mwdata.zip(s0_bank_mwrite_en).map { case (d, en) => RegEnable(d, en) }
    s1_bank_mwmask.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s1_bank_mwmask${bank}_$i")
    }
    s1_mwdata.zipWithIndex.foreach { case (sig, i) =>
      sig.suggestName(s"s1_mwdata${bank}_$i")
    }

    // entry write
    for (entry <- 0 until numEntryPerBank) {
      val index = bank * numEntryPerBank + entry

      // write ports (guarded: VecInit / Mux1H cannot take an empty Seq)
      if (numWrite > 0) {
        val s1_entry_write_en_vec = s1_bank_wen.zip(s1_bank_waddr_dec).map { case (en, dec) => en && dec(entry) }
        val s1_entry_write_en = VecInit(s1_entry_write_en_vec).asUInt.orR
        val s1_entry_write_data = Mux1H(s1_entry_write_en_vec, s1_wdata)
        when (s1_entry_write_en) {
          data(index) := s1_entry_write_data
        }
        s1_entry_write_en_vec.zipWithIndex.foreach { case (sig, i) =>
          sig.suggestName(s"s1_entry_write_en_vec${bank}_${entry}_$i")
        }
        s1_entry_write_en.suggestName(s"s1_entry_write_en${bank}_$entry")
        s1_entry_write_data.suggestName(s"s1_entry_write_data${bank}_$entry")
      }

      // masked write ports: connected after the address-indexed write, so by last-connect
      // semantics a masked write wins when both target this entry in the same cycle
      // (previously a Mux1H ORed the two data values together in that case).
      if (numMWrite > 0) {
        val s1_bank_mwrite_en_vec = s1_bank_mwmask.map(_(entry))
        val s1_bank_mwrite_en = VecInit(s1_bank_mwrite_en_vec).asUInt.orR
        val s1_bank_mwrite_data = Mux1H(s1_bank_mwrite_en_vec, s1_mwdata)
        when (s1_bank_mwrite_en) {
          data(index) := s1_bank_mwrite_data
        }
        s1_bank_mwrite_en_vec.zipWithIndex.foreach { case (sig, i) =>
          sig.suggestName(s"s1_bank_mwrite_en_vec${bank}_${entry}_$i")
        }
        s1_bank_mwrite_en.suggestName(s"s1_bank_mwrite_en${bank}_$entry")
        s1_bank_mwrite_data.suggestName(s"s1_bank_mwrite_data${bank}_$entry")
      }
    }
  }
}
209