/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  * Copyright (c) 2020-2021 Peng Cheng Laboratory
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.cache.mmu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.experimental.chiselName
import chisel3.util._
import utils._
import freechips.rocketchip.formal.PropertyClass

import scala.math.min
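
/** An asynchronous-read data module, physically split into `numBanks` banks,
  * with each read port's result duplicated `numDup` times.
  *
  * Reads are combinational out of the banks, but the per-bank data and the
  * one-hot bank index are registered, so `rdata` arrives one cycle after
  * `raddr`. A write to the address currently being read is bypassed into the
  * read registers, so the freshly written entry is visible on the next cycle.
  */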
class BankedAsyncDataModuleTemplateWithDup[T <: Data](
  gen: T,
  numEntries: Int,
  numRead: Int,
  numDup: Int,
  numBanks: Int
) extends Module {
  val io = IO(new Bundle {
    val raddr = Vec(numRead, Input(UInt(log2Ceil(numEntries).W)))
    val rdata = Vec(numRead, Vec(numDup, Output(gen)))
    val wen   = Input(Bool())
    val waddr = Input(UInt(log2Ceil(numEntries).W))
    val wdata = Input(gen)
  })
  require(numBanks > 1)
  require(numEntries > numBanks)

  val numBankEntries = numEntries / numBanks
  def bankOffset(address: UInt): UInt = {
    address(log2Ceil(numBankEntries) - 1, 0)
  }

  def bankIndex(address: UInt): UInt = {
    address(log2Ceil(numEntries) - 1, log2Ceil(numBankEntries))
  }

  val dataBanks = Seq.tabulate(numBanks)(i => {
    // the last bank absorbs the remainder when numEntries is not a multiple of numBanks
    val bankEntries = if (i < numBanks - 1) numBankEntries else (numEntries - (i * numBankEntries))
    Mem(bankEntries, gen)
  })

  // async read from the banks; the per-bank data and the one-hot bank index
  // are registered, so rdata is valid one cycle after raddr
  for (i <- 0 until numRead) {
    val data_read = Reg(Vec(numDup, Vec(numBanks, gen)))
    val bank_index = Reg(Vec(numDup, UInt(numBanks.W)))
    for (j <- 0 until numDup) {
      bank_index(j) := UIntToOH(bankIndex(io.raddr(i)))
      for (k <- 0 until numBanks) {
        // bypass a same-address write so the fresh data is readable next cycle
        data_read(j)(k) := Mux(io.wen && (io.waddr === io.raddr(i)),
          io.wdata, dataBanks(k)(bankOffset(io.raddr(i))))
      }
    }
    // next cycle: one-hot select among the registered bank data
    for (j <- 0 until numDup) {
      io.rdata(i)(j) := Mux1H(bank_index(j), data_read(j))
    }
  }

  // write to the selected bank
  for (i <- 0 until numBanks) {
    when (io.wen && (bankIndex(io.waddr) === i.U)) {
      dataBanks(i)(bankOffset(io.waddr)) := io.wdata
    }
  }
}
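
// A hypothetical instantiation, for illustration only (not part of this
// file's design): with numEntries = 128 and numBanks = 8, each bank holds
// 16 entries, so a 7-bit address splits into a 3-bit bank index (high bits)
// and a 4-bit bank offset (low bits):
//
//   val m = Module(new BankedAsyncDataModuleTemplateWithDup(
//     UInt(8.W), numEntries = 128, numRead = 2, numDup = 2, numBanks = 8))
//   m.io.raddr(0) := addr // m.io.rdata(0)(_) is valid on the next cycle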
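
/** Fully-associative TLB storage.
  *
  * Entries and valid bits live in registers, so a lookup compares the request
  * vpn against every way in parallel; the hit vector is registered, and the
  * response comes back one cycle after the request. Ways being refilled in
  * the same cycle are masked out of the hit vector. When `saveLevel` is set,
  * the PPN is generated from the incoming (un-registered) vpn, trading extra
  * logic for an earlier result.
  */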
@chiselName
class TLBFA(
  parentName: String,
  ports: Int,
  nSets: Int,
  nWays: Int,
  saveLevel: Boolean = false,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule with HasPerfEvents {

  val io = IO(new TlbStorageIO(nSets, nWays, ports))
  io.r.req.map(_.ready := true.B)

  val v = RegInit(VecInit(Seq.fill(nWays)(false.B)))
  val entries = Reg(Vec(nWays, new TlbEntry(normalPage, superPage)))
  val g = entries.map(_.perm.g)

  for (i <- 0 until ports) {
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())
    val vpn_gen_ppn = if (saveLevel) vpn else vpn_reg

    // ways being refilled in this cycle are masked out of the hit vector
    val refill_mask = Mux(io.w.valid, UIntToOH(io.w.bits.wayIdx), 0.U(nWays.W))
    val hitVec = VecInit((entries.zipWithIndex).zip(v zip refill_mask.asBools).map{case (e, m) => e._1.hit(vpn, io.csr.satp.asid) && m._1 && !m._2 })

    hitVec.suggestName("hitVec")

    val hitVecReg = RegEnable(hitVec, req.fire())
    assert(!resp.valid || (PopCount(hitVecReg) === 0.U || PopCount(hitVecReg) === 1.U), s"${parentName} fa port${i} multi-hit")

    resp.valid := RegNext(req.valid)
    resp.bits.hit := Cat(hitVecReg).orR
    if (nWays == 1) {
      resp.bits.ppn(0) := entries(0).genPPN(saveLevel, req.valid)(vpn_gen_ppn)
      resp.bits.perm(0) := entries(0).perm
    } else {
      resp.bits.ppn(0) := ParallelMux(hitVecReg zip entries.map(_.genPPN(saveLevel, req.valid)(vpn_gen_ppn)))
      resp.bits.perm(0) := ParallelMux(hitVecReg zip entries.map(_.perm))
    }

    access.sets := get_set_idx(vpn_reg, nSets) // not used by the fully-associative storage
    access.touch_ways.valid := resp.valid && Cat(hitVecReg).orR
    access.touch_ways.bits := OHToUInt(hitVecReg)

    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")
  }

  when (io.w.valid) {
    v(io.w.bits.wayIdx) := true.B
    entries(io.w.bits.wayIdx).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)
  }
  // write assert: the refilled entry should not duplicate an existing entry
  val w_hit_vec = VecInit(entries.zip(v).map{case (e, vi) => e.hit(io.w.bits.data.entry.tag, io.csr.satp.asid) && vi })
  XSError(io.w.valid && Cat(w_hit_vec).orR, s"${parentName} refill duplicates an existing entry")

  val refill_vpn_reg = RegNext(io.w.bits.data.entry.tag)
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }
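
  // SFENCE.VMA decode: sfence.bits.rs1 / rs2 are set when the corresponding
  // register index is x0, i.e. when no specific address / ASID is given:
  //   rs1 == x0, rs2 == x0 : flush every entry
  //   rs1 == x0, rs2 != x0 : flush the given ASID's entries (global entries survive)
  //   rs1 != x0, rs2 == x0 : flush the given address for every ASID
  //   rs1 != x0, rs2 != x0 : flush the given address for the given ASID (global entries survive)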
  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  val sfenceHit = entries.map(_.hit(sfence_vpn, sfence.bits.asid))
  val sfenceHit_noasid = entries.map(_.hit(sfence_vpn, sfence.bits.asid, ignoreAsid = true))
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: no specific virtual address
      when (sfence.bits.rs2) { // rs2 is x0: no specific ASID
        // all addr and all asid
        v.map(_ := false.B)
      }.otherwise {
        // all addr but specific asid; global entries survive
        v.zipWithIndex.map{ case (a,i) => a := a & (g(i) | !(entries(i).asid === sfence.bits.asid)) }
      }
    }.otherwise {
      when (sfence.bits.rs2) {
        // specific addr but all asid
        v.zipWithIndex.map{ case (a,i) => a := a & !sfenceHit_noasid(i) }
      }.otherwise {
        // specific addr and specific asid; global entries survive
        v.zipWithIndex.map{ case (a,i) => a := a & !(sfenceHit(i) && !g(i)) }
      }
    }
  }

  val victim_idx = io.w.bits.wayIdx
  io.victim.out.valid := v(victim_idx) && io.w.valid && entries(victim_idx).is_normalentry()
  io.victim.out.bits.entry := ns_to_n(entries(victim_idx))

  // convert a normal+super entry into a normal-page-only entry
  def ns_to_n(ns: TlbEntry): TlbEntry = {
    val n = Wire(new TlbEntry(pageNormal = true, pageSuper = false))
    n.perm := ns.perm
    n.ppn := ns.ppn
    n.tag := ns.tag
    n.asid := ns.asid
    n
  }

  XSPerfAccumulate(s"access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"access${i}", io.r.resp.zip(io.access.map(acc => UIntToOH(acc.touch_ways.bits))).map{ case (a, b) =>
      a.valid && a.bits.hit && b(i)}.fold(0.U)(_.asUInt() + _.asUInt()))
  }
  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"refill${i}", io.w.valid && io.w.bits.wayIdx === i.U)
  }

  val perfEvents = Seq(
    ("tlbstore_access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _)                            ),
    ("tlbstore_hit   ", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt())),
  )
  generatePerfEvent()

  println(s"${parentName} tlb_fa: nSets:${nSets} nWays:${nWays}")
}
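
/** Set-associative TLB storage for 4KB pages (effectively direct-mapped,
  * since nWays is required to be 1). Entries live in a banked async-read data
  * module; valid bits are kept in registers and narrowed over two cycles for
  * timing. Super pages must use the fully-associative storage instead.
  */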
@chiselName
class TLBSA(
  parentName: String,
  ports: Int,
  nDups: Int,
  nSets: Int,
  nWays: Int,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule {
  require(!superPage, "super page should use reg/fa")
  require(nWays == 1, "nWays larger than 1 causes bad timing")

  // timing optimization: split the valid-bit selection into two cycles
  val VPRE_SELECT = min(8, nSets)
  val VPOST_SELECT = nSets / VPRE_SELECT
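  // e.g. (hypothetical numbers) with nSets = 64: VPRE_SELECT = 8 and
  // VPOST_SELECT = 8, so cycle 0 registers one of 8 groups of valid bits
  // and cycle 1 picks the indexed set's bits within that group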
  val nBanks = 8

  val io = IO(new TlbStorageIO(nSets, nWays, ports, nDups))

  io.r.req.map(_.ready := true.B)
  val v = RegInit(VecInit(Seq.fill(nSets)(VecInit(Seq.fill(nWays)(false.B)))))
  val entries = Module(new BankedAsyncDataModuleTemplateWithDup(new TlbEntry(normalPage, superPage), nSets, ports, nDups, nBanks))

  for (i <- 0 until ports) { // duplicate sram
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())

    val ridx = get_set_idx(vpn, nSets)
    val v_resize = v.asTypeOf(Vec(VPRE_SELECT, Vec(VPOST_SELECT, UInt(nWays.W))))
    val vidx_resize = RegNext(v_resize(get_set_idx(drop_set_idx(vpn, VPOST_SELECT), VPRE_SELECT)))
    val vidx = vidx_resize(get_set_idx(vpn_reg, VPOST_SELECT)).asBools.map(_ && RegNext(req.fire()))
    // a write to the set being read is bypassed by the data module; treat it as valid
    val vidx_bypass = RegNext((entries.io.waddr === ridx) && entries.io.wen)
    entries.io.raddr(i) := ridx

    val data = entries.io.rdata(i)
    val hit = data(0).hit(vpn_reg, io.csr.satp.asid, nSets) && (vidx(0) || vidx_bypass)
    resp.bits.hit := hit
    for (d <- 0 until nDups) {
      resp.bits.ppn(d) := data(d).genPPN()(vpn_reg)
      resp.bits.perm(d) := data(d).perm
    }

    resp.valid := RegNext(req.valid)
    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")

    access.sets := get_set_idx(vpn_reg, nSets) // not used
    access.touch_ways.valid := resp.valid && hit
    access.touch_ways.bits := 1.U // TODO: set-assoc needs no way replacer when nWays is 1
  }

  // the number of write ports must be 1; otherwise the bypass check above would be wrong
  entries.io.wen := io.w.valid || io.victim.in.valid
  entries.io.waddr := Mux(io.w.valid,
    get_set_idx(io.w.bits.data.entry.tag, nSets),
    get_set_idx(io.victim.in.bits.entry.tag, nSets))
  entries.io.wdata := Mux(io.w.valid,
    (Wire(new TlbEntry(normalPage, superPage)).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)),
    io.victim.in.bits.entry)

  when (io.victim.in.valid) {
    v(get_set_idx(io.victim.in.bits.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }
  // the refill write has higher priority than the victim write
  when (io.w.valid) {
    v(get_set_idx(io.w.bits.data.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }

  val refill_vpn_reg = RegNext(Mux(io.victim.in.valid, io.victim.in.bits.entry.tag, io.w.bits.data.entry.tag))
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid || io.victim.in.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }
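
  // sfence here is conservative: a specific-address fence invalidates the
  // whole indexed set, regardless of ASID or the global bit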
  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: flush everything
      v.map(a => a.map(b => b := false.B))
    }.otherwise {
      // specific addr: flush the whole indexed set, for all ASIDs
      v(get_set_idx(sfence_vpn, nSets)).map(_ := false.B)
    }
  }

  io.victim.out := DontCare
  io.victim.out.valid := false.B

  XSPerfAccumulate(s"access", io.r.req.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"refill${i}", (io.w.valid || io.victim.in.valid) &&
        (Mux(io.w.valid, get_set_idx(io.w.bits.data.entry.tag, nSets), get_set_idx(io.victim.in.bits.entry.tag, nSets)) === i.U)
      )
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"hit${i}", io.r.resp.map(a => a.valid & a.bits.hit)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{a => (a._1 && a._2).asUInt()}
      .fold(0.U)(_ + _)
    )
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"access${i}", io.r.resp.map(_.valid)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{a => (a._1 && a._2).asUInt()}
      .fold(0.U)(_ + _)
    )
  }

  println(s"${parentName} tlb_sa: nSets:${nSets} nWays:${nWays}")
}
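
/** Builds a TLB storage of the requested organization and returns its IO:
  * "fa" selects the fully-associative TLBFA; anything else selects the
  * set-associative TLBSA.
  *
  * A sketch of a call site (the parameter values here are made up):
  * {{{
  * val sp = TlbStorage("itlb_sp_storage", associative = "fa", ports = 2,
  *   nSets = 1, nWays = 4, normalPage = false, superPage = true)
  * sp.r_req_apply(valid = reqValid, vpn = reqVpn, i = 0)
  * }}}
  */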
object TlbStorage {
  def apply
  (
    parentName: String,
    associative: String,
    ports: Int,
    nDups: Int = 1,
    nSets: Int,
    nWays: Int,
    saveLevel: Boolean = false,
    normalPage: Boolean,
    superPage: Boolean
  )(implicit p: Parameters) = {
    if (associative == "fa") {
      val storage = Module(new TLBFA(parentName, ports, nSets, nWays, saveLevel, normalPage, superPage))
      storage.suggestName(s"${parentName}_fa")
      storage.io
    } else {
      val storage = Module(new TLBSA(parentName, ports, nDups, nSets, nWays, normalPage, superPage))
      storage.suggestName(s"${parentName}_sa")
      storage.io
    }
  }
}
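
/** Wraps a normal-page (4KB) storage and a super-page storage behind a single
  * IO: reads query both, and a super-page hit takes precedence in the muxed
  * response. Refills are steered by page level (with q.normalAsVictim, all
  * refills go to the super-page side and 4KB entries reach the normal side
  * only as victims); replacement is computed here, or exported through
  * io.replace when q.outReplace is set.
  */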
class TlbStorageWrapper(ports: Int, q: TLBParameters, nDups: Int = 1)(implicit p: Parameters) extends TlbModule {
  val io = IO(new TlbStorageWrapperIO(ports, q, nDups))

  // TODO: wrap normal page and super page together; clean up the declaration & refill code
  val normalPage = TlbStorage(
    parentName = q.name + "_np_storage",
    associative = q.normalAssociative,
    ports = ports,
    nDups = nDups,
    nSets = q.normalNSets,
    nWays = q.normalNWays,
    saveLevel = q.saveLevel,
    normalPage = true,
    superPage = false
  )
  val superPage = TlbStorage(
    parentName = q.name + "_sp_storage",
    associative = q.superAssociative,
    ports = ports,
    nSets = q.superNSets,
    nWays = q.superNWays,
    normalPage = q.normalAsVictim,
    superPage = true,
  )

  for (i <- 0 until ports) {
    normalPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
    superPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
  }

  for (i <- 0 until ports) {
    val nq = normalPage.r.req(i)
    val np = normalPage.r.resp(i)
    val sq = superPage.r.req(i)
    val sp = superPage.r.resp(i)
    val rq = io.r.req(i)
    val rp = io.r.resp(i)
    rq.ready := nq.ready && sq.ready // actually, not used
    rp.valid := np.valid && sp.valid // actually, not used
    rp.bits.hit := np.bits.hit || sp.bits.hit
    for (d <- 0 until nDups) {
      rp.bits.ppn(d) := Mux(sp.bits.hit, sp.bits.ppn(0), np.bits.ppn(d))
      rp.bits.perm(d) := Mux(sp.bits.hit, sp.bits.perm(0), np.bits.perm(d))
    }
    rp.bits.super_hit := sp.bits.hit
    rp.bits.super_ppn := sp.bits.ppn(0)
    rp.bits.spm := np.bits.perm(0).pm
    assert(!np.bits.hit || !sp.bits.hit || !rp.valid, s"${q.name} storage port${i} normal and super multi-hit")
  }
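
  // cross-connect the victim channels: when the super-page storage overwrites
  // a still-valid 4KB-mapped entry, that entry migrates into the normal-page
  // storage (TLBSA never produces victims; its victim.out.valid is tied low)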
  normalPage.victim.in <> superPage.victim.out
  normalPage.victim.out <> superPage.victim.in
  normalPage.sfence <> io.sfence
  superPage.sfence <> io.sfence
  normalPage.csr <> io.csr
  superPage.csr <> io.csr

  val normal_refill_idx = if (q.outReplace) {
    io.replace.normalPage.access <> normalPage.access
    io.replace.normalPage.chosen_set := get_set_idx(io.w.bits.data.entry.tag, q.normalNSets)
    io.replace.normalPage.refillIdx
  } else if (q.normalAssociative == "fa") {
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNWays)
    re.access(normalPage.access.map(_.touch_ways))
    re.way
  } else { // set-assoc && plru
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNSets, q.normalNWays)
    re.access(normalPage.access.map(_.sets), normalPage.access.map(_.touch_ways))
    re.way(get_set_idx(io.w.bits.data.entry.tag, q.normalNSets))
  }

  val super_refill_idx = if (q.outReplace) {
    io.replace.superPage.access <> superPage.access
    io.replace.superPage.chosen_set := DontCare
    io.replace.superPage.refillIdx
  } else {
    val re = ReplacementPolicy.fromString(q.superReplacer, q.superNWays)
    re.access(superPage.access.map(_.touch_ways))
    re.way
  }

  // entry.level === 2 denotes a 4KB (leaf) page; other levels are super pages
  normalPage.w_apply(
    valid = { if (q.normalAsVictim) false.B
    else io.w.valid && io.w.bits.data.entry.level.get === 2.U },
    wayIdx = normal_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )
  superPage.w_apply(
    valid = { if (q.normalAsVictim) io.w.valid
    else io.w.valid && io.w.bits.data.entry.level.get =/= 2.U },
    wayIdx = super_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )

  // replacement helper
  def get_access(one_hot: UInt, valid: Bool): Valid[UInt] = {
    val res = Wire(Valid(UInt(log2Up(one_hot.getWidth).W)))
    res.valid := Cat(one_hot).orR && valid
    res.bits := OHToUInt(one_hot)
    res
  }
}