xref: /XiangShan/src/main/scala/xiangshan/backend/fu/vector/VFPU.scala (revision db72af19c28fcb9c75ed7bf8385b050efd6f0b23)
1/****************************************************************************************
2  * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3  * Copyright (c) 2020-2021 Peng Cheng Laboratory
4  *
5  * XiangShan is licensed under Mulan PSL v2.
6  * You can use this software according to the terms and conditions of the Mulan PSL v2.
7  * You may obtain a copy of Mulan PSL v2 at:
8  *          http://license.coscl.org.cn/MulanPSL2
9  *
10  * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11  * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12  * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13  *
14  * See the Mulan PSL v2 for more details.
15  ****************************************************************************************
16  */
17
18
19package xiangshan.backend.fu.vector
20
21import chipsalliance.rocketchip.config.Parameters
22import chisel3.{Mux, _}
23import chisel3.util._
24import utils._
25import utility._
26import yunsuan.vector.{VectorFloatAdder,VectorFloatFMA,VectorFloatDivider}
27import yunsuan.VfpuType
28import xiangshan.{FuOpType, SrcType, XSBundle, XSCoreParamsKey, XSModule}
29import xiangshan.backend.fu.fpu.FPUSubModule
30
/**
  * Vector floating-point unit built from three wrapped sub-units.
  *
  * An incoming uop is steered by its `fuOpType` to exactly one of
  *   - `vfalu`  : ops for which `VfpuType.isVfalu` holds,
  *   - `vfmacc` : `VfpuType.fmacc`,
  *   - `vfdiv`  : `VfpuType.fdiv`,
  * and the result / fflags of the sub-unit matching the latched uop are muxed to the output.
  *
  * Only one uop is tracked at a time: `io.in.ready` is the AND of all three sub-unit
  * readies, and the accepted uop is kept in `s0_uopReg` until its result appears.
  *
  * `rm` and `fflags` are not declared here.
  * NOTE(review): presumably inherited from FPUSubModule -- confirm.
  */
class VFPU(implicit p: Parameters) extends FPUSubModule(p(XSCoreParamsKey).VLEN){
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VfpuType.dummy, "VFPU OpType not supported")
  XSError(io.in.valid && (io.in.bits.uop.ctrl.vconfig.vtype.vsew === 0.U), "8 bits not supported in VFPU")
  override val dataModule = null // Only use IO, not dataModule

  // rename signal
  val in = io.in.bits
  val ctrl = io.in.bits.uop.ctrl
  val vtype = ctrl.vconfig.vtype
  val src1Type = io.in.bits.uop.ctrl.srcType

  // reg input signal: the accepted uop and its element mask are latched on the input handshake
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val s0_maskReg = Reg(UInt(8.W))
  val inHs = io.in.fire()
  when(inHs){
    s0_uopReg := io.in.bits.uop
    s0_maskReg := Fill(8, 1.U(1.W)) // every mask bit set: no element is masked off
  }

  // connect the input port of vfalu
  val vfalu = Module(new VfaluWrapper()(p))
  vfalu.io.in.bits.src <> in.src
  vfalu.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfalu.io.in.bits.round_mode := rm
  vfalu.io.in.bits.fp_format := vtype.vsew(1,0)
  vfalu.io.in.bits.opb_widening := false.B // TODO
  vfalu.io.in.bits.res_widening := false.B // TODO
  vfalu.io.in.bits.op_code := ctrl.fuOpType
  // the "ready_out" side-band comes from the latched uop, not from the live input
  vfalu.io.ready_out.s0_mask := s0_maskReg
  vfalu.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfalu.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfmacc
  val vfmacc = Module(new VfmaccWrapper()(p))
  vfmacc.io.in.bits.src <> in.src
  vfmacc.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfmacc.io.in.bits.round_mode := rm
  vfmacc.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfmacc.io.in.bits.opb_widening := DontCare // TODO
  vfmacc.io.in.bits.res_widening := false.B // TODO
  vfmacc.io.in.bits.op_code := DontCare
  vfmacc.io.ready_out.s0_mask := s0_maskReg
  vfmacc.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfmacc.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfdiv
  val vfdiv = Module(new VfdivWrapper()(p))
  vfdiv.io.in.bits.src <> in.src
  vfdiv.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfdiv.io.in.bits.round_mode := rm
  vfdiv.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfdiv.io.in.bits.opb_widening := DontCare // TODO
  vfdiv.io.in.bits.res_widening := DontCare // TODO
  vfdiv.io.in.bits.op_code := DontCare
  vfdiv.io.ready_out.s0_mask := s0_maskReg
  vfdiv.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfdiv.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the output port: pick the sub-unit that matches the latched uop
  fflags := LookupTree(s0_uopReg.ctrl.fuOpType, List(
    VfpuType.fadd  -> vfalu.io.out.bits.fflags,
    VfpuType.fsub  -> vfalu.io.out.bits.fflags,
    VfpuType.fmacc -> vfmacc.io.out.bits.fflags,
    VfpuType.fdiv  -> vfdiv.io.out.bits.fflags,
  ))
  io.out.bits.data := LookupTree(s0_uopReg.ctrl.fuOpType, List(
    VfpuType.fadd -> vfalu.io.out.bits.result,
    VfpuType.fsub -> vfalu.io.out.bits.result,
    VfpuType.fmacc -> vfmacc.io.out.bits.result,
    VfpuType.fdiv -> vfdiv.io.out.bits.result,
  ))
  io.out.bits.uop := s0_uopReg

  // valid/ready
  // A sub-unit may only start on the top-level handshake. io.in.ready (below) is the AND of
  // all three sub-unit readies, so gating on io.in.valid alone would let an idle sub-unit
  // accept an op while another one is not ready -- i.e. an op this module never accepted and
  // for which s0_uopReg was never updated.
  // NOTE(review): this assumes no sub-unit's in.ready depends combinationally on its own
  // in.valid (true for VfaluWrapper/VfmaccWrapper as written; confirm for the divider).
  vfalu.io.in.valid := inHs && VfpuType.isVfalu(in.uop.ctrl.fuOpType)
  vfmacc.io.in.valid := inHs && in.uop.ctrl.fuOpType === VfpuType.fmacc
  vfdiv.io.in.valid := inHs && in.uop.ctrl.fuOpType === VfpuType.fdiv
  io.out.valid := vfalu.io.out.valid || vfmacc.io.out.valid || vfdiv.io.out.valid
  vfalu.io.out.ready := io.out.ready
  vfmacc.io.out.ready := io.out.ready
  vfdiv.io.out.ready := io.out.ready
  io.in.ready := vfalu.io.in.ready && vfmacc.io.in.ready && vfdiv.io.in.ready
}
114
/**
  * IO shared by the three VFPU sub-unit wrappers (VfaluWrapper, VfmaccWrapper, VfdivWrapper).
  *
  *  - `in`        : decoupled request carrying the operands and per-op control.
  *  - `ready_out` : side-band inputs (mask / sew / vl) that the wrappers use when reducing
  *                  the per-element fflags.
  *  - `out`       : decoupled response carrying the result and the 5-bit fflags.
  */
class VFPUWraaperBundle (implicit p: Parameters)  extends XSBundle{
  val in = Flipped(DecoupledIO(Output(new Bundle {
    val src = Vec(3, Input(UInt(VLEN.W)))
    val srcType = Vec(4, SrcType())

    val round_mode = UInt(3.W)
    val fp_format = UInt(2.W) // vsew
    val opb_widening = Bool()
    val res_widening = Bool()
    val op_code = FuOpType()
  })))

  val ready_out = Input(new Bundle {
    val s0_mask = UInt((VLEN / 16).W)
    val s0_sew = UInt(2.W)
    val s0_vl = UInt(8.W)
  })

  val out = DecoupledIO(Output(new Bundle {
    // Sized from VLEN like `src`, instead of a literal 128, so the result port follows the
    // configured vector length (identical when VLEN is 128).
    val result = UInt(VLEN.W)
    val fflags = UInt(5.W)
  }))
}
138
/**
  * Wraps `NumAdder` (= VLEN / XLEN) VectorFloatDivider lanes behind a VFPUWraaperBundle.
  *
  * Each lane receives one XLEN-wide slice of the operands. Operands whose srcType is not
  * `SrcType.vp` are first broadcast with VecExtractor. The request is accepted only when
  * every lane reports `start_ready_o`, and the response is valid only when every lane
  * reports `finish_valid_o`.
  */
class VfdivWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = List(5, 7, 12)
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfdiv = Seq.fill(NumAdder)(Module(new VectorFloatDivider()))
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  for (i <- 0 until NumAdder) {
    // src2 feeds opa_i and src1 feeds opb_i
    vfdiv(i).io.opa_i := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfdiv(i).io.opb_i := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfdiv(i).io.is_vec_i := true.B // If you can enter, it must be vector
    vfdiv(i).io.rm_i := in.round_mode
    vfdiv(i).io.fp_format_i := Mux(inHs, in.fp_format, 3.U(2.W))
    // Start all lanes together, and only on the wrapper-level handshake. The operands above
    // are zeroed unless inHs, so a lane started by the raw io.in.valid while a sibling lane
    // is not ready would compute on zeroed inputs.
    // NOTE(review): assumes start_ready_o does not depend combinationally on start_valid_i.
    vfdiv(i).io.start_valid_i := inHs
    // Retire all lanes together: io.out.valid needs every finish_valid_o at the same time,
    // so a lane must not complete its own finish handshake before the others are done.
    // NOTE(review): assumes finish_valid_o does not depend combinationally on finish_ready_i.
    vfdiv(i).io.finish_ready_i := io.out.ready && io.out.valid
    vfdiv(i).io.flush_i := false.B  // TODO
  }

  // fflags: per-element flags are masked and OR-accumulated by fflagsGen; entry (vl - 1) is
  // the OR over the first vl elements. vl == 0 yields no flags.
  val s4_fflagsVec = VecInit(vfdiv.map(_.io.fflags_o)).asUInt()
  val s4_fflags16vl = fflagsGen(s0_mask, s4_fflagsVec, List.range(0, 8))
  val s4_fflags32vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 1, 4, 5))
  val s4_fflags64vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 4))
  val s4_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> Mux(s0_vl.orR, s4_fflags16vl(s0_vl - 1.U), 0.U(5.W)),
    "b10".U -> Mux(s0_vl.orR, s4_fflags32vl(s0_vl - 1.U), 0.U(5.W)),
    "b11".U -> Mux(s0_vl.orR, s4_fflags64vl(s0_vl - 1.U), 0.U(5.W)),
  ))
  out.fflags := s4_fflags

  // result: lane outputs concatenated, lane 0 in the low bits
  val s4_result = VecInit(vfdiv.map(_.io.fpdiv_res_o)).asUInt()
  out.result := s4_result

  io.in.ready := VecInit(vfdiv.map(_.io.start_ready_o)).asUInt().andR()
  io.out.valid := VecInit(vfdiv.map(_.io.finish_valid_o)).asUInt().andR()
}
185
/**
  * Wraps `NumAdder` (= VLEN / XLEN) VectorFloatFMA lanes behind a VFPUWraaperBundle.
  *
  * A request is accepted only while no stage of `validPipe` is set and the output side is
  * ready. The accepted valid bit is then shifted through `Latency` registers and the last
  * stage drives `io.out.valid`. Result and fflags are taken directly from the FMA lanes.
  */
class VfmaccWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 3
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // Valid shift register: stage 0 samples the input handshake, stage k samples stage k-1.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  (inHs +: validPipe.init).zip(validPipe).foreach { case (prev, stage) => stage := prev }

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // Operand k as a full vector: used as-is when it is a vector register (SrcType.vp),
  // otherwise broadcast by VecExtractor according to fp_format.
  private def operand(k: Int): UInt =
    Mux(in.srcType(k) === SrcType.vp, in.src(k), VecExtractor(in.fp_format, in.src(k)))

  // Entry (vl - 1) of an fflagsGen result, or no flags at all when vl is zero.
  private def fflagsUpToVl(accumulated: Vec[UInt]): UInt =
    Mux(s0_vl.orR, accumulated(s0_vl - 1.U), 0.U(5.W))

  val vfmacc = Seq.fill(NumAdder)(Module(new VectorFloatFMA()))
  val src1 = operand(0)
  val src2 = operand(1)
  val src3 = operand(2)
  vfmacc.zipWithIndex.foreach { case (fma, lane) =>
    val hi = AdderWidth * (lane + 1) - 1
    val lo = AdderWidth * lane
    // outside the handshake cycle the lane sees zero operands and fp_format 3
    fma.io.fp_a := Mux(inHs, src1(hi, lo), 0.U)
    fma.io.fp_b := Mux(inHs, src2(hi, lo), 0.U)
    fma.io.fp_c := Mux(inHs, src3(hi, lo), 0.U)
    fma.io.is_vec := true.B // If you can enter, it must be vector
    fma.io.round_mode := in.round_mode
    fma.io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    fma.io.res_widening := in.res_widening // TODO
  }

  // output signal generation
  val s2_fflagsVec = VecInit(vfmacc.map(_.io.fflags)).asUInt()
  val s2_fflags16vl = fflagsGen(s0_mask, s2_fflagsVec, List.range(0, 8))
  val s2_fflags32vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 1, 4, 5))
  val s2_fflags64vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 4))
  val s2_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> fflagsUpToVl(s2_fflags16vl),
    "b10".U -> fflagsUpToVl(s2_fflags32vl),
    "b11".U -> fflagsUpToVl(s2_fflags64vl),
  ))
  out.fflags := s2_fflags

  // lane results concatenated, lane 0 in the low bits
  val s2_result = VecInit(vfmacc.map(_.io.fp_result)).asUInt()
  out.result := s2_result

  // accept a new op only when nothing is in flight and the consumer is ready
  io.in.ready := !validPipe.reduce(_ || _) && io.out.ready
  io.out.valid := validPipe.last
}
239
/**
  * Wraps `NumAdder` (= VLEN / XLEN) VectorFloatAdder lanes behind a VFPUWraaperBundle.
  *
  * A request is accepted only while no stage of `validPipe` is set and the output side is
  * ready. Result and fflags are selected by `s0_sew`, captured into registers while
  * `validPipe(Latency - 2)` is set, and presented when `validPipe(Latency - 1)` is set.
  */
class VfaluWrapper(implicit p: Parameters)  extends XSModule{
  val Latency = 2
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // Valid shift register: stage 0 samples the input handshake, stage k samples stage k-1.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  (inHs +: validPipe.init).zip(validPipe).foreach { case (prev, stage) => stage := prev }

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // Operand k as a full vector: used as-is when it is a vector register (SrcType.vp),
  // otherwise broadcast by VecExtractor according to fp_format.
  private def operand(k: Int): UInt =
    Mux(in.srcType(k) === SrcType.vp, in.src(k), VecExtractor(in.fp_format, in.src(k)))

  // Entry (vl - 1) of an fflagsGen result, or no flags at all when vl is zero.
  private def fflagsUpToVl(accumulated: Vec[UInt]): UInt =
    Mux(s0_vl.orR, accumulated(s0_vl - 1.U), 0.U(5.W))

  // connect the input signal
  val vfalu = Seq.fill(NumAdder)(Module(new VectorFloatAdder()))
  val src1 = operand(0)
  val src2 = operand(1)
  vfalu.zipWithIndex.foreach { case (adder, lane) =>
    val hi = AdderWidth * (lane + 1) - 1
    val lo = AdderWidth * lane
    // outside the handshake cycle the lane sees zero operands and fp_format 3
    adder.io.fp_a := Mux(inHs, src1(hi, lo), 0.U)
    adder.io.fp_b := Mux(inHs, src2(hi, lo), 0.U)
    adder.io.is_vec := true.B // If you can enter, it must be vector
    adder.io.round_mode := in.round_mode
    adder.io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    adder.io.opb_widening := in.opb_widening // TODO
    adder.io.res_widening := in.res_widening // TODO
    adder.io.op_code := in.op_code
  }

  // output signal generation
  val s0_fflagsVec = VecInit(vfalu.map(_.io.fflags)).asUInt()
  val s0_fflags16vl = fflagsGen(s0_mask, s0_fflagsVec, List.range(0, 8))
  val s0_fflags32vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 1, 4, 5))
  val s0_fflags64vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 4))
  val s0_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> fflagsUpToVl(s0_fflags16vl),
    "b10".U -> fflagsUpToVl(s0_fflags32vl),
    "b11".U -> fflagsUpToVl(s0_fflags64vl),
  ))
  val s1_fflags = RegEnable(s0_fflags, validPipe(Latency - 2))
  out.fflags := s1_fflags

  // each lane exposes one result port per element width; pick the set matching s0_sew
  val s0_result = LookupTree(s0_sew(1, 0), List(
    "b01".U -> VecInit(vfalu.map(_.io.fp_f16_result)).asUInt(),
    "b10".U -> VecInit(vfalu.map(_.io.fp_f32_result)).asUInt(),
    "b11".U -> VecInit(vfalu.map(_.io.fp_f64_result)).asUInt(),
  ))
  val s1_result = RegEnable(s0_result, validPipe(Latency - 2))
  out.result := s1_result

  // accept a new op only when nothing is in flight and the consumer is ready
  io.in.ready := !validPipe.reduce(_ || _) && io.out.ready
  io.out.valid := validPipe.last
}
301
/** Builds masked, OR-accumulated per-element fflags from a packed fflags word. */
object fflagsGen{
  /**
    * @param vmask        element mask; only its low `idx.length` bits are used
    * @param fflagsResult packed flags, one 5-bit group per slot (slot k is bits 5k+4 .. 5k)
    * @param idx          slot numbers to read, in element order
    * @return a Vec with `idx.length` entries where entry i is the OR of the masked flags of
    *         entries 0..i
    *
    * The mask bits are applied in reversed order: the flags read from `idx(i)` are gated by
    * `vmask(idx.length - 1 - i)`.
    * NOTE(review): confirm the reversed mask order is intended.
    */
  def fflagsGen(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    val num = idx.length
    val maskBits = vmask(num - 1, 0).asBools().reverse
    // one 5-bit wire per element: its flag group if enabled by the mask, otherwise zero
    val masked = idx.zip(maskBits).map { case (slot, enabled) =>
      val gated = Wire(UInt(5.W))
      gated := Mux(enabled, fflagsResult(slot * 5 + 4, slot * 5), 0.U)
      gated
    }
    // running OR: entry i accumulates entries 0..i
    val fflagsVl = Wire(Vec(num, UInt(5.W)))
    fflagsVl.zipWithIndex.foreach { case (acc, i) =>
      acc := (if (i == 0) masked(i) else fflagsVl(i - 1) | masked(i))
    }
    fflagsVl
  }

  /** Same as [[fflagsGen]]; lets callers write `fflagsGen(mask, flags, idx)`. */
  def apply(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    fflagsGen(vmask, fflagsResult, idx)
  }
}
322
/** Broadcasts the low element of a value across a 128-bit vector. */
object VecExtractor{
  /**
    * Replicates the low element of `xf` until 128 bits are filled. The element width is
    * chosen by `sew(1, 0)`: 00 -> 8 bits x16, 01 -> 16 bits x8, 10 -> 32 bits x4,
    * 11 -> 64 bits x2.
    */
  def xf2v_sew(sew: UInt, xf:UInt): UInt = {
    // the list position of each element width is its 2-bit sew code
    val bySew = List(8, 16, 32, 64).zipWithIndex.map { case (elemBits, code) =>
      code.U(2.W) -> Fill(128 / elemBits, xf(elemBits - 1, 0))
    }
    LookupTree(sew(1, 0), bySew)
  }

  /** Same as [[xf2v_sew]]; lets callers write `VecExtractor(sew, xf)`. */
  def apply(sew: UInt, xf: UInt): UInt = {
    xf2v_sew(sew, xf)
  }
}