xref: /XiangShan/src/main/scala/xiangshan/backend/datapath/WbArbiter.scala (revision 92b88f30156d46e844042eea94f7121557fd09a1)
1package xiangshan.backend.datapath
2
3import chipsalliance.rocketchip.config.Parameters
4import chisel3._
5import chisel3.util._
6import difftest.{DifftestFpWriteback, DifftestIntWriteback}
7import utils.XSError
8import xiangshan.backend.BackendParams
9import xiangshan.backend.Bundles.{ExuOutput, WriteBackBundle}
10import xiangshan.backend.regfile.RfWritePortWithConfig
11import xiangshan.{Redirect, XSBundle, XSModule}
12
/** IO bundle of [[WbArbiterDispatcher]]: one decoupled input and `n` decoupled outputs.
  *
  * @param gen payload type carried by the input and by every output
  * @param n   number of output channels
  */
class WbArbiterDispatcherIO[T <: Data](private val gen: T, n: Int) extends Bundle {
  // Consumer side: the dispatcher receives valid/bits here and drives ready back.
  val in = Flipped(new DecoupledIO(gen))

  // Producer side: one channel per possible destination.
  val out = Vec(n, new DecoupledIO(gen))
}
18
/** Steers one decoupled input towards the output(s) selected by `acceptCond`.
  *
  * `acceptCond` is evaluated on the input payload and yields one Bool per output.
  * Output `i` is valid when the input is valid and condition `i` holds; the payload
  * is broadcast unchanged to every output. The input is ready when at least one
  * selected output is ready. Selecting more than one output at a time is reported
  * through `XSError`.
  *
  * @param gen        payload type
  * @param n          number of outputs
  * @param acceptCond maps the input payload to the per-output accept conditions
  */
class WbArbiterDispatcher[T <: Data](private val gen: T, n: Int, acceptCond: T => Seq[Bool])
                           (implicit p: Parameters)
  extends Module {

  val io = IO(new WbArbiterDispatcherIO(gen, n))

  private val acceptVec: Vec[Bool] = VecInit(acceptCond(io.in.bits))

  // The message is a Chisel Printable (p"...") rather than a Scala string (s"..."):
  // with the s-interpolator `Binary(...)` is stringified once at elaboration time,
  // so the log would never show the actual accept vector seen in simulation.
  XSError(io.in.valid && PopCount(acceptVec) > 1.U, p"[ExeUnit] accept vec should no more than 1, ${Binary(acceptVec.asUInt)} ")

  io.out.zipWithIndex.foreach { case (out, i) =>
    out.valid := acceptVec(i) && io.in.valid
    out.bits := io.in.bits
  }

  // Ready as soon as any output that is currently selected can take the payload.
  io.in.ready := Cat(io.out.zip(acceptVec).map { case (out, canAccept) => out.ready && canAccept }).orR
}
36
/** IO bundle of [[WbArbiter]]; the port shapes come from the implicit `WbArbiterParams`. */
class WbArbiterIO()(implicit p: Parameters, params: WbArbiterParams) extends XSBundle {
  // Incoming flush, carried as a valid-qualified Redirect.
  val flush = Flipped(Valid(new Redirect))
  // Decoupled write-back requests competing for the output ports.
  val in: MixedVec[DecoupledIO[WriteBackBundle]] = Flipped(params.genInput)
  // Valid-only write-back results, one per output port.
  val out: MixedVec[ValidIO[WriteBackBundle]] = params.genOutput

  /** Inputs grouped by the output port number stored in each input's write-back config. */
  def inGroup: Map[Int, IndexedSeq[DecoupledIO[WriteBackBundle]]] =
    in.groupBy(wbIn => wbIn.bits.params.port)
}
44
/** Arbitrates write-back requests onto `params.numOut` output ports.
  *
  * Inputs are grouped by the port number in their write-back config. Every port that
  * has at least one input gets its own `Arbiter`; ports without inputs are tied to zero.
  * Arbiter outputs are always accepted (`ready` is constantly true), so back-pressure
  * towards the inputs comes only from losing arbitration.
  */
class WbArbiter(params: WbArbiterParams)(implicit p: Parameters) extends XSModule {
  val io = IO(new WbArbiterIO()(p, params))
  // Todo: Sorted by priority
  private val inGroup: Map[Int, IndexedSeq[DecoupledIO[WriteBackBundle]]] = io.inGroup

  // One optional arbiter per output port. The payload type is taken from the port's own
  // input group (groupBy never yields an empty group, so `head` is safe) instead of from
  // an arbitrary map entry whose choice would depend on Map iteration order.
  private val arbiters: Seq[Option[Arbiter[WriteBackBundle]]] = Seq.tabulate(params.numOut) { port =>
    inGroup.get(port).map { wbIns =>
      Module(new Arbiter(new WriteBackBundle(wbIns.head.bits.params), wbIns.length))
    }
  }

  // Hook every input of a group to the arbiter that serves its port.
  arbiters.zipWithIndex.foreach { case (arbOpt, port) =>
    arbOpt.foreach { arb =>
      arb.io.in.zip(inGroup(port)).foreach { case (arbIn, wbIn) =>
        arbIn <> wbIn
      }
    }
  }

  io.out.zip(arbiters).foreach {
    case (wbOut, Some(arb)) =>
      val arbOut = arb.io.out
      // The downstream side is valid-only, so the arbiter result is always consumed.
      arbOut.ready := true.B
      wbOut.valid := arbOut.valid
      wbOut.bits := arbOut.bits
    case (wbOut, None) =>
      // No producer targets this port: drive it to all zeros.
      wbOut := 0.U.asTypeOf(wbOut)
  }

  /** Maps each input index to the output port given by its write-back config. */
  def getInOutMap: Map[Int, Int] = {
    params.wbCfgs.zipWithIndex.map { case (cfg, idx) => idx -> cfg.port }.toMap
  }
}
81
/** IO bundle of [[WbDataPath]]; all port shapes are derived from the implicit `BackendParams`. */
class WbDataPathIO()(implicit p: Parameters, params: BackendParams) extends XSBundle {
  // Incoming flush, carried as a valid-qualified Redirect.
  val flush = Flipped(Valid(new Redirect))

  // Values supplied by the enclosing top level.
  val fromTop = new Bundle {
    val hartId = Input(UInt(8.W))
  }

  // Decoupled execution-unit outputs, one nested vector per scheduler kind (int / vf / mem).
  val fromIntExu: MixedVec[MixedVec[DecoupledIO[ExuOutput]]] =
    Flipped(params.intSchdParams.get.genExuOutputDecoupledBundle)

  val fromVfExu: MixedVec[MixedVec[DecoupledIO[ExuOutput]]] =
    Flipped(params.vfSchdParams.get.genExuOutputDecoupledBundle)

  val fromMemExu: MixedVec[MixedVec[DecoupledIO[ExuOutput]]] =
    Flipped(params.memSchdParams.get.genExuOutputDecoupledBundle)

  // Register-file write ports: `numWrite` ports each for the int and the vf physical regfile.
  val toIntPreg = Flipped(MixedVec(Vec(
    params.intPregParams.numWrite,
    new RfWritePortWithConfig(params.intPregParams.dataCfg, params.intPregParams.addrWidth)
  )))

  val toVfPreg = Flipped(MixedVec(Vec(
    params.vfPregParams.numWrite,
    new RfWritePortWithConfig(params.vfPregParams.dataCfg, params.vfPregParams.addrWidth)
  )))

  // Valid-only copies of the execution-unit outputs for the control block.
  val toCtrlBlock = new Bundle {
    val writeback: MixedVec[ValidIO[ExuOutput]] = params.genWrite2CtrlBundles
  }
}
105
/** Write-back data path of the backend.
  *
  * Collects the decoupled outputs of the int, vf and mem execution units, feeds them to
  * one [[WbArbiter]] for the int physical regfile and one for the vf physical regfile,
  * drives the regfile write ports from the arbiter results, and forwards every fired
  * execution-unit output to the control block. When difftest is enabled, one difftest
  * write-back probe is instantiated per arbiter output.
  *
  * The order of the hardware statements below is significant (later connections
  * override earlier ones), so it must be preserved when editing.
  */
class WbDataPath(params: BackendParams)(implicit p: Parameters) extends XSModule {
  val io = IO(new WbDataPathIO()(p, params))

  // alias
  val fromExu = (io.fromIntExu ++ io.fromVfExu ++ io.fromMemExu).flatten
  // Per-regfile wire copies of all exu outputs, split into those that can (Y) and
  // cannot (N) write the respective regfile according to their static config.
  val intArbiterInputsWire = WireInit(MixedVecInit(fromExu))
  val intArbiterInputsWireY = intArbiterInputsWire.filter(_.bits.params.writeIntRf)
  val intArbiterInputsWireN = intArbiterInputsWire.filterNot(_.bits.params.writeIntRf)
  val vfArbiterInputsWire = WireInit(MixedVecInit(fromExu))
  val vfArbiterInputsWireY = vfArbiterInputsWire.filter(_.bits.params.writeVfRf)
  val vfArbiterInputsWireN = vfArbiterInputsWire.filterNot(_.bits.params.writeVfRf)

  /** Accept conditions for the dispatcher: element 0 selects the int regfile path,
    * element 1 the vf regfile path. A write-enable that the exu output does not
    * have is treated as false.
    */
  def acceptCond(exuOutput: ExuOutput): Seq[Bool] = {
    val intWen = exuOutput.intWen.getOrElse(false.B)
    val fpwen  = exuOutput.fpWen.getOrElse(false.B)
    val vecWen = exuOutput.vecWen.getOrElse(false.B)
    Seq(intWen, fpwen || vecWen)
  }

  // One dispatcher per exu output; this is pure wiring, hence foreach rather than map.
  fromExu.zip(intArbiterInputsWire.zip(vfArbiterInputsWire)).foreach {
    case (exuOut, (intArbiterInput, vfArbiterInput)) =>
      val regfilesTypeNum = params.pregParams.size
      val in1ToN = Module(new WbArbiterDispatcher(new ExuOutput(exuOut.bits.params), regfilesTypeNum, acceptCond))
      in1ToN.io.in.valid := exuOut.valid
      in1ToN.io.in.bits := exuOut.bits
      exuOut.ready := in1ToN.io.in.ready
      // NOTE(review): MixedVecInit builds a fresh wire initialised from its arguments, so
      // `sink` below refers to that new wire rather than to intArbiterInput / vfArbiterInput
      // themselves - confirm this is the intended connection.
      in1ToN.io.out.zip(MixedVecInit(intArbiterInput, vfArbiterInput)).foreach { case (source, sink) =>
        sink.valid := source.valid
        sink.bits := source.bits
        source.ready := sink.ready
      }
  }
  // Outputs that cannot write a regfile never get accepted on that regfile's path.
  intArbiterInputsWireN.foreach(_.ready := false.B)
  vfArbiterInputsWireN.foreach(_.ready := false.B)

  println(s"[WbDataPath] write int preg: " +
    s"IntExu(${io.fromIntExu.flatten.count(_.bits.params.writeIntRf)}) " +
    s"VfExu(${io.fromVfExu.flatten.count(_.bits.params.writeIntRf)}) " +
    s"MemExu(${io.fromMemExu.flatten.count(_.bits.params.writeIntRf)})"
  )
  println(s"[WbDataPath] write vf preg: " +
    s"IntExu(${io.fromIntExu.flatten.count(_.bits.params.writeVfRf)}) " +
    s"VfExu(${io.fromVfExu.flatten.count(_.bits.params.writeVfRf)}) " +
    s"MemExu(${io.fromMemExu.flatten.count(_.bits.params.writeVfRf)})"
  )

  // modules
  private val intWbArbiter = Module(new WbArbiter(params.getIntWbArbiterParams))
  private val vfWbArbiter = Module(new WbArbiter(params.getVfWbArbiterParams))
  println(s"[WbDataPath] int preg write back port num: ${intWbArbiter.io.out.size}, active port: ${intWbArbiter.io.inGroup.keys.toSeq.sorted}")
  println(s"[WbDataPath] vf preg write back port num: ${vfWbArbiter.io.out.size}, active port: ${vfWbArbiter.io.inGroup.keys.toSeq.sorted}")

  // module assign
  intWbArbiter.io.flush <> io.flush
  require(intWbArbiter.io.in.size == intArbiterInputsWireY.size, s"intWbArbiter input size: ${intWbArbiter.io.in.size}, all int wb size: ${intArbiterInputsWireY.size}")
  intWbArbiter.io.in.zip(intArbiterInputsWireY).foreach { case (arbiterIn, in) =>
    // `.get` is kept on purpose: elaboration fails loudly if an input selected by
    // writeIntRf has no intWen (presumably the two always go together - verify in ExuOutput).
    arbiterIn.valid := in.valid && in.bits.intWen.get
    in.ready := arbiterIn.ready
    arbiterIn.bits.fromExuOutput(in.bits)
  }
  private val intWbArbiterOut = intWbArbiter.io.out

  vfWbArbiter.io.flush <> io.flush
  require(vfWbArbiter.io.in.size == vfArbiterInputsWireY.size, s"vfWbArbiter input size: ${vfWbArbiter.io.in.size}, all vf wb size: ${vfArbiterInputsWireY.size}")
  vfWbArbiter.io.in.zip(vfArbiterInputsWireY).foreach { case (arbiterIn, in) =>
    arbiterIn.valid := in.valid && (in.bits.fpWen.getOrElse(false.B) || in.bits.vecWen.getOrElse(false.B))
    in.ready := arbiterIn.ready
    arbiterIn.bits.fromExuOutput(in.bits)
  }

  private val vfWbArbiterOut = vfWbArbiter.io.out

  private val intExuInputs = io.fromIntExu.flatten
  private val intExuWBs = WireInit(MixedVecInit(io.fromIntExu.flatten))
  private val vfExuInputs = io.fromVfExu.flatten
  private val vfExuWBs = WireInit(MixedVecInit(io.fromVfExu.flatten))
  private val memExuInputs = io.fromMemExu.flatten
  private val memExuWBs = WireInit(MixedVecInit(io.fromMemExu.flatten))

  // only fired port can write back to ctrl block
  (intExuWBs zip intExuInputs).foreach { case (wb, input) => wb.valid := input.fire }
  (vfExuWBs zip vfExuInputs).foreach { case (wb, input) => wb.valid := input.fire }
  (memExuWBs zip memExuInputs).foreach { case (wb, input) => wb.valid := input.fire }

  // the ports not writing back pregs are always ready
  (intExuInputs ++ vfExuInputs ++ memExuInputs).foreach { x =>
    if (x.bits.params.hasNoDataWB) x.ready := true.B
  }

  // io assign
  private val toIntPreg: MixedVec[RfWritePortWithConfig] = MixedVecInit(intWbArbiterOut.map(x => x.bits.asIntRfWriteBundle(x.fire)))
  private val toVfPreg: MixedVec[RfWritePortWithConfig] = MixedVecInit(vfWbArbiterOut.map(x => x.bits.asVfRfWriteBundle(x.fire)))

  private val wb2Ctrl = intExuWBs ++ vfExuWBs ++ memExuWBs

  io.toIntPreg := toIntPreg
  io.toVfPreg := toVfPreg
  io.toCtrlBlock.writeback.zip(wb2Ctrl).foreach { case (sink, source) =>
    sink.valid := source.valid
    sink.bits := source.bits
    source.ready := true.B
  }

  // Difftest probes: int write-backs first, then vf write-backs (same order as before).
  if (env.EnableDifftest || env.AlwaysBasicDiff) {
    intWbArbiterOut.foreach { out =>
      val difftest = Module(new DifftestIntWriteback)
      difftest.io.clock := clock
      difftest.io.coreid := io.fromTop.hartId
      difftest.io.valid := out.fire && out.bits.rfWen
      difftest.io.dest := out.bits.pdest
      difftest.io.data := out.bits.data
    }

    vfWbArbiterOut.foreach { out =>
      val difftest = Module(new DifftestFpWriteback)
      difftest.io.clock := clock
      difftest.io.coreid := io.fromTop.hartId
      difftest.io.valid := out.fire // all fp instr will write fp rf
      difftest.io.dest := out.bits.pdest
      difftest.io.data := out.bits.data
    }
  }

}
232
233
234
235
236