xref: /XiangShan/src/main/scala/xiangshan/backend/decode/FPDecoder.scala (revision 887862dbb8debde8ab099befc426493834a69ee7)
1/***************************************************************************************
2* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
3* Copyright (c) 2020-2021 Peng Cheng Laboratory
4*
5* XiangShan is licensed under Mulan PSL v2.
6* You can use this software according to the terms and conditions of the Mulan PSL v2.
7* You may obtain a copy of Mulan PSL v2 at:
8*          http://license.coscl.org.cn/MulanPSL2
9*
10* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
11* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
12* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
13*
14* See the Mulan PSL v2 for more details.
15***************************************************************************************/
16
17package xiangshan.backend.decode
18
19import org.chipsalliance.cde.config.Parameters
20import chisel3._
21import chisel3.util._
22import freechips.rocketchip.rocket.DecodeLogic
23import freechips.rocketchip.rocket.Instructions._
24import xiangshan.backend.decode.isa.bitfield.XSInstBitFields
25import xiangshan.backend.fu.fpu.FPU
26import xiangshan.backend.fu.vector.Bundles.{VSew, VLmul}
27import xiangshan.backend.Bundles.VPUCtrlSignals
28import xiangshan.{FPUCtrlSignals, XSModule}
29
/**
 * Derives vector-unit control signals ([[VPUCtrlSignals]]) for the scalar
 * floating-point instructions listed in `fpToVecInsts`.
 *
 * The module is purely combinational: every output is a function of
 * `io.instr`. Any field of `io.vpuCtrl` that is not driven explicitly below
 * is tied to zero.
 */
class FPToVecDecoder(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val instr = Input(UInt(32.W))
    val vpuCtrl = Output(new VPUCtrlSignals)
  })

  val inst = io.instr.asTypeOf(new XSInstBitFields)

  /** High when `word` matches at least one entry of `patterns`. */
  private def matchesAny(word: UInt, patterns: Seq[BitPat]): Bool =
    patterns.map(pattern => pattern === word).reduce(_ || _)

  // Scalar F/D arithmetic, compare, min/max, div/sqrt, FMA, classify and
  // sign-injection instructions. This list prefixes both `fpToVecInsts`
  // and `needReverseInsts`.
  private val fpArithInsts = Seq(
    FADD_S, FSUB_S, FADD_D, FSUB_D,
    FEQ_S, FLT_S, FLE_S, FEQ_D, FLT_D, FLE_D,
    FMIN_S, FMAX_S, FMIN_D, FMAX_D,
    FMUL_S, FMUL_D,
    FDIV_S, FDIV_D, FSQRT_S, FSQRT_D,
    FMADD_S, FMSUB_S, FNMADD_S, FNMSUB_S, FMADD_D, FMSUB_D, FNMADD_D, FNMSUB_D,
    FCLASS_S, FCLASS_D, FSGNJ_S, FSGNJ_D, FSGNJX_S, FSGNJX_D, FSGNJN_S, FSGNJN_D
  )

  // Zfa quiet compares and min/max; also part of both lists above.
  private val zfaCmpMinMaxInsts = Seq(
    FLEQ_H, FLEQ_S, FLEQ_D, FLTQ_H, FLTQ_S, FLTQ_D,
    FMINM_H, FMINM_S, FMINM_D, FMAXM_H, FMAXM_S, FMAXM_D
  )

  /** Every instruction handled here; drives `fpu.isFpToVecInst`. */
  val fpToVecInsts = fpArithInsts ++
    Seq(
      // scalar cvt inst
      FCVT_W_S, FCVT_WU_S, FCVT_L_S, FCVT_LU_S,
      FCVT_W_D, FCVT_WU_D, FCVT_L_D, FCVT_LU_D, FCVT_S_D, FCVT_D_S,
      FCVT_S_H, FCVT_H_S, FCVT_H_D, FCVT_D_H,
      FMV_X_W, FMV_X_D, FMV_X_H
    ) ++
    zfaCmpMinMaxInsts ++
    Seq(
      // zfa rounding and modular conversion
      FROUND_H, FROUND_S, FROUND_D, FROUNDNX_H, FROUNDNX_S, FROUNDNX_D, FCVTMOD_W_D
    )
  val isFpToVecInst = matchesAny(io.instr, fpToVecInsts)

  // ---- Precision classes (conversions are classified separately below) ----

  /** Half-precision (Zfa) operations: selected SEW is e16. */
  val isFP16Instrs = Seq(
    // zfa inst
    FLEQ_H, FLTQ_H, FMINM_H, FMAXM_H,
    FROUND_H, FROUNDNX_H
  )
  val isFP16Instr = matchesAny(io.instr, isFP16Instrs)

  /** Single-precision operations: selected SEW is e32 and LMUL is mf4. */
  val isFP32Instrs = Seq(
    FADD_S, FSUB_S, FEQ_S, FLT_S, FLE_S, FMIN_S, FMAX_S,
    FMUL_S, FDIV_S, FSQRT_S,
    FMADD_S, FMSUB_S, FNMADD_S, FNMSUB_S,
    FCLASS_S, FSGNJ_S, FSGNJX_S, FSGNJN_S,
    // zfa inst
    FLEQ_S, FLTQ_S, FMINM_S, FMAXM_S,
    FROUND_S, FROUNDNX_S
  )
  val isFP32Instr = matchesAny(io.instr, isFP32Instrs)

  /**
   * Double-precision operations; only used to drive `fpu.isFP64Instr`.
   * Zfa `_D` ops are not listed here; they still get e64/mf2 through the
   * default arms of the SEW/LMUL selection.
   */
  val isFP64Instrs = Seq(
    FADD_D, FSUB_D, FEQ_D, FLT_D, FLE_D, FMIN_D, FMAX_D,
    FMUL_D, FDIV_D, FSQRT_D,
    FMADD_D, FMSUB_D, FNMADD_D, FNMSUB_D,
    FCLASS_D, FSGNJ_D, FSGNJX_D, FSGNJN_D
  )
  val isFP64Instr = matchesAny(io.instr, isFP64Instrs)

  // ---- Scalar conversions / moves ----

  /** Conversions and moves that select SEW = e32. */
  val isSew2Cvts = Seq(
    FCVT_W_S, FCVT_WU_S, FCVT_L_S, FCVT_LU_S,
    FCVT_W_D, FCVT_WU_D, FCVT_S_D, FCVT_D_S,
    FMV_X_W,
    // zfa inst
    FCVTMOD_W_D
  )

  /*
   * Conversions and moves that select SEW = e16.
   * FCVT_D_H and FCVT_H_D share the same optype, so they are told apart
   * by SEW alone: FCVT_D_H is listed here (VSew.e16), while FCVT_H_D is
   * deliberately absent and falls through to VSew.e64.
   */
  val isSew2Cvth = Seq(
    FCVT_S_H, FCVT_H_S, FCVT_D_H,
    FMV_X_H
  )
  val isSew2Cvt32 = matchesAny(io.instr, isSew2Cvts)
  val isSew2Cvt16 = matchesAny(io.instr, isSew2Cvth)

  /** Conversions that select LMUL = mf4 (everything else not FP32 gets mf2). */
  val isLmulMf4Cvts = Seq(
    FCVT_W_S, FCVT_WU_S,
    FMV_X_W
  )
  val isLmulMf4Cvt = matchesAny(io.instr, isLmulMf4Cvts)

  /** Instructions that raise `isReverse`. Compared against `inst.ALL`. */
  val needReverseInsts = fpArithInsts ++ zfaCmpMinMaxInsts
  val needReverseInst = matchesAny(inst.ALL, needReverseInsts)

  // ---- Output connections ----
  private val out = io.vpuCtrl

  // All-zero default first; the connections below override individual fields.
  out := 0.U.asTypeOf(out)

  out.fpu.isFpToVecInst := isFpToVecInst
  out.fpu.isFP32Instr   := isFP32Instr
  out.fpu.isFP64Instr   := isFP64Instr

  // SEW priority: the e32 condition is checked before the e16 one; e64 otherwise.
  out.vsew := MuxCase(VSew.e64, Seq(
    (isFP32Instr || isSew2Cvt32) -> VSew.e32,
    (isFP16Instr || isSew2Cvt16) -> VSew.e16
  ))
  out.vlmul := Mux(isFP32Instr || isLmulMf4Cvt, VLmul.mf4, VLmul.mf2)

  // Copied verbatim from instruction fields.
  out.vm   := inst.VM
  out.nf   := inst.NF
  out.veew := inst.WIDTH

  out.isReverse := needReverseInst

  // Never illegal; vma / vta set to 1.
  out.vill := false.B
  out.vma  := true.B
  out.vta  := true.B

  // Flags explicitly held low (same value as the zero default above).
  Seq(
    out.isExt, out.isNarrow, out.isDstMask,
    out.isOpMask, out.isDependOldvd, out.isWritePartVd
  ).foreach(_ := false.B)
}
136
137
/**
 * Table-driven decoder for the scalar single- (`_S`) and double-precision
 * (`_D`) floating-point instructions, producing [[FPUCtrlSignals]].
 *
 * Two DecodeLogic lookups are made on `io.instr`:
 *  - `table` (`single ++ double`) drives the nine fields listed in `sigs`;
 *  - `fmaTable` drives `fmaCmd` and `ren3`.
 * `typ`, `fmt` and `rm` are copied straight from the instruction fields.
 * Instructions missing from a table take that table's default row.
 */
class FPDecoder(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val instr = Input(UInt(32.W))
    val fpCtrl = Output(new FPUCtrlSignals)
  })

  private val inst: XSInstBitFields = io.instr.asTypeOf(new XSInstBitFields)

  // Single-bit decode values: don't-care, 0, 1.
  def X = BitPat("b?")
  def N = BitPat("b0")
  def Y = BitPat("b1")

  // Operand / result type tags.
  val s = BitPat(FPU.S(0))
  val d = BitPat(FPU.D(0))
  // NOTE(review): the integer-side tag is built from FPU.D, so it equals `d`.
  // Presumably intentional; confirm before changing.
  val i = BitPat(FPU.D(0))

  /**
   * Builds one decode row. The argument order is the column order shared by
   * `default`, `single`, `double` and `sigs`.
   */
  private def row(
    isAddSub: BitPat, tagIn: BitPat, tagOut: BitPat,
    fromInt: BitPat, wflags: BitPat, fpWen: BitPat,
    div: BitPat, sqrt: BitPat, fcvt: BitPat
  ): List[BitPat] =
    List(isAddSub, tagIn, tagOut, fromInt, wflags, fpWen, div, sqrt, fcvt)

  // Unlisted instructions: fromInt, wflags and fpWen low; the rest don't-care.
  val default = row(X, X, X, N, N, N, X, X, X)

  // Rows with an FP result tag (s/d) set fpWen; rows whose tagOut is i do not.
  val single: Array[(BitPat, List[BitPat])] = Array(
    // IntToFP
    FMV_W_X   -> row(N, i, s, Y, N, Y, N, N, N),
    FCVT_S_W  -> row(N, i, s, Y, Y, Y, N, N, Y),
    FCVT_S_WU -> row(N, i, s, Y, Y, Y, N, N, Y),
    FCVT_S_L  -> row(N, i, s, Y, Y, Y, N, N, Y),
    FCVT_S_LU -> row(N, i, s, Y, Y, Y, N, N, Y),
    // FPToInt
    // tagIn is d rather than s: don't box the result of fmv.fp.int
    FMV_X_W   -> row(N, d, i, N, N, N, N, N, N),
    FCLASS_S  -> row(N, s, i, N, N, N, N, N, N),
    FCVT_W_S  -> row(N, s, i, N, Y, N, N, N, Y),
    FCVT_WU_S -> row(N, s, i, N, Y, N, N, N, Y),
    FCVT_L_S  -> row(N, s, i, N, Y, N, N, N, Y),
    FCVT_LU_S -> row(N, s, i, N, Y, N, N, N, Y),
    FEQ_S     -> row(N, s, i, N, Y, N, N, N, N),
    FLT_S     -> row(N, s, i, N, Y, N, N, N, N),
    FLE_S     -> row(N, s, i, N, Y, N, N, N, N),
    // FPToFP
    FSGNJ_S   -> row(N, s, s, N, N, Y, N, N, N),
    FSGNJN_S  -> row(N, s, s, N, N, Y, N, N, N),
    FSGNJX_S  -> row(N, s, s, N, N, Y, N, N, N),
    FMIN_S    -> row(N, s, s, N, Y, Y, N, N, N),
    FMAX_S    -> row(N, s, s, N, Y, Y, N, N, N),
    FADD_S    -> row(Y, s, s, N, Y, Y, N, N, N),
    FSUB_S    -> row(Y, s, s, N, Y, Y, N, N, N),
    FMUL_S    -> row(N, s, s, N, Y, Y, N, N, N),
    FMADD_S   -> row(N, s, s, N, Y, Y, N, N, N),
    FMSUB_S   -> row(N, s, s, N, Y, Y, N, N, N),
    FNMADD_S  -> row(N, s, s, N, Y, Y, N, N, N),
    FNMSUB_S  -> row(N, s, s, N, Y, Y, N, N, N),
    FDIV_S    -> row(N, s, s, N, Y, Y, Y, N, N),
    FSQRT_S   -> row(N, s, s, N, Y, Y, N, Y, N)
  )

  val double: Array[(BitPat, List[BitPat])] = Array(
    // IntToFP
    FMV_D_X   -> row(N, i, d, Y, N, Y, N, N, N),
    FCVT_D_W  -> row(N, i, d, Y, Y, Y, N, N, Y),
    FCVT_D_WU -> row(N, i, d, Y, Y, Y, N, N, Y),
    FCVT_D_L  -> row(N, i, d, Y, Y, Y, N, N, Y),
    FCVT_D_LU -> row(N, i, d, Y, Y, Y, N, N, Y),
    // FPToInt
    FMV_X_D   -> row(N, d, i, N, N, N, N, N, N),
    FCLASS_D  -> row(N, d, i, N, N, N, N, N, N),
    FCVT_W_D  -> row(N, d, i, N, Y, N, N, N, Y),
    FCVT_WU_D -> row(N, d, i, N, Y, N, N, N, Y),
    FCVT_L_D  -> row(N, d, i, N, Y, N, N, N, Y),
    FCVT_LU_D -> row(N, d, i, N, Y, N, N, N, Y),
    // Precision conversions
    FCVT_S_D  -> row(N, d, s, N, Y, Y, N, N, Y),
    FCVT_D_S  -> row(N, s, d, N, Y, Y, N, N, Y),
    // Compares
    FEQ_D     -> row(N, d, i, N, Y, N, N, N, N),
    FLT_D     -> row(N, d, i, N, Y, N, N, N, N),
    FLE_D     -> row(N, d, i, N, Y, N, N, N, N),
    // FPToFP
    FSGNJ_D   -> row(N, d, d, N, N, Y, N, N, N),
    FSGNJN_D  -> row(N, d, d, N, N, Y, N, N, N),
    FSGNJX_D  -> row(N, d, d, N, N, Y, N, N, N),
    FMIN_D    -> row(N, d, d, N, Y, Y, N, N, N),
    FMAX_D    -> row(N, d, d, N, Y, Y, N, N, N),
    FADD_D    -> row(Y, d, d, N, Y, Y, N, N, N),
    FSUB_D    -> row(Y, d, d, N, Y, Y, N, N, N),
    FMUL_D    -> row(N, d, d, N, Y, Y, N, N, N),
    FMADD_D   -> row(N, d, d, N, Y, Y, N, N, N),
    FMSUB_D   -> row(N, d, d, N, Y, Y, N, N, N),
    FNMADD_D  -> row(N, d, d, N, Y, Y, N, N, N),
    FNMSUB_D  -> row(N, d, d, N, Y, Y, N, N, N),
    FDIV_D    -> row(N, d, d, N, Y, Y, Y, N, N),
    FSQRT_D   -> row(N, d, d, N, Y, Y, N, Y, N)
  )

  val table = single ++ double

  val decoder = DecodeLogic(io.instr, default, table)

  val ctrl = io.fpCtrl
  // Must stay in the same order as the parameters of `row`.
  val sigs = Seq(
    ctrl.isAddSub, ctrl.typeTagIn, ctrl.typeTagOut,
    ctrl.fromInt, ctrl.wflags, ctrl.fpWen,
    ctrl.div, ctrl.sqrt, ctrl.fcvt
  )
  for ((sink, value) <- sigs.zip(decoder)) {
    sink := value
  }

  // Raw instruction fields.
  ctrl.typ := inst.TYP
  ctrl.fmt := inst.FMT
  ctrl.rm  := inst.RM

  /** Builds one FMA decode row: a 2-bit `fmaCmd` pattern followed by `ren3`. */
  private def fmaRow(cmd: String, ren3: BitPat): List[BitPat] = List(BitPat(cmd), ren3)

  // fmaCmd values used: FADD/FMUL/FMADD -> 00, FSUB/FMSUB -> 01,
  // FNMSUB -> 10, FNMADD -> 11. ren3 is Y only for FMADD/FMSUB/FNMADD/FNMSUB.
  val fmaTable: Array[(BitPat, List[BitPat])] = Array(
    FADD_S   -> fmaRow("b00", N),
    FADD_D   -> fmaRow("b00", N),
    FSUB_S   -> fmaRow("b01", N),
    FSUB_D   -> fmaRow("b01", N),
    FMUL_S   -> fmaRow("b00", N),
    FMUL_D   -> fmaRow("b00", N),
    FMADD_S  -> fmaRow("b00", Y),
    FMADD_D  -> fmaRow("b00", Y),
    FMSUB_S  -> fmaRow("b01", Y),
    FMSUB_D  -> fmaRow("b01", Y),
    FNMADD_S -> fmaRow("b11", Y),
    FNMADD_D -> fmaRow("b11", Y),
    FNMSUB_S -> fmaRow("b10", Y),
    FNMSUB_D -> fmaRow("b10", Y)
  )
  val fmaDefault = fmaRow("b??", N)

  private val fmaDecoded = DecodeLogic(io.instr, fmaDefault, fmaTable)
  ctrl.fmaCmd := fmaDecoded(0)
  ctrl.ren3   := fmaDecoded(1)
}
263