xref: /XiangShan/src/main/scala/xiangshan/cache/mmu/TLBStorage.scala (revision b56f947ea6e9fe50fd06047a225356a808f2a3b1)
/***************************************************************************************
  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
  * Copyright (c) 2020-2021 Peng Cheng Laboratory
  *
  * XiangShan is licensed under Mulan PSL v2.
  * You can use this software according to the terms and conditions of the Mulan PSL v2.
  * You may obtain a copy of Mulan PSL v2 at:
  *          http://license.coscl.org.cn/MulanPSL2
  *
  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
  *
  * See the Mulan PSL v2 for more details.
  ***************************************************************************************/

package xiangshan.cache.mmu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.experimental.chiselName
import chisel3.util._
import utils._
import freechips.rocketchip.formal.PropertyClass

import scala.math.min

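// TLBFA: fully-associative TLB storage.
// All nWays entries are searched in parallel against the request VPN; the nSets
// parameter exists only to match the common TlbStorageIO/replacer interface.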
@chiselName
class TLBFA(
  ports: Int,
  nSets: Int,
  nWays: Int,
  saveLevel: Boolean = false,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule with HasPerfEvents {

  val io = IO(new TlbStorageIO(nSets, nWays, ports))
  io.r.req.map(_.ready := true.B)

  val v = RegInit(VecInit(Seq.fill(nWays)(false.B)))
  val entries = Reg(Vec(nWays, new TlbEntry(normalPage, superPage)))
  val g = entries.map(_.perm.g)

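  // One lookup pipeline per access port: the entry match happens in the request
  // cycle and the response is presented one cycle later (resp.valid := RegNext(req.valid)).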
  for (i <- 0 until ports) {
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())
    val vpn_gen_ppn = if(saveLevel) vpn else vpn_reg

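    // Compare the request VPN against every valid entry, but mask out the way that
    // is being refilled in this very cycle (refill_mask) to avoid matching stale data.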
    val refill_mask = Mux(io.w.valid, UIntToOH(io.w.bits.wayIdx), 0.U(nWays.W))
    val hitVec = VecInit((entries.zipWithIndex).zip(v zip refill_mask.asBools).map{case (e, m) => e._1.hit(vpn, io.csr.satp.asid) && m._1 && !m._2 })

    hitVec.suggestName("hitVec")

    val hitVecReg = RegEnable(hitVec, req.fire())

    resp.valid := RegNext(req.valid)
    resp.bits.hit := Cat(hitVecReg).orR
    if (nWays == 1) {
      resp.bits.ppn := entries(0).genPPN(saveLevel, req.valid)(vpn_gen_ppn)
      resp.bits.perm := entries(0).perm
    } else {
      resp.bits.ppn := ParallelMux(hitVecReg zip entries.map(_.genPPN(saveLevel, req.valid)(vpn_gen_ppn)))
      resp.bits.perm := ParallelMux(hitVecReg zip entries.map(_.perm))
    }

    access.sets := get_set_idx(vpn_reg, nSets) // no use
    access.touch_ways.valid := resp.valid && Cat(hitVecReg).orR
    access.touch_ways.bits := OHToUInt(hitVecReg)

    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")
  }

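  // Refill: on a write request, mark the chosen way valid and overwrite its entry;
  // the new entry takes its ASID from the current satp.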
  when (io.w.valid) {
    v(io.w.bits.wayIdx) := true.B
    entries(io.w.bits.wayIdx).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)
  }

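  // One cycle after a refill, report the written way to the replacer so it counts as
  // recently used (this overrides the per-port access assignments made above).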
  val refill_vpn_reg = RegNext(io.w.bits.data.entry.tag)
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }

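  // sfence.vma handling. bits.rs1/bits.rs2 indicate that the corresponding source
  // register was x0, i.e. no virtual address / no ASID was specified, which selects
  // one of the four flush cases below. Global (g) entries survive ASID-specific flushes.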
  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  val sfenceHit = entries.map(_.hit(sfence_vpn, sfence.bits.asid))
  val sfenceHit_noasid = entries.map(_.hit(sfence_vpn, sfence.bits.asid, ignoreAsid = true))
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: no virtual address given, flush all addresses
      when (sfence.bits.rs2) { // rs2 is x0: no ASID given, flush regardless of ASID
        // all addr and all asid
        v.map(_ := false.B)
      }.otherwise {
        // all addr but specific asid
        v.zipWithIndex.map{ case (a,i) => a := a & (g(i) | !(entries(i).asid === sfence.bits.asid)) }
      }
    }.otherwise {
      when (sfence.bits.rs2) {
        // specific addr but all asid
        v.zipWithIndex.map{ case (a,i) => a := a & !sfenceHit_noasid(i) }
      }.otherwise {
        // specific addr and specific asid
        v.zipWithIndex.map{ case (a,i) => a := a & !(sfenceHit(i) && !g(i)) }
      }
    }
  }

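  // Victim output: when a refill is about to overwrite a valid normal-page entry,
  // export that entry so it can be re-inserted into the normal-page storage.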
  val victim_idx = io.w.bits.wayIdx
  io.victim.out.valid := v(victim_idx) && io.w.valid && entries(victim_idx).is_normalentry()
  io.victim.out.bits.entry := ns_to_n(entries(victim_idx))

  def ns_to_n(ns: TlbEntry): TlbEntry = {
    val n = Wire(new TlbEntry(pageNormal = true, pageSuper = false))
    n.perm := ns.perm
    n.ppn := ns.ppn
    n.tag := ns.tag
    n.asid := ns.asid
    n
  }

  XSPerfAccumulate(s"access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"access${i}", io.r.resp.zip(io.access.map(acc => UIntToOH(acc.touch_ways.bits))).map{ case (a, b) =>
      a.valid && a.bits.hit && b(i)}.fold(0.U)(_.asUInt() + _.asUInt()))
  }
  for (i <- 0 until nWays) {
    XSPerfAccumulate(s"refill${i}", io.w.valid && io.w.bits.wayIdx === i.U)
  }

  val perfEvents = Seq(
    ("tlbstore_access", io.r.resp.map(_.valid.asUInt()).fold(0.U)(_ + _)                            ),
    ("tlbstore_hit   ", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt())),
  )
  generatePerfEvent()

  println(s"tlb_fa: nSets:${nSets} nWays:${nWays}")
}

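// TLBSA: set-associative TLB storage for normal pages, restricted to nWays == 1
// (direct-mapped) for timing. Entries live in a SyncDataModuleTemplate with one read
// port per access port and a single write port shared by refills and victim inserts.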
@chiselName
class TLBSA(
  ports: Int,
  nSets: Int,
  nWays: Int,
  normalPage: Boolean,
  superPage: Boolean
)(implicit p: Parameters) extends TlbModule {
  require(!superPage, "super page should use reg/fa")
  require(nWays == 1, "nWays larger than 1 causes bad timing")

  // timing optimization to divide v select into two cycles.
  val VPRE_SELECT = min(8, nSets)
  val VPOST_SELECT = nSets / VPRE_SELECT

  val io = IO(new TlbStorageIO(nSets, nWays, ports))

  io.r.req.map(_.ready :=  true.B)
  val v = RegInit(VecInit(Seq.fill(nSets)(VecInit(Seq.fill(nWays)(false.B)))))
  val entries = Module(new SyncDataModuleTemplate(new TlbEntry(normalPage, superPage), nSets, ports, 1))

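  // Per-port lookup. The data-array read and the two-stage valid-bit (v) selection
  // are both registered, so hit/ppn/perm are produced one cycle after the request.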
  for (i <- 0 until ports) { // duplicate sram
    val req = io.r.req(i)
    val resp = io.r.resp(i)
    val access = io.access(i)

    val vpn = req.bits.vpn
    val vpn_reg = RegEnable(vpn, req.fire())

    val ridx = get_set_idx(vpn, nSets)
    val v_resize = v.asTypeOf(Vec(VPRE_SELECT, Vec(VPOST_SELECT, UInt(nWays.W))))
    val vidx_resize = RegNext(v_resize(get_set_idx(drop_set_idx(vpn, VPOST_SELECT), VPRE_SELECT)))
    val vidx = vidx_resize(get_set_idx(vpn_reg, VPOST_SELECT)).asBools.map(_ && RegNext(req.fire()))
    entries.io.raddr(i) := ridx

    val data = entries.io.rdata(i)
    val hit = data.hit(vpn_reg, io.csr.satp.asid, nSets) && vidx(0)
    resp.bits.hit := hit
    resp.bits.ppn := data.genPPN()(vpn_reg)
    resp.bits.perm := data.perm

    resp.valid := { RegNext(req.valid) }
    resp.bits.hit.suggestName("hit")
    resp.bits.ppn.suggestName("ppn")
    resp.bits.perm.suggestName("perm")

    access.sets := get_set_idx(vpn_reg, nSets) // no use
    access.touch_ways.bits := 1.U // TODO: set-assoc needs no way replacer when nWays is 1
    access.touch_ways.valid := resp.valid && hit
  }

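  // Single write port: a refill write (io.w) takes priority over a victim entry
  // handed over from the fully-associative storage (io.victim.in).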
  entries.io.wen(0) := io.w.valid || io.victim.in.valid
  entries.io.waddr(0) := Mux(io.w.valid,
    get_set_idx(io.w.bits.data.entry.tag, nSets),
    get_set_idx(io.victim.in.bits.entry.tag, nSets))
  entries.io.wdata(0) := Mux(io.w.valid,
    (Wire(new TlbEntry(normalPage, superPage)).apply(io.w.bits.data, io.csr.satp.asid, io.w.bits.data_replenish)),
    io.victim.in.bits.entry)

  when (io.victim.in.valid) {
    v(get_set_idx(io.victim.in.bits.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }
  // w has higher priority than victim
  when (io.w.valid) {
    v(get_set_idx(io.w.bits.data.entry.tag, nSets))(io.w.bits.wayIdx) := true.B
  }

  val refill_vpn_reg = RegNext(Mux(io.victim.in.valid, io.victim.in.bits.entry.tag, io.w.bits.data.entry.tag))
  val refill_wayIdx_reg = RegNext(io.w.bits.wayIdx)
  when (RegNext(io.w.valid || io.victim.in.valid)) {
    io.access.map { access =>
      access.sets := get_set_idx(refill_vpn_reg, nSets)
      access.touch_ways.valid := true.B
      access.touch_ways.bits := refill_wayIdx_reg
    }
  }

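  // sfence.vma handling for the set-associative storage: flush conservatively and
  // ignore the ASID, clearing either every set or just the set indexed by the
  // specified virtual address.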
  val sfence = io.sfence
  val sfence_vpn = sfence.bits.addr.asTypeOf(new VaBundle().cloneType).vpn
  when (io.sfence.valid) {
    when (sfence.bits.rs1) { // rs1 is x0: no virtual address given, flush all sets
        v.map(a => a.map(b => b := false.B))
    }.otherwise {
        // a specific address: flush the whole indexed set, regardless of ASID
        v(get_set_idx(sfence_vpn, nSets)).map(_ := false.B)
    }
  }

  io.victim.out := DontCare
  io.victim.out.valid := false.B

  XSPerfAccumulate(s"access", io.r.req.map(_.valid.asUInt()).fold(0.U)(_ + _))
  XSPerfAccumulate(s"hit", io.r.resp.map(a => a.valid && a.bits.hit).fold(0.U)(_.asUInt() + _.asUInt()))

  for (i <- 0 until nSets) {
    for (j <- 0 until nWays) {
      XSPerfAccumulate(s"refill${i}_${j}", (io.w.valid || io.victim.in.valid) &&
        (Mux(io.w.valid, get_set_idx(io.w.bits.data.entry.tag, nSets), get_set_idx(io.victim.in.bits.entry.tag, nSets)) === i.U) &&
        (j.U === io.w.bits.wayIdx)
      )
    }
  }

  for (i <- 0 until nSets) {
    for (j <- 0 until nWays) {
      XSPerfAccumulate(s"hit${i}_${j}", io.r.resp.map(_.valid)
        .zip(io.access.map(a => UIntToOH(a.touch_ways.bits)(j)))
        .map{case(vi, hi) => vi && hi }
        .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
        .map{a => (a._1 && a._2).asUInt()}
        .fold(0.U)(_ + _)
      )
    }
  }

  for (i <- 0 until nSets) {
    XSPerfAccumulate(s"access${i}", io.r.resp.map(_.valid)
      .zip(io.r.req.map(a => RegNext(get_set_idx(a.bits.vpn, nSets)) === i.U))
      .map{a => (a._1 && a._2).asUInt()}
      .fold(0.U)(_ + _)
    )
  }

  println(s"tlb_sa: nSets:${nSets} nWays:${nWays}")
}

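// Factory that builds either a fully-associative or a set-associative storage,
// names the instance, and returns its TlbStorageIO.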
object TlbStorage {
  def apply
  (
    name: String,
    associative: String,
    ports: Int,
    nSets: Int,
    nWays: Int,
    saveLevel: Boolean = false,
    normalPage: Boolean,
    superPage: Boolean
  )(implicit p: Parameters) = {
    if (associative == "fa") {
       val storage = Module(new TLBFA(ports, nSets, nWays, saveLevel, normalPage, superPage))
       storage.suggestName(s"tlb_${name}_fa")
       storage.io
    } else {
       val storage = Module(new TLBSA(ports, nSets, nWays, normalPage, superPage))
       storage.suggestName(s"tlb_${name}_sa")
       storage.io
    }
  }
}

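// TlbStorageWrapper instantiates one storage for normal (4KB) pages and one for
// super pages, sends every request to both, and merges the two responses with the
// super-page hit taking precedence. It also drives the replacement policy for both
// storages unless replacement is handled outside (q.outReplace).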
class TlbStorageWrapper(ports: Int, q: TLBParameters)(implicit p: Parameters) extends TlbModule {
  val io = IO(new TlbStorageWrapperIO(ports, q))

// TODO: wrap normal-page and super-page storage together; clean up the declaration and refill code
  val normalPage = TlbStorage(
    name = "normal",
    associative = q.normalAssociative,
    ports = ports,
    nSets = q.normalNSets,
    nWays = q.normalNWays,
    saveLevel = q.saveLevel,
    normalPage = true,
    superPage = false
  )
  val superPage = TlbStorage(
    name = "super",
    associative = q.superAssociative,
    ports = ports,
    nSets = q.superNSets,
    nWays = q.superNWays,
    normalPage = q.normalAsVictim,
    superPage = true,
  )

  for (i <- 0 until ports) {
    normalPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
    superPage.r_req_apply(
      valid = io.r.req(i).valid,
      vpn = io.r.req(i).bits.vpn,
      i = i
    )
  }

  for (i <- 0 until ports) {
    val nq = normalPage.r.req(i)
    val np = normalPage.r.resp(i)
    val sq = superPage.r.req(i)
    val sp = superPage.r.resp(i)
    val rq = io.r.req(i)
    val rp = io.r.resp(i)
    rq.ready := nq.ready && sq.ready // actually, not used
    rp.valid := np.valid && sp.valid // actually, not used
    rp.bits.hit := np.bits.hit || sp.bits.hit
    rp.bits.ppn := Mux(sp.bits.hit, sp.bits.ppn, np.bits.ppn)
    rp.bits.perm := Mux(sp.bits.hit, sp.bits.perm, np.bits.perm)
    rp.bits.super_hit := sp.bits.hit
    rp.bits.super_ppn := sp.bits.ppn
    rp.bits.spm := np.bits.perm.pm
  }

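  // Cross-connect the victim channels: when the super-page storage overwrites a valid
  // entry that actually holds a normal-sized page, that entry is re-inserted into the
  // normal-page storage. The reverse path exists in the interface, but TLBSA never
  // produces victims (its victim.out.valid is tied to false).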
  normalPage.victim.in <> superPage.victim.out
  normalPage.victim.out <> superPage.victim.in
  normalPage.sfence <> io.sfence
  superPage.sfence <> io.sfence
  normalPage.csr <> io.csr
  superPage.csr <> io.csr

  val normal_refill_idx = if (q.outReplace) {
    io.replace.normalPage.access <> normalPage.access
    io.replace.normalPage.chosen_set := get_set_idx(io.w.bits.data.entry.tag, q.normalNSets)
    io.replace.normalPage.refillIdx
  } else if (q.normalAssociative == "fa") {
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNWays)
    re.access(normalPage.access.map(_.touch_ways)) // normalhitVecVec.zipWithIndex.map{ case (hv, i) => get_access(hv, validRegVec(i))})
    re.way
  } else { // set-associative && plru
    val re = ReplacementPolicy.fromString(q.normalReplacer, q.normalNSets, q.normalNWays)
    re.access(normalPage.access.map(_.sets), normalPage.access.map(_.touch_ways))
    re.way(get_set_idx(io.w.bits.data.entry.tag, q.normalNSets))
  }

  val super_refill_idx = if (q.outReplace) {
    io.replace.superPage.access <> superPage.access
    io.replace.superPage.chosen_set := DontCare
    io.replace.superPage.refillIdx
  } else {
    val re = ReplacementPolicy.fromString(q.superReplacer, q.superNWays)
    re.access(superPage.access.map(_.touch_ways))
    re.way
  }

  normalPage.w_apply(
    valid = { if (q.normalAsVictim) false.B
    else io.w.valid && io.w.bits.data.entry.level.get === 2.U },
    wayIdx = normal_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )
  superPage.w_apply(
    valid = { if (q.normalAsVictim) io.w.valid
    else io.w.valid && io.w.bits.data.entry.level.get =/= 2.U },
    wayIdx = super_refill_idx,
    data = io.w.bits.data,
    data_replenish = io.w.bits.data_replenish
  )

  // replacement
  def get_access(one_hot: UInt, valid: Bool): Valid[UInt] = {
    val res = Wire(Valid(UInt(log2Up(one_hot.getWidth).W)))
    res.valid := Cat(one_hot).orR && valid
    res.bits := OHToUInt(one_hot)
    res
  }
}