/****************************************************************************************
 * Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
 * Copyright (c) 2020-2021 Peng Cheng Laboratory
 *
 * XiangShan is licensed under Mulan PSL v2.
 * You can use this software according to the terms and conditions of the Mulan PSL v2.
 * You may obtain a copy of Mulan PSL v2 at:
 *          http://license.coscl.org.cn/MulanPSL2
 *
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
 * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
 * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
 *
 * See the Mulan PSL v2 for more details.
 ****************************************************************************************
 */


package xiangshan.backend.fu.vector

import chipsalliance.rocketchip.config.Parameters
import chisel3.{Mux, _}
import chisel3.util._
import utils._
import utility._
import yunsuan.vector.{VectorFloatAdder,VectorFloatFMA,VectorFloatDivider}
import yunsuan.VfpuType
import xiangshan.{FuOpType, SrcType, XSBundle, XSCoreParamsKey, XSModule}
import xiangshan.backend.fu.fpu.FPUSubModule

/**
 * Vector floating-point unit.
 *
 * Holds at most one uop at a time (io.in.ready can only be high in s_idle) and dispatches it,
 * by fuOpType, to one of three lane-parallel wrappers:
 *   - VfaluWrapper  for ops where VfpuType.isVfalu is true,
 *   - VfmaccWrapper for VfpuType.fmacc,
 *   - VfdivWrapper  for VfpuType.fdiv.
 * The accepted uop is captured in s0_uopReg; its vtype/vl steer the wrappers' fflags reduction,
 * and it is returned unchanged on io.out.bits.uop.
 *
 * FSM:
 *   s_idle    -> s_compute  when an input is accepted;
 *   s_compute -> s_idle     when a wrapper's result is accepted in the cycle it appears,
 *             -> s_finish   when the result appears but io.out does not fire (the result is
 *                           latched into dataReg / fflagsReg);
 *   s_finish  -> s_idle     once io.out fires, presenting the latched result meanwhile.
 */
class VFPU(implicit p: Parameters) extends FPUSubModule(p(XSCoreParamsKey).VLEN){
  XSError(io.in.valid && io.in.bits.uop.ctrl.fuOpType === VfpuType.dummy, "VFPU OpType not supported")
  XSError(io.in.valid && (io.in.bits.uop.ctrl.vconfig.vtype.vsew === 0.U), "8 bits not supported in VFPU")
  override val dataModule = null // Only use IO, not dataModule

  // Aliases for the incoming uop.
  val in = io.in.bits
  val ctrl = io.in.bits.uop.ctrl
  val vtype = ctrl.vconfig.vtype
  val src1Type = io.in.bits.uop.ctrl.srcType

  // Result holding state: the *Reg values are presented while waiting in s_finish.
  val fflagsReg = RegInit(0.U(5.W))
  val fflagsWire = WireInit(0.U(5.W))
  val dataReg = Reg(io.out.bits.data.cloneType)
  val dataWire = Wire(dataReg.cloneType)
  val s_idle :: s_compute :: s_finish :: Nil = Enum(3)
  val state = RegInit(s_idle)
  val vfalu = Module(new VfaluWrapper()(p))
  val vfmacc = Module(new VfmaccWrapper()(p))
  val vfdiv = Module(new VfdivWrapper()(p))
  // Only one uop is in flight, so at most one wrapper reports a result at a time.
  val outValid = vfalu.io.out.valid || vfmacc.io.out.valid || vfdiv.io.out.valid
  val outFire = vfalu.io.out.fire() || vfmacc.io.out.fire() || vfdiv.io.out.fire()

  // Capture the accepted uop; its vconfig drives the wrappers' fflags reduction.
  val s0_uopReg = Reg(io.in.bits.uop.cloneType)
  val s0_maskReg = Reg(UInt(8.W))
  val inHs = io.in.fire()
  when(inHs && state === s_idle) {
    s0_uopReg := io.in.bits.uop
    // Every element is currently treated as active.
    s0_maskReg := Fill(8, 1.U(1.W))
  }

  // fsm
  switch (state) {
    is (s_idle) {
      state := Mux(inHs, s_compute, s_idle)
    }
    is (s_compute) {
      state := Mux(outValid, Mux(outFire, s_idle, s_finish), s_compute)
    }
    is (s_finish) {
      state := Mux(io.out.fire(), s_idle, s_finish)
    }
  }
  // Latch the result whenever a wrapper reports one, so s_finish can replay it.
  fflagsReg := Mux(outValid, fflagsWire, fflagsReg)
  dataReg := Mux(outValid, dataWire, dataReg)

  // connect the input port of vfalu
  vfalu.io.in.bits.src <> in.src
  vfalu.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfalu.io.in.bits.round_mode := rm
  vfalu.io.in.bits.fp_format := vtype.vsew(1,0)
  vfalu.io.in.bits.opb_widening := false.B // TODO
  vfalu.io.in.bits.res_widening := false.B // TODO
  vfalu.io.in.bits.op_code := ctrl.fuOpType
  vfalu.io.ready_out.s0_mask := s0_maskReg
  vfalu.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfalu.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfmacc
  vfmacc.io.in.bits.src <> in.src
  vfmacc.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfmacc.io.in.bits.round_mode := rm
  vfmacc.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfmacc.io.in.bits.opb_widening := DontCare // TODO
  vfmacc.io.in.bits.res_widening := false.B // TODO
  vfmacc.io.in.bits.op_code := DontCare
  vfmacc.io.ready_out.s0_mask := s0_maskReg
  vfmacc.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfmacc.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the input port of vfdiv
  vfdiv.io.in.bits.src <> in.src
  vfdiv.io.in.bits.srcType <> in.uop.ctrl.srcType
  vfdiv.io.in.bits.round_mode := rm
  vfdiv.io.in.bits.fp_format := vtype.vsew(1, 0)
  vfdiv.io.in.bits.opb_widening := DontCare // TODO
  vfdiv.io.in.bits.res_widening := DontCare // TODO
  vfdiv.io.in.bits.op_code := DontCare
  vfdiv.io.ready_out.s0_mask := s0_maskReg
  vfdiv.io.ready_out.s0_sew := s0_uopReg.ctrl.vconfig.vtype.vsew(1, 0)
  vfdiv.io.ready_out.s0_vl := s0_uopReg.ctrl.vconfig.vl

  // connect the output port: select the wrapper that matches the captured uop
  fflagsWire := LookupTree(s0_uopReg.ctrl.fuOpType, List(
    VfpuType.fadd  -> vfalu.io.out.bits.fflags,
    VfpuType.fsub  -> vfalu.io.out.bits.fflags,
    VfpuType.fmacc -> vfmacc.io.out.bits.fflags,
    VfpuType.fdiv  -> vfdiv.io.out.bits.fflags,
  ))
  fflags := Mux(state === s_compute && outFire, fflagsWire, fflagsReg)
  dataWire := LookupTree(s0_uopReg.ctrl.fuOpType, List(
    VfpuType.fadd  -> vfalu.io.out.bits.result,
    VfpuType.fsub  -> vfalu.io.out.bits.result,
    VfpuType.fmacc -> vfmacc.io.out.bits.result,
    VfpuType.fdiv  -> vfdiv.io.out.bits.result,
  ))
  io.out.bits.data := Mux(state === s_compute && outFire, dataWire, dataReg)
  io.out.bits.uop := s0_uopReg

  // valid/ready
  vfalu.io.in.valid := io.in.valid && VfpuType.isVfalu(in.uop.ctrl.fuOpType) && state === s_idle
  vfmacc.io.in.valid := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.fmacc && state === s_idle
  vfdiv.io.in.valid := io.in.valid && in.uop.ctrl.fuOpType === VfpuType.fdiv && state === s_idle
  io.out.valid := state === s_compute && outValid || state === s_finish
  vfalu.io.out.ready := io.out.ready
  vfmacc.io.out.ready := io.out.ready
  vfdiv.io.out.ready := io.out.ready
  // The wrappers only start on their own in.fire. If VFPU accepted a uop while the target wrapper
  // was not ready (e.g. vfalu/vfmacc require io.out.ready), the FSM would sit in s_compute forever.
  // So only accept when every wrapper can take an input this cycle.
  io.in.ready := state === s_idle && vfalu.io.in.ready && vfmacc.io.in.ready && vfdiv.io.in.ready
}

/**
 * Common IO of the VFPU lane wrappers.
 *
 *  - in:        operands (up to 4 sources with their SrcType), rounding mode, element format
 *               (vsew(1,0)), widening controls and op code.
 *  - ready_out: per-uop information registered by VFPU (mask, sew, vl), used for fflags reduction.
 *  - out:       result and the reduced fflags.
 *
 * NOTE(review): `result` is hard-coded to 128 bits, which presumably assumes VLEN = 128.
 * The class name is misspelled ("Wraaper") but kept because other code may reference it.
 */
class VFPUWraaperBundle (implicit p: Parameters) extends XSBundle{
  val in = Flipped(DecoupledIO(Output(new Bundle {
    val src = Vec(4, Input(UInt(VLEN.W)))
    val srcType = Vec(4, SrcType())

    val round_mode = UInt(3.W)
    val fp_format = UInt(2.W) // vsew
    val opb_widening = Bool()
    val res_widening = Bool()
    val op_code = FuOpType()
  })))

  val ready_out = Input(new Bundle {
    val s0_mask = UInt((VLEN / 16).W)
    val s0_sew = UInt(2.W)
    val s0_vl = UInt(8.W)
  })

  val out = DecoupledIO(Output(new Bundle {
    val result = UInt(128.W)
    val fflags = UInt(5.W)
  }))
}

/**
 * Lane-parallel vector FP divider built from NumAdder (= VLEN / XLEN) VectorFloatDivider
 * instances. Each instance handles one XLEN-bit slice of the operands.
 *
 * opa_i is driven from src(1) and opb_i from src(0). A source whose srcType is not vp is a scalar
 * and is replicated across the vector with VecExtractor. Operands are forced to 0 and fp_format
 * to 3 when there is no input handshake.
 * Handshake: in.ready / out.valid are the AND of every divider's start_ready_o / finish_valid_o.
 * NOTE(review): `Latency` is not used by this wrapper.
 */
class VfdivWrapper(implicit p: Parameters) extends XSModule{
  val Latency = List(5, 7, 12)
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfdiv = Seq.fill(NumAdder)(Module(new VectorFloatDivider()))
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  for (i <- 0 until NumAdder) {
    vfdiv(i).io.opa_i := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfdiv(i).io.opb_i := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfdiv(i).io.is_vec_i := true.B // If you can enter, it must be vector
    vfdiv(i).io.rm_i := in.round_mode
    vfdiv(i).io.fp_format_i := Mux(inHs, in.fp_format, 3.U(2.W))
    vfdiv(i).io.start_valid_i := io.in.valid
    vfdiv(i).io.finish_ready_i := io.out.ready
    vfdiv(i).io.flush_i := false.B // TODO
  }

  // output signal generation
  val s4_fflagsVec = VecInit(vfdiv.map(_.io.fflags_o)).asUInt()
  val s4_fflags = fflagsGen.selectByVl(s0_sew, s0_vl, s0_mask, s4_fflagsVec)
  out.fflags := s4_fflags

  val s4_result = VecInit(vfdiv.map(_.io.fpdiv_res_o)).asUInt()
  out.result := s4_result

  io.in.ready := VecInit(vfdiv.map(_.io.start_ready_o)).asUInt().andR()
  io.out.valid := VecInit(vfdiv.map(_.io.finish_valid_o)).asUInt().andR()
}

/**
 * Lane-parallel vector FMA built from NumAdder (= VLEN / XLEN) VectorFloatFMA instances.
 *
 * Fixed latency: a `Latency`-deep valid shift register tracks the op in flight. out.valid is
 * asserted for one cycle, Latency cycles after the input handshake. A new op is accepted only
 * when nothing is in flight and io.out.ready is high.
 * Sources src(0..2) feed fp_a/fp_b/fp_c. A source whose srcType is not vp is replicated with
 * VecExtractor.
 */
class VfmaccWrapper(implicit p: Parameters) extends XSModule{
  val Latency = 3
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.zipWithIndex.foreach {
    case (valid, idx) =>
      val _valid = if (idx == 0) Mux(inHs, true.B, false.B) else validPipe(idx - 1)
      valid := _valid
  }
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  val vfmacc = Seq.fill(NumAdder)(Module(new VectorFloatFMA()))
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  val src3 = Mux(in.srcType(2) === SrcType.vp, in.src(2), VecExtractor(in.fp_format, in.src(2)))
  for (i <- 0 until NumAdder) {
    vfmacc(i).io.fp_a := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfmacc(i).io.fp_b := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfmacc(i).io.fp_c := Mux(inHs, src3(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfmacc(i).io.is_vec := true.B // If you can enter, it must be vector
    vfmacc(i).io.round_mode := in.round_mode
    vfmacc(i).io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    vfmacc(i).io.res_widening := in.res_widening // TODO
  }

  // output signal generation
  val s2_fflagsVec = VecInit(vfmacc.map(_.io.fflags)).asUInt()
  val s2_fflags = fflagsGen.selectByVl(s0_sew, s0_vl, s0_mask, s2_fflagsVec)
  out.fflags := s2_fflags

  val s2_result = VecInit(vfmacc.map(_.io.fp_result)).asUInt()
  out.result := s2_result

  io.in.ready := !(validPipe.foldLeft(false.B)(_ | _)) && io.out.ready
  io.out.valid := validPipe(Latency - 1)
}

/**
 * Lane-parallel vector FP adder (add/sub family) built from NumAdder (= VLEN / XLEN)
 * VectorFloatAdder instances.
 *
 * Fixed latency of `Latency` cycles, tracked by a valid shift register. The adders' combinational
 * outputs are registered with RegEnable at validPipe(Latency - 2). The result is taken from the
 * f16/f32/f64 result port selected by s0_sew. A new op is accepted only when nothing is in flight
 * and io.out.ready is high.
 */
class VfaluWrapper(implicit p: Parameters) extends XSModule{
  val Latency = 2
  val AdderWidth = XLEN
  val NumAdder = VLEN / XLEN

  val io = IO(new VFPUWraaperBundle)

  val in = io.in.bits
  val out = io.out.bits
  val inHs = io.in.fire()

  // reg input signal
  val validPipe = Seq.fill(Latency)(RegInit(false.B))
  validPipe.zipWithIndex.foreach {
    case (valid, idx) =>
      val _valid = if (idx == 0) Mux(inHs, true.B, false.B) else validPipe(idx - 1)
      valid := _valid
  }
  val s0_mask = io.ready_out.s0_mask
  val s0_sew = io.ready_out.s0_sew
  val s0_vl = io.ready_out.s0_vl

  // connect the input signal
  val vfalu = Seq.fill(NumAdder)(Module(new VectorFloatAdder()))
  val src1 = Mux(in.srcType(0) === SrcType.vp, in.src(0), VecExtractor(in.fp_format, in.src(0)))
  val src2 = Mux(in.srcType(1) === SrcType.vp, in.src(1), VecExtractor(in.fp_format, in.src(1)))
  for (i <- 0 until NumAdder) {
    vfalu(i).io.fp_a := Mux(inHs, src1(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfalu(i).io.fp_b := Mux(inHs, src2(AdderWidth * (i + 1) - 1, AdderWidth * i), 0.U)
    vfalu(i).io.is_vec := true.B // If you can enter, it must be vector
    vfalu(i).io.round_mode := in.round_mode
    vfalu(i).io.fp_format := Mux(inHs, in.fp_format, 3.U(2.W))
    vfalu(i).io.opb_widening := in.opb_widening // TODO
    vfalu(i).io.res_widening := in.res_widening // TODO
    vfalu(i).io.op_code := in.op_code
  }

  // output signal generation
  val s0_fflagsVec = VecInit(vfalu.map(_.io.fflags)).asUInt()
  val s0_fflags = fflagsGen.selectByVl(s0_sew, s0_vl, s0_mask, s0_fflagsVec)
  val s1_fflags = RegEnable(s0_fflags, validPipe(Latency-2))
  out.fflags := s1_fflags

  val s0_result = LookupTree(s0_sew(1, 0), List(
    "b01".U -> VecInit(vfalu.map(_.io.fp_f16_result)).asUInt(),
    "b10".U -> VecInit(vfalu.map(_.io.fp_f32_result)).asUInt(),
    "b11".U -> VecInit(vfalu.map(_.io.fp_f64_result)).asUInt(),
  ))
  val s1_result = RegEnable(s0_result, validPipe(Latency-2))
  out.result := s1_result

  io.in.ready := !(validPipe.foldLeft(false.B)(_|_)) && io.out.ready
  io.out.valid := validPipe(Latency-1)
}

/** Helpers that reduce per-element 5-bit fflags into a single fflags value. */
object fflagsGen{
  /**
   * Prefix-OR of masked per-element fflags.
   *
   * @param vmask        element mask; bits [idx.length-1, 0] are used.
   *                     NOTE(review): the mask bits are applied in reversed order (bit
   *                     idx.length-1 gates the first element). This has no effect while callers
   *                     pass an all-ones mask; confirm it is intended.
   * @param fflagsResult concatenated 5-bit fflags fields; element `id` lives in bits
   *                     [id*5+4, id*5].
   * @param idx          which 5-bit fields of fflagsResult hold the elements, in element order.
   * @return entry i is the OR of the (masked) fflags of elements 0..i.
   */
  def fflagsGen(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    val num = idx.length
    val masked = vmask(num-1, 0).asBools().reverse.zip(idx).map {
      case (mask, id) => Mux(mask, fflagsResult(id*5+4, id*5+0), 0.U(5.W))
    }
    VecInit(masked.tail.scanLeft(masked.head)(_ | _))
  }

  def apply(vmask: UInt, fflagsResult:UInt, idx:List[Int] = List(0, 1, 4, 5)): Vec[UInt] = {
    fflagsGen(vmask, fflagsResult, idx)
  }

  /**
   * Accumulated fflags of the first `vl` elements, for the element width given by `sew`.
   * Assumes the fflags field layout used by the VFPU wrappers: 16-bit elements in fields 0..7,
   * 32-bit elements in fields 0,1,4,5 and 64-bit elements in fields 0,4.
   *
   * Returns 0 when vl == 0 or sew == 0. When vl exceeds the number of elements held here, the OR
   * over all of them is returned instead of indexing out of range.
   * NOTE(review): elements are assumed to be numbered from 0 in this register; confirm for LMUL > 1.
   */
  def selectByVl(sew: UInt, vl: UInt, vmask: UInt, fflagsResult: UInt): UInt = {
    def upToVl(acc: Vec[UInt]): UInt = {
      // Slice the index to the Vec's width; it is only used when vl <= acc.length.
      val lastIdx = (vl - 1.U)(log2Up(acc.length) - 1, 0)
      Mux(vl === 0.U, 0.U(5.W), Mux(vl >= acc.length.U, acc.last, acc(lastIdx)))
    }
    LookupTree(sew(1, 0), List(
      "b01".U -> upToVl(fflagsGen(vmask, fflagsResult, List.range(0, 8))),
      "b10".U -> upToVl(fflagsGen(vmask, fflagsResult, List(0, 1, 4, 5))),
      "b11".U -> upToVl(fflagsGen(vmask, fflagsResult, List(0, 4))),
    ))
  }
}

/**
 * Replicates a scalar operand across a vector register.
 * The low element (8/16/32/64 bits for sew = 0/1/2/3) of `xf` is repeated to fill 128 bits.
 * NOTE(review): the fixed 128-bit width presumably assumes VLEN = 128.
 */
object VecExtractor{
  def xf2v_sew(sew: UInt, xf:UInt): UInt = {
    LookupTree(sew(1, 0), List(
      "b00".U -> VecInit(Seq.fill(16)(xf(7, 0))).asUInt,
      "b01".U -> VecInit(Seq.fill(8)(xf(15, 0))).asUInt,
      "b10".U -> VecInit(Seq.fill(4)(xf(31, 0))).asUInt,
      "b11".U -> VecInit(Seq.fill(2)(xf(63, 0))).asUInt,
    ))
  }

  def apply(sew: UInt, xf: UInt): UInt = {
    xf2v_sew(sew, xf)
  }
}