/***************************************************************************************
* Copyright (c) 2020-2021 Institute of Computing Technology, Chinese Academy of Sciences
* Copyright (c) 2020-2021 Peng Cheng Laboratory
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/

package xiangshan.backend.fu

import chipsalliance.rocketchip.config.Parameters
import chisel3._
import chisel3.util._
import utils.{LookupTreeDefault, ParallelMux, ParallelXOR, SignExt, XSDebug, ZeroExt}
import xiangshan._
import xiangshan.backend.fu.util._



/** Zero-count / population-count unit.
  *
  * Two-stage: the first levels of the leading-zero tree (through `c2`) and the
  * four 16-bit popcounts (`cpopTmp`) are captured in registers; the remaining
  * reduction and the result select run in the next cycle using `funcReg`.
  *
  * `func` bits as used in this module:
  *  - bit 2: popcount result (1) vs. zero-count result (0)
  *  - bit 1: bit-reverse the source before the zero count, so the leading-zero
  *           tree counts trailing zeros of the original value
  *  - bit 0: 32-bit ("word") variant (1) vs. full-width variant (0)
  *
  * The tree is hard-wired for a 64-bit source (32 two-bit groups).
  */
class CountModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val funcReg = RegNext(io.func)

  /** Leading-zero count of a 2-bit group: 00 -> 2, 01 -> 1, 1x -> 0. */
  def encode(bits: UInt): UInt = {
    LookupTreeDefault(bits, 0.U, List(0.U -> 2.U(2.W), 1.U -> 1.U(2.W)))
  }

  /** Merges the leading-zero counts of two adjacent, equal-width halves.
    *
    * @param msb   index of the top bit of each input count; that bit is set only
    *              when the corresponding half is entirely zero
    * @param left  count for the more-significant half
    * @param right count for the less-significant half
    * @return      count for the combined span (one bit wider than the inputs)
    */
  def clzi(msb: Int, left: UInt, right: UInt): UInt = {
    // If the upper half is all zeros, the result is (upper width) + right count;
    // otherwise it is just the upper half's count.
    Mux(left(msb),
      Cat(left(msb) && right(msb), !right(msb), if(msb==1)right(0) else right(msb-1, 0)),
      left)
  }

  val c0 = Wire(Vec(32, UInt(2.W)))
  val c1 = Wire(Vec(16, UInt(3.W)))
  val c2 = Reg(Vec(8, UInt(4.W)))   // pipeline register inside the zero-count tree
  val c3 = Wire(Vec(4, UInt(5.W)))
  val c4 = Wire(Vec(2, UInt(6.W)))

  val countSrc = Mux(io.func(1), Reverse(io.src), io.src)

  for(i <- 0 until 32){ c0(i) := encode(countSrc(2*i+1, 2*i)) }
  for(i <- 0 until 16){ c1(i) := clzi(1, c0(i*2+1), c0(i*2)) }
  for(i <- 0 until 8){ c2(i) := clzi(2, c1(i*2+1), c1(i*2)) }
  for(i <- 0 until 4){ c3(i) := clzi(3, c2(i*2+1), c2(i*2)) }
  for(i <- 0 until 2){ c4(i) := clzi(4, c3(i*2+1), c3(i*2)) }
  val zeroRes = clzi(5, c4(1), c4(0))
  // Word variant: c4(0) counts over countSrc(31,0); c4(1) counts over countSrc(63,32),
  // which holds the bit-reversed low word when the source was reversed (func(1)).
  val zeroWRes = Mux(funcReg(1), c4(1), c4(0))

  // Population count of each 16-bit quarter, registered for the second cycle.
  val cpopTmp = Reg(Vec(4, UInt(5.W)))

  for(i <- 0 until 4){
    cpopTmp(i) := PopCount(io.src(i*16+15, i*16))
  }

  val cpopLo32 = cpopTmp(0) +& cpopTmp(1)
  val cpopHi32 = cpopTmp(2) +& cpopTmp(3)

  val cpopRes = cpopLo32 +& cpopHi32
  val cpopWRes = cpopLo32

  io.out := Mux(funcReg(2), Mux(funcReg(0), cpopWRes, cpopRes), Mux(funcReg(0), zeroWRes, zeroRes))
}

/** Carry-less multiply unit producing the low (`clmul`), high (`clmulh`) and
  * bits 126..63 (`clmulr`) slices of the 128-bit carry-less product.
  *
  * One shifted copy of `src2` per set bit of `src1` is formed, XORed pairwise
  * through three levels, registered at `mul3`, and the last reduction and
  * result select happen in the next cycle using `funcReg`.
  */
class ClmulModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val funcReg = RegNext(io.func)

  val (src1, src2) = (io.src(0), io.src(1))

  val mul0 = Wire(Vec(64, UInt(128.W)))
  val mul1 = Wire(Vec(32, UInt(128.W)))
  val mul2 = Wire(Vec(16, UInt(128.W)))
  val mul3 = Reg(Vec(8, UInt(128.W)))   // pipeline register between the XOR levels

  // Partial product i: src2 << i when src1(i) is set, else 0.
  (0 until XLEN) map { i =>
    mul0(i) := Mux(src1(i), if(i==0) src2 else Cat(src2, 0.U(i.W)), 0.U)
  }

  (0 until 32) map { i => mul1(i) := mul0(i*2) ^ mul0(i*2+1)}
  (0 until 16) map { i => mul2(i) := mul1(i*2) ^ mul1(i*2+1)}
  (0 until 8) map { i => mul3(i) := mul2(i*2) ^ mul2(i*2+1)}

  val res = ParallelXOR(mul3)

  val clmul = res(63,0)
  val clmulh = res(127,64)
  val clmulr = res(126,63)

  io.out := LookupTreeDefault(funcReg, clmul, List(
    BKUOpType.clmul  -> clmul,
    BKUOpType.clmulh -> clmulh,
    BKUOpType.clmulr -> clmulr
  ))
}

/** Lane-permutation unit (xpermn: 4-bit lanes, xpermb: 8-bit lanes).
  *
  * Each lane of `src2` is used as an index selecting a lane of `src1`; a byte
  * index of 8 or more yields 0. `func(0)` selects the byte form. The result is
  * registered, giving one cycle of latency.
  */
class MiscModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2) = (io.src(0), io.src(1))

  /** Returns the `width`-bit lane of `table` selected by `idx` (0 if no lane matches). */
  def xpermLUT(table: UInt, idx: UInt, width: Int) : UInt = {
    // ParallelMux((0 until XLEN/width).map( i => i.U -> table(i)).map( x => (x._1 === idx, x._2)))
    LookupTreeDefault(idx, 0.U(width.W), (0 until XLEN/width).map( i => i.U -> table(i*width+width-1, i*width)))
  }

  val xpermnVec = Wire(Vec(16, UInt(4.W)))
  (0 until 16).map( i => xpermnVec(i) := xpermLUT(src1, src2(i*4+3, i*4), 4))
  val xpermn = Cat(xpermnVec.reverse)

  val xpermbVec = Wire(Vec(8, UInt(8.W)))
  // Any set bit in 7..3 of a byte index means the index is >= 8 (out of range) -> 0.
  (0 until 8).map( i => xpermbVec(i) := Mux(src2(i*8+7, i*8+3).orR, 0.U, xpermLUT(src1, src2(i*8+2, i*8), 8)))
  val xpermb = Cat(xpermbVec.reverse)

  io.out := RegNext(Mux(io.func(0), xpermb, xpermn))
}

/** SHA-256 / SHA-512 sigma/sum functions and SM3 P0/P1 applied to `src`.
  *
  * `func(3)` selects SM3 (P1 when `func(0)` is set, else P0); otherwise
  * `func(2,0)` indexes the eight SHA functions in the order of `shaSource`.
  * 32-bit results are sign-extended to XLEN. The result is registered.
  */
class HashModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Input(UInt(XLEN.W))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val src1 = io.src

  val sha256sum0 = ROR32(src1, 2)  ^ ROR32(src1, 13) ^ ROR32(src1, 22)
  val sha256sum1 = ROR32(src1, 6)  ^ ROR32(src1, 11) ^ ROR32(src1, 25)
  val sha256sig0 = ROR32(src1, 7)  ^ ROR32(src1, 18) ^ SHR32(src1, 3)
  val sha256sig1 = ROR32(src1, 17) ^ ROR32(src1, 19) ^ SHR32(src1, 10)
  val sha512sum0 = ROR64(src1, 28) ^ ROR64(src1, 34) ^ ROR64(src1, 39)
  val sha512sum1 = ROR64(src1, 14) ^ ROR64(src1, 18) ^ ROR64(src1, 41)
  val sha512sig0 = ROR64(src1, 1)  ^ ROR64(src1, 8)  ^ SHR64(src1, 7)
  val sha512sig1 = ROR64(src1, 19) ^ ROR64(src1, 61) ^ SHR64(src1, 6)
  // Right-rotates by 23/15 and 9/17 are the left-rotates by 9/17 and 23/15.
  val sm3p0 = ROR32(src1, 23) ^ ROR32(src1, 15) ^ src1
  val sm3p1 = ROR32(src1, 9)  ^ ROR32(src1, 17) ^ src1

  val shaSource = VecInit(Seq(
    SignExt(sha256sum0(31,0), XLEN),
    SignExt(sha256sum1(31,0), XLEN),
    SignExt(sha256sig0(31,0), XLEN),
    SignExt(sha256sig1(31,0), XLEN),
    sha512sum0,
    sha512sum1,
    sha512sig0,
    sha512sig1
  ))
  val sha = shaSource(io.func(2,0))
  val sm3 = Mux(io.func(0), SignExt(sm3p1(31,0), XLEN), SignExt(sm3p0(31,0), XLEN))

  io.out := RegNext(Mux(io.func(3), sm3, sha))
}

/** AES (aes64* forms) and SM4 (sm4ed / sm4ks) datapath.
  *
  * Each S-box is split around a register, so all results appear one cycle after
  * the inputs; the per-op selection therefore uses `funcReg`. `funcReg(3)`
  * selects the SM4 result over the AES result.
  */
class BlockCipherModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func, funcReg) = (io.src(0), io.src(1), io.func, RegNext(io.func))

  val src1Bytes = VecInit((0 until 8).map(i => src1(i*8+7, i*8)))
  val src2Bytes = VecInit((0 until 8).map(i => src2(i*8+7, i*8)))

  // AES
  // Forward / inverse ShiftRows feed eight S-boxes each; the register sits after
  // SboxInv (the "mid" signals).
  val aesSboxIn = ForwardShiftRows(src1Bytes, src2Bytes)
  val aesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val aesSboxOut = Wire(Vec(8, UInt(8.W)))

  val iaesSboxIn = InverseShiftRows(src1Bytes, src2Bytes)
  val iaesSboxMid = Reg(Vec(8, Vec(18, Bool())))
  val iaesSboxOut = Wire(Vec(8, UInt(8.W)))

  aesSboxOut.zip(aesSboxMid).zip(aesSboxIn).foreach { case ((out, mid), in) =>
    mid := SboxInv(SboxAesTop(in))
    out := SboxAesOut(mid)
  }

  iaesSboxOut.zip(iaesSboxMid).zip(iaesSboxIn).foreach { case ((out, mid), in) =>
    mid := SboxInv(SboxIaesTop(in))
    out := SboxIaesOut(mid)
  }

  val aes64es = aesSboxOut.asUInt
  val aes64ds = iaesSboxOut.asUInt

  // aes64im has no S-box stage; register its input to match the one-cycle latency.
  val imMinIn = RegNext(src1Bytes)

  val aes64esm = Cat(MixFwd(Seq(aesSboxOut(4), aesSboxOut(5), aesSboxOut(6), aesSboxOut(7))),
                     MixFwd(Seq(aesSboxOut(0), aesSboxOut(1), aesSboxOut(2), aesSboxOut(3))))
  val aes64dsm = Cat(MixInv(Seq(iaesSboxOut(4), iaesSboxOut(5), iaesSboxOut(6), iaesSboxOut(7))),
                     MixInv(Seq(iaesSboxOut(0), iaesSboxOut(1), iaesSboxOut(2), iaesSboxOut(3))))
  val aes64im = Cat(MixInv(Seq(imMinIn(4), imMinIn(5), imMinIn(6), imMinIn(7))),
                    MixInv(Seq(imMinIn(0), imMinIn(1), imMinIn(2), imMinIn(3))))


  // Round constants indexed by rnum = src2(3,0); entry 10 (0xA) is 0.
  val rcon = WireInit(VecInit(Seq("h01".U, "h02".U, "h04".U, "h08".U,
                                  "h10".U, "h20".U, "h40".U, "h80".U,
                                  "h1b".U, "h36".U, "h00".U)))

  // aes64ks1i: S-box input is src1(63,32) as-is when rnum == 0xA, otherwise with
  // its bytes rotated down by one position.
  val ksSboxIn = Wire(Vec(4, UInt(8.W)))
  val ksSboxTop = Reg(Vec(4, Vec(21, Bool())))
  val ksSboxOut = Wire(Vec(4, UInt(8.W)))
  ksSboxIn(0) := Mux(src2(3,0) === "ha".U, src1Bytes(4), src1Bytes(5))
  ksSboxIn(1) := Mux(src2(3,0) === "ha".U, src1Bytes(5), src1Bytes(6))
  ksSboxIn(2) := Mux(src2(3,0) === "ha".U, src1Bytes(6), src1Bytes(7))
  ksSboxIn(3) := Mux(src2(3,0) === "ha".U, src1Bytes(7), src1Bytes(4))
  ksSboxOut.zip(ksSboxTop).zip(ksSboxIn).foreach{ case ((out, top), in) =>
    top := SboxAesTop(in)
    out := SboxAesOut(SboxInv(top))
  }

  // NOTE(review): ks1Idx is 4 bits but rcon has 11 entries; rnum 0xB-0xF indexes
  // past the table. Presumably those encodings are rejected at decode — confirm.
  val ks1Idx = RegNext(src2(3,0))
  val aes64ks1i = Cat(ksSboxOut.asUInt ^ rcon(ks1Idx), ksSboxOut.asUInt ^ rcon(ks1Idx))

  val aes64ks2Temp = src1(63,32) ^ src2(31,0)
  val aes64ks2 = RegNext(Cat(aes64ks2Temp ^ src2(63,32), aes64ks2Temp))

  val aesResult = LookupTreeDefault(funcReg, aes64es, List(
    BKUOpType.aes64es   -> aes64es,
    BKUOpType.aes64esm  -> aes64esm,
    BKUOpType.aes64ds   -> aes64ds,
    BKUOpType.aes64dsm  -> aes64dsm,
    BKUOpType.aes64im   -> aes64im,
    BKUOpType.aes64ks1i -> aes64ks1i,
    BKUOpType.aes64ks2  -> aes64ks2
  ))

  // SM4
  // The byte of src2 selected by func(1,0) goes through the SM4 S-box, split
  // around a register like the AES S-boxes.
  val sm4SboxIn = src2Bytes(func(1,0))
  val sm4SboxTop = Reg(Vec(21, Bool()))
  sm4SboxTop := SboxSm4Top(sm4SboxIn)
  val sm4SboxOut = SboxSm4Out(SboxInv(sm4SboxTop))

  // Linear transforms for sm4ed / sm4ks. Each mask must be applied BEFORE its
  // shift: in Scala `<<` binds tighter than `&`, so an unparenthesized
  // `x & m << k` parses as `x & (m << k)`. With x being the S-box byte (assumed
  // 8 bits wide), that term would always be 0.
  val sm4ed = sm4SboxOut ^ (sm4SboxOut << 8) ^ (sm4SboxOut << 2) ^ (sm4SboxOut << 18) ^
    ((sm4SboxOut & "h3f".U) << 26) ^ ((sm4SboxOut & "hc0".U) << 10)
  val sm4ks = sm4SboxOut ^ ((sm4SboxOut & "h07".U) << 29) ^ ((sm4SboxOut & "hfe".U) << 7) ^
    ((sm4SboxOut & "h01".U) << 23) ^ ((sm4SboxOut & "hf8".U) << 13)
  // Index = funcReg(2,0): bit 2 picks ks vs. ed, bits 1..0 rotate left by 8*bs.
  val sm4Source = VecInit(Seq(
    sm4ed(31,0),
    Cat(sm4ed(23,0), sm4ed(31,24)),
    Cat(sm4ed(15,0), sm4ed(31,16)),
    Cat(sm4ed( 7,0), sm4ed(31,8)),
    sm4ks(31,0),
    Cat(sm4ks(23,0), sm4ks(31,24)),
    Cat(sm4ks(15,0), sm4ks(31,16)),
    Cat(sm4ks( 7,0), sm4ks(31,8))
  ))
  val sm4Result = SignExt((sm4Source(funcReg(2,0)) ^ RegNext(src1(31,0)))(31,0), XLEN)

  io.out := Mux(funcReg(3), sm4Result, aesResult)
}

/** Scalar-crypto wrapper: feeds both the hash and block-cipher units and picks
  * the hash result when `funcReg(4)` is set.
  */
class CryptoModule(implicit p: Parameters) extends XSModule {
  val io = IO(new Bundle() {
    val src = Vec(2, Input(UInt(XLEN.W)))
    val func = Input(UInt())
    val out = Output(UInt(XLEN.W))
  })

  val (src1, src2, func) = (io.src(0), io.src(1), io.func)
  val funcReg = RegNext(func)

  val hashModule = Module(new HashModule)
  hashModule.io.src := src1
  hashModule.io.func := func

  val blockCipherModule = Module(new BlockCipherModule)
  blockCipherModule.io.src(0) := src1
  blockCipherModule.io.src(1) := src2
  blockCipherModule.io.func := func

  io.out := Mux(funcReg(4), hashModule.io.out, blockCipherModule.io.out)
}

/** Bit-manipulation / crypto function unit with a latency of one cycle.
  *
  * Every sub-unit receives the same sources and op type. The result is selected
  * with `funcReg`, taken from `uopVec(latency)` (presumably the uop at the output
  * stage of HasPipelineReg — confirm): bit 5 picks crypto, else bit 3 picks
  * count, else bit 2 picks the permutation unit, else carry-less multiply.
  */
class Bku(implicit p: Parameters) extends FunctionUnit with HasPipelineReg {

  override def latency = 1

  val (src1, src2, func, funcReg) = (
    io.in.bits.src(0),
    io.in.bits.src(1),
    io.in.bits.uop.ctrl.fuOpType,
    uopVec(latency).ctrl.fuOpType
  )

  val countModule = Module(new CountModule)
  countModule.io.src := src1
  countModule.io.func := func

  val clmulModule = Module(new ClmulModule)
  clmulModule.io.src(0) := src1
  clmulModule.io.src(1) := src2
  clmulModule.io.func := func

  val miscModule = Module(new MiscModule)
  miscModule.io.src(0) := src1
  miscModule.io.src(1) := src2
  miscModule.io.func := func

  val cryptoModule = Module(new CryptoModule)
  cryptoModule.io.src(0) := src1
  cryptoModule.io.src(1) := src2
  cryptoModule.io.func := func


  val result = Mux(funcReg(5), cryptoModule.io.out,
               Mux(funcReg(3), countModule.io.out,
               Mux(funcReg(2), miscModule.io.out, clmulModule.io.out)))

  io.out.bits.data := result
}