xref: /XiangShan/src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala (revision dcd58560d0a04f86ddbf73a98bd16c41d4a8e205)
1package xiangshan.mem.prefetch
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import xiangshan._
7import utils._
8import xiangshan.cache.HasDCacheParameters
9import xiangshan.cache.mmu._
10
/** Configuration parameters for the SMS (Spatial Memory Streaming) prefetcher. */
case class SMSParams
(
  region_size: Int = 1024,             // spatial region size in bytes
  vaddr_hash_width: Int = 5,           // width of each of the 3 XOR-folded vaddr hash slices
  block_addr_raw_width: Int = 10,      // low block-address bits kept un-hashed in tags
  stride_pc_bits: Int = 10,            // pc bits used to match entries in the stride table
  max_stride: Int = 1024,              // max trackable stride, in cache blocks
  stride_entries: Int = 16,            // number of stride table entries
  active_gen_table_size: Int = 16,     // number of AGT entries
  pht_size: Int = 64,                  // total PHT entries (sets * ways)
  pht_ways: Int = 2,                   // PHT associativity
  pht_hist_bits: Int = 2,              // width of each PHT saturating history counter
  pht_tag_bits: Int = 13,              // PHT tag width (derived from pc)
  pht_lookup_queue_size: Int = 4,      // depth of PHT lookup/evict queues
  pf_filter_size: Int = 16             // prefetch filter size
) extends PrefetcherParams
27
/** Shared constants and address-manipulation helpers for the SMS prefetcher modules. */
trait HasSMSModuleHelper extends HasCircularQueuePtrHelper with HasDCacheParameters
{ this: HasXSParameter =>
  val smsParams = coreParams.prefetcher.get.asInstanceOf[SMSParams]
  // width of a cache-block-aligned virtual address
  val BLK_ADDR_WIDTH = VAddrBits - log2Up(dcacheParameters.blockBytes)
  val REGION_SIZE = smsParams.region_size
  // number of cache blocks per spatial region
  val REGION_BLKS = smsParams.region_size / dcacheParameters.blockBytes
  // width of a region-aligned virtual address
  val REGION_ADDR_BITS = VAddrBits - log2Up(REGION_SIZE)
  // bits needed to index a block within a region
  val REGION_OFFSET = log2Up(REGION_BLKS)
  val VADDR_HASH_WIDTH = smsParams.vaddr_hash_width
  val BLK_ADDR_RAW_WIDTH = smsParams.block_addr_raw_width
  val REGION_ADDR_RAW_WIDTH = BLK_ADDR_RAW_WIDTH - REGION_OFFSET
  // a tag is the raw low bits concatenated with a hash of the higher bits
  val BLK_TAG_WIDTH = BLK_ADDR_RAW_WIDTH + VADDR_HASH_WIDTH
  val REGION_TAG_WIDTH = REGION_ADDR_RAW_WIDTH + VADDR_HASH_WIDTH
  val PHT_INDEX_BITS = log2Up(smsParams.pht_size / smsParams.pht_ways)
  val PHT_TAG_BITS = smsParams.pht_tag_bits
  val PHT_HIST_BITS = smsParams.pht_hist_bits
  // page bit index in block addr
  val BLOCK_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / dcacheParameters.blockBytes)
  // page bit index in region addr
  val REGION_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / smsParams.region_size)
  val STRIDE_PC_BITS = smsParams.stride_pc_bits
  val STRIDE_BLK_ADDR_BITS = log2Up(smsParams.max_stride)

  /** Strip the block offset from a full address. */
  def block_addr(x: UInt): UInt = {
    val offset = log2Up(dcacheParameters.blockBytes)
    x(x.getWidth - 1, offset)
  }

  /** Strip the region offset from a full address. */
  def region_addr(x: UInt): UInt = {
    val offset = log2Up(REGION_SIZE)
    x(x.getWidth - 1, offset)
  }

  /** One-hot region bitmap with only bit `off` set. */
  def region_offset_to_bits(off: UInt): UInt = {
    (1.U << off).asUInt
  }

  /** Compress a region address into a tag: keep the low raw bits verbatim,
    * XOR-fold the next 3*VADDR_HASH_WIDTH bits into VADDR_HASH_WIDTH bits.
    */
  def region_hash_tag(rg_addr: UInt): UInt = {
    val low = rg_addr(REGION_ADDR_RAW_WIDTH - 1, 0)
    val high = rg_addr(REGION_ADDR_RAW_WIDTH + 3 * VADDR_HASH_WIDTH - 1, REGION_ADDR_RAW_WIDTH)
    val high_hash = vaddr_hash(high)
    Cat(high_hash, low)
  }

  /** Page-boundary bit of a region address.
    * Consistency fix: use the shared REGION_ADDR_PAGE_BIT constant instead of
    * recomputing log2Up(pageSize / REGION_SIZE) inline (identical value).
    */
  def page_bit(region_addr: UInt): UInt = {
    region_addr(REGION_ADDR_PAGE_BIT)
  }

  /** Compress a full address into a block tag (raw low bits ++ hashed high bits). */
  def block_hash_tag(x: UInt): UInt = {
    val blk_addr = block_addr(x)
    val low = blk_addr(BLK_ADDR_RAW_WIDTH - 1, 0)
    // same slice as before, rewritten in the same form as region_hash_tag for consistency
    val high = blk_addr(BLK_ADDR_RAW_WIDTH + 3 * VADDR_HASH_WIDTH - 1, BLK_ADDR_RAW_WIDTH)
    val high_hash = vaddr_hash(high)
    Cat(high_hash, low)
  }

  /** XOR-fold a 3*VADDR_HASH_WIDTH-bit value down to VADDR_HASH_WIDTH bits. */
  def vaddr_hash(x: UInt): UInt = {
    val width = VADDR_HASH_WIDTH
    val low = x(width - 1, 0)
    val mid = x(2 * width - 1, width)
    val high = x(3 * width - 1, 2 * width)
    low ^ mid ^ high
  }

  /** PHT set index from pc: pc bits above the 2-bit low offset, with the
    * top index bit XOR-folded with pc(1).
    */
  def pht_index(pc: UInt): UInt = {
    val low_bits = pc(PHT_INDEX_BITS, 2)
    val hi_bit = pc(1) ^ pc(PHT_INDEX_BITS+1)
    Cat(hi_bit, low_bits)
  }

  /** PHT tag: pc bits directly above those consumed by the index. */
  def pht_tag(pc: UInt): UInt = {
    pc(PHT_INDEX_BITS + 2 + PHT_TAG_BITS - 1, PHT_INDEX_BITS + 2)
  }
}
101
/** PC-indexed stride prefetcher.
  *
  * Pipeline:
  *  - s0: fully-associative lookup of the stride table by pc
  *  - s1: update the hit entry / allocate a new one on miss; compute the new stride
  *  - s2: on a confirmed stride hit, emit a prefetch request in region format
  *    (region tag / region addr / one-hot block bitmap)
  */
class StridePF()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    // global enable for stride-generated prefetch requests
    val stride_en = Input(Bool())
    val s0_lookup = Flipped(new ValidIO(new Bundle() {
      val pc = UInt(STRIDE_PC_BITS.W)
      val vaddr = UInt(VAddrBits.W)
      val paddr = UInt(PAddrBits.W)
    }))
    // NOTE(review): externally-driven s1 qualifier; exact semantics depend on the parent module
    val s1_valid = Input(Bool())
    val s2_gen_req = ValidIO(new PfGenReq())
  })

  val prev_valid = RegNext(io.s0_lookup.valid, false.B)
  val prev_pc = RegEnable(io.s0_lookup.bits.pc, io.s0_lookup.valid)

  // filter out back-to-back lookups from the same pc
  val s0_valid = io.s0_lookup.valid && !(prev_valid && prev_pc === io.s0_lookup.bits.pc)

  def entry_map[T](fn: Int => T) = (0 until smsParams.stride_entries).map(fn)

  val replacement = ReplacementPolicy.fromString("plru", smsParams.stride_entries)
  val valids = entry_map(_ => RegInit(false.B))
  val entries_pc = entry_map(_ => Reg(UInt(STRIDE_PC_BITS.W)) )
  // 2-bit saturating confidence counter, reset value 1
  val entries_conf = entry_map(_ => RegInit(1.U(2.W)))
  // low STRIDE_BLK_ADDR_BITS bits of the last block address seen for this pc
  val entries_last_addr = entry_map(_ => Reg(UInt(STRIDE_BLK_ADDR_BITS.W)) )
  // learned stride in cache blocks (signed)
  val entries_stride = entry_map(_ => Reg(SInt((STRIDE_BLK_ADDR_BITS+1).W)))


  // fully-associative pc match against all valid entries
  val s0_match_vec = valids.zip(entries_pc).map({
    case (v, pc) => v && pc === io.s0_lookup.bits.pc
  })

  val s0_hit = s0_valid && Cat(s0_match_vec).orR
  val s0_miss = s0_valid && !s0_hit
  val s0_matched_conf = Mux1H(s0_match_vec, entries_conf)
  val s0_matched_last_addr = Mux1H(s0_match_vec, entries_last_addr)
  val s0_matched_last_stride = Mux1H(s0_match_vec, entries_stride)


  val s1_vaddr = RegEnable(io.s0_lookup.bits.vaddr, s0_valid)
  val s1_paddr = RegEnable(io.s0_lookup.bits.paddr, s0_valid)
  val s1_hit = RegNext(s0_hit) && io.s1_valid
  val s1_alloc = RegNext(s0_miss) && io.s1_valid
  val s1_conf = RegNext(s0_matched_conf)
  val s1_last_addr = RegNext(s0_matched_last_addr)
  val s1_last_stride = RegNext(s0_matched_last_stride)
  val s1_match_vec = RegNext(VecInit(s0_match_vec))

  val BLOCK_OFFSET = log2Up(dcacheParameters.blockBytes)
  val s1_new_stride_vaddr = s1_vaddr(BLOCK_OFFSET + STRIDE_BLK_ADDR_BITS - 1, BLOCK_OFFSET)
  // observed stride = current block addr - last recorded block addr (zero-extend, signed sub)
  val s1_new_stride = (0.U(1.W) ## s1_new_stride_vaddr).asSInt - (0.U(1.W) ## s1_last_addr).asSInt
  val s1_stride_non_zero = s1_last_stride =/= 0.S
  val s1_stride_match = s1_new_stride === s1_last_stride && s1_stride_non_zero
  val s1_replace_idx = replacement.way

  for(i <- 0 until smsParams.stride_entries){
    val alloc = s1_alloc && i.U === s1_replace_idx
    val update = s1_hit && s1_match_vec(i)
    when(update){
      assert(valids(i))
      // saturating increment on stride match, saturating decrement otherwise
      entries_conf(i) := Mux(s1_stride_match,
        Mux(s1_conf === 3.U, 3.U, s1_conf + 1.U),
        Mux(s1_conf === 0.U, 0.U, s1_conf - 1.U)
      )
      entries_last_addr(i) := s1_new_stride_vaddr
      // only re-learn the stride while confidence is low (conf < 2)
      when(!s1_conf(1)){
        entries_stride(i) := s1_new_stride
      }
    }
    when(alloc){
      valids(i) := true.B
      // prev_pc still holds the pc looked up in s0 (registered one cycle earlier)
      entries_pc(i) := prev_pc
      entries_conf(i) := 0.U
      entries_last_addr(i) := s1_new_stride_vaddr
      entries_stride(i) := 0.S
    }
    assert(!(update && alloc))
  }
  // touch PLRU state on hit or allocation
  when(s1_hit){
    replacement.access(OHToUInt(s1_match_vec.asUInt))
  }.elsewhen(s1_alloc){
    replacement.access(s1_replace_idx)
  }

  val s1_block_vaddr = block_addr(s1_vaddr)
  val s1_pf_block_vaddr = (s1_block_vaddr.asSInt + s1_last_stride).asUInt
  // prefetch target falls in a different page => the physical address below is unknown
  val s1_pf_cross_page = s1_pf_block_vaddr(BLOCK_ADDR_PAGE_BIT) =/= s1_block_vaddr(BLOCK_ADDR_PAGE_BIT)

  val s2_pf_gen_valid = RegNext(s1_hit && s1_stride_match, false.B)
  val s2_pf_gen_paddr_valid = RegEnable(!s1_pf_cross_page, s1_hit && s1_stride_match)
  val s2_pf_block_vaddr = RegEnable(s1_pf_block_vaddr, s1_hit && s1_stride_match)
  val s2_block_paddr = RegEnable(block_addr(s1_paddr), s1_hit && s1_stride_match)

  // same page: splice the known physical page frame with the in-page bits of
  // the prefetch vaddr; different page: fall back to the virtual address
  val s2_pf_block_addr = Mux(s2_pf_gen_paddr_valid,
    Cat(
      s2_block_paddr(PAddrBits - BLOCK_OFFSET - 1, BLOCK_ADDR_PAGE_BIT),
      s2_pf_block_vaddr(BLOCK_ADDR_PAGE_BIT - 1, 0)
    ),
    s2_pf_block_vaddr
  )
  val s2_pf_full_addr = Wire(UInt(VAddrBits.W))
  s2_pf_full_addr := s2_pf_block_addr ## 0.U(BLOCK_OFFSET.W)

  val s2_pf_region_addr = region_addr(s2_pf_full_addr)
  val s2_pf_region_offset = s2_pf_block_addr(REGION_OFFSET - 1, 0)

  val s2_full_vaddr = Wire(UInt(VAddrBits.W))
  s2_full_vaddr := s2_pf_block_vaddr ## 0.U(BLOCK_OFFSET.W)

  // the region tag is always derived from the virtual address
  val s2_region_tag = region_hash_tag(region_addr(s2_full_vaddr))

  io.s2_gen_req.valid := s2_pf_gen_valid && io.stride_en
  io.s2_gen_req.bits.region_tag := s2_region_tag
  io.s2_gen_req.bits.region_addr := s2_pf_region_addr
  io.s2_gen_req.bits.region_bits := region_offset_to_bits(s2_pf_region_offset)
  io.s2_gen_req.bits.paddr_valid := s2_pf_gen_paddr_valid
  io.s2_gen_req.bits.decr_mode := false.B

}
220
/** One Active Generation Table entry: the recorded access footprint of a
  * spatial region that is currently in its training ("generation") phase.
  */
class AGTEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)    // PHT set index derived from the trigger pc
  val pht_tag = UInt(PHT_TAG_BITS.W)        // PHT tag derived from the trigger pc
  val region_bits = UInt(REGION_BLKS.W)     // bitmap of blocks accessed within this region
  val region_tag = UInt(REGION_TAG_WIDTH.W) // hashed region tag
  val region_offset = UInt(REGION_OFFSET.W) // block offset of the trigger access
  // count of distinct blocks touched, saturating at REGION_BLKS - 1
  val access_cnt = UInt((REGION_BLKS-1).U.getWidth.W)
  val decr_mode = Bool()                    // access pattern advances toward lower addresses
}
230
/** Prefetch generation request, at region granularity. */
class PfGenReq()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val region_tag = UInt(REGION_TAG_WIDTH.W)   // hashed (virtual) region tag
  val region_addr = UInt(REGION_ADDR_BITS.W)  // region address; physical when paddr_valid
  val region_bits = UInt(REGION_BLKS.W)       // bitmap of blocks to prefetch in the region
  val paddr_valid = Bool()                    // region_addr is physical (no page cross)
  val decr_mode = Bool()                      // request came from a descending pattern
}
238
/** Active Generation Table (AGT).
  *
  * Tracks spatial regions under training. A lookup that hits records the
  * accessed block; a hit (or a cross-region match with a neighbouring
  * region) may generate an incr/decr prefetch once the region is "active"
  * (access_cnt above act_threshold). An entry displaced by allocation is
  * evicted to the PHT; lookups that generate no prefetch are forwarded as
  * PHT lookups.
  *
  * Pipeline: s0 match/alloc decision, s1 entry update + pf computation,
  * s2 registered outputs.
  */
class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val agt_en = Input(Bool())
    val s0_lookup = Flipped(ValidIO(new Bundle() {
      val region_tag = UInt(REGION_TAG_WIDTH.W)       // tag of the accessed region
      val region_p1_tag = UInt(REGION_TAG_WIDTH.W)    // tag of region + 1
      val region_m1_tag = UInt(REGION_TAG_WIDTH.W)    // tag of region - 1
      val region_offset = UInt(REGION_OFFSET.W)
      val pht_index = UInt(PHT_INDEX_BITS.W)
      val pht_tag = UInt(PHT_TAG_BITS.W)
      val allow_cross_region_p1 = Bool()
      val allow_cross_region_m1 = Bool()
      val region_p1_cross_page = Bool()
      val region_m1_cross_page = Bool()
      val region_paddr = UInt(REGION_ADDR_BITS.W)
      val region_vaddr = UInt(REGION_ADDR_BITS.W)
    }))
    val s1_sel_stride = Output(Bool())
    val s2_stride_hit = Input(Bool())
    // if agt/stride missed, try lookup pht
    val s2_pht_lookup = ValidIO(new PhtLookup())
    // evict entry to pht
    val s2_evict = ValidIO(new AGTEntry())
    val s2_pf_gen_req = ValidIO(new PfGenReq())
    val act_threshold = Input(UInt(REGION_OFFSET.W))  // min access_cnt before pf gen
    val act_stride = Input(UInt(6.W))                 // pf distance in blocks
  })

  val entries = Seq.fill(smsParams.active_gen_table_size){ Reg(new AGTEntry()) }
  val valids = Seq.fill(smsParams.active_gen_table_size){ RegInit(false.B) }
  val replacement = ReplacementPolicy.fromString("plru", smsParams.active_gen_table_size)

  // driven from s1: mask of the entry being replaced this cycle (bypass to s0)
  val s1_replace_mask_w = Wire(UInt(smsParams.active_gen_table_size.W))

  val s0_lookup = io.s0_lookup.bits
  val s0_lookup_valid = io.s0_lookup.valid

  // previous-cycle lookup, used to squash consecutive same-region lookups
  val prev_lookup = RegEnable(s0_lookup, s0_lookup_valid)
  val prev_lookup_valid = RegNext(s0_lookup_valid, false.B)

  val s0_match_prev = prev_lookup_valid && s0_lookup.region_tag === prev_lookup.region_tag

  // fully-associative tag match over all valid entries
  def gen_match_vec(region_tag: UInt): Seq[Bool] = {
    entries.zip(valids).map({
      case (ent, v) => v && ent.region_tag === region_tag
    })
  }

  val region_match_vec_s0 = gen_match_vec(s0_lookup.region_tag)
  val region_p1_match_vec_s0 = gen_match_vec(s0_lookup.region_p1_tag)
  val region_m1_match_vec_s0 = gen_match_vec(s0_lookup.region_m1_tag)

  val any_region_match = Cat(region_match_vec_s0).orR
  // neighbouring-region (+1 / -1) hits, gated by the allow flags
  val any_region_p1_match = Cat(region_p1_match_vec_s0).orR && s0_lookup.allow_cross_region_p1
  val any_region_m1_match = Cat(region_m1_match_vec_s0).orR && s0_lookup.allow_cross_region_m1

  val s0_region_hit = any_region_match
  val s0_cross_region_hit = any_region_m1_match || any_region_p1_match
  // allocate on miss, unless this is a repeat of last cycle's region
  val s0_alloc = s0_lookup_valid && !s0_region_hit && !s0_match_prev
  // entry whose access_cnt will qualify pf generation: exact hit first,
  // then the -1 neighbour, then the +1 neighbour
  val s0_pf_gen_match_vec = valids.indices.map(i => {
    Mux(any_region_match,
      region_match_vec_s0(i),
      Mux(any_region_m1_match,
        region_m1_match_vec_s0(i), region_p1_match_vec_s0(i)
      )
    )
  })
  val s0_agt_entry = Wire(new AGTEntry())

  s0_agt_entry.pht_index := s0_lookup.pht_index
  s0_agt_entry.pht_tag := s0_lookup.pht_tag
  s0_agt_entry.region_bits := region_offset_to_bits(s0_lookup.region_offset)
  s0_agt_entry.region_tag := s0_lookup.region_tag
  s0_agt_entry.region_offset := s0_lookup.region_offset
  s0_agt_entry.access_cnt := 1.U
  // lookup_region + 1 == entry_region
  // lookup_region = entry_region - 1 => decr mode
  s0_agt_entry.decr_mode := !s0_region_hit && !any_region_m1_match && any_region_p1_match
  val s0_replace_way = replacement.way
  val s0_replace_mask = UIntToOH(s0_replace_way)
  // s0 hit a entry that may be replaced in s1
  val s0_update_conflict = Cat(VecInit(region_match_vec_s0).asUInt & s1_replace_mask_w).orR
  val s0_update = s0_lookup_valid && s0_region_hit && !s0_update_conflict

  val s0_access_way = Mux1H(
    Seq(s0_update, s0_alloc),
    Seq(OHToUInt(region_match_vec_s0), s0_replace_way)
  )
  when(s0_update || s0_alloc) {
    replacement.access(s0_access_way)
  }

  // stage1: update/alloc
  // region hit, update entry
  val s1_update = RegNext(s0_update, false.B)
  val s1_update_mask = RegEnable(VecInit(region_match_vec_s0), s0_lookup_valid)
  val s1_agt_entry = RegEnable(s0_agt_entry, s0_lookup_valid)
  val s1_cross_region_match = RegNext(s0_lookup_valid && s0_cross_region_hit, false.B)
  val s1_alloc = RegNext(s0_alloc, false.B)
  val s1_alloc_entry = s1_agt_entry
  val s1_replace_mask = RegEnable(s0_replace_mask, s0_lookup_valid)
  s1_replace_mask_w := s1_replace_mask & Fill(smsParams.active_gen_table_size, s1_alloc)
  // the entry displaced by this allocation (sent to the PHT in s2)
  val s1_evict_entry = Mux1H(s1_replace_mask, entries)
  val s1_evict_valid = Mux1H(s1_replace_mask, valids)
  // pf gen
  val s1_pf_gen_match_vec = RegEnable(VecInit(s0_pf_gen_match_vec), s0_lookup_valid)
  val s1_region_paddr = RegEnable(s0_lookup.region_paddr, s0_lookup_valid)
  val s1_region_vaddr = RegEnable(s0_lookup.region_vaddr, s0_lookup_valid)
  val s1_region_offset = RegEnable(s0_lookup.region_offset, s0_lookup_valid)
  for(i <- entries.indices){
    val alloc = s1_replace_mask(i) && s1_alloc
    val update = s1_update_mask(i) && s1_update
    val update_entry = WireInit(entries(i))
    update_entry.region_bits := entries(i).region_bits | s1_agt_entry.region_bits
    // access_cnt saturates at REGION_BLKS - 1; +1 only if a new block bit is set
    update_entry.access_cnt := Mux(entries(i).access_cnt === (REGION_BLKS - 1).U,
      entries(i).access_cnt,
      entries(i).access_cnt + (s1_agt_entry.region_bits & (~entries(i).region_bits).asUInt).orR
    )
    valids(i) := valids(i) || alloc
    entries(i) := Mux(alloc, s1_alloc_entry, Mux(update, update_entry, entries(i)))
  }

  when(s1_update){
    assert(PopCount(s1_update_mask) === 1.U, "multi-agt-update")
  }
  when(s1_alloc){
    assert(PopCount(s1_replace_mask) === 1.U, "multi-agt-alloc")
  }

  // pf_addr
  // 1.hit => pf_addr = lookup_addr + (decr ? -1 : 1)
  // 2.lookup region - 1 hit => lookup_addr + 1 (incr mode)
  // 3.lookup region + 1 hit => lookup_addr - 1 (decr mode)
  val s1_hited_entry_decr = Mux1H(s1_update_mask, entries.map(_.decr_mode))
  val s1_pf_gen_decr_mode = Mux(s1_update,
    s1_hited_entry_decr,
    s1_agt_entry.decr_mode
  )

  // block-granularity addr = {region vaddr tag bits, region offset} +/- stride
  val s1_pf_gen_vaddr_inc = Cat(0.U, s1_region_vaddr(REGION_TAG_WIDTH - 1, 0), s1_region_offset) + io.act_stride
  val s1_pf_gen_vaddr_dec = Cat(0.U, s1_region_vaddr(REGION_TAG_WIDTH - 1, 0), s1_region_offset) - io.act_stride
  val s1_vaddr_inc_cross_page = s1_pf_gen_vaddr_inc(BLOCK_ADDR_PAGE_BIT) =/= s1_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s1_vaddr_dec_cross_page = s1_pf_gen_vaddr_dec(BLOCK_ADDR_PAGE_BIT) =/= s1_region_vaddr(REGION_ADDR_PAGE_BIT)
  // carry/borrow out of the computed address => out of range, suppress pf
  val s1_vaddr_inc_cross_max_lim = s1_pf_gen_vaddr_inc.head(1).asBool
  val s1_vaddr_dec_cross_max_lim = s1_pf_gen_vaddr_dec.head(1).asBool

  //val s1_pf_gen_vaddr_p1 = s1_region_vaddr(REGION_TAG_WIDTH - 1, 0) + 1.U
  //val s1_pf_gen_vaddr_m1 = s1_region_vaddr(REGION_TAG_WIDTH - 1, 0) - 1.U
  val s1_pf_gen_vaddr = Cat(
    s1_region_vaddr(REGION_ADDR_BITS - 1, REGION_TAG_WIDTH),
    Mux(s1_pf_gen_decr_mode,
      s1_pf_gen_vaddr_dec.tail(1).head(REGION_TAG_WIDTH),
      s1_pf_gen_vaddr_inc.tail(1).head(REGION_TAG_WIDTH)
    )
  )
  val s1_pf_gen_offset = Mux(s1_pf_gen_decr_mode,
    s1_pf_gen_vaddr_dec(REGION_OFFSET - 1, 0),
    s1_pf_gen_vaddr_inc(REGION_OFFSET - 1, 0)
  )
  val s1_pf_gen_offset_mask = UIntToOH(s1_pf_gen_offset)
  val s1_pf_gen_access_cnt = Mux1H(s1_pf_gen_match_vec, entries.map(_.access_cnt))
  // region considered "active" once enough distinct blocks were touched
  val s1_in_active_page = s1_pf_gen_access_cnt > io.act_threshold
  val s1_pf_gen_valid = prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && Mux(s1_pf_gen_decr_mode,
    !s1_vaddr_dec_cross_max_lim,
    !s1_vaddr_inc_cross_max_lim
  ) && s1_in_active_page && io.agt_en
  val s1_pf_gen_paddr_valid = Mux(s1_pf_gen_decr_mode, !s1_vaddr_dec_cross_page, !s1_vaddr_inc_cross_page)
  // same page: splice physical page frame with virtual in-page bits
  val s1_pf_gen_region_addr = Mux(s1_pf_gen_paddr_valid,
    Cat(s1_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT), s1_pf_gen_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)),
    s1_pf_gen_vaddr
  )
  val s1_pf_gen_region_tag = region_hash_tag(s1_pf_gen_vaddr)
  // incr mode: prefetch the target block and everything above it
  val s1_pf_gen_incr_region_bits = VecInit((0 until REGION_BLKS).map(i => {
    if(i == 0) true.B else !s1_pf_gen_offset_mask(i - 1, 0).orR
  })).asUInt
  // decr mode: prefetch the target block and everything below it
  val s1_pf_gen_decr_region_bits = VecInit((0 until REGION_BLKS).map(i => {
    if(i == REGION_BLKS - 1) true.B
    else !s1_pf_gen_offset_mask(REGION_BLKS - 1, i + 1).orR
  })).asUInt
  val s1_pf_gen_region_bits = Mux(s1_pf_gen_decr_mode,
    s1_pf_gen_decr_region_bits,
    s1_pf_gen_incr_region_bits
  )
  val s1_pht_lookup_valid = Wire(Bool())
  val s1_pht_lookup = Wire(new PhtLookup())

  // no pf generated this cycle => forward the lookup to the PHT
  s1_pht_lookup_valid := !s1_pf_gen_valid && prev_lookup_valid
  s1_pht_lookup.pht_index := s1_agt_entry.pht_index
  s1_pht_lookup.pht_tag := s1_agt_entry.pht_tag
  s1_pht_lookup.region_vaddr := s1_region_vaddr
  s1_pht_lookup.region_paddr := s1_region_paddr
  s1_pht_lookup.region_offset := s1_region_offset

  // region matched but not yet active => let the stride prefetcher take over
  io.s1_sel_stride := prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && !s1_in_active_page

  // stage2: gen pf reg / evict entry to pht
  val s2_evict_entry = RegEnable(s1_evict_entry, s1_alloc)
  val s2_evict_valid = RegNext(s1_alloc && s1_evict_valid, false.B)
  val s2_paddr_valid = RegEnable(s1_pf_gen_paddr_valid, s1_pf_gen_valid)
  val s2_pf_gen_region_tag = RegEnable(s1_pf_gen_region_tag, s1_pf_gen_valid)
  val s2_pf_gen_decr_mode = RegEnable(s1_pf_gen_decr_mode, s1_pf_gen_valid)
  val s2_pf_gen_region_paddr = RegEnable(s1_pf_gen_region_addr, s1_pf_gen_valid)
  val s2_pf_gen_region_bits = RegEnable(s1_pf_gen_region_bits, s1_pf_gen_valid)
  val s2_pf_gen_valid = RegNext(s1_pf_gen_valid, false.B)
  // drop the pht lookup if the stride prefetcher hit in s2
  val s2_pht_lookup_valid = RegNext(s1_pht_lookup_valid, false.B) && !io.s2_stride_hit
  val s2_pht_lookup = RegEnable(s1_pht_lookup, s1_pht_lookup_valid)

  io.s2_evict.valid := s2_evict_valid
  io.s2_evict.bits := s2_evict_entry

  io.s2_pf_gen_req.bits.region_tag := s2_pf_gen_region_tag
  io.s2_pf_gen_req.bits.region_addr := s2_pf_gen_region_paddr
  io.s2_pf_gen_req.bits.region_bits := s2_pf_gen_region_bits
  io.s2_pf_gen_req.bits.paddr_valid := s2_paddr_valid
  io.s2_pf_gen_req.bits.decr_mode := s2_pf_gen_decr_mode
  io.s2_pf_gen_req.valid := s2_pf_gen_valid

  io.s2_pht_lookup.valid := s2_pht_lookup_valid
  io.s2_pht_lookup.bits := s2_pht_lookup

  XSPerfAccumulate("sms_agt_in", io.s0_lookup.valid)
  XSPerfAccumulate("sms_agt_alloc", s1_alloc) // cross region match or filter evict
  XSPerfAccumulate("sms_agt_update", s1_update) // entry hit
  XSPerfAccumulate("sms_agt_pf_gen", io.s2_pf_gen_req.valid)
  XSPerfAccumulate("sms_agt_pf_gen_paddr_valid",
    io.s2_pf_gen_req.valid && io.s2_pf_gen_req.bits.paddr_valid
  )
  XSPerfAccumulate("sms_agt_pf_gen_decr_mode",
    io.s2_pf_gen_req.valid && io.s2_pf_gen_req.bits.decr_mode
  )
  for(i <- 0 until smsParams.active_gen_table_size){
    XSPerfAccumulate(s"sms_agt_access_entry_$i",
      s1_alloc && s1_replace_mask(i) || s1_update && s1_update_mask(i)
    )
  }

}
476
/** Request to look up the PHT after the AGT (and stride) paths produced no prefetch. */
class PhtLookup()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)       // PHT set index (from the trigger pc)
  val pht_tag = UInt(PHT_TAG_BITS.W)           // PHT tag (from the trigger pc)
  val region_paddr = UInt(REGION_ADDR_BITS.W)  // physical region address of the trigger
  val region_vaddr = UInt(REGION_ADDR_BITS.W)  // virtual region address of the trigger
  val region_offset = UInt(REGION_OFFSET.W)    // trigger block offset within the region
}
484
/** One PHT way: the learned spatial pattern for a (pc-derived) tag.
  * hist holds 2*(REGION_BLKS-1) saturating counters, covering block positions
  * both after and before the trigger offset.
  */
class PhtEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val hist = Vec(2 * (REGION_BLKS - 1), UInt(PHT_HIST_BITS.W))
  val tag = UInt(PHT_TAG_BITS.W)
  val decr_mode = Bool()  // pattern was learned from a descending access stream
}
490
491class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
492  val io = IO(new Bundle() {
493    // receive agt evicted entry
494    val agt_update = Flipped(ValidIO(new AGTEntry()))
495    // at stage2, if we know agt missed, lookup pht
496    val s2_agt_lookup = Flipped(ValidIO(new PhtLookup()))
497    // pht-generated prefetch req
498    val pf_gen_req = ValidIO(new PfGenReq())
499  })
500
501  val pht_ram = Module(new SRAMTemplate[PhtEntry](new PhtEntry,
502    set = smsParams.pht_size / smsParams.pht_ways,
503    way =smsParams.pht_ways,
504    singlePort = true
505  ))
506  def PHT_SETS = smsParams.pht_size / smsParams.pht_ways
507  val pht_valids = Seq.fill(smsParams.pht_ways){
508    RegInit(VecInit(Seq.fill(PHT_SETS){false.B}))
509  }
510  val replacement = Seq.fill(PHT_SETS) { ReplacementPolicy.fromString("plru", smsParams.pht_ways) }
511
512  val lookup_queue = Module(new OverrideableQueue(new PhtLookup, smsParams.pht_lookup_queue_size))
513  lookup_queue.io.in := io.s2_agt_lookup
514  val lookup = lookup_queue.io.out
515
516  val evict_queue = Module(new OverrideableQueue(new AGTEntry, smsParams.pht_lookup_queue_size))
517  evict_queue.io.in := io.agt_update
518  val evict = evict_queue.io.out
519
520  XSPerfAccumulate("sms_pht_lookup_in", lookup_queue.io.in.fire)
521  XSPerfAccumulate("sms_pht_lookup_out", lookup_queue.io.out.fire)
522  XSPerfAccumulate("sms_pht_evict_in", evict_queue.io.in.fire)
523  XSPerfAccumulate("sms_pht_evict_out", evict_queue.io.out.fire)
524
525  val s3_ram_en = Wire(Bool())
526  val s1_valid = Wire(Bool())
527  // if s1.raddr == s2.waddr or s3 is using ram port, block s1
528  val s1_wait = Wire(Bool())
529  // pipe s0: select an op from [lookup, update], generate ram read addr
530  val s0_valid = lookup.valid || evict.valid
531
532  evict.ready := !s1_valid || !s1_wait
533  lookup.ready := evict.ready && !evict.valid
534
535  val s0_ram_raddr = Mux(evict.valid,
536    evict.bits.pht_index,
537    lookup.bits.pht_index
538  )
539  val s0_tag = Mux(evict.valid, evict.bits.pht_tag, lookup.bits.pht_tag)
540  val s0_region_offset = Mux(evict.valid, evict.bits.region_offset, lookup.bits.region_offset)
541  val s0_region_paddr = lookup.bits.region_paddr
542  val s0_region_vaddr = lookup.bits.region_vaddr
543  val s0_region_bits = evict.bits.region_bits
544  val s0_decr_mode = evict.bits.decr_mode
545  val s0_evict = evict.valid
546
547  // pipe s1: send addr to ram
548  val s1_valid_r = RegInit(false.B)
549  s1_valid_r := Mux(s1_valid && s1_wait, true.B, s0_valid)
550  s1_valid := s1_valid_r
551  val s1_reg_en = s0_valid && (!s1_wait || !s1_valid)
552  val s1_ram_raddr = RegEnable(s0_ram_raddr, s1_reg_en)
553  val s1_tag = RegEnable(s0_tag, s1_reg_en)
554  val s1_region_bits = RegEnable(s0_region_bits, s1_reg_en)
555  val s1_decr_mode = RegEnable(s0_decr_mode, s1_reg_en)
556  val s1_region_paddr = RegEnable(s0_region_paddr, s1_reg_en)
557  val s1_region_vaddr = RegEnable(s0_region_vaddr, s1_reg_en)
558  val s1_region_offset = RegEnable(s0_region_offset, s1_reg_en)
559  val s1_pht_valids = pht_valids.map(way => Mux1H(
560    (0 until PHT_SETS).map(i => i.U === s1_ram_raddr),
561    way
562  ))
563  val s1_evict = RegEnable(s0_evict, s1_reg_en)
564  val s1_replace_way = Mux1H(
565    (0 until PHT_SETS).map(i => i.U === s1_ram_raddr),
566    replacement.map(_.way)
567  )
568  val s1_hist_update_mask = Cat(
569    Fill(REGION_BLKS - 1, true.B), 0.U((REGION_BLKS - 1).W)
570  ) >> s1_region_offset
571  val s1_hist_bits = Cat(
572    s1_region_bits.head(REGION_BLKS - 1) >> s1_region_offset,
573    (Cat(
574      s1_region_bits.tail(1), 0.U((REGION_BLKS - 1).W)
575    ) >> s1_region_offset)(REGION_BLKS - 2, 0)
576  )
577
578  // pipe s2: generate ram write addr/data
579  val s2_valid = RegNext(s1_valid && !s1_wait, false.B)
580  val s2_reg_en = s1_valid && !s1_wait
581  val s2_hist_update_mask = RegEnable(s1_hist_update_mask, s2_reg_en)
582  val s2_hist_bits = RegEnable(s1_hist_bits, s2_reg_en)
583  val s2_tag = RegEnable(s1_tag, s2_reg_en)
584  val s2_region_bits = RegEnable(s1_region_bits, s2_reg_en)
585  val s2_decr_mode = RegEnable(s1_decr_mode, s2_reg_en)
586  val s2_region_paddr = RegEnable(s1_region_paddr, s2_reg_en)
587  val s2_region_vaddr = RegEnable(s1_region_vaddr, s2_reg_en)
588  val s2_region_offset = RegEnable(s1_region_offset, s2_reg_en)
589  val s2_region_offset_mask = region_offset_to_bits(s2_region_offset)
590  val s2_evict = RegEnable(s1_evict, s2_reg_en)
591  val s2_pht_valids = s1_pht_valids.map(v => RegEnable(v, s2_reg_en))
592  val s2_replace_way = RegEnable(s1_replace_way, s2_reg_en)
593  val s2_ram_waddr = RegEnable(s1_ram_raddr, s2_reg_en)
594  val s2_ram_rdata = pht_ram.io.r.resp.data
595  val s2_ram_rtags = s2_ram_rdata.map(_.tag)
596  val s2_tag_match_vec = s2_ram_rtags.map(t => t === s2_tag)
597  val s2_hit_vec = s2_tag_match_vec.zip(s2_pht_valids).map({
598    case (tag_match, v) => v && tag_match
599  })
600  val s2_hist_update = s2_ram_rdata.map(way => VecInit(way.hist.zipWithIndex.map({
601    case (h, i) =>
602      val do_update = s2_hist_update_mask(i)
603      val hist_updated = Mux(s2_hist_bits(i),
604        Mux(h.andR, h, h + 1.U),
605        Mux(h === 0.U, 0.U, h - 1.U)
606      )
607      Mux(do_update, hist_updated, h)
608  })))
609  val s2_hist_pf_gen = Mux1H(s2_hit_vec, s2_ram_rdata.map(way => VecInit(way.hist.map(_.head(1))).asUInt))
610  val s2_new_hist = VecInit(s2_hist_bits.asBools.map(b => Cat(0.U((PHT_HIST_BITS - 1).W), b)))
611  val s2_pht_hit = Cat(s2_hit_vec).orR
612  val s2_hist = Mux(s2_pht_hit, Mux1H(s2_hit_vec, s2_hist_update), s2_new_hist)
613  val s2_repl_way_mask = UIntToOH(s2_replace_way)
614
615  // pipe s3: send addr/data to ram, gen pf_req
616  val s3_valid = RegNext(s2_valid, false.B)
617  val s3_evict = RegEnable(s2_evict, s2_valid)
618  val s3_hist = RegEnable(s2_hist, s2_valid)
619  val s3_hist_pf_gen = RegEnable(s2_hist_pf_gen, s2_valid)
620  val s3_hist_update_mask = RegEnable(s2_hist_update_mask.asUInt, s2_valid)
621  val s3_region_offset = RegEnable(s2_region_offset, s2_valid)
622  val s3_region_offset_mask = RegEnable(s2_region_offset_mask, s2_valid)
623  val s3_decr_mode = RegEnable(s2_decr_mode, s2_valid)
624  val s3_region_paddr = RegEnable(s2_region_paddr, s2_valid)
625  val s3_region_vaddr = RegEnable(s2_region_vaddr, s2_valid)
626  val s3_pht_tag = RegEnable(s2_tag, s2_valid)
627  val s3_hit_vec = s2_hit_vec.map(h => RegEnable(h, s2_valid))
628  val s3_hit = Cat(s3_hit_vec).orR
629  val s3_hit_way = OHToUInt(s3_hit_vec)
630  val s3_repl_way = RegEnable(s2_replace_way, s2_valid)
631  val s3_repl_way_mask = RegEnable(s2_repl_way_mask, s2_valid)
632  val s3_repl_update_mask = RegEnable(VecInit((0 until PHT_SETS).map(i => i.U === s2_ram_waddr)), s2_valid)
633  val s3_ram_waddr = RegEnable(s2_ram_waddr, s2_valid)
634  s3_ram_en := s3_valid && s3_evict
635  val s3_ram_wdata = Wire(new PhtEntry())
636  s3_ram_wdata.hist := s3_hist
637  s3_ram_wdata.tag := s3_pht_tag
638  s3_ram_wdata.decr_mode := s3_decr_mode
639
640  s1_wait := (s2_valid && s2_evict && s2_ram_waddr === s1_ram_raddr) || s3_ram_en
641
642  for((valids, way_idx) <- pht_valids.zipWithIndex){
643    val update_way = s3_repl_way_mask(way_idx)
644    for((v, set_idx) <- valids.zipWithIndex){
645      val update_set = s3_repl_update_mask(set_idx)
646      when(s3_valid && s3_evict && !s3_hit && update_set && update_way){
647        v := true.B
648      }
649    }
650  }
651  for((r, i) <- replacement.zipWithIndex){
652    when(s3_valid && s3_repl_update_mask(i)){
653      when(s3_hit){
654        r.access(s3_hit_way)
655      }.elsewhen(s3_evict){
656        r.access(s3_repl_way)
657      }
658    }
659  }
660
661  val s3_way_mask = Mux(s3_hit,
662    VecInit(s3_hit_vec).asUInt,
663    s3_repl_way_mask,
664  ).asUInt
665
666  pht_ram.io.r(
667    s1_valid, s1_ram_raddr
668  )
669  pht_ram.io.w(
670    s3_ram_en, s3_ram_wdata, s3_ram_waddr, s3_way_mask
671  )
672
673  when(s3_valid && s3_hit){
674    assert(!Cat(s3_hit_vec).andR, "sms_pht: multi-hit!")
675  }
676
677  // generate pf req if hit
678  val s3_hist_hi = s3_hist_pf_gen.head(REGION_BLKS - 1)
679  val s3_hist_lo = s3_hist_pf_gen.tail(REGION_BLKS - 1)
680  val s3_hist_hi_shifted = (Cat(0.U((REGION_BLKS - 1).W), s3_hist_hi) << s3_region_offset)(2 * (REGION_BLKS - 1) - 1, 0)
681  val s3_hist_lo_shifted = (Cat(0.U((REGION_BLKS - 1).W), s3_hist_lo) << s3_region_offset)(2 * (REGION_BLKS - 1) - 1, 0)
682  val s3_cur_region_bits = Cat(s3_hist_hi_shifted.tail(REGION_BLKS - 1), 0.U(1.W)) |
683    Cat(0.U(1.W), s3_hist_lo_shifted.head(REGION_BLKS - 1))
  // Prefetch bit-vector for the next (incr) region: upper half of the shifted
  // history, with the region's own first block cleared (bit 0 forced to 0).
  val s3_incr_region_bits = Cat(0.U(1.W), s3_hist_hi_shifted.head(REGION_BLKS - 1))
  // Prefetch bit-vector for the previous (decr) region: lower half of the
  // shifted history, with the last block cleared (MSB forced to 0).
  val s3_decr_region_bits = Cat(s3_hist_lo_shifted.tail(REGION_BLKS - 1), 0.U(1.W))
  // Generate prefetches only on a PHT lookup hit (not on an eviction update).
  val s3_pf_gen_valid = s3_valid && s3_hit && !s3_evict
  // Current region: only blocks at/after the trigger offset (update mask) count.
  val s3_cur_region_valid =  s3_pf_gen_valid && (s3_hist_pf_gen & s3_hist_update_mask).orR
  // Incr/decr neighbor regions: only history bits OUTSIDE the update mask count,
  // i.e. blocks that spill past the current region's boundary.
  val s3_incr_region_valid = s3_pf_gen_valid && (s3_hist_hi & (~s3_hist_update_mask.head(REGION_BLKS - 1)).asUInt).orR
  val s3_decr_region_valid = s3_pf_gen_valid && (s3_hist_lo & (~s3_hist_update_mask.tail(REGION_BLKS - 1)).asUInt).orR
  val s3_incr_region_vaddr = s3_region_vaddr + 1.U
  val s3_decr_region_vaddr = s3_region_vaddr - 1.U
  // Neighbor-region paddr: reuse the translated upper bits of the current
  // region's paddr and splice in the neighbor vaddr's in-page region bits.
  // Only valid when the neighbor stays on the same page (see crosspage below).
  val s3_incr_region_paddr = Cat(
    s3_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT),
    s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)
  )
  val s3_decr_region_paddr = Cat(
    s3_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT),
    s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)
  )
  // Page-crossing detection: the +1/-1 carried/borrowed into the page-number bit.
  val s3_incr_crosspage = s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s3_decr_crosspage = s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s3_cur_region_tag = region_hash_tag(s3_region_vaddr)
  val s3_incr_region_tag = region_hash_tag(s3_incr_region_vaddr)
  val s3_decr_region_tag = region_hash_tag(s3_decr_region_vaddr)

  // s4: up to three pf-gen candidates (current / next / previous region)
  // compete for the single output port through a priority arbiter.
  val pf_gen_req_arb = Module(new Arbiter(new PfGenReq, 3))
  val s4_pf_gen_cur_region_valid = RegInit(false.B)
  val s4_pf_gen_cur_region = Reg(new PfGenReq)
  val s4_pf_gen_incr_region_valid = RegInit(false.B)
  val s4_pf_gen_incr_region = Reg(new PfGenReq)
  val s4_pf_gen_decr_region_valid = RegInit(false.B)
  val s4_pf_gen_decr_region = Reg(new PfGenReq)

  // Current region uses arbiter port 0 (highest priority; out.ready is tied
  // high below, so port 0 always drains in one cycle — no hold term needed).
  s4_pf_gen_cur_region_valid := s3_cur_region_valid
  when(s3_cur_region_valid){
    s4_pf_gen_cur_region.region_addr := s3_region_paddr
    s4_pf_gen_cur_region.region_tag := s3_cur_region_tag
    s4_pf_gen_cur_region.region_bits := s3_cur_region_bits
    s4_pf_gen_cur_region.paddr_valid := true.B
    s4_pf_gen_cur_region.decr_mode := false.B
  }
  // Incr/decr ports can lose arbitration, so they hold their valid bit until
  // the arbiter accepts them (overwritten by a newer s3 request if one arrives).
  s4_pf_gen_incr_region_valid := s3_incr_region_valid ||
    (!pf_gen_req_arb.io.in(1).ready && s4_pf_gen_incr_region_valid)
  when(s3_incr_region_valid){
    // On a page cross the paddr splice is invalid: send the vaddr instead and
    // clear paddr_valid so the downstream filter performs a TLB lookup.
    s4_pf_gen_incr_region.region_addr := Mux(s3_incr_crosspage, s3_incr_region_vaddr, s3_incr_region_paddr)
    s4_pf_gen_incr_region.region_tag := s3_incr_region_tag
    s4_pf_gen_incr_region.region_bits := s3_incr_region_bits
    s4_pf_gen_incr_region.paddr_valid := !s3_incr_crosspage
    s4_pf_gen_incr_region.decr_mode := false.B
  }
  s4_pf_gen_decr_region_valid := s3_decr_region_valid ||
    (!pf_gen_req_arb.io.in(2).ready && s4_pf_gen_decr_region_valid)
  when(s3_decr_region_valid){
    s4_pf_gen_decr_region.region_addr := Mux(s3_decr_crosspage, s3_decr_region_vaddr, s3_decr_region_paddr)
    s4_pf_gen_decr_region.region_tag := s3_decr_region_tag
    s4_pf_gen_decr_region.region_bits := s3_decr_region_bits
    s4_pf_gen_decr_region.paddr_valid := !s3_decr_crosspage
    s4_pf_gen_decr_region.decr_mode := true.B
  }

  pf_gen_req_arb.io.in.head.valid := s4_pf_gen_cur_region_valid
  pf_gen_req_arb.io.in.head.bits := s4_pf_gen_cur_region
  pf_gen_req_arb.io.in(1).valid := s4_pf_gen_incr_region_valid
  pf_gen_req_arb.io.in(1).bits := s4_pf_gen_incr_region
  pf_gen_req_arb.io.in(2).valid := s4_pf_gen_decr_region_valid
  pf_gen_req_arb.io.in(2).bits := s4_pf_gen_decr_region
  pf_gen_req_arb.io.out.ready := true.B

  io.pf_gen_req.valid := pf_gen_req_arb.io.out.valid
  io.pf_gen_req.bits := pf_gen_req_arb.io.out.bits

  XSPerfAccumulate("sms_pht_update", io.agt_update.valid)
  XSPerfAccumulate("sms_pht_update_hit", s2_valid && s2_evict && s2_pht_hit)
  XSPerfAccumulate("sms_pht_lookup", io.s2_agt_lookup.valid)
  XSPerfAccumulate("sms_pht_lookup_hit", s2_valid && !s2_evict && s2_pht_hit)
  for(i <- 0 until smsParams.pht_ways){
    XSPerfAccumulate(s"sms_pht_write_way_$i", pht_ram.io.w.req.fire && pht_ram.io.w.req.bits.waymask.get(i))
  }
  for(i <- 0 until PHT_SETS){
    XSPerfAccumulate(s"sms_pht_write_set_$i", pht_ram.io.w.req.fire && pht_ram.io.w.req.bits.setIdx === i.U)
  }
  XSPerfAccumulate(s"sms_pht_pf_gen", io.pf_gen_req.valid)
763}
764
// One prefetch-filter entry: tracks a single region's requested blocks
// (region_bits) versus already-issued blocks (filter_bits) so duplicate
// prefetches to the same block are suppressed.
// NOTE: field order defines the Bundle's bit layout — do not reorder.
class PrefetchFilterEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  // hashed region tag used to match incoming generation requests
  val region_tag = UInt(REGION_TAG_WIDTH.W)
  // region address; physical once paddr_valid is set, otherwise still
  // virtual and awaiting the entry's TLB translation
  val region_addr = UInt(REGION_ADDR_BITS.W)
  // per-block bitmap of prefetches requested for this region
  val region_bits = UInt(REGION_BLKS.W)
  // per-block bitmap of prefetches already sent to L2
  val filter_bits = UInt(REGION_BLKS.W)
  // region_addr holds a translated physical address
  val paddr_valid = Bool()
  // issue pending blocks from highest offset downward instead of upward
  val decr_mode = Bool()
}
773
// Prefetch filter: a small fully-associative table that merges prefetch
// generation requests per region, performs TLB translation for entries that
// arrived with a virtual address, and issues at most one deduplicated L2
// prefetch per cycle.
class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val gen_req = Flipped(ValidIO(new PfGenReq()))
    val tlb_req = new TlbRequestIO(2)
    val l2_pf_addr = ValidIO(UInt(PAddrBits.W))
  })
  // Entry storage with PLRU replacement.
  val entries = Seq.fill(smsParams.pf_filter_size){ Reg(new PrefetchFilterEntry()) }
  val valids = Seq.fill(smsParams.pf_filter_size){ RegInit(false.B) }
  val replacement = ReplacementPolicy.fromString("plru", smsParams.pf_filter_size)

  // Drop back-to-back requests for the same region before they enter s0.
  val prev_valid = RegNext(io.gen_req.valid, false.B)
  val prev_gen_req = RegEnable(io.gen_req.bits, io.gen_req.valid)

  // Round-robin arbiters: one TLB lookup and one L2 prefetch issue per cycle.
  val tlb_req_arb = Module(new RRArbiterInit(new TlbReq, smsParams.pf_filter_size))
  val pf_req_arb = Module(new RRArbiterInit(UInt(PAddrBits.W), smsParams.pf_filter_size))

  io.tlb_req.req <> tlb_req_arb.io.out
  io.tlb_req.resp.ready := true.B
  io.tlb_req.req_kill := false.B
  io.l2_pf_addr.valid := pf_req_arb.io.out.valid
  io.l2_pf_addr.bits := pf_req_arb.io.out.bits
  pf_req_arb.io.out.ready := true.B

  // s1 signals fed back into s0 matching (declared early, driven below).
  val s1_valid = Wire(Bool())
  val s1_hit = Wire(Bool())
  val s1_replace_vec = Wire(UInt(smsParams.pf_filter_size.W))
  val s1_tlb_fire_vec = Wire(UInt(smsParams.pf_filter_size.W))

  // s0: entries lookup
  val s0_gen_req = io.gen_req.bits
  val s0_match_prev = prev_valid && (s0_gen_req.region_tag === prev_gen_req.region_tag)
  val s0_gen_req_valid = io.gen_req.valid && !s0_match_prev
  // An entry being replaced by s1 this cycle must not count as a match.
  val s0_match_vec = valids.indices.map(i => {
    valids(i) && entries(i).region_tag === s0_gen_req.region_tag && !(s1_valid && !s1_hit && s1_replace_vec(i))
  })
  val s0_any_matched = Cat(s0_match_vec).orR
  val s0_replace_vec = UIntToOH(replacement.way)
  val s0_hit = s0_gen_req_valid && s0_any_matched

  for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
    val is_evicted = s1_valid && s1_replace_vec(i)
    // Request translation for entries that still hold a virtual address and
    // did not already fire a TLB request last cycle.
    tlb_req_arb.io.in(i).valid := v && !s1_tlb_fire_vec(i) && !ent.paddr_valid && !is_evicted
    tlb_req_arb.io.in(i).bits.vaddr := Cat(ent.region_addr, 0.U(log2Up(REGION_SIZE).W))
    tlb_req_arb.io.in(i).bits.cmd := TlbCmd.read
    tlb_req_arb.io.in(i).bits.size := 3.U
    tlb_req_arb.io.in(i).bits.no_translate := false.B
    tlb_req_arb.io.in(i).bits.debug := DontCare

    // Blocks requested but not yet issued.
    val pending_req_vec = ent.region_bits & (~ent.filter_bits).asUInt
    // Select the first / last pending block offset within the region.
    // BUGFIX: the candidate list must span REGION_BLKS (the width of
    // pending_req_vec), not pf_filter_size (the number of filter entries).
    // The original code only worked because the default configuration makes
    // both equal to 16; any other region size / filter size would mis-select
    // the prefetch offset.
    val first_one_offset = PriorityMux(
      pending_req_vec.asBools,
      (0 until REGION_BLKS).map(_.U(REGION_OFFSET.W))
    )
    val last_one_offset = PriorityMux(
      pending_req_vec.asBools.reverse,
      (0 until REGION_BLKS).reverse.map(_.U(REGION_OFFSET.W))
    )
    // decr_mode issues from the highest pending offset downward.
    val pf_addr = Cat(
      ent.region_addr,
      Mux(ent.decr_mode, last_one_offset, first_one_offset),
      0.U(log2Up(dcacheParameters.blockBytes).W)
    )
    // Only issue once the region address is physical.
    pf_req_arb.io.in(i).valid := v && Cat(pending_req_vec).orR && ent.paddr_valid && !is_evicted
    pf_req_arb.io.in(i).bits := pf_addr
  }

  val s0_tlb_fire_vec = VecInit(tlb_req_arb.io.in.map(_.fire))
  val s0_pf_fire_vec = VecInit(pf_req_arb.io.in.map(_.fire))

  // Touch PLRU on every accepted request (hit: matched way, miss: victim way).
  val s0_update_way = OHToUInt(s0_match_vec)
  val s0_replace_way = replacement.way
  val s0_access_way = Mux(s0_any_matched, s0_update_way, s0_replace_way)
  when(s0_gen_req_valid){
    replacement.access(s0_access_way)
  }

  // s1: update or alloc
  val s1_valid_r = RegNext(s0_gen_req_valid, false.B)
  val s1_hit_r = RegEnable(s0_hit, false.B, s0_gen_req_valid)
  val s1_gen_req = RegEnable(s0_gen_req, s0_gen_req_valid)
  val s1_replace_vec_r = RegEnable(s0_replace_vec, s0_gen_req_valid && !s0_hit)
  val s1_update_vec = RegEnable(VecInit(s0_match_vec).asUInt, s0_gen_req_valid && s0_hit)
  val s1_tlb_fire_vec_r = RegNext(s0_tlb_fire_vec, 0.U.asTypeOf(s0_tlb_fire_vec))
  val s1_alloc_entry = Wire(new PrefetchFilterEntry())
  s1_valid := s1_valid_r
  s1_hit := s1_hit_r
  s1_replace_vec := s1_replace_vec_r
  s1_tlb_fire_vec := s1_tlb_fire_vec_r.asUInt
  s1_alloc_entry.region_tag := s1_gen_req.region_tag
  s1_alloc_entry.region_addr := s1_gen_req.region_addr
  s1_alloc_entry.region_bits := s1_gen_req.region_bits
  s1_alloc_entry.paddr_valid := s1_gen_req.paddr_valid
  s1_alloc_entry.decr_mode := s1_gen_req.decr_mode
  s1_alloc_entry.filter_bits := 0.U
  for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
    val alloc = s1_valid && !s1_hit && s1_replace_vec(i)
    val update = s1_valid && s1_hit && s1_update_vec(i)
    // for pf: use s0 data
    val pf_fired = s0_pf_fire_vec(i)
    // TLB response for the request this entry fired last cycle.
    val tlb_fired = s1_tlb_fire_vec(i) && !io.tlb_req.resp.bits.miss
    when(tlb_fired){
      // tlb_fired already conjoins !miss, so the translation is valid here
      // (original redundantly re-evaluated !miss).
      ent.paddr_valid := true.B
      ent.region_addr := region_addr(io.tlb_req.resp.bits.paddr.head)
    }
    when(update){
      // Merge newly requested blocks into the existing region entry.
      ent.region_bits := ent.region_bits | s1_gen_req.region_bits
    }
    when(pf_fired){
      // Mark the issued block so it is never prefetched twice.
      val curr_bit = UIntToOH(block_addr(pf_req_arb.io.in(i).bits)(REGION_OFFSET - 1, 0))
      ent.filter_bits := ent.filter_bits | curr_bit
    }
    when(alloc){
      ent := s1_alloc_entry
      v := true.B
    }
  }
  when(s1_valid && s1_hit){
    assert(PopCount(s1_update_vec) === 1.U, "sms_pf_filter: multi-hit")
  }

  XSPerfAccumulate("sms_pf_filter_recv_req", io.gen_req.valid)
  XSPerfAccumulate("sms_pf_filter_hit", s1_valid && s1_hit)
  XSPerfAccumulate("sms_pf_filter_tlb_req", io.tlb_req.req.fire)
  XSPerfAccumulate("sms_pf_filter_tlb_resp_miss", io.tlb_req.resp.fire && io.tlb_req.resp.bits.miss)
  for(i <- 0 until smsParams.pf_filter_size){
    XSPerfAccumulate(s"sms_pf_filter_access_way_$i", s0_gen_req_valid && s0_access_way === i.U)
  }
  XSPerfAccumulate("sms_pf_filter_l2_req", io.l2_pf_addr.valid)
}
903
// Top-level SMS (Spatial Memory Streaming) prefetcher: filters/orders the two
// load-pipeline training inputs, drives the AGT, stride table and PHT, and
// funnels their generation requests through the prefetch filter to L2.
class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSModuleHelper {

  // The train-filter logic below is hard-wired for exactly two load pipes.
  require(exuParameters.LduCnt == 2)

  // Runtime enables and tuning knobs, driven from CSR/control outside.
  val io_agt_en = IO(Input(Bool()))
  val io_stride_en = IO(Input(Bool()))
  val io_pht_en = IO(Input(Bool()))
  val io_act_threshold = IO(Input(UInt(REGION_OFFSET.W)))
  val io_act_stride = IO(Input(UInt(6.W)))

  val ld_curr = io.ld_in.map(_.bits)
  val ld_curr_block_tag = ld_curr.map(x => block_hash_tag(x.vaddr))

  // block filter
  // Drop a training access if it targets the same block as either load seen
  // in the previous cycle, or as the other pipe in the same cycle.
  val ld_prev = io.ld_in.map(ld => RegEnable(ld.bits, ld.valid))
  val ld_prev_block_tag = ld_curr_block_tag.zip(io.ld_in.map(_.valid)).map({
    case (tag, v) => RegEnable(tag, v)
  })
  val ld_prev_vld = io.ld_in.map(ld => RegNext(ld.valid, false.B))

  val ld_curr_match_prev = ld_curr_block_tag.map(cur_tag =>
    Cat(ld_prev_block_tag.zip(ld_prev_vld).map({
      case (prev_tag, prev_vld) => prev_vld && prev_tag === cur_tag
    })).orR
  )
  val ld0_match_ld1 = io.ld_in.head.valid && io.ld_in.last.valid && ld_curr_block_tag.head === ld_curr_block_tag.last
  val ld_curr_vld = Seq(
    io.ld_in.head.valid && !ld_curr_match_prev.head,
    io.ld_in.last.valid && !ld_curr_match_prev.last && !ld0_match_ld1
  )
  // When both pipes survive filtering, train on the older (by ROB index)
  // load first and keep the younger one pending for the next cycle.
  val ld0_older_than_ld1 = Cat(ld_curr_vld).andR && isBefore(ld_curr.head.uop.robIdx, ld_curr.last.uop.robIdx)
  val pending_vld = RegNext(Cat(ld_curr_vld).andR, false.B)
  // The pending slot selects the load NOT chosen last cycle (hence the
  // inverted select when pending_vld was low).
  val pending_sel_ld0 = RegNext(Mux(pending_vld, ld0_older_than_ld1, !ld0_older_than_ld1))
  val pending_ld = Mux(pending_sel_ld0, ld_prev.head, ld_prev.last)
  val pending_ld_block_tag = Mux(pending_sel_ld0, ld_prev_block_tag.head, ld_prev_block_tag.last)
  // Pending load from last cycle takes priority over this cycle's loads.
  val oldest_ld = Mux(pending_vld,
    pending_ld,
    Mux(ld0_older_than_ld1 || !ld_curr_vld.last, ld_curr.head, ld_curr.last)
  )

  val train_ld = RegEnable(oldest_ld, pending_vld || Cat(ld_curr_vld).orR)

  val train_block_tag = block_hash_tag(train_ld.vaddr)
  val train_region_tag = train_block_tag.head(REGION_TAG_WIDTH)

  // Neighboring regions (+1/-1) with one extra bit to detect address-space
  // overflow/underflow of the raw (truncated) region address.
  val train_region_addr_raw = region_addr(train_ld.vaddr)(REGION_TAG_WIDTH + 2 * VADDR_HASH_WIDTH - 1, 0)
  val train_region_addr_p1 = Cat(0.U(1.W), train_region_addr_raw) + 1.U
  val train_region_addr_m1 = Cat(0.U(1.W), train_region_addr_raw) - 1.U
  // addr_p1 or addr_m1 is valid?
  val train_allow_cross_region_p1 = !train_region_addr_p1.head(1).asBool
  val train_allow_cross_region_m1 = !train_region_addr_m1.head(1).asBool

  val train_region_p1_tag = region_hash_tag(train_region_addr_p1.tail(1))
  val train_region_m1_tag = region_hash_tag(train_region_addr_m1.tail(1))

  // Neighbor region falls on a different page -> its paddr is unknown.
  val train_region_p1_cross_page = page_bit(train_region_addr_p1) ^ page_bit(train_region_addr_raw)
  val train_region_m1_cross_page = page_bit(train_region_addr_m1) ^ page_bit(train_region_addr_raw)

  val train_region_paddr = region_addr(train_ld.paddr)
  val train_region_vaddr = region_addr(train_ld.vaddr)
  val train_region_offset = train_block_tag(REGION_OFFSET - 1, 0)
  val train_vld = RegNext(pending_vld || Cat(ld_curr_vld).orR, false.B)


  // prefetch stage0
  val active_gen_table = Module(new ActiveGenerationTable())
  val stride = Module(new StridePF())
  val pht = Module(new PatternHistoryTable())
  val pf_filter = Module(new PrefetchFilter())

  // One more register stage before the sub-module lookups (s0).
  val train_vld_s0 = RegNext(train_vld, false.B)
  val train_s0 = RegEnable(train_ld, train_vld)
  val train_region_tag_s0 = RegEnable(train_region_tag, train_vld)
  val train_region_p1_tag_s0 = RegEnable(train_region_p1_tag, train_vld)
  val train_region_m1_tag_s0 = RegEnable(train_region_m1_tag, train_vld)
  val train_allow_cross_region_p1_s0 = RegEnable(train_allow_cross_region_p1, train_vld)
  val train_allow_cross_region_m1_s0 = RegEnable(train_allow_cross_region_m1, train_vld)
  val train_pht_tag_s0 = RegEnable(pht_tag(train_ld.uop.cf.pc), train_vld)
  val train_pht_index_s0 = RegEnable(pht_index(train_ld.uop.cf.pc), train_vld)
  val train_region_offset_s0 = RegEnable(train_region_offset, train_vld)
  val train_region_p1_cross_page_s0 = RegEnable(train_region_p1_cross_page, train_vld)
  val train_region_m1_cross_page_s0 = RegEnable(train_region_m1_cross_page, train_vld)
  val train_region_paddr_s0 = RegEnable(train_region_paddr, train_vld)
  val train_region_vaddr_s0 = RegEnable(train_region_vaddr, train_vld)

  active_gen_table.io.agt_en := io_agt_en
  active_gen_table.io.act_threshold := io_act_threshold
  active_gen_table.io.act_stride := io_act_stride
  active_gen_table.io.s0_lookup.valid := train_vld_s0
  active_gen_table.io.s0_lookup.bits.region_tag := train_region_tag_s0
  active_gen_table.io.s0_lookup.bits.region_p1_tag := train_region_p1_tag_s0
  active_gen_table.io.s0_lookup.bits.region_m1_tag := train_region_m1_tag_s0
  active_gen_table.io.s0_lookup.bits.region_offset := train_region_offset_s0
  active_gen_table.io.s0_lookup.bits.pht_index := train_pht_index_s0
  active_gen_table.io.s0_lookup.bits.pht_tag := train_pht_tag_s0
  active_gen_table.io.s0_lookup.bits.allow_cross_region_p1 := train_allow_cross_region_p1_s0
  active_gen_table.io.s0_lookup.bits.allow_cross_region_m1 := train_allow_cross_region_m1_s0
  active_gen_table.io.s0_lookup.bits.region_p1_cross_page := train_region_p1_cross_page_s0
  active_gen_table.io.s0_lookup.bits.region_m1_cross_page := train_region_m1_cross_page_s0
  active_gen_table.io.s0_lookup.bits.region_paddr := train_region_paddr_s0
  active_gen_table.io.s0_lookup.bits.region_vaddr := train_region_vaddr_s0
  active_gen_table.io.s2_stride_hit := stride.io.s2_gen_req.valid

  stride.io.stride_en := io_stride_en
  stride.io.s0_lookup.valid := train_vld_s0
  stride.io.s0_lookup.bits.pc := train_s0.uop.cf.pc(STRIDE_PC_BITS - 1, 0)
  // Stride table trains on full block-aligned addresses.
  stride.io.s0_lookup.bits.vaddr := Cat(
    train_region_vaddr_s0, train_region_offset_s0, 0.U(log2Up(dcacheParameters.blockBytes).W)
  )
  stride.io.s0_lookup.bits.paddr := Cat(
    train_region_paddr_s0, train_region_offset_s0, 0.U(log2Up(dcacheParameters.blockBytes).W)
  )
  stride.io.s1_valid := active_gen_table.io.s1_sel_stride

  // PHT learns from AGT evictions and is looked up on AGT misses.
  pht.io.s2_agt_lookup := active_gen_table.io.s2_pht_lookup
  pht.io.agt_update := active_gen_table.io.s2_evict

  // AGT/stride requests (mutually exclusive, asserted below) have priority
  // over PHT-generated requests.
  val pht_gen_valid = pht.io.pf_gen_req.valid && io_pht_en
  val agt_gen_valid = active_gen_table.io.s2_pf_gen_req.valid
  val stride_gen_valid = stride.io.s2_gen_req.valid
  val pf_gen_req = Mux(agt_gen_valid || stride_gen_valid,
    Mux1H(Seq(
      agt_gen_valid -> active_gen_table.io.s2_pf_gen_req.bits,
      stride_gen_valid -> stride.io.s2_gen_req.bits
    )),
    pht.io.pf_gen_req.bits
  )
  assert(!(agt_gen_valid && stride_gen_valid))
  pf_filter.io.gen_req.valid := pht_gen_valid || agt_gen_valid || stride_gen_valid
  pf_filter.io.gen_req.bits := pf_gen_req
  io.tlb_req <> pf_filter.io.tlb_req
  // NOTE(review): hard-coded address floor — presumably the SoC DRAM base
  // (0x8000_0000), suppressing prefetches to MMIO/low addresses. Confirm
  // against the platform memory map; should arguably come from a parameter.
  val is_valid_address = pf_filter.io.l2_pf_addr.bits > 0x80000000L.U
  io.pf_addr.valid := pf_filter.io.l2_pf_addr.valid && io.enable && is_valid_address
  io.pf_addr.bits := pf_filter.io.l2_pf_addr.bits

  XSPerfAccumulate("sms_pf_gen_conflict",
    pht_gen_valid && agt_gen_valid
  )
  XSPerfAccumulate("sms_pht_disabled", pht.io.pf_gen_req.valid && !io_pht_en)
  XSPerfAccumulate("sms_agt_disabled", active_gen_table.io.s2_pf_gen_req.valid && !io_agt_en)
  XSPerfAccumulate("sms_pf_real_issued", io.pf_addr.valid)
}
1046