/****************************************************************************************
 * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
 * Copyright (c) 2020-2021 Peng Cheng Laboratory
 *
 * XiangShan is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 * http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 *
 * See the Mulan PSL v2 for more details.
 ****************************************************************************************
 */


package xiangshan.backend.fu.vector

import chipsalliance.rocketchip.config.Parameters
import chisel3.{Mux, _}
import chisel3.util._
import utils._
import utility._
import yunsuan.vector.{VectorFloatAdder,VectorFloatFMA,VectorFloatDivider}
import yunsuan.VfpuType
import xiangshan.{FuOpType, SrcType, XSBundle, XSCoreParamsKey, XSModule}
import xiangshan.backend.fu.fpu.FPUSubModule

/**
 * Vector floating-point unit.
 *
 * Accepts one uop at a time and forwards it to one of three sub-units, selected by an exact
 * match on `fuOpType`:
 *  - [[VfaluWrapper]]  when `fuOpType === VfpuType.isVfalu`
 *  - [[VfmaccWrapper]] when `fuOpType === VfpuType.isVfmacc`
 *  - [[VfdivWrapper]]  when `fuOpType === VfpuType.isVfdiv`
 *
 * A three-state FSM serialises operations:
 *  - s_idle:    `io.in.ready` is high; an input handshake latches the uop and moves to s_compute.
 *  - s_compute: waits for a sub-unit output valid. If `io.out` accepts it in the same cycle the
 *               FSM returns to s_idle, otherwise it moves to s_finish.
 *  - s_finish:  presents the registered result (`dataReg` / `fflagsReg`) until `io.out` fires.
 */
class VFPU(implicit p: Parameters) extends FPUSubModule(p(XSCoreParamsKey).VLEN){
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VfpuType.dummy, "VFPU OpType not supported")
  XSError(io.in.valid && (io.in.bits.uop.ctrl.vconfig.vtype.vsew === 0.U), "8 bits not supported in VFPU")
  override val dataModule = null // Only use IO, not dataModule

  // Width of the latched element mask: VLEN / 16 bits, the same width as
  // VFPUWraaperBundle.ready_out.s0_mask (previously hard-coded to 8, i.e. VLEN = 128 only).
  private val maskWidth = p(XSCoreParamsKey).VLEN / 16

  // rename signal
  val in = io.in.bits
  val ctrl = io.in.bits.uop.ctrl
  val vtype = ctrl.vconfig.vtype
  val src1Type = io.in.bits.uop.ctrl.srcType
  val uopIdx = ctrl.uopIdx

  // def some signal
  val fflagsReg = RegInit(0.U(5.W))
  val fflagsWire = WireInit(0.U(5.W))
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vfalu = Module(new VfaluWrapper()(p))
  val vfmacc = Module(new VfmaccWrapper()(p))
  val vfdiv = Module(new VfdivWrapper()(p))
  val outValid = vfalu.io.out.valid || vfmacc.io.out.valid || vfdiv.io.out.valid
  val outFire = vfalu.io.out.fire() || vfmacc.io.out.fire() || vfdiv.io.out.fire()

  // reg input signal: latch the uop on the accepting handshake. The sub-units read sew / vl /
  // mask of the in-flight uop from these registers (through ready_out) and the output mux
  // selects the sub-unit from the latched fuOpType.
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val s0_maskReg = Reg(UInt(maskWidth.W))
  val inHs = io.in.fire()
  when(inHs && state===s_idle){
    s0_uopReg := io.in.bits.uop
    s0_maskReg := Fill(maskWidth, 1.U(1.W)) // every element is currently treated as active
  }

  // fsm
  // NOTE(review): a uop whose fuOpType matches none of isVfalu / isVfmacc / isVfdiv is still
  // accepted (io.in.ready only depends on state), but no sub-unit is started, so outValid never
  // rises and the FSM stays in s_compute. Only VfpuType.dummy is flagged by XSError above.
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      state := Mux(outValid, Mux(outFire, s_idle, s_finish), s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }
  // Capture every sub-unit result so it can be replayed in s_finish while io.out stalls.
  fflagsReg := Mux(outValid, fflagsWire, fflagsReg)
  dataReg := Mux(outValid, dataWire, dataReg)

  // connect the input port of vfalu
  vfalu.io.in.bits.src <> in.src
  vfalu.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfalu.io.in.bits.round_mode := rm
  vfalu.io.in.bits.fp_format := vtype.vsew(1,0)
  vfalu.io.in.bits.uopIdx := uopIdx(0) //TODO
  vfalu.io.in.bits.opb_widening := false.B // TODO
  vfalu.io.in.bits.res_widening := false.B // TODO
  vfalu.io.in.bits.op_code := ctrl.fuOpType
  vfalu.io.ready_out.s0_mask := s0_maskReg
  vfalu.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfalu.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfmacc
  vfmacc.io.in.bits.src <> in.src
  vfmacc.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfmacc.io.in.bits.round_mode := rm
  vfmacc.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfmacc.io.in.bits.uopIdx := uopIdx(0) //TODO
  vfmacc.io.in.bits.opb_widening := DontCare // TODO
  vfmacc.io.in.bits.res_widening := false.B // TODO
  // NOTE(review): VfmaccWrapper drives each FMA unit's op_code from in.op_code(3, 0), so the
  // FMA operation is currently left undefined here -- confirm how the FMA op should be selected.
  vfmacc.io.in.bits.op_code := DontCare
  vfmacc.io.ready_out.s0_mask := s0_maskReg
  vfmacc.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfmacc.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfdiv
  vfdiv.io.in.bits.src <> in.src
  vfdiv.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfdiv.io.in.bits.round_mode := rm
  vfdiv.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfdiv.io.in.bits.uopIdx := uopIdx(0) // TODO
  vfdiv.io.in.bits.opb_widening := DontCare // TODO
  vfdiv.io.in.bits.res_widening := DontCare // TODO
  vfdiv.io.in.bits.op_code := DontCare
  vfdiv.io.ready_out.s0_mask := s0_maskReg
  vfdiv.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfdiv.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the output port
  fflagsWire := Mux1H(
    Seq(
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfalu ) -> vfalu.io.out.bits.fflags,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfmacc) -> vfmacc.io.out.bits.fflags,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfdiv ) -> vfdiv.io.out.bits.fflags
    )
  )
  dataWire := Mux1H(
    Seq(
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfalu ) -> vfalu.io.out.bits.result,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfmacc) -> vfmacc.io.out.bits.result,
      (s0_uopReg.ctrl.fuOpType === VfpuType.isVfdiv ) -> vfdiv.io.out.bits.result
    )
  )
  // In s_compute, io.out.valid is raised as soon as outValid is high, so the live result must be
  // presented in that same cycle (not only when the output fires). Otherwise a stalled consumer
  // would see valid together with the stale dataReg/fflagsReg, and the bits would then change
  // in s_finish. The value shown here is exactly what dataReg/fflagsReg hold in s_finish.
  fflags := Mux(state === s_compute && outValid, fflagsWire, fflagsReg)
  io.out.bits.data := Mux(state === s_compute && outValid, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg

  // valid/ready
  vfalu.io.in.valid := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.isVfalu && state === s_idle
  vfmacc.io.in.valid := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.isVfmacc && state === s_idle
  vfdiv.io.in.valid := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.isVfdiv && state === s_idle
  io.out.valid := state === s_compute && outValid || state === s_finish
  vfalu.io.out.ready := io.out.ready
  vfmacc.io.out.ready := io.out.ready
  vfdiv.io.out.ready := io.out.ready
  // NOTE(review): io.in.ready does not consult vfdiv.io.in.ready (the dividers' start_ready_o);
  // this assumes the dividers are always ready whenever this FSM is idle -- confirm.
  io.in.ready := state === s_idle
}

/**
 * Common IO of the VFPU sub-unit wrappers.
 *
 *  - in:        decoupled operands (up to four VLEN-wide sources with their SrcType), rounding
 *               mode, element format (vsew), uop index, widening controls and op code.
 *  - ready_out: per-uop information latched by [[VFPU]] (element mask, sew, vl), used by the
 *               wrappers to fold per-element fflags into one 5-bit value.
 *  - out:       decoupled VLEN-wide result and accumulated 5-bit fflags.
 */
class VFPUWraaperBundle (implicit p: Parameters) extends XSBundle{
  val in = Flipped(DecoupledIO(Output(new Bundle {
    val src = Vec(4, Input(UInt(VLEN.W)))
    val srcType = Vec(4, SrcType())

    val round_mode = UInt(3.W)
    val fp_format = UInt(2.W) // vsew
    val uopIdx = Bool()
    val opb_widening = Bool()
    val res_widening = Bool()
    val op_code = FuOpType()
  })))

  val ready_out = Input(new Bundle {
    val s0_mask = UInt((VLEN / 16).W)
    val s0_sew = UInt(2.W)
    val s0_vl = UInt(8.W)
  })

  val out = DecoupledIO(Output(new Bundle {
    // VLEN-wide to match the wrappers (NumAdder * XLEN = VLEN bits) and VFPU's data port;
    // previously hard-coded to 128 bits.
    val result = UInt(VLEN.W)
    val fflags = UInt(5.W)
  }))
}

/**
 * Vector floating-point divide wrapper: VLEN / XLEN instances of VectorFloatDivider, each handling
 * one XLEN-bit slice of the operands.
 *
 * Ready/valid are the AND of all dividers' start_ready_o / finish_valid_o, so the wrapper
 * handshakes only when every slice is ready / done. Divider outputs are concatenated
 * combinationally; no register is added here.
 */
class VfdivWrapper(implicit p: Parameters) extends XSModule{
  val Latency = List(5, 7, 12) // not referenced inside this wrapper
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfdiv = Seq.fill(NumAdder)(Module(new VectorFloatDivider()))
  // A source that is not a vector register (SrcType.vp) is broadcast across VLEN by VecExtractor.
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  for (i <- 0 until NumAdder) {
    // Operands are forced to zero (and the format to 3) outside the handshake cycle.
    vfdiv(i).io.opb_i := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfdiv(i).io.opa_i := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfdiv(i).io.is_vec_i := true.B // If you can enter, it must be vector
    vfdiv(i).io.frs2_i := in.src(0)(63,0) // f[rs2]
    vfdiv(i).io.frs1_i := in.src(1)(63,0) // f[rs1]
    vfdiv(i).io.is_frs1_i := false.B // if true, vs2 / f[rs1]
    vfdiv(i).io.is_frs2_i := false.B // if true, f[rs2] / vs1
    vfdiv(i).io.is_sqrt_i := false.B // must false, not support sqrt now
    vfdiv(i).io.rm_i := in.round_mode
    vfdiv(i).io.fp_format_i := Mux(inHs, in.fp_format, 3.U(2.W))
    vfdiv(i).io.start_valid_i := io.in.valid
    vfdiv(i).io.finish_ready_i := io.out.ready
    vfdiv(i).io.flush_i := false.B // TODO
  }

  // Fold per-element fflags: OR of the (masked) flags of the first vl elements, for the current
  // sew. vl == 0 yields 0. There is no sew = 8 entry (8-bit is rejected by VFPU's XSError).
  val s4_fflagsVec = VecInit(vfdiv.map(_.io.fflags_o)).asUInt()
  val s4_fflags16vl = fflagsGen(s0_mask, s4_fflagsVec, List.range(0, 8))
  val s4_fflags32vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 1, 4, 5))
  val s4_fflags64vl = fflagsGen(s0_mask, s4_fflagsVec, List(0, 4))
  val s4_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> Mux(s0_vl.orR, s4_fflags16vl(s0_vl - 1.U), 0.U(5.W)),
    "b10".U -> Mux(s0_vl.orR, s4_fflags32vl(s0_vl - 1.U), 0.U(5.W)),
    "b11".U -> Mux(s0_vl.orR, s4_fflags64vl(s0_vl - 1.U), 0.U(5.W)),
  ))
  out.fflags := s4_fflags

  val s4_result = VecInit(vfdiv.map(_.io.fpdiv_res_o)).asUInt()
  out.result := s4_result

  io.in.ready := VecInit(vfdiv.map(_.io.start_ready_o)).asUInt().andR()
  io.out.valid := VecInit(vfdiv.map(_.io.finish_valid_o)).asUInt().andR()
}

/**
 * Vector floating-point multiply-accumulate wrapper: VLEN / XLEN instances of VectorFloatFMA, each
 * handling one XLEN-bit slice of the three operands.
 *
 * io.in.ready is always high. io.out.valid is the input handshake delayed by `Latency` cycles
 * through `validPipe`. The FMA results are concatenated without an extra register here, so any
 * pipelining is assumed to live inside VectorFloatFMA -- TODO confirm it matches `Latency`.
 */
class VfmaccWrapper(implicit p: Parameters) extends XSModule{
  val Latency = 3
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // Shift register of the input handshake; the last stage drives io.out.valid.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.zipWithIndex.foreach {
    case (valid, idx) =>
      val _valid = if (idx == 0) Mux(inHs, true.B, false.B) else validPipe(idx - 1)
      valid := _valid
  }
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfmacc = Seq.fill(NumAdder)(Module(new VectorFloatFMA()))
  // A source that is not a vector register (SrcType.vp) is broadcast across VLEN by VecExtractor.
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  val src3 = Mux(in.srcType(2) === SrcType.vp, in.src(2), VecExtractor(in.fp_format, in.src(2)))
  for (i <- 0 until NumAdder) {
    // Operands are forced to zero (and the format to 3) outside the handshake cycle.
    vfmacc(i).io.fp_a := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfmacc(i).io.fp_b := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfmacc(i).io.fp_c := Mux(inHs, src3(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    // Widening inputs: two XLEN/2-bit chunks of the source, concatenated into one XLEN value.
    vfmacc(i).io.widen_b := Mux(inHs, Cat(src1((AdderWidth / 2) * (i + 3) - 1, (AdderWidth / 2) * (i + 2)), src1((AdderWidth / 2) * (i + 1) - 1, (AdderWidth / 2) * i)), 0.U)
    vfmacc(i).io.widen_a := Mux(inHs, Cat(src2((AdderWidth / 2) * (i + 3) - 1, (AdderWidth / 2) * (i + 2)), src2((AdderWidth / 2) * (i + 1) - 1, (AdderWidth / 2) * i)), 0.U)
    vfmacc(i).io.frs1 := in.src(0)(63,0)
    vfmacc(i).io.is_frs1 := false.B // TODO: support vf inst
    vfmacc(i).io.uop_idx := in.uopIdx // TODO
    vfmacc(i).io.op_code := in.op_code(3,0)
    vfmacc(i).io.is_vec := true.B // If you can enter, it must be vector
    vfmacc(i).io.round_mode := in.round_mode
    vfmacc(i).io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    vfmacc(i).io.res_widening := in.res_widening // TODO
  }

  // output signal generation
  // Fold per-element fflags: OR of the (masked) flags of the first vl elements, for the current
  // sew. vl == 0 yields 0. There is no sew = 8 entry (8-bit is rejected by VFPU's XSError).
  val s2_fflagsVec = VecInit(vfmacc.map(_.io.fflags)).asUInt()
  val s2_fflags16vl = fflagsGen(s0_mask, s2_fflagsVec, List.range(0, 8))
  val s2_fflags32vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 1, 4, 5))
  val s2_fflags64vl = fflagsGen(s0_mask, s2_fflagsVec, List(0, 4))
  val s2_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> Mux(s0_vl.orR, s2_fflags16vl(s0_vl - 1.U), 0.U(5.W)),
    "b10".U -> Mux(s0_vl.orR, s2_fflags32vl(s0_vl - 1.U), 0.U(5.W)),
    "b11".U -> Mux(s0_vl.orR, s2_fflags64vl(s0_vl - 1.U), 0.U(5.W)),
  ))
  out.fflags := s2_fflags

  val s2_result = VecInit(vfmacc.map(_.io.fp_result)).asUInt()
  out.result := s2_result

  io.in.ready := true.B
  io.out.valid := validPipe(Latency - 1)
}

/**
 * Vector floating-point add/compare-class wrapper: VLEN / XLEN instances of VectorFloatAdder, each
 * handling one XLEN-bit slice of the two operands.
 *
 * io.in.ready is always high. io.out.valid is the input handshake delayed by `Latency` (2) cycles.
 * The adder result and folded fflags are registered once (the s1 stage) when
 * validPipe(Latency - 2) is high, and are held until the next operation.
 */
class VfaluWrapper(implicit p: Parameters) extends XSModule{
  val Latency = 2
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // reg input signal: shift register of the input handshake; the last stage drives io.out.valid.
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.zipWithIndex.foreach {
    case (valid, idx) =>
      val _valid = if (idx == 0) Mux(inHs, true.B, false.B) else validPipe(idx - 1)
      valid := _valid
  }
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // connect the input signal
  val vfalu = Seq.fill(NumAdder)(Module(new VectorFloatAdder()))
  // A source that is not a vector register (SrcType.vp) is broadcast across VLEN by VecExtractor.
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  for (i <- 0 until NumAdder) {
    // Operands are forced to zero (and the format to 3) outside the handshake cycle.
    vfalu(i).io.fp_b := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfalu(i).io.fp_a := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    // Widening inputs: two XLEN/2-bit chunks of the source, concatenated into one XLEN value.
    vfalu(i).io.widen_b := Mux(inHs, Cat(src1((AdderWidth / 2) * (i + 3) - 1, (AdderWidth / 2) * (i + 2)), src1((AdderWidth / 2) * (i + 1) - 1, (AdderWidth / 2) * i)), 0.U)
    vfalu(i).io.widen_a := Mux(inHs, Cat(src2((AdderWidth / 2) * (i + 3) - 1, (AdderWidth / 2) * (i + 2)), src2((AdderWidth / 2) * (i + 1) - 1, (AdderWidth / 2) * i)), 0.U)
    vfalu(i).io.frs1 := in.src(0)(63, 0)
    vfalu(i).io.is_frs1 := false.B // TODO: support vf inst
    vfalu(i).io.mask := 0.U //TODO
    vfalu(i).io.uop_idx := in.uopIdx //TODO
    vfalu(i).io.is_vec := true.B // If you can enter, it must be vector
    vfalu(i).io.round_mode := in.round_mode
    vfalu(i).io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    vfalu(i).io.opb_widening := in.opb_widening // TODO
    vfalu(i).io.res_widening := in.res_widening // TODO
    vfalu(i).io.op_code := in.op_code(4,0)
  }

  // output signal generation
  // Fold per-element fflags: OR of the (masked) flags of the first vl elements, for the current
  // sew. vl == 0 yields 0. There is no sew = 8 entry (8-bit is rejected by VFPU's XSError).
  val s0_fflagsVec = VecInit(vfalu.map(_.io.fflags)).asUInt()
  val s0_fflags16vl = fflagsGen(s0_mask, s0_fflagsVec, List.range(0, 8))
  val s0_fflags32vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 1, 4, 5))
  val s0_fflags64vl = fflagsGen(s0_mask, s0_fflagsVec, List(0, 4))
  val s0_fflags = LookupTree(s0_sew(1, 0), List(
    "b01".U -> Mux(s0_vl.orR, s0_fflags16vl(s0_vl - 1.U), 0.U(5.W)),
    "b10".U -> Mux(s0_vl.orR, s0_fflags32vl(s0_vl - 1.U), 0.U(5.W)),
    "b11".U -> Mux(s0_vl.orR, s0_fflags64vl(s0_vl - 1.U), 0.U(5.W)),
  ))
  val s1_fflags = RegEnable(s0_fflags, validPipe(Latency-2))
  out.fflags := s1_fflags

  val s0_result = VecInit(vfalu.map(_.io.fp_result)).asUInt()
  val s1_result = RegEnable(s0_result, validPipe(Latency-2))
  out.result := s1_result

  io.in.ready := true.B
  io.out.valid := validPipe(Latency-1)
}

/**
 * Builds running (prefix-OR) fflags over selected elements of a packed per-element fflags vector.
 *
 * For each position k in `idx`, the 5-bit slice `fflagsResult(idx(k) * 5 + 4, idx(k) * 5)` is
 * taken and zeroed when its mask bit is clear. Element k of the returned Vec is the OR of the
 * gated slices 0..k, so indexing it with `vl - 1` gives the flags accumulated over the first vl
 * selected elements.
 *
 * Mask bits are paired in reverse order: with num = idx.length, idx(0) uses vmask(num - 1) and
 * idx(num - 1) uses vmask(0).
 * NOTE(review): VFPU currently drives an all-ones mask, so the pairing order has no visible effect
 * yet -- confirm it is the intended element order before real masks are used.
 */
object fflagsGen{
  /**
   * @param vmask        element mask; only the low idx.length bits are used
   * @param fflagsResult packed 5-bit fflags, slot i occupying bits [i * 5 + 4, i * 5]
   * @param idx          slot indices to accumulate, in element order
   * @return Vec of idx.length prefix-OR'ed 5-bit fflags
   */
  def fflagsGen(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    val num = idx.length
    val fflags = Seq.fill(num)(Wire(UInt(5.W)))
    fflags.zip(vmask(num-1, 0).asBools().reverse).zip(idx).foreach {
      case ((fflags0, mask), id) =>
        fflags0 := Mux(mask, fflagsResult(id*5+4,id*5+0), 0.U)
    }
    val fflagsVl = Wire(Vec(num,UInt(5.W)))
    for (i <- 0 until num) {
      val _fflags = if (i == 0) fflags(i) else (fflagsVl(i - 1) | fflags(i))
      fflagsVl(i) := _fflags
    }
    fflagsVl
  }

  def apply(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    fflagsGen(vmask, fflagsResult, idx)
  }
}

/**
 * Broadcasts a scalar operand across a 128-bit vector according to sew:
 * sew = 0 replicates xf(7, 0) 16 times, 1 replicates xf(15, 0) 8 times,
 * 2 replicates xf(31, 0) 4 times, and 3 replicates xf(63, 0) twice.
 */
object VecExtractor{
  def xf2v_sew(sew: UInt, xf:UInt): UInt = {
    LookupTree(sew(1, 0), List(
      "b00".U -> VecInit(Seq.fill(16)(xf(7, 0))).asUInt,
      "b01".U -> VecInit(Seq.fill(8)(xf(15, 0))).asUInt,
      "b10".U -> VecInit(Seq.fill(4)(xf(31, 0))).asUInt,
      "b11".U -> VecInit(Seq.fill(2)(xf(63, 0))).asUInt,
    ))
  }

  def apply(sew: UInt, xf: UInt): UInt = {
    xf2v_sew(sew, xf)
  }
}