package xiangshan.mem.prefetch

import org.chipsalliance.cde.config.Parameters
import chisel3._
import chisel3.util._
import xiangshan._
import utils._
import utility._
import xiangshan.cache.HasDCacheParameters
import xiangshan.cache.mmu._
import xiangshan.mem.{L1PrefetchReq, LdPrefetchTrainBundle}
import xiangshan.mem.trace._
import scala.collection.SeqLike

/**
  * Elaboration-time constants shared by the stride prefetcher.
  *
  * `BLOCK_OFFSET` is not defined in this file; it is presumably inherited from
  * `HasL1PrefetchHelper` (NOTE(review): confirm).
  */
trait HasStridePrefetchHelper extends HasL1PrefetchHelper {
  val STRIDE_FILTER_SIZE = 6
  // number of entries in the stride meta array
  val STRIDE_ENTRY_NUM = 10
  // width of the stored stride and of the truncated previous vaddr (kept equal)
  val STRIDE_BITS = 10 + BLOCK_OFFSET
  val STRIDE_VADDR_BITS = 10 + BLOCK_OFFSET
  // width of the saturating confidence counter
  val STRIDE_CONF_BITS = 2

  // detail control
  val ALWAYS_UPDATE_PRE_VADDR = 1 // 1 for true, 0 for false
  val AGGRESIVE_POLICY = false // if true, prefetch degree is greater than 1, 1 otherwise
  val STRIDE_LOOK_AHEAD_BLOCKS = 2 // aggressive degree
  val LOOK_UP_STREAM = false // if true, avoid collision with stream

  // `width` passed to the generated prefetch requests
  val STRIDE_WIDTH_BLOCKS = if(AGGRESIVE_POLICY) STRIDE_LOOK_AHEAD_BLOCKS else 1

  // saturation value of the confidence counter
  def MAX_CONF = (1 << STRIDE_CONF_BITS) - 1
}

/**
  * One entry of the stride table: the low bits of the last trained vaddr, the
  * current stride candidate, a saturating confidence counter and the hashed PC
  * used as the tag.
  *
  * The methods below emit `:=` connections to the fields, so they are meant to
  * be called on an instance that is writable hardware (in this file: an element
  * of the `array` register in [[StrideMetaArray]]).
  */
class StrideMetaBundle(implicit p: Parameters) extends XSBundle with HasStridePrefetchHelper {
  val pre_vaddr = UInt(STRIDE_VADDR_BITS.W)
  val stride = UInt(STRIDE_BITS.W)
  val confidence = UInt(STRIDE_CONF_BITS.W)
  val hash_pc = UInt(HASH_TAG_WIDTH.W)

  /**
    * Clear the entry and set its tag to the entry's own index.
    *
    * @param index elaboration-time index written into `hash_pc`; presumably this
    *              keeps the tags of freshly reset entries distinct from each other
    */
  def reset(index: Int) = {
    pre_vaddr := 0.U
    stride := 0.U
    confidence := 0.U
    hash_pc := index.U
  }

  /**
    * (Re)allocate the entry for a new PC: record the truncated vaddr, clear the
    * stride and the confidence, and install the new tag.
    *
    * @param vaddr         training virtual address; only the low STRIDE_VADDR_BITS are kept
    * @param alloc_hash_pc hashed PC stored as the tag
    */
  def alloc(vaddr: UInt, alloc_hash_pc: UInt) = {
    pre_vaddr := vaddr(STRIDE_VADDR_BITS - 1, 0)
    stride := 0.U
    confidence := 0.U
    hash_pc := alloc_hash_pc
  }

  /**
    * Train the entry with a new access from the same (hashed) PC.
    *
    * @param vaddr                   training virtual address
    * @param always_update_pre_vaddr if true, `pre_vaddr` is overwritten even when the
    *                                observed stride is rejected by `stride_valid`
    * @return `(can_send_pf, new_stride)`: whether a prefetch may be issued (valid
    *         stride, equal to the stored one, confidence already at MAX_CONF), and
    *         the stride observed by this access
    */
  def update(vaddr: UInt, always_update_pre_vaddr: Bool) = {
    val new_vaddr = vaddr(STRIDE_VADDR_BITS - 1, 0)
    val new_stride = new_vaddr - pre_vaddr
    val new_stride_blk = block_addr(new_stride)
    // NOTE: for now, disable negative stride
    // A stride is accepted only if its block part is neither 0 nor 1 and its top bit is clear.
    val stride_valid = new_stride_blk =/= 0.U && new_stride_blk =/= 1.U && new_stride(STRIDE_VADDR_BITS - 1) === 0.U
    val stride_match = new_stride === stride
    val low_confidence = confidence <= 1.U
    val can_send_pf = stride_valid && stride_match && confidence === MAX_CONF.U

    when(stride_valid) {
      when(stride_match) {
        // saturating increment
        confidence := Mux(confidence === MAX_CONF.U, confidence, confidence + 1.U)
      }.otherwise {
        // saturating decrement; the stored stride is only replaced while confidence is low
        confidence := Mux(confidence === 0.U, confidence, confidence - 1.U)
        when(low_confidence) {
          stride := new_stride
        }
      }
      pre_vaddr := new_vaddr
    }
    // Kept as a separate, later connection: it also covers the !stride_valid case.
    when(always_update_pre_vaddr) {
      pre_vaddr := new_vaddr
    }

    (can_send_pf, new_stride)
  }

}

/**
  * PC-indexed stride table with a small training/issue pipeline:
  *
  *  - s0: hash the training PC and compare it against every entry's tag
  *  - s1: allocate a new entry on a miss, or train the hit entry
  *  - s2: compute the L1 and L2 prefetch addresses from the trained stride
  *  - s3: drive `l1_prefetch_req`
  *  - s4: drive `l2_l3_prefetch_req`
  *
  * `io.enable` and `io.dynamic_depth` are declared but not read in this module.
  */
class StrideMetaArray(implicit p: Parameters) extends XSModule with HasStridePrefetchHelper {
  val io = IO(new XSBundle {
    val enable = Input(Bool())
    // TODO: flush all entry when process changing happens, or disable stream prefetch for a while
    val flush = Input(Bool())
    val dynamic_depth = Input(UInt(32.W)) // TODO: enable dynamic stride depth
    val train_req = Flipped(DecoupledIO(new PrefetchReqBundle))
    val l1_prefetch_req = ValidIO(new StreamPrefetchReqBundle)
    val l2_l3_prefetch_req = ValidIO(new StreamPrefetchReqBundle)
    // query Stream component to see if a stream pattern has already been detected
    val stream_lookup_req = ValidIO(new PrefetchReqBundle)
    val stream_lookup_resp = Input(Bool())
  })

  val array = Reg(Vec(STRIDE_ENTRY_NUM, new StrideMetaBundle))
  val replacement = ReplacementPolicy.fromString("plru", STRIDE_ENTRY_NUM)

  // s0: hash pc -> cam all entries
  val s0_can_accept = Wire(Bool())
  val s0_valid = io.train_req.fire
  val s0_vaddr = io.train_req.bits.vaddr
  val s0_pc = io.train_req.bits.pc
  val s0_pc_hash = pc_hash_tag(s0_pc)
  val s0_pc_match_vec = VecInit(array.map(_.hash_pc === s0_pc_hash)).asUInt
  val s0_hit = s0_pc_match_vec.orR
  // on a hit use the matching entry, otherwise the victim chosen by the replacement policy
  val s0_index = Mux(s0_hit, OHToUInt(s0_pc_match_vec), replacement.way)
  io.train_req.ready := s0_can_accept
  io.stream_lookup_req.valid := s0_valid
  io.stream_lookup_req.bits := io.train_req.bits

  when(s0_valid) {
    replacement.access(s0_index)
  }

  // at most one entry may carry a given tag, otherwise OHToUInt above is meaningless
  assert(PopCount(s0_pc_match_vec) <= 1.U)
  XSPerfAccumulate("s0_valid", s0_valid)
  XSPerfAccumulate("s0_hit", s0_valid && s0_hit)
  XSPerfAccumulate("s0_miss", s0_valid && !s0_hit)

  // s1: alloc or update
  // Valid bits are reset to false so that no stage is spuriously valid right after reset.
  val s1_valid = RegNext(s0_valid, false.B)
  val s1_index = RegEnable(s0_index, s0_valid)
  val s1_pc_hash = RegEnable(s0_pc_hash, s0_valid)
  val s1_vaddr = RegEnable(s0_vaddr, s0_valid)
  val s1_hit = RegEnable(s0_hit, s0_valid)
  val s1_alloc = s1_valid && !s1_hit
  val s1_update = s1_valid && s1_hit
  val s1_stride = array(s1_index).stride
  val s1_new_stride = WireInit(0.U(STRIDE_BITS.W))
  val s1_can_send_pf = WireInit(false.B)
  // Stall s0 while s1 is working on the same hashed PC: the s0 tag compare reads the
  // array before the s1 write for that PC has landed.
  s0_can_accept := !(s1_valid && s1_pc_hash === s0_pc_hash)

  val always_update = WireInit(Constantin.createRecord("always_update" + p(XSCoreParamsKey).HartId.toString, initValue = ALWAYS_UPDATE_PRE_VADDR.U)) === 1.U

  when(s1_alloc) {
    array(s1_index).alloc(
      vaddr = s1_vaddr,
      alloc_hash_pc = s1_pc_hash
    )
  }.elsewhen(s1_update) {
    val res = array(s1_index).update(s1_vaddr, always_update)
    s1_can_send_pf := res._1
    s1_new_stride := res._2
  }

  // Prefetch distance is stride << ratio; only the low 4 bits of each record are used.
  val l1_stride_ratio_const = WireInit(Constantin.createRecord("l1_stride_ratio" + p(XSCoreParamsKey).HartId.toString, initValue = 2.U))
  val l1_stride_ratio = l1_stride_ratio_const(3, 0)
  val l2_stride_ratio_const = WireInit(Constantin.createRecord("l2_stride_ratio" + p(XSCoreParamsKey).HartId.toString, initValue = 5.U))
  val l2_stride_ratio = l2_stride_ratio_const(3, 0)
  // s2: calculate L1 & L2 pf addr
  val s2_valid = RegNext(s1_valid && s1_can_send_pf, false.B)
  val s2_vaddr = RegEnable(s1_vaddr, s1_valid && s1_can_send_pf)
  val s2_stride = RegEnable(s1_stride, s1_valid && s1_can_send_pf)
  val s2_l1_depth = s2_stride << l1_stride_ratio
  val s2_l1_pf_vaddr = (s2_vaddr + s2_l1_depth)(VAddrBits - 1, 0)
  val s2_l2_depth = s2_stride << l2_stride_ratio
  val s2_l2_pf_vaddr = (s2_vaddr + s2_l2_depth)(VAddrBits - 1, 0)
  val s2_l1_pf_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l1_pf_vaddr,
    width = STRIDE_WIDTH_BLOCKS,
    decr_mode = false.B,
    sink = SINK_L1,
    source = L1_HW_PREFETCH_STRIDE,
    // TODO: add stride debug db, not useful for now
    t_pc = 0xdeadbeefL.U,
    t_va = 0xdeadbeefL.U
  )
  val s2_l2_pf_req_bits = (new StreamPrefetchReqBundle).getStreamPrefetchReqBundle(
    valid = s2_valid,
    vaddr = s2_l2_pf_vaddr,
    width = STRIDE_WIDTH_BLOCKS,
    decr_mode = false.B,
    sink = SINK_L2,
    source = L1_HW_PREFETCH_STRIDE,
    // TODO: add stride debug db, not useful for now
    t_pc = 0xdeadbeefL.U,
    t_va = 0xdeadbeefL.U
  )

  // s3: send l1 pf out
  // With LOOK_UP_STREAM the request is dropped when the stream component reports a hit.
  val s3_valid = if (LOOK_UP_STREAM) RegNext(s2_valid, false.B) && !io.stream_lookup_resp else RegNext(s2_valid, false.B)
  val s3_l1_pf_req_bits = RegEnable(s2_l1_pf_req_bits, s2_valid)
  val s3_l2_pf_req_bits = RegEnable(s2_l2_pf_req_bits, s2_valid)

  // s4: send l2 pf out
  val s4_valid = RegNext(s3_valid, false.B)
  val s4_l2_pf_req_bits = RegEnable(s3_l2_pf_req_bits, s3_valid)

  io.l1_prefetch_req.valid := s3_valid
  io.l1_prefetch_req.bits := s3_l1_pf_req_bits
  io.l2_l3_prefetch_req.valid := s4_valid
  io.l2_l3_prefetch_req.bits := s4_l2_pf_req_bits

  XSPerfAccumulate("pf_valid", PopCount(Seq(io.l1_prefetch_req.valid, io.l2_l3_prefetch_req.valid)))
  XSPerfAccumulate("l1_pf_valid", s3_valid)
  XSPerfAccumulate("l2_pf_valid", s4_valid)
  XSPerfAccumulate("detect_stream", io.stream_lookup_resp)
  XSPerfHistogram("high_conf_num", PopCount(VecInit(array.map(_.confidence === MAX_CONF.U))).asUInt, true.B, 0, STRIDE_ENTRY_NUM, 1)
  for(i <- 0 until STRIDE_ENTRY_NUM) {
    XSPerfAccumulate(s"entry_${i}_update", i.U === s1_index && s1_update)
    // "disturb": a fully confident entry observes a small stride (0..3) different from its own
    for(j <- 0 until 4) {
      XSPerfAccumulate(s"entry_${i}_disturb_${j}", i.U === s1_index && s1_update &&
                                                   j.U === s1_new_stride &&
                                                   array(s1_index).confidence === MAX_CONF.U &&
                                                   array(s1_index).stride =/= s1_new_stride
      )
    }
  }

  // Must stay after the s1 alloc/update block: being the last connection, it takes
  // priority over any entry write made in the same cycle.
  for(i <- 0 until STRIDE_ENTRY_NUM) {
    when(reset.asBool || RegNext(io.flush)) {
      array(i).reset(i)
    }
  }
}