// Source: /XiangShan/src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala (revision 5d13017ec21a682fa7c9fb193ba217ea13795ed3)
1package xiangshan.mem.prefetch
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import xiangshan._
7import utils._
8import xiangshan.cache.HasDCacheParameters
9import xiangshan.cache.mmu._
10
// Configuration parameters for the SMS (Spatial Memory Streaming) prefetcher.
case class SMSParams
(
  region_size: Int = 1024,            // bytes per spatial region being tracked
  vaddr_hash_width: Int = 5,          // width of the XOR-folded hash of high vaddr bits
  block_addr_raw_width: Int = 10,     // low (un-hashed) block-address bits kept in tags
  filter_table_size: Int = 16,        // entries in the first-touch filter table
  active_gen_table_size: Int = 16,    // entries in the active generation table (AGT)
  pht_size: Int = 64,                 // total pattern history table entries (sets * ways)
  pht_ways: Int = 2,                  // PHT associativity
  pht_hist_bits: Int = 2,             // saturating-counter width per PHT history position
  pht_tag_bits: Int = 13,             // PHT tag width (pc bits above the index field)
  pht_lookup_queue_size: Int = 4,     // depth of the PHT lookup and evict queues
  pf_filter_size: Int = 16            // entries in the prefetch request filter
) extends PrefetcherParams
25
// Shared width constants and address-manipulation helpers for all SMS modules.
trait HasSMSModuleHelper extends HasCircularQueuePtrHelper with HasDCacheParameters
{ this: HasXSParameter =>
  val smsParams = coreParams.prefetcher.get.asInstanceOf[SMSParams]
  // widths derived from the SMS configuration and cache geometry
  val BLK_ADDR_WIDTH = VAddrBits - log2Up(dcacheParameters.blockBytes)
  val REGION_SIZE = smsParams.region_size
  val REGION_BLKS = smsParams.region_size / dcacheParameters.blockBytes
  val REGION_ADDR_BITS = VAddrBits - log2Up(REGION_SIZE)
  val REGION_OFFSET = log2Up(REGION_BLKS)
  val VADDR_HASH_WIDTH = smsParams.vaddr_hash_width
  val BLK_ADDR_RAW_WIDTH = smsParams.block_addr_raw_width
  val REGION_ADDR_RAW_WIDTH = BLK_ADDR_RAW_WIDTH - REGION_OFFSET
  val BLK_TAG_WIDTH = BLK_ADDR_RAW_WIDTH + VADDR_HASH_WIDTH
  val REGION_TAG_WIDTH = REGION_ADDR_RAW_WIDTH + VADDR_HASH_WIDTH
  val PHT_INDEX_BITS = log2Up(smsParams.pht_size / smsParams.pht_ways)
  val PHT_TAG_BITS = smsParams.pht_tag_bits
  val PHT_HIST_BITS = smsParams.pht_hist_bits
  // bit position of the page boundary within a block / region address
  val BLOCK_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / dcacheParameters.blockBytes)
  val REGION_ADDR_PAGE_BIT = log2Up(dcacheParameters.pageSize / smsParams.region_size)

  /** Strip the in-block byte offset from a full virtual address. */
  def block_addr(x: UInt): UInt = {
    x(x.getWidth - 1, log2Up(dcacheParameters.blockBytes))
  }

  /** Strip the in-region byte offset from a full virtual address. */
  def region_addr(x: UInt): UInt = {
    x(x.getWidth - 1, log2Up(REGION_SIZE))
  }

  /** One-hot bitmap marking a single block offset inside a region. */
  def region_offset_to_bits(off: UInt): UInt = {
    (1.U << off).asUInt
  }

  /** Compress a region address into a tag: keep the raw low bits unchanged
    * and XOR-fold the next 3*VADDR_HASH_WIDTH bits down to VADDR_HASH_WIDTH. */
  def region_hash_tag(rg_addr: UInt): UInt = {
    val raw_part = rg_addr(REGION_ADDR_RAW_WIDTH - 1, 0)
    val folded_part = vaddr_hash(
      rg_addr(REGION_ADDR_RAW_WIDTH + 3 * VADDR_HASH_WIDTH - 1, REGION_ADDR_RAW_WIDTH)
    )
    Cat(folded_part, raw_part)
  }

  /** Select the page-boundary bit of a region address. */
  def page_bit(region_addr: UInt): UInt = {
    region_addr(log2Up(dcacheParameters.pageSize/REGION_SIZE))
  }

  /** Same tag compression as region_hash_tag, at block-address granularity. */
  def block_hash_tag(x: UInt): UInt = {
    val blk = block_addr(x)
    val raw_part = blk(BLK_ADDR_RAW_WIDTH - 1, 0)
    val folded_part = vaddr_hash(
      blk(BLK_ADDR_RAW_WIDTH + 3 * VADDR_HASH_WIDTH - 1, BLK_ADDR_RAW_WIDTH)
    )
    Cat(folded_part, raw_part)
  }

  /** XOR-fold a 3*VADDR_HASH_WIDTH-bit value into VADDR_HASH_WIDTH bits. */
  def vaddr_hash(x: UInt): UInt = {
    val w = VADDR_HASH_WIDTH
    x(w - 1, 0) ^ x(2 * w - 1, w) ^ x(3 * w - 1, 2 * w)
  }

  /** PHT set index derived from the pc: bits [PHT_INDEX_BITS:2], with the
    * top index bit additionally mixed with pc(1). */
  def pht_index(pc: UInt): UInt = {
    val mixed_hi = pc(1) ^ pc(PHT_INDEX_BITS+1)
    Cat(mixed_hi, pc(PHT_INDEX_BITS, 2))
  }

  /** PHT tag: the pc bits immediately above the index field. */
  def pht_tag(pc: UInt): UInt = {
    pc(PHT_INDEX_BITS + 2 + PHT_TAG_BITS - 1, PHT_INDEX_BITS + 2)
  }
}
97
// Fully-associative table that records the first touch of each region.
// A later access to the same region at a *different* block offset hits here
// (a "second touch"), is reported on s1_result so the requester can promote
// the region into the AGT, and the matching entry is invalidated.
class FilterTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    // stage0: incoming access to look up
    val s0_lookup = Flipped(ValidIO(new FilterEntry()))
    // stage1: the entry matched by last cycle's lookup, if any
    val s1_result = ValidIO(new FilterEntry())
    // stage1: allocate last cycle's lookup if it missed
    val s1_update = Input(Bool())
  })

  val s0_lookup_entry = io.s0_lookup.bits
  val s0_lookup_valid = io.s0_lookup.valid

  val entries = Seq.fill(smsParams.filter_table_size){ Reg(new FilterEntry()) }
  val valids = Seq.fill(smsParams.filter_table_size){ RegInit(false.B) }
  // round-robin victim pointer, only consulted when every entry is valid
  val w_ptr = RegInit(0.U(log2Up(smsParams.filter_table_size).W))

  // last cycle's lookup, used both for s0 bypass matching and for s1 allocation
  val prev_entry = RegEnable(s0_lookup_entry, s0_lookup_valid)
  val prev_lookup_valid = RegNext(s0_lookup_valid, false.B)

  // a hit requires the same region but a different block offset (true second touch)
  val s0_entry_match_vec = entries.zip(valids).map({
    case (ent, v) => v && ent.region_tag === s0_lookup_entry.region_tag && ent.offset =/= s0_lookup_entry.offset
  })
  val s0_any_entry_match = Cat(s0_entry_match_vec).orR
  val s0_matched_entry = Mux1H(s0_entry_match_vec, entries)
  // bypass: last cycle's (not yet written) lookup also counts as a match
  val s0_match_s1 = prev_lookup_valid &&
    prev_entry.region_tag === s0_lookup_entry.region_tag && prev_entry.offset =/= s0_lookup_entry.offset

  val s0_hit = s0_lookup_valid && (s0_any_entry_match || s0_match_s1)

  val s0_lookup_result = Wire(new FilterEntry())
  s0_lookup_result := Mux(s0_match_s1, prev_entry, s0_matched_entry)
  io.s1_result.valid := RegNext(s0_hit, false.B)
  io.s1_result.bits := RegEnable(s0_lookup_result, s0_hit)

  val s0_invalid_mask = valids.map(!_)
  val s0_has_invalid_entry = Cat(s0_invalid_mask).orR
  val s0_invalid_index = PriorityEncoder(s0_invalid_mask)
  // on a match, invalidate the entry: it has served its purpose.
  // NOTE: the s1 allocation loop below may override this same cycle
  // (Chisel last-connect semantics) — allocation wins intentionally.
  for((v, i) <- valids.zipWithIndex){
    when(s0_lookup_valid && s0_entry_match_vec(i)){
      v := false.B
    }
  }

  // stage1
  val s1_has_invalid_entry = RegEnable(s0_has_invalid_entry, s0_lookup_valid)
  val s1_invalid_index = RegEnable(s0_invalid_index, s0_lookup_valid)
  // allocate an entry only if (s0 missed && requester asserts s1_update)
  val s1_do_update = io.s1_update && prev_lookup_valid && !io.s1_result.valid
  // prefer a free slot; otherwise take the round-robin victim
  val update_ptr = Mux(s1_has_invalid_entry, s1_invalid_index, w_ptr)
  when(s1_do_update && !s1_has_invalid_entry){ w_ptr := w_ptr + 1.U }
  for((ent, i) <- entries.zipWithIndex){
    val wen = s1_do_update && update_ptr === i.U
    when(wen){
      valids(i) := true.B
      ent := prev_entry
    }
  }

  XSPerfAccumulate("sms_filter_table_hit", io.s1_result.valid)
  XSPerfAccumulate("sms_filter_table_update", s1_do_update)
  for(i <- 0 until smsParams.filter_table_size){
    XSPerfAccumulate(s"sms_filter_table_access_$i",
      s1_do_update && update_ptr === i.U
    )
  }
}
163
// One filter-table record: which region was first touched, where, and the
// PHT location (derived from the trigger pc) to use if it becomes active.
class FilterEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)    // PHT set for the trigger pc
  val pht_tag = UInt(PHT_TAG_BITS.W)        // PHT tag for the trigger pc
  val region_tag = UInt(REGION_TAG_WIDTH.W) // hashed region address
  val offset = UInt(REGION_OFFSET.W)        // block offset of the first touch
}
170
// One active-generation-table record: the footprint accumulated for a region
// while it is "active", later evicted to train the PHT.
class AGTEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)   // PHT set to update on eviction
  val pht_tag = UInt(PHT_TAG_BITS.W)       // PHT tag to update on eviction
  val region_bits = UInt(REGION_BLKS.W)    // bitmap of blocks accessed so far
  val region_tag = UInt(REGION_TAG_WIDTH.W)
  // number of distinct blocks accessed; saturates at REGION_BLKS-1
  val access_cnt = UInt((REGION_BLKS-1).U.getWidth.W)
  val decr_mode = Bool()                   // accesses stream downward (region+1 matched first)
}
179
// A prefetch-generation request: which blocks of which region to prefetch.
class PfGenReq()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val region_tag = UInt(REGION_TAG_WIDTH.W)   // hashed region tag (for dedup downstream)
  val region_addr = UInt(REGION_ADDR_BITS.W)  // physical region addr if paddr_valid, else virtual
  val region_bits = UInt(REGION_BLKS.W)       // bitmap of blocks to prefetch
  val paddr_valid = Bool()                    // region_addr is physical (no TLB needed)
  val decr_mode = Bool()                      // stream direction is downward
}
187
// Tracks regions currently being streamed over. An entry is allocated either
// when an access falls next to (+1/-1 region) an existing entry, or when the
// filter table reports a second touch (s1_recv_entry). Hits accumulate the
// accessed block into the entry's footprint; evicted entries train the PHT;
// sufficiently active entries (access_cnt > act_threshold) generate
// stride-based prefetch requests directly.
class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val s0_lookup = Flipped(ValidIO(new Bundle() {
      val region_tag = UInt(REGION_TAG_WIDTH.W)     // tag of the accessed region
      val region_p1_tag = UInt(REGION_TAG_WIDTH.W)  // tag of region + 1
      val region_m1_tag = UInt(REGION_TAG_WIDTH.W)  // tag of region - 1
      val region_offset = UInt(REGION_OFFSET.W)
      val pht_index = UInt(PHT_INDEX_BITS.W)
      val pht_tag = UInt(PHT_TAG_BITS.W)
      val allow_cross_region_p1 = Bool()
      val allow_cross_region_m1 = Bool()
      val region_p1_cross_page = Bool()
      val region_m1_cross_page = Bool()
      val region_paddr = UInt(REGION_ADDR_BITS.W)
      val region_vaddr = UInt(REGION_ADDR_BITS.W)
    }))
    // do not alloc entry in filter table if agt hit
    val s1_match_or_alloc = Output(Bool())
    // if agt missed, try lookup pht
    val s2_pht_lookup = ValidIO(new PhtLookup())
    // receive second hit from filter table
    val s1_recv_entry = Flipped(ValidIO(new AGTEntry()))
    // evict entry to pht
    val s2_evict = ValidIO(new AGTEntry())
    val s2_pf_gen_req = ValidIO(new PfGenReq())
    // min access_cnt before an entry may generate prefetches
    val act_threshold = Input(UInt(REGION_OFFSET.W))
    // prefetch look-ahead distance, in blocks
    val act_stride = Input(UInt(6.W))
  })

  val entries = Seq.fill(smsParams.active_gen_table_size){ Reg(new AGTEntry()) }
  val valids = Seq.fill(smsParams.active_gen_table_size){ RegInit(false.B) }
  val replacement = ReplacementPolicy.fromString("plru", smsParams.active_gen_table_size)

  val s0_lookup = io.s0_lookup.bits
  val s0_lookup_valid = io.s0_lookup.valid

  val prev_lookup = RegEnable(s0_lookup, s0_lookup_valid)
  val prev_lookup_valid = RegNext(s0_lookup_valid, false.B)

  // back-to-back access to the same region as last cycle
  val s0_match_prev = prev_lookup_valid && s0_lookup.region_tag === prev_lookup.region_tag

  def gen_match_vec(region_tag: UInt): Seq[Bool] = {
    entries.zip(valids).map({
      case (ent, v) => v && ent.region_tag === region_tag
    })
  }

  // CAM lookup of the accessed region and both neighboring regions
  val region_match_vec_s0 = gen_match_vec(s0_lookup.region_tag)
  val region_p1_match_vec_s0 = gen_match_vec(s0_lookup.region_p1_tag)
  val region_m1_match_vec_s0 = gen_match_vec(s0_lookup.region_m1_tag)

  val any_region_match = Cat(region_match_vec_s0).orR
  val any_region_p1_match = Cat(region_p1_match_vec_s0).orR && s0_lookup.allow_cross_region_p1
  val any_region_m1_match = Cat(region_m1_match_vec_s0).orR && s0_lookup.allow_cross_region_m1

  val s0_region_hit = any_region_match
  // region miss, but cross region match
  val s0_alloc = !s0_region_hit && (any_region_p1_match || any_region_m1_match) && !s0_match_prev
  val s0_match_or_alloc = any_region_match || any_region_p1_match || any_region_m1_match
  // which entry's access_cnt to read for pf generation: exact match first,
  // then region-1, then region+1
  val s0_pf_gen_match_vec = valids.indices.map(i => {
    Mux(any_region_match,
      region_match_vec_s0(i),
      Mux(any_region_m1_match,
        region_m1_match_vec_s0(i), region_p1_match_vec_s0(i)
      )
    )
  })
  // candidate entry to allocate for this access
  val s0_agt_entry = Wire(new AGTEntry())

  s0_agt_entry.pht_index := s0_lookup.pht_index
  s0_agt_entry.pht_tag := s0_lookup.pht_tag
  s0_agt_entry.region_bits := region_offset_to_bits(s0_lookup.region_offset)
  s0_agt_entry.region_tag := s0_lookup.region_tag
  s0_agt_entry.access_cnt := 1.U
  // lookup_region + 1 == entry_region
  // lookup_region = entry_region - 1 => decr mode
  s0_agt_entry.decr_mode := !s0_region_hit && !any_region_m1_match && any_region_p1_match
  val s0_replace_mask = UIntToOH(replacement.way)
  // s0 hit an entry that may be replaced in s1
  val s0_update_conflict = Cat(VecInit(region_match_vec_s0).asUInt & s0_replace_mask).orR

  // stage1: update/alloc
  // region hit, update entry (unless the matched entry is being replaced)
  val s1_update_conflict = RegEnable(s0_update_conflict, s0_lookup_valid && s0_region_hit)
  val s1_update = RegNext(s0_lookup_valid && s0_region_hit, false.B) && !s1_update_conflict
  val s1_update_mask = RegEnable(
    VecInit(region_match_vec_s0),
    VecInit(Seq.fill(smsParams.active_gen_table_size){ false.B }),
    s0_lookup_valid
  )
  val s1_agt_entry = RegEnable(s0_agt_entry, s0_lookup_valid)
  val s1_recv_entry = io.s1_recv_entry
  val s1_drop = RegInit(false.B)
  // cross region match or filter table second hit
  val s1_cross_region_match = RegNext(s0_lookup_valid && s0_alloc, false.B)
  val s1_alloc = s1_cross_region_match || (s1_recv_entry.valid && !s1_drop && !s1_update)
  s1_drop := s0_lookup_valid && s0_match_prev && s1_alloc // TODO: use bypass update instead of drop
  // filter-table entries take precedence over the locally-built candidate
  val s1_alloc_entry = Mux(s1_recv_entry.valid, s1_recv_entry.bits, s1_agt_entry)
  val s1_replace_mask = RegEnable(s0_replace_mask, s0_lookup_valid)
  // the PLRU victim, sent to the PHT at s2 if it held a valid entry
  val s1_evict_entry = Mux1H(s1_replace_mask, entries)
  val s1_evict_valid = Mux1H(s1_replace_mask, valids)
  // pf gen
  val s1_pf_gen_match_vec = RegEnable(VecInit(s0_pf_gen_match_vec), s0_lookup_valid)
  val s1_region_paddr = RegEnable(s0_lookup.region_paddr, s0_lookup_valid)
  val s1_region_vaddr = RegEnable(s0_lookup.region_vaddr, s0_lookup_valid)
  val s1_region_offset = RegEnable(s0_lookup.region_offset, s0_lookup_valid)
  for(i <- entries.indices){
    val alloc = s1_replace_mask(i) && s1_alloc
    val update = s1_update_mask(i) && s1_update
    val update_entry = WireInit(entries(i))
    // merge the new block into the footprint
    update_entry.region_bits := entries(i).region_bits | s1_agt_entry.region_bits
    // count only first touches of a block, saturating at REGION_BLKS-1
    update_entry.access_cnt := Mux(entries(i).access_cnt === (REGION_BLKS - 1).U,
      entries(i).access_cnt,
      entries(i).access_cnt + (s1_agt_entry.region_bits & (~entries(i).region_bits).asUInt).orR
    )
    valids(i) := valids(i) || alloc
    entries(i) := Mux(alloc, s1_alloc_entry, Mux(update, update_entry, entries(i)))
  }
  when(s1_update) {
    replacement.access(OHToUInt(s1_update_mask))
  }.elsewhen(s1_alloc){
    replacement.access(OHToUInt(s1_replace_mask))
  }

  io.s1_match_or_alloc := s1_update || s1_alloc || s1_drop

  when(s1_update){
    assert(PopCount(s1_update_mask) === 1.U, "multi-agt-update")
  }
  when(s1_alloc){
    assert(PopCount(s1_replace_mask) === 1.U, "multi-agt-alloc")
  }

  // pf_addr
  // 1.hit => pf_addr = lookup_addr + (decr ? -1 : 1)
  // 2.lookup region - 1 hit => lookup_addr + 1 (incr mode)
  // 3.lookup region + 1 hit => lookup_addr - 1 (decr mode)
  val s1_hited_entry_decr = Mux1H(s1_update_mask, entries.map(_.decr_mode))
  val s1_pf_gen_decr_mode = Mux(s1_update,
    s1_hited_entry_decr,
    s1_agt_entry.decr_mode
  )

  // block-granularity target = {region tag bits, offset} +/- stride;
  // the extra leading 0 bit catches wrap-around (overflow/underflow)
  val s1_pf_gen_vaddr_inc = Cat(0.U, s1_region_vaddr(REGION_TAG_WIDTH - 1, 0), s1_region_offset) + io.act_stride
  val s1_pf_gen_vaddr_dec = Cat(0.U, s1_region_vaddr(REGION_TAG_WIDTH - 1, 0), s1_region_offset) - io.act_stride
  val s1_vaddr_inc_cross_page = s1_pf_gen_vaddr_inc(BLOCK_ADDR_PAGE_BIT) =/= s1_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s1_vaddr_dec_cross_page = s1_pf_gen_vaddr_dec(BLOCK_ADDR_PAGE_BIT) =/= s1_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s1_vaddr_inc_cross_max_lim = s1_pf_gen_vaddr_inc.head(1).asBool
  val s1_vaddr_dec_cross_max_lim = s1_pf_gen_vaddr_dec.head(1).asBool

  //val s1_pf_gen_vaddr_p1 = s1_region_vaddr(REGION_TAG_WIDTH - 1, 0) + 1.U
  //val s1_pf_gen_vaddr_m1 = s1_region_vaddr(REGION_TAG_WIDTH - 1, 0) - 1.U
  // splice the strided low bits back under the untouched high vaddr bits
  val s1_pf_gen_vaddr = Cat(
    s1_region_vaddr(REGION_ADDR_BITS - 1, REGION_TAG_WIDTH),
    Mux(s1_pf_gen_decr_mode,
      s1_pf_gen_vaddr_dec.tail(1).head(REGION_TAG_WIDTH),
      s1_pf_gen_vaddr_inc.tail(1).head(REGION_TAG_WIDTH)
    )
  )
  val s1_pf_gen_offset = Mux(s1_pf_gen_decr_mode,
    s1_pf_gen_vaddr_dec(REGION_OFFSET - 1, 0),
    s1_pf_gen_vaddr_inc(REGION_OFFSET - 1, 0)
  )
  val s1_pf_gen_offset_mask = UIntToOH(s1_pf_gen_offset)
  val s1_pf_gen_access_cnt = Mux1H(s1_pf_gen_match_vec, entries.map(_.access_cnt))
  // generate only when active enough and the target did not wrap the vaddr space
  val s1_pf_gen_valid = prev_lookup_valid && io.s1_match_or_alloc && Mux(s1_pf_gen_decr_mode,
    !s1_vaddr_dec_cross_max_lim,
    !s1_vaddr_inc_cross_max_lim
  ) && (s1_pf_gen_access_cnt > io.act_threshold)
  // the known paddr is only usable if the target stays within the same page
  val s1_pf_gen_paddr_valid = Mux(s1_pf_gen_decr_mode, !s1_vaddr_dec_cross_page, !s1_vaddr_inc_cross_page)
  val s1_pf_gen_region_addr = Mux(s1_pf_gen_paddr_valid,
    Cat(s1_region_paddr(REGION_ADDR_BITS - 1, REGION_ADDR_PAGE_BIT), s1_pf_gen_vaddr(REGION_ADDR_PAGE_BIT - 1, 0)),
    s1_pf_gen_vaddr
  )
  val s1_pf_gen_region_tag = region_hash_tag(s1_pf_gen_vaddr)
  // incr mode: prefetch the target block and everything below it;
  // decr mode: the target block and everything above it
  val s1_pf_gen_incr_region_bits = VecInit((0 until REGION_BLKS).map(i => {
    if(i == 0) true.B else !s1_pf_gen_offset_mask(i - 1, 0).orR
  })).asUInt
  val s1_pf_gen_decr_region_bits = VecInit((0 until REGION_BLKS).map(i => {
    if(i == REGION_BLKS - 1) true.B
    else !s1_pf_gen_offset_mask(REGION_BLKS - 1, i + 1).orR
  })).asUInt
  val s1_pf_gen_region_bits = Mux(s1_pf_gen_decr_mode,
    s1_pf_gen_decr_region_bits,
    s1_pf_gen_incr_region_bits
  )
  val s1_pht_lookup_valid = Wire(Bool())
  val s1_pht_lookup = Wire(new PhtLookup())

  // fall back to the PHT whenever the AGT did not generate a prefetch itself
  s1_pht_lookup_valid := !s1_pf_gen_valid && prev_lookup_valid
  s1_pht_lookup.pht_index := s1_agt_entry.pht_index
  s1_pht_lookup.pht_tag := s1_agt_entry.pht_tag
  s1_pht_lookup.region_vaddr := s1_region_vaddr
  s1_pht_lookup.region_paddr := s1_region_paddr
  s1_pht_lookup.region_offset := s1_region_offset

  // stage2: gen pf reg / evict entry to pht
  val s2_evict_entry = RegEnable(s1_evict_entry, s1_alloc)
  val s2_evict_valid = RegNext(s1_alloc && s1_evict_valid, false.B)
  val s2_paddr_valid = RegEnable(s1_pf_gen_paddr_valid, s1_pf_gen_valid)
  val s2_pf_gen_region_tag = RegEnable(s1_pf_gen_region_tag, s1_pf_gen_valid)
  val s2_pf_gen_decr_mode = RegEnable(s1_pf_gen_decr_mode, s1_pf_gen_valid)
  val s2_pf_gen_region_paddr = RegEnable(s1_pf_gen_region_addr, s1_pf_gen_valid)
  val s2_pf_gen_region_bits = RegEnable(s1_pf_gen_region_bits, s1_pf_gen_valid)
  val s2_pf_gen_valid = RegNext(s1_pf_gen_valid, false.B)
  val s2_pht_lookup_valid = RegNext(s1_pht_lookup_valid, false.B)
  val s2_pht_lookup = RegEnable(s1_pht_lookup, s1_pht_lookup_valid)

  io.s2_evict.valid := s2_evict_valid
  io.s2_evict.bits := s2_evict_entry

  io.s2_pf_gen_req.bits.region_tag := s2_pf_gen_region_tag
  io.s2_pf_gen_req.bits.region_addr := s2_pf_gen_region_paddr
  io.s2_pf_gen_req.bits.region_bits := s2_pf_gen_region_bits
  io.s2_pf_gen_req.bits.paddr_valid := s2_paddr_valid
  io.s2_pf_gen_req.bits.decr_mode := s2_pf_gen_decr_mode
  io.s2_pf_gen_req.valid := s2_pf_gen_valid

  io.s2_pht_lookup.valid := s2_pht_lookup_valid
  io.s2_pht_lookup.bits := s2_pht_lookup

  XSPerfAccumulate("sms_agt_alloc", s1_alloc) // cross region match or filter evict
  XSPerfAccumulate("sms_agt_update", s1_update) // entry hit
  XSPerfAccumulate("sms_agt_pf_gen", io.s2_pf_gen_req.valid)
  XSPerfAccumulate("sms_agt_pf_gen_paddr_valid",
    io.s2_pf_gen_req.valid && io.s2_pf_gen_req.bits.paddr_valid
  )
  XSPerfAccumulate("sms_agt_pf_gen_decr_mode",
    io.s2_pf_gen_req.valid && io.s2_pf_gen_req.bits.decr_mode
  )
  for(i <- 0 until smsParams.active_gen_table_size){
    XSPerfAccumulate(s"sms_agt_access_entry_$i",
      s1_alloc && s1_replace_mask(i) || s1_update && s1_update_mask(i)
    )
  }

}
425
// A PHT lookup request: where in the PHT to look and which access triggered it.
class PhtLookup()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val pht_index = UInt(PHT_INDEX_BITS.W)     // PHT set
  val pht_tag = UInt(PHT_TAG_BITS.W)         // PHT tag
  val region_paddr = UInt(REGION_ADDR_BITS.W)
  val region_vaddr = UInt(REGION_ADDR_BITS.W)
  val region_offset = UInt(REGION_OFFSET.W)  // trigger block offset inside the region
}
433
// One PHT way: a footprint history keyed by the trigger pc.
class PhtEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  // 2*(REGION_BLKS-1) saturating counters — positions relative to the trigger
  // offset (the trigger block itself always hits, so it needs no counter)
  val hist = Vec(2 * (REGION_BLKS - 1), UInt(PHT_HIST_BITS.W))
  val tag = UInt(PHT_TAG_BITS.W)
  val decr_mode = Bool()
}
439
// Set-associative pattern history table over a single-port SRAM.
// Two input streams — AGT evictions (training writes) and AGT-miss lookups —
// are queued and arbitrated into a 4-stage pipeline:
//   s0: pick evict (priority) or lookup, form the read address
//   s1: issue the SRAM read (stalled by s1_wait on hazards)
//   s2: tag compare, compute the saturating-counter history update
//   s3: write back on evict; on a lookup hit, derive prefetch region bits
// A further s4 stage holds up to 3 pending prefetch requests (current,
// region+1, region-1) feeding a priority arbiter.
class PatternHistoryTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    // receive agt evicted entry
    val agt_update = Flipped(ValidIO(new AGTEntry()))
    // at stage2, if we know agt missed, lookup pht
    val s2_agt_lookup = Flipped(ValidIO(new PhtLookup()))
    // pht-generated prefetch req
    val pf_gen_req = ValidIO(new PfGenReq())
  })

  val pht_ram = Module(new SRAMTemplate[PhtEntry](new PhtEntry,
    set = smsParams.pht_size / smsParams.pht_ways,
    way =smsParams.pht_ways,
    singlePort = true
  ))
  def PHT_SETS = smsParams.pht_size / smsParams.pht_ways
  // per-way valid bits kept outside the SRAM (one Vec of sets per way)
  val pht_valids = Seq.fill(smsParams.pht_ways){
    RegInit(VecInit(Seq.fill(PHT_SETS){false.B}))
  }
  val replacement = Seq.fill(PHT_SETS) { ReplacementPolicy.fromString("plru", smsParams.pht_ways) }

  val lookup_queue = Module(new OverrideableQueue(new PhtLookup, smsParams.pht_lookup_queue_size))
  lookup_queue.io.in := io.s2_agt_lookup
  val lookup = lookup_queue.io.out

  val evict_queue = Module(new OverrideableQueue(new AGTEntry, smsParams.pht_lookup_queue_size))
  evict_queue.io.in := io.agt_update
  val evict = evict_queue.io.out

  val s3_ram_en = Wire(Bool())
  val s1_valid = Wire(Bool())
  // if s1.raddr == s2.waddr or s3 is using ram port, block s1
  val s1_wait = Wire(Bool())
  // pipe s0: select an op from [lookup, update], generate ram read addr
  val s0_valid = lookup.valid || evict.valid

  // evictions have strict priority over lookups
  evict.ready := !s1_valid || !s1_wait
  lookup.ready := evict.ready && !evict.valid

  val s0_ram_raddr = Mux(evict.valid,
    evict.bits.pht_index,
    lookup.bits.pht_index
  )
  val s0_tag = Mux(evict.valid, evict.bits.pht_tag, lookup.bits.pht_tag)
  val s0_region_paddr = lookup.bits.region_paddr
  val s0_region_vaddr = lookup.bits.region_vaddr
  val s0_region_offset = lookup.bits.region_offset
  val s0_region_bits = evict.bits.region_bits
  val s0_decr_mode = evict.bits.decr_mode
  val s0_evict = evict.valid

  // pipe s1: send addr to ram
  val s1_valid_r = RegInit(false.B)
  // hold the s1 op while it is stalled
  s1_valid_r := Mux(s1_valid && s1_wait, true.B, s0_valid)
  s1_valid := s1_valid_r
  val s1_reg_en = s0_valid && (!s1_wait || !s1_valid)
  val s1_ram_raddr = RegEnable(s0_ram_raddr, s1_reg_en)
  val s1_tag = RegEnable(s0_tag, s1_reg_en)
  val s1_region_bits = RegEnable(s0_region_bits, s1_reg_en)
  val s1_decr_mode = RegEnable(s0_decr_mode, s1_reg_en)
  val s1_region_paddr = RegEnable(s0_region_paddr, s1_reg_en)
  val s1_region_vaddr = RegEnable(s0_region_vaddr, s1_reg_en)
  val s1_region_offset = RegEnable(s0_region_offset, s1_reg_en)
  // select this set's valid bits / replacement state
  val s1_pht_valids = pht_valids.map(way => Mux1H(
    (0 until PHT_SETS).map(i => i.U === s1_ram_raddr),
    way
  ))
  val s1_evict = RegEnable(s0_evict, s1_reg_en)
  val s1_replace_way = Mux1H(
    (0 until PHT_SETS).map(i => i.U === s1_ram_raddr),
    replacement.map(_.way)
  )
  // which history positions this eviction provides information about:
  // hi half covers offsets above the trigger, shifted into trigger-relative frame
  val s1_hist_update_mask = Cat(
    Fill(REGION_BLKS - 1, true.B), 0.U((REGION_BLKS - 1).W)
  ) >> s1_region_offset
  // region_bits re-centered around the trigger offset: {above-trigger, below-trigger}
  val s1_hist_bits = Cat(
    s1_region_bits.head(REGION_BLKS - 1) >> s1_region_offset,
    (Cat(
      s1_region_bits.tail(1), 0.U((REGION_BLKS - 1).W)
    ) >> s1_region_offset)(REGION_BLKS - 2, 0)
  )

  // pipe s2: generate ram write addr/data
  val s2_valid = RegNext(s1_valid && !s3_ram_en, false.B)
  val s2_reg_en = s1_valid && !s3_ram_en
  val s2_hist_update_mask = RegEnable(s1_hist_update_mask, s2_reg_en)
  val s2_hist_bits = RegEnable(s1_hist_bits, s2_reg_en)
  val s2_tag = RegEnable(s1_tag, s2_reg_en)
  val s2_region_bits = RegEnable(s1_region_bits, s2_reg_en)
  val s2_decr_mode = RegEnable(s1_decr_mode, s2_reg_en)
  val s2_region_paddr = RegEnable(s1_region_paddr, s2_reg_en)
  val s2_region_vaddr = RegEnable(s1_region_vaddr, s2_reg_en)
  val s2_region_offset = RegEnable(s1_region_offset, s2_reg_en)
  val s2_region_offset_mask = region_offset_to_bits(s2_region_offset)
  val s2_evict = RegEnable(s1_evict, s2_reg_en)
  val s2_pht_valids = s1_pht_valids.map(v => RegEnable(v, s2_reg_en))
  val s2_replace_way = RegEnable(s1_replace_way, s2_reg_en)
  val s2_ram_waddr = RegEnable(s1_ram_raddr, s2_reg_en)
  val s2_ram_rdata = pht_ram.io.r.resp.data
  val s2_ram_rtags = s2_ram_rdata.map(_.tag)
  val s2_tag_match_vec = s2_ram_rtags.map(t => t === s2_tag)
  val s2_hit_vec = s2_tag_match_vec.zip(s2_pht_valids).map({
    case (tag_match, v) => v && tag_match
  })
  // saturating increment if the block was accessed, decrement if it was not,
  // but only at positions covered by the update mask
  val s2_hist_update = s2_ram_rdata.map(way => VecInit(way.hist.zipWithIndex.map({
    case (h, i) =>
      val do_update = s2_hist_update_mask(i)
      val hist_updated = Mux(s2_hist_bits(i),
        Mux(h.andR, h, h + 1.U),
        Mux(h === 0.U, 0.U, h - 1.U)
      )
      Mux(do_update, hist_updated, h)
  })))
  // prediction bit per position = MSB of its counter
  val s2_hist_pf_gen = Mux1H(s2_hit_vec, s2_ram_rdata.map(way => VecInit(way.hist.map(_.head(1))).asUInt))
  // on a miss, start fresh counters seeded from this eviction's bits
  val s2_new_hist = VecInit(s2_hist_bits.asBools.map(b => Cat(0.U((PHT_HIST_BITS - 1).W), b)))
  val s2_pht_hit = Cat(s2_hit_vec).orR
  val s2_hist = Mux(s2_pht_hit, Mux1H(s2_hit_vec, s2_hist_update), s2_new_hist)
  val s2_repl_way_mask = UIntToOH(s2_replace_way)

  // pipe s3: send addr/data to ram, gen pf_req
  val s3_valid = RegNext(s2_valid, false.B)
  val s3_evict = RegEnable(s2_evict, s2_valid)
  val s3_hist = RegEnable(s2_hist, s2_valid)
  val s3_hist_pf_gen = RegEnable(s2_hist_pf_gen, s2_valid)
  val s3_hist_update_mask = RegEnable(s2_hist_update_mask.asUInt, s2_valid)
  val s3_region_offset = RegEnable(s2_region_offset, s2_valid)
  val s3_region_offset_mask = RegEnable(s2_region_offset_mask, s2_valid)
  val s3_decr_mode = RegEnable(s2_decr_mode, s2_valid)
  val s3_region_paddr = RegEnable(s2_region_paddr, s2_valid)
  val s3_region_vaddr = RegEnable(s2_region_vaddr, s2_valid)
  val s3_pht_tag = RegEnable(s2_tag, s2_valid)
  val s3_hit_vec = s2_hit_vec.map(h => RegEnable(h, s2_valid))
  val s3_hit = Cat(s3_hit_vec).orR
  val s3_hit_way = OHToUInt(s3_hit_vec)
  val s3_repl_way = RegEnable(s2_replace_way, s2_valid)
  val s3_repl_way_mask = RegEnable(s2_repl_way_mask, s2_valid)
  val s3_repl_update_mask = RegEnable(VecInit((0 until PHT_SETS).map(i => i.U === s2_ram_waddr)), s2_valid)
  val s3_ram_waddr = RegEnable(s2_ram_waddr, s2_valid)
  // only evictions write the ram; the single port makes s1 stall (s1_wait)
  s3_ram_en := s3_valid && s3_evict
  val s3_ram_wdata = Wire(new PhtEntry())
  s3_ram_wdata.hist := s3_hist
  s3_ram_wdata.tag := s3_pht_tag
  s3_ram_wdata.decr_mode := s3_decr_mode

  // stall s1 on a same-set read-after-write hazard or a port conflict
  s1_wait := (s2_valid && s2_evict && s2_ram_waddr === s1_ram_raddr) || s3_ram_en

  // mark the victim way valid when an eviction allocates a new entry
  for((valids, way_idx) <- pht_valids.zipWithIndex){
    val update_way = s3_repl_way_mask(way_idx)
    for((v, set_idx) <- valids.zipWithIndex){
      val update_set = s3_repl_update_mask(set_idx)
      when(s3_valid && s3_evict && !s3_hit && update_set && update_way){
        v := true.B
      }
    }
  }
  for((r, i) <- replacement.zipWithIndex){
    when(s3_valid && s3_repl_update_mask(i)){
      when(s3_hit){
        r.access(s3_hit_way)
      }.elsewhen(s3_evict){
        r.access(s3_repl_way)
      }
    }
  }

  val s3_way_mask = Mux(s3_hit,
    VecInit(s3_hit_vec).asUInt,
    s3_repl_way_mask,
  ).asUInt

  pht_ram.io.r(
    s1_valid, s1_ram_raddr
  )
  pht_ram.io.w(
    s3_ram_en, s3_ram_wdata, s3_ram_waddr, s3_way_mask
  )

  when(s3_valid && s3_hit){
    assert(!Cat(s3_hit_vec).andR, "sms_pht: multi-hit!")
  }

  // generate pf req if hit
  // shift the trigger-relative prediction bits back into the region frame;
  // bits falling off either end belong to the neighboring regions
  val s3_hist_hi = s3_hist_pf_gen.head(REGION_BLKS - 1)
  val s3_hist_lo = s3_hist_pf_gen.tail(REGION_BLKS - 1)
  val s3_hist_hi_shifted = (Cat(0.U((REGION_BLKS - 1).W), s3_hist_hi) << s3_region_offset)(2 * (REGION_BLKS - 1) - 1, 0)
  val s3_hist_lo_shifted = (Cat(0.U((REGION_BLKS - 1).W), s3_hist_lo) << s3_region_offset)(2 * (REGION_BLKS - 1) - 1, 0)
  val s3_cur_region_bits = Cat(s3_hist_hi_shifted.tail(REGION_BLKS - 1), 0.U(1.W)) |
    Cat(0.U(1.W), s3_hist_lo_shifted.head(REGION_BLKS - 1))
  val s3_incr_region_bits = Cat(0.U(1.W), s3_hist_hi_shifted.head(REGION_BLKS - 1))
  val s3_decr_region_bits = Cat(s3_hist_lo_shifted.tail(REGION_BLKS - 1), 0.U(1.W))
  val s3_pf_gen_valid = s3_valid && s3_hit && !s3_evict
  val s3_cur_region_valid =  s3_pf_gen_valid && (s3_hist_pf_gen & s3_hist_update_mask).orR
  val s3_incr_region_valid = s3_pf_gen_valid && (s3_hist_hi & (~s3_hist_update_mask.head(REGION_BLKS - 1)).asUInt).orR
  val s3_decr_region_valid = s3_pf_gen_valid && (s3_hist_lo & (~s3_hist_update_mask.tail(REGION_BLKS - 1)).asUInt).orR
  val s3_incr_region_vaddr = s3_region_vaddr + 1.U
  val s3_decr_region_vaddr = s3_region_vaddr - 1.U
  // neighbor regions may leave the current page -> paddr unusable for them
  val s3_incr_crosspage = s3_incr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s3_decr_crosspage = s3_decr_region_vaddr(REGION_ADDR_PAGE_BIT) =/= s3_region_vaddr(REGION_ADDR_PAGE_BIT)
  val s3_cur_region_tag = region_hash_tag(s3_region_vaddr)
  val s3_incr_region_tag = region_hash_tag(s3_incr_region_vaddr)
  val s3_decr_region_tag = region_hash_tag(s3_decr_region_vaddr)

  // s4: up to 3 pending requests; incr/decr stay pending until the arbiter
  // grants them, the current region (highest priority) is refreshed each cycle
  val pf_gen_req_arb = Module(new Arbiter(new PfGenReq, 3))
  val s4_pf_gen_cur_region_valid = RegInit(false.B)
  val s4_pf_gen_cur_region = Reg(new PfGenReq)
  val s4_pf_gen_incr_region_valid = RegInit(false.B)
  val s4_pf_gen_incr_region = Reg(new PfGenReq)
  val s4_pf_gen_decr_region_valid = RegInit(false.B)
  val s4_pf_gen_decr_region = Reg(new PfGenReq)

  s4_pf_gen_cur_region_valid := s3_cur_region_valid
  when(s3_cur_region_valid){
    s4_pf_gen_cur_region.region_addr := s3_region_paddr
    s4_pf_gen_cur_region.region_tag := s3_cur_region_tag
    s4_pf_gen_cur_region.region_bits := s3_cur_region_bits
    s4_pf_gen_cur_region.paddr_valid := true.B
    s4_pf_gen_cur_region.decr_mode := false.B
  }
  s4_pf_gen_incr_region_valid := s3_incr_region_valid ||
    (!pf_gen_req_arb.io.in(1).ready && s4_pf_gen_incr_region_valid)
  when(s3_incr_region_valid){
    s4_pf_gen_incr_region.region_addr := Mux(s3_incr_crosspage, s3_incr_region_vaddr, s3_region_paddr)
    s4_pf_gen_incr_region.region_tag := s3_incr_region_tag
    s4_pf_gen_incr_region.region_bits := s3_incr_region_bits
    s4_pf_gen_incr_region.paddr_valid := !s3_incr_crosspage
    s4_pf_gen_incr_region.decr_mode := false.B
  }
  s4_pf_gen_decr_region_valid := s3_decr_region_valid ||
    (!pf_gen_req_arb.io.in(2).ready && s4_pf_gen_decr_region_valid)
  when(s3_decr_region_valid){
    s4_pf_gen_decr_region.region_addr := Mux(s3_decr_crosspage, s3_decr_region_vaddr, s3_region_paddr)
    s4_pf_gen_decr_region.region_tag := s3_decr_region_tag
    s4_pf_gen_decr_region.region_bits := s3_decr_region_bits
    s4_pf_gen_decr_region.paddr_valid := !s3_decr_crosspage
    s4_pf_gen_decr_region.decr_mode := true.B
  }

  pf_gen_req_arb.io.in.head.valid := s4_pf_gen_cur_region_valid
  pf_gen_req_arb.io.in.head.bits := s4_pf_gen_cur_region
  pf_gen_req_arb.io.in(1).valid := s4_pf_gen_incr_region_valid
  pf_gen_req_arb.io.in(1).bits := s4_pf_gen_incr_region
  pf_gen_req_arb.io.in(2).valid := s4_pf_gen_decr_region_valid
  pf_gen_req_arb.io.in(2).bits := s4_pf_gen_decr_region
  pf_gen_req_arb.io.out.ready := true.B

  io.pf_gen_req.valid := pf_gen_req_arb.io.out.valid
  io.pf_gen_req.bits := pf_gen_req_arb.io.out.bits

  XSPerfAccumulate("sms_pht_update", io.agt_update.valid)
  XSPerfAccumulate("sms_pht_update_hit", s2_valid && s2_evict && s2_pht_hit)
  XSPerfAccumulate("sms_pht_lookup", io.s2_agt_lookup.valid)
  XSPerfAccumulate("sms_pht_lookup_hit", s2_valid && !s2_evict && s2_pht_hit)
  for(i <- 0 until smsParams.pht_ways){
    XSPerfAccumulate(s"sms_pht_write_way_$i", pht_ram.io.w.req.fire && pht_ram.io.w.req.bits.waymask.get(i))
  }
  for(i <- 0 until PHT_SETS){
    XSPerfAccumulate(s"sms_pht_write_set_$i", pht_ram.io.w.req.fire && pht_ram.io.w.req.bits.setIdx === i.U)
  }
  XSPerfAccumulate(s"sms_pht_pf_gen", io.pf_gen_req.valid)
}
700
// One prefetch-filter record: a pending/issued prefetch region.
class PrefetchFilterEntry()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelper {
  val region_tag = UInt(REGION_TAG_WIDTH.W)
  val region_addr = UInt(REGION_ADDR_BITS.W)
  val region_bits = UInt(REGION_BLKS.W)   // blocks requested for this region
  val filter_bits = UInt(REGION_BLKS.W)   // blocks already issued/filtered out
  val paddr_valid = Bool()                // region_addr is physical (translation done)
  val decr_mode = Bool()
}
709
/**
 * Prefetch filter: accepts region prefetch generation requests, merges
 * requests to the same region, translates virtual region addresses via the
 * TLB, and issues each block of a region to L2 at most once.
 *
 * Two-stage flow:
 *  - s0: look the incoming request up in the entry array (and against last
 *        cycle's request for back-to-back merging);
 *  - s1: on hit, OR the new region_bits into the matching entry; on miss,
 *        allocate the PLRU victim entry.
 */
class PrefetchFilter()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
  val io = IO(new Bundle() {
    val gen_req = Flipped(ValidIO(new PfGenReq()))
    val tlb_req = new TlbRequestIO(2)
    val l2_pf_addr = ValidIO(UInt(PAddrBits.W))
  })
  // fully-associative entry array with PLRU replacement
  val entries = Seq.fill(smsParams.pf_filter_size){ Reg(new PrefetchFilterEntry()) }
  val valids = Seq.fill(smsParams.pf_filter_size){ RegInit(false.B) }
  val replacement = ReplacementPolicy.fromString("plru", smsParams.pf_filter_size)

  // last cycle's request, to squash a back-to-back request for the same region
  val prev_valid = RegNext(io.gen_req.valid, false.B)
  val prev_gen_req = RegEnable(io.gen_req.bits, io.gen_req.valid)

  // round-robin arbiters: one picks an entry still awaiting translation,
  // the other picks an entry with a pending block to prefetch
  val tlb_req_arb = Module(new RRArbiter(new TlbReq, smsParams.pf_filter_size))
  val pf_req_arb = Module(new RRArbiter(UInt(PAddrBits.W), smsParams.pf_filter_size))

  io.tlb_req.req <> tlb_req_arb.io.out
  io.tlb_req.resp.ready := true.B
  io.tlb_req.req_kill := false.B
  io.l2_pf_addr.valid := pf_req_arb.io.out.valid
  io.l2_pf_addr.bits := pf_req_arb.io.out.bits
  pf_req_arb.io.out.ready := true.B

  val s1_valid = Wire(Bool())
  val s1_hit = Wire(Bool())
  val s1_replace_vec = Wire(UInt(smsParams.pf_filter_size.W))
  val s1_tlb_fire_vec = Wire(UInt(smsParams.pf_filter_size.W))

  // s0: entries lookup
  val s0_gen_req = io.gen_req.bits
  val s0_match_prev = prev_valid && (s0_gen_req.region_tag === prev_gen_req.region_tag)
  val s0_gen_req_valid = io.gen_req.valid && !s0_match_prev
  // an entry being evicted by s1 this cycle must not count as a hit
  val s0_match_vec = valids.indices.map(i => {
    valids(i) && entries(i).region_tag === s0_gen_req.region_tag && !(s1_valid && !s1_hit && s1_replace_vec(i))
  })
  val s0_any_matched = Cat(s0_match_vec).orR
  val s0_replace_vec = UIntToOH(replacement.way)
  val s0_hit = s0_gen_req_valid && s0_any_matched

  for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
    val is_evicted = s1_valid && s1_replace_vec(i)
    // request translation only for entries that still hold a virtual address
    tlb_req_arb.io.in(i).valid := v && !s1_tlb_fire_vec(i) && !ent.paddr_valid && !is_evicted
    tlb_req_arb.io.in(i).bits.vaddr := Cat(ent.region_addr, 0.U(log2Up(REGION_SIZE).W))
    tlb_req_arb.io.in(i).bits.cmd := TlbCmd.read
    tlb_req_arb.io.in(i).bits.size := 3.U
    tlb_req_arb.io.in(i).bits.robIdx := DontCare
    tlb_req_arb.io.in(i).bits.debug := DontCare

    // blocks requested by the generator but not yet issued to L2
    val pending_req_vec = ent.region_bits & (~ent.filter_bits).asUInt
    // BUG FIX: the selector list must cover REGION_BLKS entries — the width
    // of pending_req_vec — not filter_table_size. The two values only
    // coincide in the default configuration (both 16); with any other
    // region_size / filter_table_size the mux inputs were mis-sized.
    val first_one_offset = PriorityMux(
      pending_req_vec.asBools,
      (0 until REGION_BLKS).map(_.U(REGION_OFFSET.W))
    )
    val last_one_offset = PriorityMux(
      pending_req_vec.asBools.reverse,
      (0 until REGION_BLKS).reverse.map(_.U(REGION_OFFSET.W))
    )
    // decr_mode issues from the highest pending block; otherwise the lowest
    val pf_addr = Cat(
      ent.region_addr,
      Mux(ent.decr_mode, last_one_offset, first_one_offset),
      0.U(log2Up(dcacheParameters.blockBytes).W)
    )
    pf_req_arb.io.in(i).valid := v && Cat(pending_req_vec).orR && ent.paddr_valid && !is_evicted
    pf_req_arb.io.in(i).bits := pf_addr
  }

  val s0_tlb_fire_vec = VecInit(tlb_req_arb.io.in.map(_.fire))
  val s0_pf_fire_vec = VecInit(pf_req_arb.io.in.map(_.fire))

  // s1: update or alloc
  val s1_valid_r = RegNext(s0_gen_req_valid, false.B)
  val s1_hit_r = RegEnable(s0_hit, s0_gen_req_valid)
  val s1_gen_req = RegEnable(s0_gen_req, s0_gen_req_valid)
  val s1_replace_vec_r = RegEnable(s0_replace_vec, s0_gen_req_valid && !s0_hit)
  val s1_update_vec = RegEnable(VecInit(s0_match_vec).asUInt, s0_gen_req_valid && s0_hit)
  val s1_tlb_fire_vec_r = RegNext(s0_tlb_fire_vec, 0.U.asTypeOf(s0_tlb_fire_vec))
  val s1_alloc_entry = Wire(new PrefetchFilterEntry())
  s1_valid := s1_valid_r
  s1_hit := s1_hit_r
  s1_replace_vec := s1_replace_vec_r
  s1_tlb_fire_vec := s1_tlb_fire_vec_r.asUInt
  s1_alloc_entry.region_tag := s1_gen_req.region_tag
  s1_alloc_entry.region_addr := s1_gen_req.region_addr
  s1_alloc_entry.region_bits := s1_gen_req.region_bits
  s1_alloc_entry.paddr_valid := s1_gen_req.paddr_valid
  s1_alloc_entry.decr_mode := s1_gen_req.decr_mode
  s1_alloc_entry.filter_bits := 0.U
  for(((v, ent), i) <- valids.zip(entries).zipWithIndex){
    val alloc = s1_valid && !s1_hit && s1_replace_vec(i)
    val update = s1_valid && s1_hit && s1_update_vec(i)
    // for pf: use s0 data
    val pf_fired = s0_pf_fire_vec(i)
    // entry i's TLB request fired last cycle and the response is not a miss
    val tlb_fired = s1_tlb_fire_vec(i) && !io.tlb_req.resp.bits.miss
    when(tlb_fired){
      ent.paddr_valid := !io.tlb_req.resp.bits.miss
      ent.region_addr := region_addr(io.tlb_req.resp.bits.paddr.head)
    }
    when(update){
      ent.region_bits := ent.region_bits | s1_gen_req.region_bits
    }
    when(pf_fired){
      // mark the just-issued block so it is never sent to L2 twice
      val curr_bit = UIntToOH(block_addr(pf_req_arb.io.in(i).bits)(REGION_OFFSET - 1, 0))
      ent.filter_bits := ent.filter_bits | curr_bit
    }
    // alloc overrides the updates above (Chisel last-connect semantics)
    when(alloc){
      ent := s1_alloc_entry
      v := true.B
    }
  }
  // touch PLRU with the way that was updated or allocated this cycle
  val s1_access_mask = Mux(s1_hit, s1_update_vec, s1_replace_vec)
  val s1_access_way = OHToUInt(s1_access_mask.asUInt)
  when(s1_valid){
    replacement.access(s1_access_way)
  }
  when(s1_valid && s1_hit){
    assert(PopCount(s1_update_vec) === 1.U, "sms_pf_filter: multi-hit")
  }

  XSPerfAccumulate("sms_pf_filter_recv_req", io.gen_req.valid)
  XSPerfAccumulate("sms_pf_filter_hit", s1_valid && s1_hit)
  XSPerfAccumulate("sms_pf_filter_tlb_req", io.tlb_req.req.fire)
  XSPerfAccumulate("sms_pf_filter_tlb_resp_miss", io.tlb_req.resp.fire && io.tlb_req.resp.bits.miss)
  for(i <- 0 until smsParams.pf_filter_size){
    XSPerfAccumulate(s"sms_pf_filter_access_way_$i", s1_valid && s1_access_way === i.U)
  }
  XSPerfAccumulate("sms_pf_filter_l2_req", io.l2_pf_addr.valid)
}
837
// Top level of the SMS (Spatial Memory Streaming) prefetcher: filters the two
// load-pipe training streams down to at most one train record per cycle,
// feeds it through the filter table / active generation table (AGT) /
// pattern history table (PHT), and forwards generated region prefetch
// requests through the PrefetchFilter to the L2 prefetch port.
class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSModuleHelper {

  // the 2-way block-filter / pending-load logic below hard-codes two inputs
  require(exuParameters.LduCnt == 2)

  // runtime enables and tuning knobs driven from outside the module
  val io_agt_en = IO(Input(Bool()))
  val io_pht_en = IO(Input(Bool()))
  val io_act_threshold = IO(Input(UInt(REGION_OFFSET.W)))
  val io_act_stride = IO(Input(UInt(6.W)))

  val ld_curr = io.ld_in.map(_.bits)
  val ld_curr_block_tag = ld_curr.map(x => block_hash_tag(x.vaddr))

  // block filter
  // drop a load whose hashed block tag matches one seen last cycle (either
  // pipe) or the other pipe this cycle, so each block trains at most once
  val ld_prev = io.ld_in.map(ld => RegEnable(ld.bits, ld.valid))
  val ld_prev_block_tag = ld_curr_block_tag.zip(io.ld_in.map(_.valid)).map({
    case (tag, v) => RegEnable(tag, v)
  })
  val ld_prev_vld = io.ld_in.map(ld => RegNext(ld.valid, false.B))

  val ld_curr_match_prev = ld_curr_block_tag.map(cur_tag =>
    Cat(ld_prev_block_tag.zip(ld_prev_vld).map({
      case (prev_tag, prev_vld) => prev_vld && prev_tag === cur_tag
    })).orR
  )
  val ld0_match_ld1 = io.ld_in.head.valid && io.ld_in.last.valid && ld_curr_block_tag.head === ld_curr_block_tag.last
  val ld_curr_vld = Seq(
    io.ld_in.head.valid && !ld_curr_match_prev.head,
    io.ld_in.last.valid && !ld_curr_match_prev.last && !ld0_match_ld1
  )
  // when both pipes survive filtering, train the ROB-older load first and
  // keep the other pending for the following cycle
  val ld0_older_than_ld1 = Cat(ld_curr_vld).andR && isBefore(ld_curr.head.uop.robIdx, ld_curr.last.uop.robIdx)
  val pending_vld = RegNext(Cat(ld_curr_vld).andR, false.B)
  // NOTE(review): the select is computed from last cycle's pending_vld —
  // presumably to pick the not-yet-trained load of the previous pair; verify
  // intended behavior when a pending load overlaps a new dual-load cycle
  val pending_sel_ld0 = RegNext(Mux(pending_vld, ld0_older_than_ld1, !ld0_older_than_ld1))
  val pending_ld = Mux(pending_sel_ld0, ld_prev.head, ld_prev.last)
  val pending_ld_block_tag = Mux(pending_sel_ld0, ld_prev_block_tag.head, ld_prev_block_tag.last)

  // prepare training data
  // priority: last cycle's pending load > older current load > current load
  val train_ld = RegEnable(
    Mux(pending_vld, pending_ld, Mux(ld0_older_than_ld1 || !ld_curr_vld.last, ld_curr.head, ld_curr.last)),
    pending_vld || Cat(ld_curr_vld).orR
  )

  val train_block_tag = RegEnable(
    Mux(pending_vld, pending_ld_block_tag,
      Mux(ld0_older_than_ld1 || !ld_curr_vld.last, ld_curr_block_tag.head, ld_curr_block_tag.last)
    ),
    pending_vld || Cat(ld_curr_vld).orR
  )
  val train_region_tag = train_block_tag.head(REGION_TAG_WIDTH)

  // neighboring regions (+1 / -1), widened by one bit to detect wrap-around
  val train_region_addr_raw = region_addr(train_ld.vaddr)(REGION_TAG_WIDTH + 2 * VADDR_HASH_WIDTH - 1, 0)
  val train_region_addr_p1 = Cat(0.U(1.W), train_region_addr_raw) + 1.U
  val train_region_addr_m1 = Cat(0.U(1.W), train_region_addr_raw) - 1.U
  // addr_p1 or addr_m1 is valid? (carry/borrow into the extra bit => invalid)
  val train_allow_cross_region_p1 = !train_region_addr_p1.head(1).asBool
  val train_allow_cross_region_m1 = !train_region_addr_m1.head(1).asBool

  val train_region_p1_tag = region_hash_tag(train_region_addr_p1.tail(1))
  val train_region_m1_tag = region_hash_tag(train_region_addr_m1.tail(1))

  // a neighbor region on a different page needs a fresh translation
  val train_region_p1_cross_page = page_bit(train_region_addr_p1) ^ page_bit(train_region_addr_raw)
  val train_region_m1_cross_page = page_bit(train_region_addr_m1) ^ page_bit(train_region_addr_raw)

  val train_region_paddr = region_addr(train_ld.paddr)
  val train_region_vaddr = region_addr(train_ld.vaddr)
  val train_region_offset = train_block_tag(REGION_OFFSET - 1, 0)
  val train_vld = RegNext(pending_vld || Cat(ld_curr_vld).orR, false.B)


  // prefetch stage0
  val filter_table = Module(new FilterTable())
  val active_gen_table = Module(new ActiveGenerationTable())
  val pht = Module(new PatternHistoryTable())
  val pf_filter = Module(new PrefetchFilter())

  // registered copy of the training record consumed by the s0 lookups below
  val train_vld_s0 = RegNext(train_vld, false.B)
  val train_s0 = RegEnable(train_ld, train_vld)
  val train_region_tag_s0 = RegEnable(train_region_tag, train_vld)
  val train_region_p1_tag_s0 = RegEnable(train_region_p1_tag, train_vld)
  val train_region_m1_tag_s0 = RegEnable(train_region_m1_tag, train_vld)
  val train_allow_cross_region_p1_s0 = RegEnable(train_allow_cross_region_p1, train_vld)
  val train_allow_cross_region_m1_s0 = RegEnable(train_allow_cross_region_m1, train_vld)
  val train_pht_tag_s0 = RegEnable(pht_tag(train_ld.uop.cf.pc), train_vld)
  val train_pht_index_s0 = RegEnable(pht_index(train_ld.uop.cf.pc), train_vld)
  val train_region_offset_s0 = RegEnable(train_region_offset, train_vld)
  val train_region_p1_cross_page_s0 = RegEnable(train_region_p1_cross_page, train_vld)
  val train_region_m1_cross_page_s0 = RegEnable(train_region_m1_cross_page, train_vld)
  val train_region_paddr_s0 = RegEnable(train_region_paddr, train_vld)
  val train_region_vaddr_s0 = RegEnable(train_region_vaddr, train_vld)

  filter_table.io.s0_lookup.valid := train_vld_s0
  filter_table.io.s0_lookup.bits.pht_tag := train_pht_tag_s0
  filter_table.io.s0_lookup.bits.pht_index := train_pht_index_s0
  filter_table.io.s0_lookup.bits.region_tag := train_region_tag_s0
  filter_table.io.s0_lookup.bits.offset := train_region_offset_s0
  // only update the filter table when the AGT neither matched nor allocated
  filter_table.io.s1_update := !active_gen_table.io.s1_match_or_alloc

  active_gen_table.io.act_threshold := io_act_threshold
  active_gen_table.io.act_stride := io_act_stride
  active_gen_table.io.s0_lookup.valid := train_vld_s0
  active_gen_table.io.s0_lookup.bits.region_tag := train_region_tag_s0
  active_gen_table.io.s0_lookup.bits.region_p1_tag := train_region_p1_tag_s0
  active_gen_table.io.s0_lookup.bits.region_m1_tag := train_region_m1_tag_s0
  active_gen_table.io.s0_lookup.bits.region_offset := train_region_offset_s0
  active_gen_table.io.s0_lookup.bits.pht_index := train_pht_index_s0
  active_gen_table.io.s0_lookup.bits.pht_tag := train_pht_tag_s0
  active_gen_table.io.s0_lookup.bits.allow_cross_region_p1 := train_allow_cross_region_p1_s0
  active_gen_table.io.s0_lookup.bits.allow_cross_region_m1 := train_allow_cross_region_m1_s0
  active_gen_table.io.s0_lookup.bits.region_p1_cross_page := train_region_p1_cross_page_s0
  active_gen_table.io.s0_lookup.bits.region_m1_cross_page := train_region_m1_cross_page_s0
  active_gen_table.io.s0_lookup.bits.region_paddr := train_region_paddr_s0
  active_gen_table.io.s0_lookup.bits.region_vaddr := train_region_vaddr_s0

  // a filter-table hit promotes the region into the AGT with both the old
  // and the new block offsets set; access_cnt starts at 2 accordingly
  val train_region_offset_s1 = RegEnable(train_region_offset_s0, train_vld_s0)
  val agt_region_bits_s1 = region_offset_to_bits(train_region_offset_s1) |
    region_offset_to_bits(filter_table.io.s1_result.bits.offset)

  active_gen_table.io.s1_recv_entry.valid := filter_table.io.s1_result.valid
  active_gen_table.io.s1_recv_entry.bits.pht_index := filter_table.io.s1_result.bits.pht_index
  active_gen_table.io.s1_recv_entry.bits.pht_tag := filter_table.io.s1_result.bits.pht_tag
  active_gen_table.io.s1_recv_entry.bits.region_bits := agt_region_bits_s1
  active_gen_table.io.s1_recv_entry.bits.region_tag := filter_table.io.s1_result.bits.region_tag
  active_gen_table.io.s1_recv_entry.bits.access_cnt := 2.U
  active_gen_table.io.s1_recv_entry.bits.decr_mode := false.B

  // PHT is looked up on AGT access and trained on AGT eviction
  pht.io.s2_agt_lookup := active_gen_table.io.s2_pht_lookup
  pht.io.agt_update := active_gen_table.io.s2_evict

  // AGT-generated requests take priority over PHT-generated ones when both
  // are valid in the same cycle (the PHT request is dropped, see perf ctr)
  val pht_gen_valid = pht.io.pf_gen_req.valid && io_pht_en
  val agt_gen_valid = active_gen_table.io.s2_pf_gen_req.valid && io_agt_en
  val pf_gen_req = Mux(agt_gen_valid,
    active_gen_table.io.s2_pf_gen_req.bits,
    pht.io.pf_gen_req.bits
  )
  pf_filter.io.gen_req.valid := pht_gen_valid || agt_gen_valid
  pf_filter.io.gen_req.bits := pf_gen_req
  io.tlb_req <> pf_filter.io.tlb_req
  // final prefetch address to L2, gated by the global enable
  io.pf_addr.valid := pf_filter.io.l2_pf_addr.valid && io.enable
  io.pf_addr.bits := pf_filter.io.l2_pf_addr.bits

  XSPerfAccumulate("sms_pf_gen_conflict",
    pht_gen_valid && agt_gen_valid
  )
  XSPerfAccumulate("sms_pht_disabled", pht.io.pf_gen_req.valid && !io_pht_en)
  XSPerfAccumulate("sms_agt_disabled", active_gen_table.io.s2_pf_gen_req.valid && !io_agt_en)
  XSPerfAccumulate("sms_pf_real_issued", io.pf_addr.valid)
}
984