xref: /XiangShan/src/main/scala/xiangshan/mem/prefetch/L1StridePrefetcher.scala (revision 4ccb2e8b3629dc48d470568215e87ee66f85508b)
package xiangshan.mem.prefetch

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import utility._
import xiangshan.cache.HasDCacheParameters
import xiangshan.cache.mmu._
import xiangshan.mem.{L1PrefetchReq, LdPrefetchTrainBundle}
import xiangshan.mem.trace._
import scala.collection.SeqLike

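/**
 * L1 stride prefetcher: a small, fully associative, PC-indexed table.
 * Each entry tracks the previous (truncated) vaddr, the last observed
 * stride and a saturating confidence counter for one load PC. Once the
 * same positive stride has been confirmed often enough, prefetch
 * requests are issued ahead of the demand access, to both the L1 DCache
 * and the L2/L3.
 */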
trait HasStridePrefetchHelper extends HasL1PrefetchHelper {
  val STRIDE_FILTER_SIZE = 6
  val STRIDE_ENTRY_NUM = 10
  val STRIDE_BITS = 10 + BLOCK_OFFSET
  val STRIDE_VADDR_BITS = 10 + BLOCK_OFFSET
  val STRIDE_CONF_BITS = 2

  // detailed control
  val ALWAYS_UPDATE_PRE_VADDR = true
  val AGGRESIVE_POLICY = false // if true, prefetch degree is STRIDE_LOOK_AHEAD_BLOCKS; otherwise it is 1
  val STRIDE_LOOK_AHEAD_BLOCKS = 2 // aggressive degree
  val LOOK_UP_STREAM = false // if true, skip prefetches that collide with an already-detected stream

  val STRIDE_WIDTH_BLOCKS = if(AGGRESIVE_POLICY) STRIDE_LOOK_AHEAD_BLOCKS else 1

  def MAX_CONF = (1 << STRIDE_CONF_BITS) - 1
}

class StrideMetaBundle(implicit p: Parameters) extends XSBundle with HasStridePrefetchHelper {
  val pre_vaddr = UInt(STRIDE_VADDR_BITS.W)
  val stride = UInt(STRIDE_BITS.W)
  val confidence = UInt(STRIDE_CONF_BITS.W)
  val hash_pc = UInt(HASH_TAG_WIDTH.W)

  def reset(index: Int) = {
    pre_vaddr := 0.U
    stride := 0.U
    confidence := 0.U
    hash_pc := index.U
  }

  def alloc(vaddr: UInt, alloc_hash_pc: UInt) = {
    pre_vaddr := vaddr(STRIDE_VADDR_BITS - 1, 0)
    stride := 0.U
    confidence := 0.U
    hash_pc := alloc_hash_pc
  }

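  /**
   * Train an existing entry with a new vaddr from the same (hashed) PC.
   * The stride is the delta between consecutive vaddrs: a matching stride
   * bumps the confidence, a mismatch decays it, and at low confidence the
   * entry is retrained to the new stride. Returns (can_send_pf, new_stride).
   */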
  def update(vaddr: UInt, always_update_pre_vaddr: Bool) = {
    val new_vaddr = vaddr(STRIDE_VADDR_BITS - 1, 0)
    val new_stride = new_vaddr - pre_vaddr
    val new_stride_blk = block_addr(new_stride)
    // NOTE: for now, negative strides are disabled (the sign bit must be 0),
    // and strides of less than two blocks are ignored
    val stride_valid = new_stride_blk =/= 0.U && new_stride_blk =/= 1.U && new_stride(STRIDE_VADDR_BITS - 1) === 0.U
    val stride_match = new_stride === stride
    val low_confidence = confidence <= 1.U
    val can_send_pf = stride_valid && stride_match && confidence === MAX_CONF.U

    when(stride_valid) {
      when(stride_match) {
        confidence := Mux(confidence === MAX_CONF.U, confidence, confidence + 1.U)
      }.otherwise {
        confidence := Mux(confidence === 0.U, confidence, confidence - 1.U)
        when(low_confidence) {
          stride := new_stride
        }
      }
      pre_vaddr := new_vaddr
    }
    when(always_update_pre_vaddr) {
      pre_vaddr := new_vaddr
    }

    (can_send_pf, new_stride)
  }
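
  // A minimal software sketch of the confidence counter above (plain Scala,
  // not Chisel; `step` is an illustrative name, not part of this file). With
  // STRIDE_CONF_BITS = 2, MAX_CONF = 3, so an entry must see the same valid
  // stride repeatedly, walking 0 -> 1 -> 2 -> 3, before it may prefetch:
  //
  //   def step(conf: Int, strideMatch: Boolean): Int =
  //     if (strideMatch) math.min(conf + 1, 3) // saturate at MAX_CONF
  //     else math.max(conf - 1, 0)             // decay towards retraining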

}

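/**
 * Stride metadata array with a short internal pipeline:
 *   s0: hash the training PC and CAM it against all entries
 *   s1: allocate a new entry on miss, or train the hitting entry
 *   s2: on a confident hit, compute the L1 and L2/L3 prefetch vaddrs
 *   s3: issue the L1 prefetch request (optionally suppressed by a Stream hit)
 *   s4: issue the L2/L3 prefetch request
 */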
class StrideMetaArray(implicit p: Parameters) extends XSModule with HasStridePrefetchHelper {
  val io = IO(new XSBundle {
    val enable = Input(Bool())
    // TODO: flush all entries when a process change happens, or disable stride prefetch for a while
    val flush = Input(Bool())
    val dynamic_depth = Input(UInt(32.W)) // TODO: enable dynamic stride depth
    val train_req = Flipped(DecoupledIO(new PrefetchReqBundle))
    val l1_prefetch_req = ValidIO(new StreamPrefetchReqBundle)
    val l2_l3_prefetch_req = ValidIO(new StreamPrefetchReqBundle)
    // query the Stream component to see if a stream pattern has already been detected
    val stream_lookup_req  = ValidIO(new PrefetchReqBundle)
    val stream_lookup_resp = Input(Bool())
  })

  val array = Reg(Vec(STRIDE_ENTRY_NUM, new StrideMetaBundle))
  val replacement = ReplacementPolicy.fromString("plru", STRIDE_ENTRY_NUM)
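
  // The table is fully associative: a lookup CAMs the hashed PC against every
  // entry, and a pseudo-LRU policy picks the victim on a miss.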

  // s0: hash pc -> cam all entries
  val s0_can_accept = Wire(Bool())
  val s0_valid = io.train_req.fire
  val s0_vaddr = io.train_req.bits.vaddr
  val s0_pc = io.train_req.bits.pc
  val s0_pc_hash = pc_hash_tag(s0_pc)
  val s0_pc_match_vec = VecInit(array.map(_.hash_pc === s0_pc_hash)).asUInt
  val s0_hit = s0_pc_match_vec.orR
  val s0_index = Mux(s0_hit, OHToUInt(s0_pc_match_vec), replacement.way)
  io.train_req.ready := s0_can_accept
  io.stream_lookup_req.valid := s0_valid
  io.stream_lookup_req.bits  := io.train_req.bits

  when(s0_valid) {
    replacement.access(s0_index)
  }

  assert(PopCount(s0_pc_match_vec) <= 1.U)
  XSPerfAccumulate("s0_valid", s0_valid)
  XSPerfAccumulate("s0_hit", s0_valid && s0_hit)
  XSPerfAccumulate("s0_miss", s0_valid && !s0_hit)

  // s1: alloc or update
  val s1_valid = GatedValidRegNext(s0_valid)
  val s1_index = RegEnable(s0_index, s0_valid)
  val s1_pc_hash = RegEnable(s0_pc_hash, s0_valid)
  val s1_vaddr = RegEnable(s0_vaddr, s0_valid)
  val s1_hit = RegEnable(s0_hit, s0_valid)
  val s1_alloc = s1_valid && !s1_hit
  val s1_update = s1_valid && s1_hit
  val s1_stride = array(s1_index).stride
  val s1_new_stride = WireInit(0.U(STRIDE_BITS.W))
  val s1_can_send_pf = WireInit(false.B)
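  // Back-pressure s0 while s1 is working on the same hashed PC, so that a
  // back-to-back train request reads the freshly written entry rather than
  // stale metadata.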
  s0_can_accept := !(s1_valid && s1_pc_hash === s0_pc_hash)

  val always_update = Constantin.createRecord(s"always_update${p(XSCoreParamsKey).HartId}", initValue = ALWAYS_UPDATE_PRE_VADDR)

  when(s1_alloc) {
    array(s1_index).alloc(
      vaddr = s1_vaddr,
      alloc_hash_pc = s1_pc_hash
    )
  }.elsewhen(s1_update) {
    val res = array(s1_index).update(s1_vaddr, always_update)
    s1_can_send_pf := res._1
    s1_new_stride := res._2
  }

  val l1_stride_ratio_const = Constantin.createRecord(s"l1_stride_ratio${p(XSCoreParamsKey).HartId}", initValue = 2)
  val l1_stride_ratio = l1_stride_ratio_const(3, 0)
  val l2_stride_ratio_const = Constantin.createRecord(s"l2_stride_ratio${p(XSCoreParamsKey).HartId}", initValue = 5)
  val l2_stride_ratio = l2_stride_ratio_const(3, 0)
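
  // Worked example (assuming 64 B cache blocks and the default ratios above):
  // a confirmed stride of one block, 0x40, yields an L1 prefetch depth of
  // 0x40 << 2 = 0x100 (4 blocks ahead) and an L2/L3 depth of
  // 0x40 << 5 = 0x800 (32 blocks ahead), so the outer cache levels run
  // further in front of the demand stream.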
  // s2: calculate L1 & L2 pf addr
  val s2_valid = GatedValidRegNext(s1_valid && s1_can_send_pf)
  val s2_vaddr = RegEnable(s1_vaddr, s1_valid && s1_can_send_pf)
  val s2_stride = RegEnable(s1_stride, s1_valid && s1_can_send_pf)
  val s2_l1_depth = s2_stride << l1_stride_ratio
  val s2_l1_pf_vaddr = (s2_vaddr + s2_l1_depth)(VAddrBits - 1, 0)
  val s2_l2_depth = s2_stride << l2_stride_ratio
  val s2_l2_pf_vaddr = (s2_vaddr + s2_l2_depth)(VAddrBits - 1, 0)
  val s2_l1_pf_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l1_pf_vaddr,
    width = STRIDE_WIDTH_BLOCKS,
    decr_mode = false.B,
    sink = SINK_L1,
    source = L1_HW_PREFETCH_STRIDE,
    // TODO: add stride debug db, not useful for now
    t_pc = 0xdeadbeefL.U,
    t_va = 0xdeadbeefL.U
  )
  val s2_l2_pf_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l2_pf_vaddr,
    width = STRIDE_WIDTH_BLOCKS,
    decr_mode = false.B,
    sink = SINK_L2,
    source = L1_HW_PREFETCH_STRIDE,
    // TODO: add stride debug db, not useful for now
    t_pc = 0xdeadbeefL.U,
    t_va = 0xdeadbeefL.U
  )

  // s3: send l1 pf out; if LOOK_UP_STREAM is set, drop it when the Stream prefetcher already tracks this pattern
  val s3_valid = if (LOOK_UP_STREAM) GatedValidRegNext(s2_valid) && !io.stream_lookup_resp else GatedValidRegNext(s2_valid)
  val s3_l1_pf_req_bits = RegEnable(s2_l1_pf_req_bits, s2_valid)
  val s3_l2_pf_req_bits = RegEnable(s2_l2_pf_req_bits, s2_valid)

  // s4: send l2 pf out, one cycle after the l1 request
  val s4_valid = GatedValidRegNext(s3_valid)
  val s4_l2_pf_req_bits = RegEnable(s3_l2_pf_req_bits, s3_valid)

  io.l1_prefetch_req.valid := s3_valid
  io.l1_prefetch_req.bits := s3_l1_pf_req_bits
  io.l2_l3_prefetch_req.valid := s4_valid
  io.l2_l3_prefetch_req.bits := s4_l2_pf_req_bits

  XSPerfAccumulate("pf_valid", PopCount(Seq(io.l1_prefetch_req.valid, io.l2_l3_prefetch_req.valid)))
  XSPerfAccumulate("l1_pf_valid", s3_valid)
  XSPerfAccumulate("l2_pf_valid", s4_valid)
  XSPerfAccumulate("detect_stream", io.stream_lookup_resp)
  XSPerfHistogram("high_conf_num", PopCount(VecInit(array.map(_.confidence === MAX_CONF.U))).asUInt, true.B, 0, STRIDE_ENTRY_NUM, 1)
  for(i <- 0 until STRIDE_ENTRY_NUM) {
    XSPerfAccumulate(s"entry_${i}_update", i.U === s1_index && s1_update)
    for(j <- 0 until 4) {
      XSPerfAccumulate(s"entry_${i}_disturb_${j}", i.U === s1_index && s1_update &&
                                                   j.U === s1_new_stride &&
                                                   array(s1_index).confidence === MAX_CONF.U &&
                                                   array(s1_index).stride =/= s1_new_stride
      )
    }
  }

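  // On reset (or one cycle after a flush) every entry is cleared; reset(i)
  // seeds hash_pc with the entry index so all tags start out distinct and the
  // at-most-one-hit assertion in s0 holds from the very first lookup.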
  for(i <- 0 until STRIDE_ENTRY_NUM) {
    when(reset.asBool || GatedValidRegNext(io.flush)) {
      array(i).reset(i)
    }
  }
}
218}