/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  * Copyright (c) 2020-2021 Peng Cheng Laboratory
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.cache.mmu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.experimental.chiselName
import chisel3.util._
import utils._
import freechips.rocketchip.formal.PropertyClass

import scala.math.min

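/** Fully-associative TLB storage.
  *
  * All nWays entries live in registers and are compared against the request
  * vpn in parallel; the set index is computed only for the replacement
  * interface and is otherwise unused. This array also holds super pages and,
  * via the victim port, hands evicted normal-size entries over to the
  * set-associative storage.
  */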
@chiselName
class TLBFA(
  parentName: String,
  ports: Int,
  nSets: Int,
  nWays: Int,
  saveLevel: Boolean = false,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule with HasPerfEvents {

  val io = IO(new TlbStorageIO(nSets, nWays, ports))
  io.r.req.map(_.ready := true.B)

  val v = RegInit(VecInit(Seq.fill(nWays)(false.B)))
  val entries = Reg(Vec(nWays, new TlbEntry(normalPage, superPage)))
  val g = entries.map(_.perm.g)

  for (i <- 0 until ports) {
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())
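    // When saveLevel is set, PPN generation takes the unregistered vpn so
    // that (presumably) the page-level mux can start in the request cycle;
    // otherwise it takes the vpn registered at req.fire().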
    val vpn_gen_ppn = if (saveLevel) vpn else vpn_reg

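    // A way that is being refilled in this cycle is masked out of the hit
    // vector, so a read never hits an entry that is only half written.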
    val refill_mask = Mux(io.w.valid, UIntToOH(io.w.bits.wayIdx), 0.U(nWays.W))
    val hitVec = VecInit(entries.zip(v zip refill_mask.asBools).map {
      case (e, (vi, refilling)) => e.hit(vpn, io.csr.satp.asid) && vi && !refilling
    })

    hitVec.suggestName("hitVec")

    val hitVecReg = RegEnable(hitVec, req.fire())
    assert(!resp.valid || (PopCount(hitVecReg) === 0.U || PopCount(hitVecReg) === 1.U), s"${parentName} fa port${i} multi-hit")

    resp.valid := RegNext(req.valid)
    resp.bits.hit := Cat(hitVecReg).orR
    if (nWays == 1) {
      resp.bits.ppn := entries(0).genPPN(saveLevel, req.valid)(vpn_gen_ppn)
      resp.bits.perm := entries(0).perm
    } else {
      resp.bits.ppn := ParallelMux(hitVecReg zip entries.map(_.genPPN(saveLevel, req.valid)(vpn_gen_ppn)))
      resp.bits.perm := ParallelMux(hitVecReg zip entries.map(_.perm))
    }

    access.sets := get_set_idx(vpn_reg, nSets) // not used
    access.touch_ways.valid := resp.valid && Cat(hitVecReg).orR
    access.touch_ways.bits := OHToUInt(hitVecReg)

    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")
  }

  when (io.w.valid) {
    v(io.w.bits.wayIdx) := true.B
    entries(io.w.bits.wayIdx).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)
  }
  // write assert: a refill should not duplicate an existing entry
  val w_hit_vec = VecInit(entries.zip(v).map{case (e, vi) => e.hit(io.w.bits.data.entry.tag, io.csr.satp.asid) && vi })
  XSError(io.w.valid && Cat(w_hit_vec).orR, s"${parentName} refill, duplicate with existing entries")

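  // One cycle after a refill, report the refilled way to the replacement
  // policy as a touch, overriding the per-port read-access signals above.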
  val refill_vpn_reg = RegNext(io.w.bits.data.entry.tag)
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }

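  // SFENCE.VMA handling. sfence.bits.rs1/rs2 are flags meaning "rs1 is x0" /
  // "rs2 is x0", giving four cases: flush everything; flush one asid (global
  // entries spared); flush one vpn across all asids; flush one vpn for one
  // asid (global entries spared).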
  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  val sfenceHit = entries.map(_.hit(sfence_vpn, sfence.bits.asid))
  val sfenceHit_noasid = entries.map(_.hit(sfence_vpn, sfence.bits.asid, ignoreAsid = true))
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: no specific virtual address
      when (sfence.bits.rs2) { // rs2 is x0: no specific asid
        // all addr and all asid
        v.map(_ := false.B)
      }.otherwise {
        // all addr but specific asid; global entries are spared
        v.zipWithIndex.map{ case (a,i) => a := a & (g(i) | !(entries(i).asid === sfence.bits.asid)) }
      }
    }.otherwise {
      when (sfence.bits.rs2) {
        // specific addr but all asid
        v.zipWithIndex.map{ case (a,i) => a := a & !sfenceHit_noasid(i) }
      }.otherwise {
        // specific addr and specific asid; global entries are spared
        v.zipWithIndex.map{ case (a,i) => a := a & !(sfenceHit(i) && !g(i)) }
      }
    }
  }

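  // Victim path: when a refill is about to overwrite a valid normal-size
  // (4 KB) entry, export that entry so the set-associative normal-page
  // storage can capture it instead of losing it.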
  val victim_idx = io.w.bits.wayIdx
  io.victim.out.valid := v(victim_idx) && io.w.valid && entries(victim_idx).is_normalentry()
  io.victim.out.bits.entry := ns_to_n(entries(victim_idx))

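  // Narrow a normal+super entry down to a normal-page-only TlbEntry.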
  def ns_to_n(ns: TlbEntry): TlbEntry = {
    val n = Wire(new TlbEntry(pageNormal = true, pageSuper = false))
    n.perm := ns.perm
    n.ppn := ns.ppn
    n.tag := ns.tag
    n.asid := ns.asid
    n
  }

  XSPerfAccumulate(s"access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"access${i}", io.r.resp.zip(io.access.map(acc => UIntToOH(acc.touch_ways.bits))).map{ case (a, b) =>
      a.valid && a.bits.hit && b(i)}.fold(0.U)(_.asUInt() + _.asUInt()))
  }
  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"refill${i}", io.w.valid && io.w.bits.wayIdx === i.U)
  }

  val perfEvents = Seq(
    ("tlbstore_access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _)                            ),
    ("tlbstore_hit   ", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt())),
  )
  generatePerfEvent()

  println(s"${parentName} tlb_fa: nSets:${nSets} nWays:${nWays}")
}

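/** Set-associative (in practice direct-mapped, since nWays is required to be
  * 1) TLB storage for normal 4 KB pages. Entries live in a synchronous-read
  * data module; valid bits live in registers.
  */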
@chiselName
class TLBSA(
  parentName: String,
  ports: Int,
  nSets: Int,
  nWays: Int,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule {
  require(!superPage, "super page should use reg/fa")
  require(nWays == 1, "nWays larger than 1 causes bad timing")

  // timing optimization: split the valid-bit select into two cycles.
  val VPRE_SELECT = min(8, nSets)
  val VPOST_SELECT = nSets / VPRE_SELECT
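  // Worked example: with nSets = 64, VPRE_SELECT = 8 and VPOST_SELECT = 8.
  // The request cycle selects one of 8 groups of valid bits (using the vpn
  // bits just above the low set bits) and registers it; the response cycle
  // then picks one of the 8 sets inside that group with the registered low
  // vpn bits.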

  val io = IO(new TlbStorageIO(nSets, nWays, ports))

  io.r.req.map(_.ready := true.B)
  val v = RegInit(VecInit(Seq.fill(nSets)(VecInit(Seq.fill(nWays)(false.B)))))
  val entries = Module(new SyncDataModuleTemplate(new TlbEntry(normalPage, superPage), nSets, ports, 1))

  for (i <- 0 until ports) { // duplicate sram
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())

    val ridx = get_set_idx(vpn, nSets)
    val v_resize = v.asTypeOf(Vec(VPRE_SELECT, Vec(VPOST_SELECT, UInt(nWays.W))))
    val vidx_resize = RegNext(v_resize(get_set_idx(drop_set_idx(vpn, VPOST_SELECT), VPRE_SELECT)))
    val vidx = vidx_resize(get_set_idx(vpn_reg, VPOST_SELECT)).asBools.map(_ && RegNext(req.fire()))
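    // The valid bits are registered, so a same-cycle write to the set being
    // read would be missed; vidx_bypass patches that case (the data module
    // presumably forwards the freshly written entry to the read port).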
    val vidx_bypass = RegNext((entries.io.waddr(0) === ridx) && entries.io.wen(0))
    entries.io.raddr(i) := ridx

    val data = entries.io.rdata(i)
    val hit = data.hit(vpn_reg, io.csr.satp.asid, nSets) && (vidx(0) || vidx_bypass)
    resp.bits.hit := hit
    resp.bits.ppn := data.genPPN()(vpn_reg)
    resp.bits.perm := data.perm

    resp.valid := RegNext(req.valid)
    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")

    access.sets := get_set_idx(vpn_reg, nSets) // not used
    access.touch_ways.valid := resp.valid && hit
    access.touch_ways.bits := 1.U // TODO: a set-associative array needs no way replacer when nWays is 1
  }

  // Write ports must be 1; otherwise the bypass check above would be wrong.
  entries.io.wen(0) := io.w.valid || io.victim.in.valid
  entries.io.waddr(0) := Mux(io.w.valid,
    get_set_idx(io.w.bits.data.entry.tag, nSets),
    get_set_idx(io.victim.in.bits.entry.tag, nSets))
  entries.io.wdata(0) := Mux(io.w.valid,
    (Wire(new TlbEntry(normalPage, superPage)).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)),
    io.victim.in.bits.entry)

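  // Since nWays is fixed to 1 (see the require above), indexing the valid
  // bits with io.w.bits.wayIdx below always selects way 0, for victim writes
  // as well as refills.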
  when (io.victim.in.valid) {
    v(get_set_idx(io.victim.in.bits.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }
  // w has higher priority than victim
  when (io.w.valid) {
    v(get_set_idx(io.w.bits.data.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }

  val refill_vpn_reg = RegNext(Mux(io.victim.in.valid, io.victim.in.bits.entry.tag, io.w.bits.data.entry.tag))
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid || io.victim.in.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }

  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: flush all addresses
      v.map(a => a.map(b => b := false.B))
    }.otherwise {
      // specific addr, all asid: clear the whole indexed set
      v(get_set_idx(sfence_vpn, nSets)).map(_ := false.B)
    }
  }

  io.victim.out := DontCare
  io.victim.out.valid := false.B

  XSPerfAccumulate(s"access", io.r.req.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"refill${i}", (io.w.valid || io.victim.in.valid) &&
        (Mux(io.w.valid, get_set_idx(io.w.bits.data.entry.tag, nSets), get_set_idx(io.victim.in.bits.entry.tag, nSets)) === i.U)
      )
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"hit${i}", io.r.resp.map(a => a.valid & a.bits.hit)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{a => (a._1 && a._2).asUInt()}
      .fold(0.U)(_ + _)
    )
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"access${i}", io.r.resp.map(_.valid)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{a => (a._1 && a._2).asUInt()}
      .fold(0.U)(_ + _)
    )
  }

  println(s"${parentName} tlb_sa: nSets:${nSets} nWays:${nWays}")
}

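/** Factory for TLB storage: instantiates either a fully-associative
  * (associative == "fa") or a set-associative array and returns its io.
  *
  * A minimal usage sketch; the parameter values and the reqValid/reqVpn
  * signals are illustrative only, not taken from any real configuration,
  * and this must run inside a module with an implicit Parameters in scope:
  * {{{
  * val storage = TlbStorage(
  *   parentName = "itlb_np_storage",
  *   associative = "sa",
  *   ports = 2,
  *   nSets = 64,
  *   nWays = 1,
  *   normalPage = true,
  *   superPage = false
  * )
  * storage.r_req_apply(valid = reqValid, vpn = reqVpn, i = 0)
  * }}}
  */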
object TlbStorage {
  def apply
  (
    parentName: String,
    associative: String,
    ports: Int,
    nSets: Int,
    nWays: Int,
    saveLevel: Boolean = false,
    normalPage: Boolean,
    superPage: Boolean
  )(implicit p: Parameters) = {
    if (associative == "fa") {
       val storage = Module(new TLBFA(parentName, ports, nSets, nWays, saveLevel, normalPage, superPage))
       storage.suggestName(s"${parentName}_fa")
       storage.io
    } else {
       val storage = Module(new TLBSA(parentName, ports, nSets, nWays, normalPage, superPage))
       storage.suggestName(s"${parentName}_sa")
       storage.io
    }
  }
}

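/** Wraps a normal-page storage and a super-page storage behind one interface,
  * merges their read responses, and cross-connects their victim ports.
  */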
class TlbStorageWrapper(ports: Int, q: TLBParameters)(implicit p: Parameters) extends TlbModule {
  val io = IO(new TlbStorageWrapperIO(ports, q))

  // TODO: wrap normal page and super page together; clean up the declaration & refill code
  val normalPage = TlbStorage(
    parentName = q.name + "_np_storage",
    associative = q.normalAssociative,
    ports = ports,
    nSets = q.normalNSets,
    nWays = q.normalNWays,
    saveLevel = q.saveLevel,
    normalPage = true,
    superPage = false
  )
  val superPage = TlbStorage(
    parentName = q.name + "_sp_storage",
    associative = q.superAssociative,
    ports = ports,
    nSets = q.superNSets,
    nWays = q.superNWays,
    normalPage = q.normalAsVictim,
    superPage = true,
  )

  for (i <- 0 until ports) {
    normalPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
    superPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
  }

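  // Merge the two responses per port. A vpn must hit in at most one of the
  // two storages (checked by the assertion below); when the super-page array
  // hits, its translation takes priority in the muxes.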
  for (i <- 0 until ports) {
    val nq = normalPage.r.req(i)
    val np = normalPage.r.resp(i)
    val sq = superPage.r.req(i)
    val sp = superPage.r.resp(i)
    val rq = io.r.req(i)
    val rp = io.r.resp(i)
    rq.ready := nq.ready && sq.ready // actually, not used
    rp.valid := np.valid && sp.valid // actually, not used
    rp.bits.hit := np.bits.hit || sp.bits.hit
    rp.bits.ppn := Mux(sp.bits.hit, sp.bits.ppn, np.bits.ppn)
    rp.bits.perm := Mux(sp.bits.hit, sp.bits.perm, np.bits.perm)
    rp.bits.super_hit := sp.bits.hit
    rp.bits.super_ppn := sp.bits.ppn
    rp.bits.spm := np.bits.perm.pm
    assert(!np.bits.hit || !sp.bits.hit || !rp.valid, s"${q.name} storage ports${i} normal and super multi-hit")
  }

  normalPage.victim.in <> superPage.victim.out
  normalPage.victim.out <> superPage.victim.in
  normalPage.sfence <> io.sfence
  superPage.sfence <> io.sfence
  normalPage.csr <> io.csr
  superPage.csr <> io.csr

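  // Pick the refill way. Three cases: the replacer lives outside this module
  // (q.outReplace), a fully-associative replacer over nWays, or a
  // set-associative replacer indexed by the set of the refill tag.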
  val normal_refill_idx = if (q.outReplace) {
    io.replace.normalPage.access <> normalPage.access
    io.replace.normalPage.chosen_set := get_set_idx(io.w.bits.data.entry.tag, q.normalNSets)
    io.replace.normalPage.refillIdx
  } else if (q.normalAssociative == "fa") {
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNWays)
    re.access(normalPage.access.map(_.touch_ways)) // normalhitVecVec.zipWithIndex.map{ case (hv, i) => get_access(hv, validRegVec(i))})
    re.way
  } else { // set-assoc && plru
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNSets, q.normalNWays)
    re.access(normalPage.access.map(_.sets), normalPage.access.map(_.touch_ways))
    re.way(get_set_idx(io.w.bits.data.entry.tag, q.normalNSets))
  }

  val super_refill_idx = if (q.outReplace) {
    io.replace.superPage.access <> superPage.access
    io.replace.superPage.chosen_set := DontCare
    io.replace.superPage.refillIdx
  } else {
    val re = ReplacementPolicy.fromString(q.superReplacer, q.superNWays)
    re.access(superPage.access.map(_.touch_ways))
    re.way
  }

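  // Route the refill: level-2 entries (4 KB leaf pages) go to the normal-page
  // storage, everything else to the super-page storage. With normalAsVictim,
  // every refill goes to the super-page array first, and normal entries reach
  // the normal array only through the victim path.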
  normalPage.w_apply(
    valid = { if (q.normalAsVictim) false.B
    else io.w.valid && io.w.bits.data.entry.level.get === 2.U },
    wayIdx = normal_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )
  superPage.w_apply(
    valid = { if (q.normalAsVictim) io.w.valid
    else io.w.valid && io.w.bits.data.entry.level.get =/= 2.U },
    wayIdx = super_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )

  // replacement helper: turn a one-hot touch vector into a Valid way index
  def get_access(one_hot: UInt, valid: Bool): Valid[UInt] = {
    val res = Wire(Valid(UInt(log2Up(one_hot.getWidth).W)))
    res.valid := Cat(one_hot).orR && valid
    res.bits := OHToUInt(one_hot)
    res
  }
}