// xref: /XiangShan/src/main/scala/device/MemEncrypt.scala (revision 11269ca741bcbed259cf718605d4720728016f90)
/***************************************************************************************
* Copyright (c) 2024-2025 Institute of Information Engineering, Chinese Academy of Sciences
*
* XiangShan is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
*          http://license.coscl.org.cn/MulanPSL2
*
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
*
* See the Mulan PSL v2 for more details.
***************************************************************************************/
15
16package device
17
18import chisel3._
19import chisel3.util._
20import chisel3.util.HasBlackBoxResource
21import org.chipsalliance.cde.config.Field
22import org.chipsalliance.cde.config.Parameters
23import freechips.rocketchip.amba.axi4._
24import freechips.rocketchip.diplomacy._
25import freechips.rocketchip.util._
26import freechips.rocketchip.amba.apb._
27import freechips.rocketchip.tilelink.AXI4TLState
28import javax.xml.crypto.dsig.keyinfo.KeyInfo
29import system._
30
/** Parameters key holding the AXI4 edge exposed as `Memconsts.MemcedgeIn`. */
case object MemcEdgeInKey extends Field[AXI4EdgeParameters]

/** Parameters key holding the AXI4 edge exposed as `Memconsts.MemcedgeOut`. */
case object MemcEdgeOutKey extends Field[AXI4EdgeParameters]
33
/** Configuration shared by every memory-encryption module and bundle in this file.
  *
  * The trait is mixed into both `Module`s and `Bundle`s, so the `require` checks below run
  * every time one of them is constructed. The AXI4 edge parameters are `lazy`:
  * `MemcEdgeInKey` / `MemcEdgeOutKey` are only looked up in `p` when first accessed.
  */
trait Memconsts {
  val p: Parameters
  val cvm = p(CVMParamskey)
  val soc = p(SoCParamsKey)
  val PAddrBits = soc.PAddrBits
  val KeyIDBits = cvm.KeyIDBits
  val MemencPipes = cvm.MemencPipes
  lazy val MemcedgeIn = p(MemcEdgeInKey)
  lazy val MemcedgeOut = p(MemcEdgeOutKey)

  require(isPow2(MemencPipes), s"AXI4MemEncrypt: MemencPipes must be a power of two, not $MemencPipes")
  // The 32 round keys (32 bits each) are split evenly across the pipeline stages,
  // 32 / MemencPipes per stage. With more than 32 stages that quotient is 0, and the
  // stages would receive no round keys at all.
  require(MemencPipes <= 32, s"AXI4MemEncrypt: MemencPipes must be at most 32, not $MemencPipes")
  require(PAddrBits > KeyIDBits,
    s"AXI4MemEncrypt: PAddrBits ($PAddrBits) must be greater than KeyIDBits ($KeyIDBits)")

  /** Flag forwarded from the CVM parameters. When set, the data pipes in this file skip the
    * tweak XOR and instantiate the `OnePipeForEncNoEnc` / `OnePipeForDecNoDec` stages. */
  def HasDelayNoencryption = cvm.HasDelayNoencryption
}
47
48
/** Base class for the memory-encryption hardware modules. It is a Chisel `Module` that exposes
  * the shared [[Memconsts]] configuration, read from the implicit `p`. */
abstract class MemEncryptModule(implicit val p: Parameters)
  extends Module
  with Memconsts
50
/** Computes keyid/tweak values from an (address, burst length) request.
  *
  * Each `enq` request has its 64-byte-aligned address encrypted by [[TweakEncrypt]]. The result
  * is then passed through [[TweakGF128]], and the outcome leaves on `deq`. The `deq` payload
  * has the same shape as `AXI4WriteMachine.io.in_kt`.
  */
class TweakEncrptyQueue(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val enq = Flipped(DecoupledIO(new Bundle {
      val addr = UInt(PAddrBits.W)
      val len  = UInt(MemcedgeIn.bundle.lenBits.W)  // number of beats - 1
    }))
    val deq = DecoupledIO(new Bundle {
      val keyid = UInt(KeyIDBits.W)
      val tweak = UInt(MemcedgeIn.bundle.dataBits.W)
      val addr = UInt(MemcedgeIn.bundle.addrBits.W)
    })
    val tweak_round_keys = Input(Vec(32, UInt(32.W)))
  })

  // 128-bit tweak seed: the physical address aligned down to 64 bytes, then zero-extended.
  val tweak_in = Cat(0.U((128 - PAddrBits).W), Cat(io.enq.bits.addr(PAddrBits - 1, 6), 0.U(6.W)))

  val tweak_enc_module = Module(new TweakEncrypt(opt = true))
  val tweakgf128_module = Module(new TweakGF128())
  val enc = tweak_enc_module.io
  val gf  = tweakgf128_module.io

  // Stage 1: enq -> TweakEncrypt. This path does not use the request id, so it is tied to 0.
  enc.tweak_enc_req.valid := io.enq.valid
  io.enq.ready := enc.tweak_enc_req.ready
  enc.tweak_enc_req.bits.tweak := tweak_in
  enc.tweak_enc_req.bits.addr_in := io.enq.bits.addr
  enc.tweak_enc_req.bits.len_in := io.enq.bits.len
  enc.tweak_enc_req.bits.id_in := 0.U
  enc.tweak_enc_req.bits.tweak_round_keys := io.tweak_round_keys

  // Stage 2: TweakEncrypt -> TweakGF128
  gf.req.valid := enc.tweak_enc_resp.valid
  enc.tweak_enc_resp.ready := gf.req.ready
  gf.req.bits.len := enc.tweak_enc_resp.bits.len_out
  gf.req.bits.addr := enc.tweak_enc_resp.bits.addr_out
  gf.req.bits.tweak_in := enc.tweak_enc_resp.bits.tweak_encrpty

  // Output: TweakGF128 -> deq
  io.deq.valid := gf.resp.valid
  gf.resp.ready := io.deq.ready
  io.deq.bits.keyid := gf.resp.bits.keyid_out
  io.deq.bits.tweak := gf.resp.bits.tweak_out
  io.deq.bits.addr := gf.resp.bits.addr_out
}
91
/** An AXI4 W beat bundled with the keyid and tweak used to encrypt it.
  *
  * @param opt selects the edge whose bundle parameters size the fields:
  *            `true` selects `MemcedgeIn`, `false` selects `MemcedgeOut`.
  */
class AXI4W_KT(opt: Boolean)(implicit val p: Parameters) extends Bundle with Memconsts
{
  val edgeUse = opt match {
    case true  => MemcedgeIn
    case false => MemcedgeOut
  }
  val axi4 = new AXI4BundleW(edgeUse.bundle)
  val keyid = UInt(KeyIDBits.W)
  val tweak = UInt(edgeUse.bundle.dataBits.W)
}
99
/** Control key for the `reqSource` user field of AXI4 requests.
  *
  * The field is meant to identify the requester (L1I/L1D/PTW).
  * NOTE(review): within this file it is only ever driven to 0, on the ARs issued by
  * AXI4WriteMachine.
  */
case object ReqSourceKey extends ControlKey[UInt]("reqSource")
102
/** Pairs each outgoing W beat with its keyid/tweak, and performs a read-modify-write for
  * partial beats.
  *
  * A beat whose strobes are all set is forwarded unchanged together with its keyid/tweak.
  * A partial beat works in three steps:
  *  1. It issues a single-beat read on `out_ar` (id MSB set) of the aligned data word.
  *  2. It merges the returned data with the new bytes under the byte strobes.
  *  3. It forwards the merged beat with all strobes set.
  *
  * `uncache_en` is high from the read response until `uncache_commit`. While it is high,
  * further responses on `in_r` are not accepted.
  */
class AXI4WriteMachine(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in_w = Flipped(Irrevocable(new AXI4BundleW(MemcedgeOut.bundle)))
    val in_kt = Flipped(DecoupledIO(new Bundle {
      val keyid = UInt(KeyIDBits.W)
      val tweak = UInt(MemcedgeOut.bundle.dataBits.W)
      val addr = UInt(MemcedgeOut.bundle.addrBits.W)
    }))
    val out_ar = Irrevocable(new AXI4BundleAR(MemcedgeOut.bundle))
    val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)))
    val out_w = DecoupledIO(new AXI4W_KT(true))
    val uncache_en = Output(Bool())
    val uncache_commit = Input(Bool())
  })

  val beatBytes   = MemcedgeOut.bundle.dataBits / 8
  val beatBytesLg = log2Ceil(beatBytes)

  // ---------------- s0: classify the incoming beat ----------------
  // A beat with every strobe set replaces the whole data word, so it needs no read-modify-write.
  val w_cacheable = io.in_w.bits.strb.andR

  // ---------------- s1: one-entry buffers for the W beat and its keyid/tweak ----------------
  val in_w_v   = RegInit(false.B)
  val in_kt_v  = RegInit(false.B)
  val in_w_req  = RegEnable(io.in_w.bits, io.in_w.fire)
  val in_kt_req = RegEnable(io.in_kt.bits, io.in_kt.fire)
  val s1_w_cacheable = RegEnable(w_cacheable, io.in_w.fire)

  // Each buffer can refill in the same cycle its contents leave on out_w.
  io.in_w.ready  := !in_w_v  || io.out_w.fire
  io.in_kt.ready := !in_kt_v || io.out_w.fire

  // Set on accept (which wins); otherwise cleared when the beat leaves on out_w.
  in_w_v  := io.in_w.fire  || (in_w_v  && !io.out_w.fire)
  in_kt_v := io.in_kt.fire || (in_kt_v && !io.out_w.fire)

  // ---------------- s2: read-modify-write state (partial beats only) ----------------
  val out_ar_v     = RegInit(false.B)
  val out_ar_mask  = RegInit(false.B) // at most one AR per buffered beat
  val in_r_v       = RegInit(false.B)
  val r_uncache_en = RegInit(false.B)

  in_r_v       := io.in_r.fire || (in_r_v && !io.out_w.fire)
  r_uncache_en := io.in_r.fire || (r_uncache_en && !io.uncache_commit)

  io.in_r.ready := !r_uncache_en || io.uncache_commit
  io.uncache_en := r_uncache_en

  val start_rmw = in_w_v && in_kt_v && !s1_w_cacheable && !out_ar_mask
  out_ar_v    := start_rmw || (out_ar_v && !io.out_ar.fire)
  out_ar_mask := start_rmw || (out_ar_mask && !io.out_w.fire)

  io.out_ar.valid := out_ar_v
  val ar = io.out_ar.bits
  ar.id    := 1.U << (MemcedgeOut.bundle.idBits - 1) // id MSB tags these reads (RdataRoute steers on it)
  ar.addr  := (in_kt_req.addr >> beatBytesLg) << beatBytesLg
  ar.len   := 0.U
  ar.size  := beatBytesLg.U
  ar.burst := AXI4Parameters.BURST_INCR
  ar.lock  := 0.U // not exclusive (LR/SC unsupported b/c no forward progress guarantee)
  ar.cache := 0.U // do not allow AXI to modify our transactions
  ar.prot  := AXI4Parameters.PROT_PRIVILEGED
  ar.qos   := 0.U // no QoS
  if (MemcedgeOut.bundle.echoFields != Nil) {
    val ar_extra = ar.echo(AXI4TLState)
    ar_extra.source := 0.U
    ar_extra.size   := 0.U
  }
  if (MemcedgeOut.bundle.requestFields != Nil) {
    ar.user(ReqSourceKey) := 0.U
  }

  /** Expands a byte strobe into a bit mask: strobe bit i covers data bits [8i+7, 8i]. */
  def gen_wmask(strb: UInt): UInt = FillInterleaved(8, strb(beatBytes - 1, 0))

  val new_data = Reg(UInt(MemcedgeOut.bundle.dataBits.W))
  val new_strb = ~0.U(beatBytes.W)
  val wmask = gen_wmask(in_w_req.strb)

  // Merge: strobed bytes come from the new beat, all other bytes come from the read response.
  when(io.in_r.fire) {
    new_data := (io.in_r.bits.data & ~wmask) | (in_w_req.data & wmask)
  }

  // ---------------- output ----------------
  // A partial beat must also wait for its read response.
  io.out_w.valid := in_w_v && in_kt_v && (s1_w_cacheable || in_r_v)
  io.out_w.bits.axi4      := in_w_req
  io.out_w.bits.axi4.data := Mux(s1_w_cacheable, in_w_req.data, new_data)
  io.out_w.bits.axi4.strb := Mux(s1_w_cacheable, in_w_req.strb, new_strb)
  io.out_w.bits.keyid     := in_kt_req.keyid
  io.out_w.bits.tweak     := in_kt_req.tweak
}
243
/** Pipelined encryption of AXI4 W data.
  *
  * Each 256-bit beat is split into two 128-bit halves (bits 127:0 and 255:128). The halves flow
  * through two parallel chains of `MemencPipes` single-stage modules that share one
  * valid/ready handshake per stage.
  *
  * Stage i is fed `enc_round_keys(i)`, which holds 32 / MemencPipes 32-bit round keys. The stage
  * exposes its keyid on `enc_keyids(i)`, presumably so that the matching round keys can be
  * supplied.
  *
  * Unless `HasDelayNoencryption` is set:
  *  - the data is XORed with the tweak before the first stage;
  *  - the last stage's output is word-reversed and XORed with the final tweak.
  */
class WdataEncrptyPipe(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in_w = Flipped(DecoupledIO(new AXI4W_KT(true)))
    val out_w = Irrevocable(new AXI4BundleW(MemcedgeIn.bundle))
    val enc_keyids = Output(Vec(MemencPipes, UInt(KeyIDBits.W)))
    val enc_round_keys = Input(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
  })
  // The datapath hard-codes two 128-bit halves. A narrower bus would fail to elaborate; a wider
  // one would silently drop every data bit above 255.
  require(MemcedgeIn.bundle.dataBits == 256,
    s"WdataEncrptyPipe: expected a 256-bit data bus, got ${MemcedgeIn.bundle.dataBits} bits")

  val reg_encdec_result_0 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_encdec_result_1 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_axi4_other_result = Reg(Vec(MemencPipes, new AXI4BundleWWithoutData(MemcedgeIn.bundle)))
  val reg_tweak_result_0 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_tweak_result_1 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_keyid = Reg(Vec(MemencPipes, UInt(KeyIDBits.W)))
  val reg_encdec_valid = RegInit(VecInit(Seq.fill(MemencPipes)(false.B)))
  val wire_ready_result = WireInit(VecInit(Seq.fill(MemencPipes)(false.B)))

  // Non-data W fields travel alongside the data through the pipeline.
  val wire_axi4_other = Wire(new AXI4BundleWWithoutData(MemcedgeIn.bundle))
  wire_axi4_other.strb := io.in_w.bits.axi4.strb
  wire_axi4_other.last := io.in_w.bits.axi4.last
  wire_axi4_other.user := io.in_w.bits.axi4.user

  val pipes_first_data_0 = Wire(UInt(128.W))
  val pipes_first_data_1 = Wire(UInt(128.W))
  if (HasDelayNoencryption) {
    pipes_first_data_0 := io.in_w.bits.axi4.data(127,0)
    pipes_first_data_1 := io.in_w.bits.axi4.data(255,128)
  } else {
    pipes_first_data_0 := io.in_w.bits.axi4.data(127,0) ^ io.in_w.bits.tweak(127, 0)
    pipes_first_data_1 := io.in_w.bits.axi4.data(255,128) ^ io.in_w.bits.tweak(255,128)
  }

  /** Builds pipeline stage `i` of chain 0 (`flag = true`) or chain 1 (`flag = false`).
    *
    * @param keyId     keyid entering the stage
    * @param dataIn    128-bit data entering the stage
    * @param tweakIn   128-bit tweak entering the stage
    * @param axi4In    non-data W fields entering the stage
    * @param roundKeys this stage's concatenated round keys (32 / MemencPipes words of 32 bits)
    * @return the instantiated stage module
    */
  def configureModule(flag: Boolean, i: Int, keyId: UInt, dataIn: UInt, tweakIn: UInt, axi4In: AXI4BundleWWithoutData, roundKeys: UInt): OnePipeEncBase = {
    // Handshake neighbours: the previous stage (or the input port) and the next stage
    // (or the output port).
    val upstream_valid = if (i == 0) io.in_w.valid else reg_encdec_valid(i - 1)
    val downstream_ready = if (i == MemencPipes - 1) io.out_w.ready else wire_ready_result(i + 1)

    // Stage-i control. Both chains call this function for the same i, so these connections
    // are made twice with identical values.
    when(wire_ready_result(i) && upstream_valid) {
      reg_encdec_valid(i) := true.B
    }.elsewhen(reg_encdec_valid(i) && downstream_ready) {
      reg_encdec_valid(i) := false.B
    }
    wire_ready_result(i) := !reg_encdec_valid(i) || downstream_ready

    val module: OnePipeEncBase = if (HasDelayNoencryption) Module(new OnePipeForEncNoEnc()) else Module(new OnePipeForEnc())
    module.io.onepipe_in.keyid := keyId
    module.io.onepipe_in.data_in := dataIn
    module.io.onepipe_in.tweak_in := tweakIn
    module.io.onepipe_in.axi4_other := axi4In
    // Round-key word k of this stage is bits [32k+31, 32k].
    for (k <- 0 until 32 / MemencPipes) {
      module.io.onepipe_in.round_key_in(k) := roundKeys(k * 32 + 31, k * 32)
    }
    // Capture the stage output when a beat moves into stage i. Only chain 0 records the keyid
    // and the non-data W fields; both chains are fed the same values for them.
    when(upstream_valid && wire_ready_result(i)) {
      if (flag) {
        reg_encdec_result_0(i) := module.io.onepipe_out.result_out
        reg_tweak_result_0(i) := module.io.onepipe_out.tweak_out
        reg_axi4_other_result(i) := module.io.onepipe_out.axi4_other_out
        reg_keyid(i) := module.io.onepipe_out.keyid_out
      } else {
        reg_encdec_result_1(i) := module.io.onepipe_out.result_out
        reg_tweak_result_1(i) := module.io.onepipe_out.tweak_out
      }
    }
    // Driven by both chains; the later (chain 1) connection takes effect.
    io.enc_keyids(i) := module.io.onepipe_out.keyid_out
    module
  }

  val modules_0 = (0 until MemencPipes).map { i =>
    if (i == 0) {
      configureModule(true, i, io.in_w.bits.keyid, pipes_first_data_0, io.in_w.bits.tweak(127, 0), wire_axi4_other, io.enc_round_keys(i))
    } else {
      configureModule(true, i, reg_keyid(i-1), reg_encdec_result_0(i-1), reg_tweak_result_0(i-1), reg_axi4_other_result(i-1), io.enc_round_keys(i))
    }
  }
  val modules_1 = (0 until MemencPipes).map { i =>
    if (i == 0) {
      configureModule(false, i, io.in_w.bits.keyid, pipes_first_data_1, io.in_w.bits.tweak(255,128), wire_axi4_other, io.enc_round_keys(i))
    } else {
      configureModule(false, i, reg_keyid(i-1), reg_encdec_result_1(i-1), reg_tweak_result_1(i-1), reg_axi4_other_result(i-1), io.enc_round_keys(i))
    }
  }

  if (HasDelayNoencryption) {
    io.out_w.bits.data := Cat(reg_encdec_result_1.last, reg_encdec_result_0.last)
  } else {
    // Reverse the order of the four 32-bit words of each half, then apply the final tweak XOR.
    val enc_0_out = Cat(
      reg_encdec_result_0.last(31, 0),
      reg_encdec_result_0.last(63, 32),
      reg_encdec_result_0.last(95, 64),
      reg_encdec_result_0.last(127, 96)
    )
    val enc_1_out = Cat(
      reg_encdec_result_1.last(31, 0),
      reg_encdec_result_1.last(63, 32),
      reg_encdec_result_1.last(95, 64),
      reg_encdec_result_1.last(127, 96)
    )
    io.out_w.bits.data := Cat(enc_1_out ^ reg_tweak_result_1.last, enc_0_out ^ reg_tweak_result_0.last)
  }

  io.out_w.bits.strb := reg_axi4_other_result.last.strb
  io.out_w.bits.last := reg_axi4_other_result.last.last
  io.out_w.bits.user := reg_axi4_other_result.last.user
  io.out_w.valid := reg_encdec_valid.last
  io.in_w.ready := wire_ready_result(0)
}
348
/** Tweak generation and storage, keyed by AXI id.
  *
  * On each accepted `enq`:
  *  - the request's decryption mode (`dec_mode && memenc_enable`) is recorded under its id;
  *  - when that mode is set, the address-derived tweak seed is encrypted by [[TweakEncrypt]]
  *    and written into [[TweakTable]].
  *
  * On `req`/`resp`, the stored tweak for an id is passed through [[GF128]]. The 256-bit half
  * selected by the table's `read_sel_counter` is returned together with the stored keyid.
  */
class TweakEncrptyTable(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val enq = Flipped(DecoupledIO(new Bundle {
      val addr = UInt(PAddrBits.W)
      val len  = UInt(MemcedgeOut.bundle.lenBits.W)  // number of beats - 1
      val id  = UInt(MemcedgeOut.bundle.idBits.W)    // 7 bits
    }))
    val req = Flipped(DecoupledIO(new Bundle {
      val id  = UInt(MemcedgeOut.bundle.idBits.W)
    }))
    val resp = DecoupledIO(new Bundle {
      val keyid = UInt(KeyIDBits.W)
      val tweak = UInt(MemcedgeOut.bundle.dataBits.W)
    })
    val dec_r = new Bundle {
      val id = Input(UInt(MemcedgeOut.bundle.idBits.W))
      val mode = Output(Bool())
    }
    val dec_keyid = Output(UInt(KeyIDBits.W))
    val dec_mode = Input(Bool())
    val tweak_round_keys = Input(Vec(32, UInt(32.W)))
    val memenc_enable = Input(Bool())
  })

  // 128-bit tweak seed: the physical address aligned down to 64 bytes, then zero-extended.
  val tweak_in = Cat(0.U((128 - PAddrBits).W), Cat(io.enq.bits.addr(PAddrBits-1, 6), 0.U(6.W)))

  // The keyid is the top KeyIDBits of the physical address. It is exported, presumably so that
  // the mode for this keyid can be looked up outside and returned on io.dec_mode.
  // NOTE(review): confirm this against the parent wiring.
  io.dec_keyid := io.enq.bits.addr(PAddrBits - 1, PAddrBits - KeyIDBits)

  val tweak_enc_module = Module(new TweakEncrypt(opt = false))
  val tweak_table = Module(new TweakTable())
  val tweak_gf128 = Module(new GF128())
  val enc = tweak_enc_module.io
  val tbl = tweak_table.io

  // A tweak is only produced for requests that will actually be decrypted.
  val needs_tweak = io.dec_mode && io.memenc_enable

  // Record the mode of every accepted request, keyed by its id.
  tbl.w_mode.valid := io.enq.fire
  tbl.w_mode.bits.id := io.enq.bits.id
  tbl.w_mode.bits.dec_mode := needs_tweak

  // enq -> TweakEncrypt
  enc.tweak_enc_req.valid := io.enq.valid && needs_tweak
  io.enq.ready := enc.tweak_enc_req.ready
  enc.tweak_enc_req.bits.tweak := tweak_in
  enc.tweak_enc_req.bits.addr_in := io.enq.bits.addr
  enc.tweak_enc_req.bits.len_in := io.enq.bits.len
  enc.tweak_enc_req.bits.id_in := io.enq.bits.id
  enc.tweak_enc_req.bits.tweak_round_keys := io.tweak_round_keys

  // TweakEncrypt -> tweak table write port (write.ready is expected to be always true)
  tbl.write.valid := enc.tweak_enc_resp.valid
  enc.tweak_enc_resp.ready := tbl.write.ready
  tbl.write.bits.id := enc.tweak_enc_resp.bits.id_out
  tbl.write.bits.addr := enc.tweak_enc_resp.bits.addr_out
  tbl.write.bits.len := enc.tweak_enc_resp.bits.len_out
  tbl.write.bits.tweak_encrpty := enc.tweak_enc_resp.bits.tweak_encrpty

  // Tweak table read port
  tbl.req.valid := io.req.valid
  io.req.ready := tbl.req.ready
  tbl.req.bits.read_id := io.req.bits.id
  tbl.resp.ready := io.resp.ready

  // Mode lookup by R-channel id
  tbl.r_mode.id := io.dec_r.id
  io.dec_r.mode := tbl.r_mode.dec_mode

  // Pass the stored tweak through GF128 and pick the 256-bit half for the current beat.
  tweak_gf128.io.tweak_in := tbl.resp.bits.read_tweak
  val gf_out = tweak_gf128.io.tweak_out
  io.resp.valid := tbl.resp.valid
  io.resp.bits.keyid := tbl.resp.bits.read_keyid
  io.resp.bits.tweak := Mux(tbl.resp.bits.read_sel_counter, gf_out(511, 256), gf_out(255, 0))
}
425
/** An AXI4 R beat bundled with the keyid and tweak needed to decrypt it.
  *
  * @param opt selects the edge whose bundle parameters size the fields:
  *            `true` selects `MemcedgeIn`, `false` selects `MemcedgeOut`.
  */
class AXI4R_KT(opt: Boolean)(implicit val p: Parameters) extends Bundle with Memconsts
{
  val edgeUse = opt match {
    case true  => MemcedgeIn
    case false => MemcedgeOut
  }
  val axi4 = new AXI4BundleR(edgeUse.bundle)
  val keyid = UInt(KeyIDBits.W)
  val tweak = UInt(edgeUse.bundle.dataBits.W)
}
433
/** Two-stage pipeline that attaches keyid/tweak to incoming R beats.
  *
  * s1 holds a beat from `in_r` and issues `kt_req` with its id. s2 holds the beat until `in_kt`
  * is valid, then emits the beat, keyid and tweak together on `out_r`.
  */
class AXI4ReadMachine(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)))
    val kt_req = DecoupledIO(new Bundle {
      val id  = UInt(MemcedgeOut.bundle.idBits.W)
    })
    val in_kt = Flipped(DecoupledIO(new Bundle {
      val keyid = UInt(KeyIDBits.W)
      val tweak = UInt(MemcedgeOut.bundle.dataBits.W)
    }))
    val out_r = DecoupledIO(new AXI4R_KT(false))
  })

  val s1_r_val = RegInit(false.B)
  val s1_r_req = RegEnable(io.in_r.bits, io.in_r.fire)
  val s1_r_out_rdy = Wire(Bool())

  val s2_r_val = RegInit(false.B)
  val s2_r_in_rdy = Wire(Bool())
  val s2_r_req = RegEnable(s1_r_req, s1_r_val && s2_r_in_rdy)

  // s1 hands its beat to s2 this cycle.
  val s1_advance = s1_r_val && s1_r_out_rdy

  // s0: accept when s1 is empty or is emptying this cycle.
  io.in_r.ready := !s1_r_val || s1_advance

  // s1: a fill has priority over a drain.
  s1_r_val := Mux(io.in_r.fire, true.B, Mux(s1_advance, false.B, s1_r_val))
  s1_r_out_rdy := s2_r_in_rdy && io.kt_req.ready
  io.kt_req.valid := s1_r_val && s2_r_in_rdy
  io.kt_req.bits.id := s1_r_req.id

  // s2: the beat leaves once its keyid/tweak has arrived.
  s2_r_val := Mux(s1_advance, true.B, Mux(s2_r_val && io.out_r.fire, false.B, s2_r_val))
  s2_r_in_rdy := !s2_r_val || io.out_r.fire

  io.in_kt.ready := io.out_r.fire
  io.out_r.valid := s2_r_val && io.in_kt.valid
  io.out_r.bits.axi4 := s2_r_req
  io.out_r.bits.keyid := io.in_kt.bits.keyid
  io.out_r.bits.tweak := io.in_kt.bits.tweak
}
494
/** Pipelined decryption of AXI4 R data.
  *
  * Each 256-bit beat is split into two 128-bit halves (bits 127:0 and 255:128). The halves flow
  * through two parallel chains of `MemencPipes` single-stage modules that share one
  * valid/ready handshake per stage.
  *
  * Stage i is fed `dec_round_keys(i)`, which holds 32 / MemencPipes 32-bit round keys. The stage
  * exposes its keyid on `dec_keyids(i)`, presumably so that the matching round keys can be
  * supplied.
  *
  * Unless `HasDelayNoencryption` is set:
  *  - the data is XORed with the tweak before the first stage;
  *  - the last stage's output is word-reversed and XORed with the final tweak.
  */
class RdataDecrptyPipe(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in_r = Flipped(DecoupledIO(new AXI4R_KT(false)))
    val out_r = Irrevocable(new AXI4BundleR(MemcedgeOut.bundle))
    val dec_keyids = Output(Vec(MemencPipes, UInt(KeyIDBits.W)))
    val dec_round_keys = Input(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
  })
  // The datapath hard-codes two 128-bit halves. A narrower bus would fail to elaborate; a wider
  // one would silently drop every data bit above 255.
  require(MemcedgeOut.bundle.dataBits == 256,
    s"RdataDecrptyPipe: expected a 256-bit data bus, got ${MemcedgeOut.bundle.dataBits} bits")

  val reg_encdec_result_0 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_encdec_result_1 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_axi4_other_result = Reg(Vec(MemencPipes, new AXI4BundleRWithoutData(MemcedgeOut.bundle)))
  val reg_tweak_result_0 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_tweak_result_1 = Reg(Vec(MemencPipes, UInt(128.W)))
  val reg_keyid = Reg(Vec(MemencPipes, UInt(KeyIDBits.W)))
  val reg_encdec_valid = RegInit(VecInit(Seq.fill(MemencPipes)(false.B)))
  val wire_ready_result = WireInit(VecInit(Seq.fill(MemencPipes)(false.B)))

  // Non-data R fields travel alongside the data through the pipeline.
  val wire_axi4_other = Wire(new AXI4BundleRWithoutData(MemcedgeOut.bundle))
  wire_axi4_other.id := io.in_r.bits.axi4.id
  wire_axi4_other.resp := io.in_r.bits.axi4.resp
  wire_axi4_other.user := io.in_r.bits.axi4.user
  wire_axi4_other.echo := io.in_r.bits.axi4.echo
  wire_axi4_other.last := io.in_r.bits.axi4.last

  val pipes_first_data_0 = Wire(UInt(128.W))
  val pipes_first_data_1 = Wire(UInt(128.W))

  if (HasDelayNoencryption) {
    pipes_first_data_0 := io.in_r.bits.axi4.data(127,0)
    pipes_first_data_1 := io.in_r.bits.axi4.data(255,128)
  } else {
    pipes_first_data_0 := io.in_r.bits.axi4.data(127,0) ^ io.in_r.bits.tweak(127, 0)
    pipes_first_data_1 := io.in_r.bits.axi4.data(255,128) ^ io.in_r.bits.tweak(255,128)
  }

  /** Builds pipeline stage `i` of chain 0 (`flag = true`) or chain 1 (`flag = false`).
    *
    * @param keyId     keyid entering the stage
    * @param dataIn    128-bit data entering the stage
    * @param tweakIn   128-bit tweak entering the stage
    * @param axi4In    non-data R fields entering the stage
    * @param roundKeys this stage's concatenated round keys (32 / MemencPipes words of 32 bits)
    * @return the instantiated stage module
    */
  def configureModule(flag: Boolean, i: Int, keyId: UInt, dataIn: UInt, tweakIn: UInt, axi4In: AXI4BundleRWithoutData, roundKeys: UInt): OnePipeDecBase = {
    // Handshake neighbours: the previous stage (or the input port) and the next stage
    // (or the output port).
    val upstream_valid = if (i == 0) io.in_r.valid else reg_encdec_valid(i - 1)
    val downstream_ready = if (i == MemencPipes - 1) io.out_r.ready else wire_ready_result(i + 1)

    // Stage-i control. Both chains call this function for the same i, so these connections
    // are made twice with identical values.
    when(wire_ready_result(i) && upstream_valid) {
      reg_encdec_valid(i) := true.B
    }.elsewhen(reg_encdec_valid(i) && downstream_ready) {
      reg_encdec_valid(i) := false.B
    }
    wire_ready_result(i) := !reg_encdec_valid(i) || downstream_ready

    val module: OnePipeDecBase = if (HasDelayNoencryption) Module(new OnePipeForDecNoDec()) else Module(new OnePipeForDec())
    module.io.onepipe_in.keyid := keyId
    module.io.onepipe_in.data_in := dataIn
    module.io.onepipe_in.tweak_in := tweakIn
    module.io.onepipe_in.axi4_other := axi4In
    // Round-key word k of this stage is bits [32k+31, 32k].
    for (k <- 0 until 32 / MemencPipes) {
      module.io.onepipe_in.round_key_in(k) := roundKeys(k * 32 + 31, k * 32)
    }
    // Capture the stage output when a beat moves into stage i. Only chain 0 records the keyid
    // and the non-data R fields; both chains are fed the same values for them.
    when(upstream_valid && wire_ready_result(i)) {
      if (flag) {
        reg_encdec_result_0(i) := module.io.onepipe_out.result_out
        reg_tweak_result_0(i) := module.io.onepipe_out.tweak_out
        reg_axi4_other_result(i) := module.io.onepipe_out.axi4_other_out
        reg_keyid(i) := module.io.onepipe_out.keyid_out
      } else {
        reg_encdec_result_1(i) := module.io.onepipe_out.result_out
        reg_tweak_result_1(i) := module.io.onepipe_out.tweak_out
      }
    }
    // Driven by both chains; the later (chain 1) connection takes effect.
    io.dec_keyids(i) := module.io.onepipe_out.keyid_out
    module
  }

  val modules_0 = (0 until MemencPipes).map { i =>
    if (i == 0) {
      configureModule(true, i, io.in_r.bits.keyid, pipes_first_data_0, io.in_r.bits.tweak(127, 0), wire_axi4_other, io.dec_round_keys(i))
    } else {
      configureModule(true, i, reg_keyid(i-1), reg_encdec_result_0(i-1), reg_tweak_result_0(i-1), reg_axi4_other_result(i-1), io.dec_round_keys(i))
    }
  }

  val modules_1 = (0 until MemencPipes).map { i =>
    if (i == 0) {
      configureModule(false, i, io.in_r.bits.keyid, pipes_first_data_1, io.in_r.bits.tweak(255,128), wire_axi4_other, io.dec_round_keys(i))
    } else {
      configureModule(false, i, reg_keyid(i-1), reg_encdec_result_1(i-1), reg_tweak_result_1(i-1), reg_axi4_other_result(i-1), io.dec_round_keys(i))
    }
  }

  if (HasDelayNoencryption) {
    io.out_r.bits.data := Cat(reg_encdec_result_1.last, reg_encdec_result_0.last)
  } else {
    // Reverse the order of the four 32-bit words of each half, then apply the final tweak XOR.
    val enc_0_out = Cat(
      reg_encdec_result_0.last(31, 0),
      reg_encdec_result_0.last(63, 32),
      reg_encdec_result_0.last(95, 64),
      reg_encdec_result_0.last(127, 96)
    )
    val enc_1_out = Cat(
      reg_encdec_result_1.last(31, 0),
      reg_encdec_result_1.last(63, 32),
      reg_encdec_result_1.last(95, 64),
      reg_encdec_result_1.last(127, 96)
    )
    io.out_r.bits.data := Cat(enc_1_out ^ reg_tweak_result_1.last, enc_0_out ^ reg_tweak_result_0.last)
  }

  io.out_r.bits.id := reg_axi4_other_result.last.id
  io.out_r.bits.resp := reg_axi4_other_result.last.resp
  io.out_r.bits.user := reg_axi4_other_result.last.user
  io.out_r.bits.echo := reg_axi4_other_result.last.echo
  io.out_r.bits.last := reg_axi4_other_result.last.last
  io.out_r.valid := reg_encdec_valid.last
  io.in_r.ready := wire_ready_result(0)
}
607
/** Steers R beats to one of two outputs by the MSB of their id.
  *
  * An id MSB of 1 goes to `out_r1` and an id MSB of 0 goes to `out_r0`. AXI4WriteMachine issues
  * its read-modify-write ARs with this bit set. Both outputs see the same payload; only the
  * valid is steered.
  */
class RdataRoute(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val in_r = Flipped(Irrevocable(new AXI4BundleR(MemcedgeOut.bundle)))
    val out_r0 = Irrevocable(new AXI4BundleR(MemcedgeIn.bundle))
    val out_r1 = Irrevocable(new AXI4BundleR(MemcedgeIn.bundle))
  })

  val id_msb = MemcedgeOut.bundle.idBits - 1
  val r_sel = io.in_r.bits.id(id_msb).asBool

  io.out_r0.bits <> io.in_r.bits
  io.out_r1.bits <> io.in_r.bits

  io.out_r0.valid := io.in_r.valid && !r_sel
  io.out_r1.valid := io.in_r.valid && r_sel
  io.in_r.ready := (r_sel && io.out_r1.ready) || (!r_sel && io.out_r0.ready)
}
625
/** Control/status register file for memory encryption.
  *
  * Register map: `addr(11,3)` selects a 64-bit register. Writes take effect only when
  * `wmask === 0xff`.
  *  - 0: control/status.
  *       Write fields: key_id[4:0], mode[6:5], tweak_flage[7], memenc_enable[8] and
  *       key_init_req[63]. memenc_enable can be written only once; it locks on the first
  *       write to this register. key_init_req is a one-cycle pulse.
  *       Read fields: key_id, mode, tweak_flage, memenc_enable, random_ready_flag[32],
  *       key_expansion_idle[33], last_req_accepted[34] and cfg_succesd[35].
  *  - 1 / 2: KEY0 / KEY1. Write-only; reads return 0.
  *  - 3: RelPaddrBitsMap (read-only).
  *  - 4: KeyIDBitsMap (read-only).
  *  - 5: version (read-only).
  * Read data appears on `rdata` one cycle after `en`, and is 0 otherwise.
  *
  * Key modes for a key_init_req:
  *  - 0: request with enc_mode = 0.
  *  - 1: key = KEY1:KEY0.
  *  - 2: key = 128 collected random bits; requires random_ready_flag.
  *  - 3: rejected.
  */
class MemEncryptCSR(implicit p: Parameters) extends MemEncryptModule
{
  val io = IO(new Bundle {
    val en    = Input(Bool())
    val wmode = Input(Bool())
    val addr  = Input(UInt(12.W))
    val wdata = Input(UInt(64.W))
    val wmask = Input(UInt(8.W))
    val rdata = Output(UInt(64.W))  // valid in the cycle after en
    val memenc_enable = Output(Bool())
    val keyextend_req = DecoupledIO(new Bundle {
      val key = UInt(128.W)
      val keyid = UInt(KeyIDBits.W)
      val enc_mode = Bool()         // 1: enable encryption for this keyid, 0: disable it
      val tweak_flage = Bool()      // 1: expand the tweak key, 0: expand the key for keyid
    })
    val randomio = new Bundle {
      val random_req = Output(Bool())
      val random_val = Input(Bool())
      val random_data = Input(Bool())
    }
  })
  // CSR fields (bit positions within register 0)
  val key_id             = RegInit(0.U(5.W)) // [4:0]
  val mode               = RegInit(0.U(2.W)) // [6:5]
  val tweak_flage        = RegInit(0.U(1.W)) // [7]
  val memenc_enable      = if (HasDelayNoencryption) RegInit(true.B) else RegInit(false.B) // [8]
  val memenc_enable_lock = RegInit(false.B)
  val random_ready_flag  = Wire(Bool()) // [32]
  val key_expansion_idle = Wire(Bool()) // [33]
  val last_req_accepted  = RegInit(false.B) // [34]
  val cfg_succesd        = Wire(Bool()) // [35]
  val key_init_req       = RegInit(false.B) // [63]
  // KEY0 & KEY1
  val key0               = RegInit(0.U(64.W))
  val key1               = RegInit(0.U(64.W))
  // RelPaddrBitsMap: mask of the physical-address bits below the keyid field
  val relpaddrbitsmap    = ~0.U((PAddrBits - KeyIDBits).W)
  // KeyIDBitsMap: mask of the top KeyIDBits bits of the physical address
  val keyidbitsmap       = ~0.U(PAddrBits.W) - ~0.U((PAddrBits - KeyIDBits).W)
  // Version
  val memenc_version_p0  = (0x0001).U(16.W)
  val memenc_version_p1  = (0x0001).U(16.W)
  val memenc_version_p2  = (0x00000002).U(32.W)
  val memenc_version     = Cat(memenc_version_p0, memenc_version_p1, memenc_version_p2)

  val reg_index = io.addr(11, 3)

  // READ
  val read_access = io.en && !io.wmode
  val rdata_reg = RegInit(0.U(64.W))
  when(read_access && (reg_index === 0.U)) {
    rdata_reg := Cat(0.U(28.W), cfg_succesd, last_req_accepted, key_expansion_idle, random_ready_flag, 0.U(23.W), memenc_enable, tweak_flage, mode, key_id)
  }.elsewhen(read_access && (reg_index === 3.U)) {
    rdata_reg := relpaddrbitsmap
  }.elsewhen(read_access && (reg_index === 4.U)) {
    rdata_reg := keyidbitsmap
  }.elsewhen(read_access && (reg_index === 5.U)) {
    rdata_reg := memenc_version
  }.otherwise {
    rdata_reg := 0.U
  }
  io.rdata := rdata_reg

  // WRITE (only full 64-bit writes are honoured)
  val wmask_legal = (io.wmask === (0xff).U)
  val write_access = io.en && io.wmode && wmask_legal

  when(write_access && (reg_index === 0.U)) {
    key_id := io.wdata(4,0)
    mode := io.wdata(6,5)
    tweak_flage := io.wdata(7)
    key_init_req := io.wdata(63).asBool
  }.otherwise {
    key_init_req := false.B
  }
  when(write_access && (reg_index === 0.U) && !memenc_enable_lock) {
    memenc_enable := io.wdata(8)
    memenc_enable_lock := true.B
  }
  when(write_access && (reg_index === 1.U)) {
    key0 := io.wdata
  }
  when(write_access && (reg_index === 2.U)) {
    key1 := io.wdata
  }
  io.memenc_enable := memenc_enable

  // RANDOM COLLECT: gather 128 random bits, then stop requesting until a mode-2 key is sent.
  val random_vec_data = RegInit(0.U(128.W))
  val random_cnt = RegInit(0.U(8.W))
  val random_key_init_done = Wire(Bool())
  io.randomio.random_req := random_cnt =/= 128.U(8.W)
  random_ready_flag := random_cnt === 128.U(8.W)

  val random_bit_fire = io.randomio.random_req && io.randomio.random_val
  // Shift left by one and insert the new bit at the LSB, so that after 128 samples every key bit
  // comes from the generator. Cat(random_vec_data(127, 1), bit) would instead keep bits 127:1
  // fixed and only ever refresh bit 0.
  when(random_bit_fire) {
    random_vec_data := Cat(random_vec_data(126, 0), io.randomio.random_data)
  }

  when(random_ready_flag && random_key_init_done) {
    random_cnt := 0.U
  }.elsewhen(random_bit_fire) {
    random_cnt := random_cnt + 1.U
  }

  // KEY Extend Req
  key_expansion_idle := io.keyextend_req.ready
  cfg_succesd := io.keyextend_req.ready

  val keyextend_req_valid = RegInit(false.B)
  val req_leagl = Wire(Bool())
  // Accept only modes 0-2, only while the key expander is idle, and (for mode 2) only once
  // 128 random bits have been collected.
  req_leagl := (mode =/= 3.U(2.W)) && key_expansion_idle && ((mode =/= 2.U(2.W)) || random_ready_flag)

  when(key_init_req && req_leagl) {
    keyextend_req_valid := true.B
  }.elsewhen(io.keyextend_req.fire) {
    keyextend_req_valid := false.B
  }

  // Records whether the most recent key_init_req was accepted.
  when(key_init_req) {
    last_req_accepted := req_leagl
  }

  random_key_init_done := io.keyextend_req.fire && (mode === 2.U(2.W))

  io.keyextend_req.valid := keyextend_req_valid
  io.keyextend_req.bits.key := Mux(mode === 1.U(2.W), Cat(key1, key0), random_vec_data)
  io.keyextend_req.bits.keyid := key_id
  io.keyextend_req.bits.enc_mode := mode =/= 0.U(2.W)
  io.keyextend_req.bits.tweak_flage := tweak_flage.asBool
}
760
/** One key-table row: the round keys stored for a keyid, plus its encryption mode. */
class KeyTableEntry extends Bundle {
  /** 32 round keys, 32 bits each. */
  val round_key_data: Vec[UInt] = Vec(32, UInt(32.W))
  /** 1: encryption enabled for this keyid, 0: disabled. Written from the key-table write request. */
  val encdec_mode: Bool = Bool()
}
/** Register file holding, per KeyID, the 32 expanded round keys and the enc/dec enable bit.
  *
  * Writes arrive one round key per cycle (`write_req`). Reads are combinational. Each of the
  * MemencPipes encrypt and decrypt pipeline stages looks up its own KeyID and receives the
  * slice of round keys that stage needs, packed into a single UInt. The `enc`/`dec` ports
  * return only the enable bit, so the AW/AR channels can query it ahead of the data.
  */
class KeyTable(implicit p: Parameters) extends MemEncryptModule {
  // Every stage must consume a whole, non-zero number of the 32 round keys.
  require(MemencPipes <= 32, s"KeyTable: MemencPipes must be at most 32, not $MemencPipes")

  val io = IO(new Bundle {
    val write_req = Input(new Bundle {
      val keyid = UInt(KeyIDBits.W)
      val keyid_valid = Input(Bool())
      val enc_mode = Input(Bool())        // 1: this keyid open enc, 0: this keyid close enc
      val round_id = UInt(5.W)
      val data = Input(UInt(32.W))
    })

    val enc_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W)))
    val enc_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
    val dec_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W)))
    val dec_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
    val dec = new Bundle {
      val keyid = Input(UInt(KeyIDBits.W))           // query dec_mode in advance in the AR channel
      val mode = Output(Bool())
    }
    val enc = new Bundle {
      val keyid = Input(UInt(KeyIDBits.W))           // query enc_mode in advance in the AW channel
      val mode = Output(Bool())
    }
  })

  // Number of 32-bit round keys handed to each pipeline stage.
  private val keysPerStage = 32 / MemencPipes

  // Reset value: round keys are left uninitialized; the enable bit resets to HasDelayNoencryption.
  val init_entry = Wire(new KeyTableEntry)
  init_entry.round_key_data := DontCare
  init_entry.encdec_mode := HasDelayNoencryption.B
  val table = RegInit(VecInit(Seq.fill(1 << KeyIDBits)(init_entry)))
  val wire_enc_round_keys = Wire(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
  val wire_dec_round_keys = Wire(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))

  // Write: the enable bit is updated on every valid write. A round key is stored only when
  // the write enables encryption (a disable write leaves the stored keys untouched).
  when(io.write_req.keyid_valid) {
    val entry = table(io.write_req.keyid)
    entry.encdec_mode := io.write_req.enc_mode
    when(io.write_req.enc_mode) {
      entry.round_key_data(io.write_req.round_id) := io.write_req.data
    }
  }

  // Read: encrypt stage i receives round keys [i*k, (i+1)*k), where k = keysPerStage.
  // Decrypt stage i receives the same positions counted from the end (31 - n).
  // The indices are elaboration-time constants, so static Int indexing is used.
  // VecInit(...).asUInt places element 0 (the lowest-indexed key of the slice) in the
  // least significant 32 bits.
  for (i <- 0 until MemencPipes) {
    val base = i * keysPerStage

    val enc_entry = table(io.enc_keyids(i))
    wire_enc_round_keys(i) :=
      VecInit(Seq.tabulate(keysPerStage)(j => enc_entry.round_key_data(base + j))).asUInt

    val dec_entry = table(io.dec_keyids(i))
    wire_dec_round_keys(i) :=
      VecInit(Seq.tabulate(keysPerStage)(j => dec_entry.round_key_data(31 - (base + j)))).asUInt
  }

  // Enable-bit lookups for the AR (dec) and AW (enc) channels.
  io.dec.mode := table(io.dec.keyid).encdec_mode
  io.enc.mode := table(io.enc.keyid).encdec_mode

  io.enc_round_keys := wire_enc_round_keys
  io.dec_round_keys := wire_dec_round_keys
}
838
/** Expands a 128-bit key into 32 round keys, one round per cycle, and owns the KeyTable.
  *
  * - An accepted request with enc_mode = 1 runs a 32-cycle expansion. Each round's low 32 bits
  *   go to the KeyTable entry of `keyid`, or, when `tweak_flage` is set, to the tweak
  *   round-key registers exposed on `tweak_round_keys`.
  * - An accepted request with enc_mode = 0 clears the enable bit of `keyid` in the KeyTable.
  * - `ready` drops in the cycle after an enc request is seen and rises again in the cycle
  *   after the last round.
  *
  * NOTE(review): `keyid` and `tweak_flage` are read from the request bits on every expansion
  * cycle rather than latched at accept, so the requester must hold them stable for the
  * whole expansion.
  */
class KeyExtender(implicit p: Parameters) extends MemEncryptModule{
  val io = IO(new Bundle {
      val keyextend_req = Flipped(DecoupledIO(new Bundle {
        val key = UInt(128.W)
        val keyid = UInt(KeyIDBits.W)
        val enc_mode = Bool()         // 1:this keyid open enc  0:this keyid close enc
        val tweak_flage = Bool()      // 1:extend tweak key  0:extend keyid key
      }))
      val tweak_round_keys = Output(Vec(32, UInt(32.W)))
      val enc_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W)))
      val enc_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
      val dec_keyids = Input(Vec(MemencPipes, UInt(KeyIDBits.W)))
      val dec_round_keys = Output(Vec(MemencPipes, UInt((32*32/MemencPipes).W)))
      val dec = new Bundle {
        val keyid = Input(UInt(KeyIDBits.W))           // query dec_mode in advance in the AR channel
        val mode = Output(Bool())
      }
      val enc = new Bundle {
        val keyid = Input(UInt(KeyIDBits.W))           // query enc_mode in advance in the AW channel
        val mode = Output(Bool())
      }
  })

  val idle :: keyExpansion :: Nil = Enum(2)
  val current = RegInit(idle)
  val next = WireDefault(idle)
  current := next

  val count_round = RegInit(0.U(5.W))          // round counter fed to GetCKI
  val reg_count_round = RegNext(count_round)   // round index fed to OneRoundForKeyExp / KeyTable
  val reg_user_key = RegInit(0.U(128.W))
  val data_for_round = Wire(UInt(128.W))
  val data_after_round = Wire(UInt(128.W))
  val reg_data_after_round = RegInit(0.U(128.W))
  val key_exp_finished_out = RegInit(true.B)
  val reg_key_valid = RegNext(io.keyextend_req.valid, false.B)
  val reg_tweak_round_keys = Reg(Vec(32, UInt(32.W)))

  // An expansion starts on the rising edge of `valid` for an enc (enc_mode = 1) request.
  val start_expansion = !reg_key_valid && io.keyextend_req.valid && io.keyextend_req.bits.enc_mode
  val expanding = current === keyExpansion

  switch(current) {
    is(idle) {
      when(start_expansion) {
        next := keyExpansion
      }
    }
    is(keyExpansion) {
      // Return to idle after the cycle that processes round 31.
      when(reg_count_round === 31.U) {
        next := idle
      }.otherwise {
        next := keyExpansion
      }
    }
  }

  when(next === keyExpansion) {
    count_round := count_round + 1.U
  }.otherwise {
    count_round := 0.U
  }

  when(start_expansion) {
    reg_user_key := io.keyextend_req.bits.key
  }

  when(expanding && next === idle) {
    key_exp_finished_out := true.B
  }.elsewhen(io.keyextend_req.valid && io.keyextend_req.bits.enc_mode) {
    key_exp_finished_out := false.B
  }
  io.keyextend_req.ready := key_exp_finished_out

  // Round 0 starts from the user key; later rounds chain on the previous round's result.
  data_for_round := Mux(reg_count_round =/= 0.U, reg_data_after_round, reg_user_key)
  val cki = Module(new GetCKI)
  cki.io.countRoundIn := count_round
  val one_round = Module(new OneRoundForKeyExp)
  one_round.io.countRoundIn := reg_count_round
  one_round.io.dataIn := data_for_round
  one_round.io.ckParameterIn := cki.io.ckiOut
  data_after_round := one_round.io.resultOut

  when(expanding) {
    reg_data_after_round := data_after_round
  }

  val keyTable = Module(new KeyTable())
  keyTable.io.write_req.keyid := io.keyextend_req.bits.keyid
  keyTable.io.write_req.enc_mode := io.keyextend_req.bits.enc_mode
  keyTable.io.write_req.round_id := reg_count_round
  keyTable.io.write_req.data := data_after_round(31, 0)

  keyTable.io.enc_keyids := io.enc_keyids
  keyTable.io.dec_keyids := io.dec_keyids
  keyTable.io.dec.keyid := io.dec.keyid
  keyTable.io.enc.keyid := io.enc.keyid
  io.dec.mode := keyTable.io.dec.mode
  io.enc.mode := keyTable.io.enc.mode
  io.enc_round_keys := keyTable.io.enc_round_keys
  io.dec_round_keys := keyTable.io.dec_round_keys

  // A disable request (enc_mode = 0) never enters keyExpansion, so it must be committed to
  // the KeyTable in the cycle it is accepted. Previously keyid_valid followed `current` only,
  // so the table's disable path was unreachable and a KeyID could never be switched off.
  val disable_accepted = io.keyextend_req.fire && !io.keyextend_req.bits.enc_mode

  when(io.keyextend_req.bits.tweak_flage) {
    // Record tweak round keys only while expanding. In idle, OneRoundForKeyExp still
    // evaluates round 0 on reg_user_key, which may hold a stale (non-tweak) key.
    when(expanding) {
      reg_tweak_round_keys(reg_count_round) := data_after_round(31, 0)
    }
    keyTable.io.write_req.keyid_valid := false.B
  }.otherwise {
    keyTable.io.write_req.keyid_valid := expanding || disable_accepted
  }
  io.tweak_round_keys := reg_tweak_round_keys
}
948
/** AXI4 memory-encryption adapter, configured through an APB register interface at `address`.
  *
  * Both AXI4 sides must carry 256-bit data. Write and read traffic is routed either straight
  * through or through the encrypt/decrypt pipelines, depending on `memenc_enable` and the
  * per-KeyID mode held in the KeyExtender's KeyTable.
  *
  * Outgoing AW/AR addresses are truncated to their low (PAddrBits - KeyIDBits) bits.
  * Presumably the stripped upper bits carry the KeyID; the extraction itself happens in the
  * route/tweak submodules, which are not shown here.
  *
  * The MemencPipes power-of-two and PAddrBits > KeyIDBits checks are enforced by the
  * Memconsts trait, whose initializer runs before this class body.
  */
class AXI4MemEncrypt(address: AddressSet)(implicit p: Parameters) extends LazyModule with Memconsts
{
  // Adapter node. On the master side it rebuilds the ID space with log2Ceil(endId) + 1 bits,
  // folding every original master into every new ID slot (names, alignment and maxFlight are
  // merged). The slave side passes through unchanged.
  val node = AXI4AdapterNode(
    masterFn = { mp =>
      val new_idbits = log2Ceil(mp.endId) + 1
      // Create one new "master" per ID
      val masters = Array.tabulate(1 << new_idbits) { i => AXI4MasterParameters(
         name      = "",
         id        = IdRange(i, i+1),
         aligned   = true,
         maxFlight = Some(0))
      }
      // Accumulate the names of masters we squish
      val names = Array.fill(1 << new_idbits) { new scala.collection.mutable.HashSet[String]() }
      // Squash the information from original masters into new ID masters
      mp.masters.foreach { m =>
        for (i <- 0 until (1 << new_idbits)) {
          val accumulated = masters(i)
          names(i) += m.name
          masters(i) = accumulated.copy(
            aligned   = accumulated.aligned && m.aligned,
            maxFlight = accumulated.maxFlight.flatMap { o => m.maxFlight.map { n => o+n } })
        }
      }
      val finalNameStrings = names.map { n => if (n.isEmpty) "(unused)" else n.toList.mkString(", ") }
      mp.copy(masters = masters.toIndexedSeq.zip(finalNameStrings.toIndexedSeq).map { case (m, n) => m.copy(name = n) })
    },
    slaveFn  = { sp => sp })

  val device = new SimpleDevice("mem-encrypt-unit", Seq("iie,memencrypt0"))
  // APB control/status register port (8-byte beats).
  val ctrl_node = APBSlaveNode(Seq(APBSlavePortParameters(
    Seq(APBSlaveParameters(
      address       = List(address),
      resources     = device.reg,
      device        = Some(device),
      regionType    = RegionType.IDEMPOTENT)),
    beatBytes     = 8)))

  lazy val module = new Impl
  class Impl extends LazyModuleImp(this) {
    // Handshake with an external source that returns random data one bit per transfer.
    val io = IO(new Bundle {
      val random_req = Output(Bool())
      val random_val = Input(Bool())
      val random_data = Input(Bool())
    })

    val en    = Wire(Bool())
    val wmode = Wire(Bool())
    val addr  = Wire(UInt(12.W))
    val wdata = Wire(UInt(64.W))
    val wmask = Wire(UInt(8.W))
    val rdata = Wire(UInt(64.W))  // get rdata next cycle after en

    // APB bridge: `en` pulses in the setup phase (psel && !penable). pready is tied high
    // (no wait states), and prdata is driven by the CSR read data.
    (ctrl_node.in) foreach { case (ctrl_in, _) =>
      en    := ctrl_in.psel && !ctrl_in.penable
      wmode := ctrl_in.pwrite
      addr  := ctrl_in.paddr(11, 0)
      wdata := ctrl_in.pwdata
      wmask := ctrl_in.pstrb
      ctrl_in.pready  := true.B
      ctrl_in.pslverr := false.B
      ctrl_in.prdata  := rdata
    }

    (node.in zip node.out) foreach { case ((in, edgeIn), (out, edgeOut)) =>
      require (edgeIn.bundle.dataBits == 256,
        s"AXI4MemEncrypt: edgeIn dataBits must be 256, got ${edgeIn.bundle.dataBits}")
      require (edgeOut.bundle.dataBits == 256,
        s"AXI4MemEncrypt: edgeOut dataBits must be 256, got ${edgeOut.bundle.dataBits}")

      // Make the concrete edge parameters visible to the submodules (via Memconsts).
      val memencParams: Parameters = p.alterPartial {
        case MemcEdgeInKey => edgeIn
        case MemcEdgeOutKey => edgeOut
      }
      // -------------------------------------
      // MemEncrypt Config and State Registers
      // -------------------------------------
      val memenc_enable = Wire(Bool())
      val memencrypt_csr = Module(new MemEncryptCSR()(memencParams))
      memencrypt_csr.io.en    := en
      memencrypt_csr.io.wmode := wmode
      memencrypt_csr.io.addr  := addr
      memencrypt_csr.io.wdata := wdata
      memencrypt_csr.io.wmask := wmask
      memenc_enable := memencrypt_csr.io.memenc_enable
      rdata := memencrypt_csr.io.rdata

      io.random_req := memencrypt_csr.io.randomio.random_req
      memencrypt_csr.io.randomio.random_val := io.random_val
      memencrypt_csr.io.randomio.random_data := io.random_data

      // -------------------------------------
      // Key Extender & Round Key Lookup Table
      // -------------------------------------
      val key_extender = Module(new KeyExtender()(memencParams))
      key_extender.io.keyextend_req :<>= memencrypt_csr.io.keyextend_req

      // -------------------
      // AXI4 chanel B
      // -------------------
      Connectable.waiveUnmatched(in.b, out.b) match {
        case (lhs, rhs) => lhs.squeezeAll :<>= rhs.squeezeAll
      }

      val write_route = Module(new WriteChanelRoute()(memencParams))
      val aw_tweakenc = Module(new TweakEncrptyQueue()(memencParams))
      val waddr_q = Module(new IrrevocableQueue(chiselTypeOf(in.aw.bits), entries = MemencPipes+1))
      val wdata_q = Module(new IrrevocableQueue(chiselTypeOf(in.w.bits), entries = MemencPipes+1))
      val write_machine = Module(new AXI4WriteMachine()(memencParams))
      val axi4w_kt_q = Module(new Queue(new AXI4W_KT(false)(memencParams), entries = 2, flow = true))
      val wdata_encpipe = Module(new WdataEncrptyPipe()(memencParams))
      val write_arb = Module(new WriteChanelArbiter()(memencParams))

      // -------------------
      // AXI4 Write Route
      // Unencrypt & Encrypt
      // -------------------
      write_route.io.memenc_enable := memenc_enable
      key_extender.io.enc.keyid := write_route.io.enc_keyid
      write_route.io.enc_mode := key_extender.io.enc.mode

      write_route.io.in.aw :<>= in.aw
      write_route.io.in.w :<>= in.w

      val unenc_aw = write_route.io.out0.aw
      val unenc_w = write_route.io.out0.w
      val pre_enc_aw = write_route.io.out1.aw
      val pre_enc_w = write_route.io.out1.w

      // -------------------
      // AXI4 chanel AW
      // -------------------
      // Fork: an encrypted AW is accepted only when both the address queue and the
      // tweak-encryption queue can take it.
      pre_enc_aw.ready := waddr_q.io.enq.ready && aw_tweakenc.io.enq.ready
      waddr_q.io.enq.valid := pre_enc_aw.valid && aw_tweakenc.io.enq.ready
      aw_tweakenc.io.enq.valid := pre_enc_aw.valid && waddr_q.io.enq.ready

      // The full address (with the upper bits) goes to the tweak path; the queued AW keeps
      // only the low bits. The later `.addr` connect overrides the bulk connect.
      waddr_q.io.enq.bits := pre_enc_aw.bits
      waddr_q.io.enq.bits.addr := pre_enc_aw.bits.addr(PAddrBits-KeyIDBits-1, 0)
      aw_tweakenc.io.enq.bits.addr := pre_enc_aw.bits.addr
      aw_tweakenc.io.enq.bits.len := pre_enc_aw.bits.len
      aw_tweakenc.io.tweak_round_keys := key_extender.io.tweak_round_keys

      // -------------------
      // AXI4 chanel W
      // -------------------
      wdata_q.io.enq :<>= pre_enc_w
      write_machine.io.in_w :<>= wdata_q.io.deq
      write_machine.io.in_kt :<>= aw_tweakenc.io.deq
      axi4w_kt_q.io.enq :<>= write_machine.io.out_w
      wdata_encpipe.io.in_w :<>= axi4w_kt_q.io.deq
      key_extender.io.enc_keyids := wdata_encpipe.io.enc_keyids
      wdata_encpipe.io.enc_round_keys := key_extender.io.enc_round_keys

      // -------------------
      // AXI4 Write Arbiter
      // Unencrypt & Encrypt
      // -------------------
      write_arb.io.in0.aw :<>= unenc_aw
      write_arb.io.in0.aw.bits.addr := unenc_aw.bits.addr(PAddrBits-KeyIDBits-1, 0)
      write_arb.io.in0.w :<>= unenc_w

      // A queued encrypted AW with len == 0 is held until the write machine raises uncache_en.
      write_arb.io.in1.aw.valid := waddr_q.io.deq.valid && (waddr_q.io.deq.bits.len =/=0.U || write_machine.io.uncache_en)
      waddr_q.io.deq.ready := write_arb.io.in1.aw.ready && (waddr_q.io.deq.bits.len =/=0.U || write_machine.io.uncache_en)
      write_machine.io.uncache_commit := write_arb.io.in1.aw.fire
      write_arb.io.in1.aw.bits := waddr_q.io.deq.bits
      write_arb.io.in1.w :<>= wdata_encpipe.io.out_w

      out.aw :<>= write_arb.io.out.aw
      out.w  :<>= write_arb.io.out.w

      val ar_arb = Module(new IrrevocableArbiter(chiselTypeOf(out.ar.bits), 2))
      val ar_tweakenc = Module(new TweakEncrptyTable()(memencParams))
      val read_machine = Module(new AXI4ReadMachine()(memencParams))
      val axi4r_kt_q = Module(new Queue(new AXI4R_KT(false)(memencParams), entries = 2, flow = true))
      val pre_dec_rdata_route = Module(new RdataChanelRoute()(memencParams))
      val rdata_decpipe = Module(new RdataDecrptyPipe()(memencParams))
      val r_arb = Module(new IrrevocableArbiter(chiselTypeOf(out.r.bits), 2))
      val post_dec_rdata_route = Module(new RdataRoute()(memencParams))

      // -------------------
      // AXI4 chanel AR
      // -------------------
      // Port 0: reads issued by the write machine; port 1: upstream reads.
      ar_arb.io.in(0) :<>= write_machine.io.out_ar
      // DecoupledIO connect IrrevocableIO
      ar_arb.io.in(1).valid := in.ar.valid
      ar_arb.io.in(1).bits := in.ar.bits
      in.ar.ready := ar_arb.io.in(1).ready

      // Fork: an AR leaves only when both `out.ar` and the tweak table accept it.
      ar_arb.io.out.ready := out.ar.ready && ar_tweakenc.io.enq.ready

      ar_tweakenc.io.enq.valid := ar_arb.io.out.valid && out.ar.ready
      ar_tweakenc.io.enq.bits.addr := ar_arb.io.out.bits.addr
      ar_tweakenc.io.enq.bits.len := ar_arb.io.out.bits.len
      ar_tweakenc.io.enq.bits.id := ar_arb.io.out.bits.id
      ar_tweakenc.io.tweak_round_keys := key_extender.io.tweak_round_keys
      ar_tweakenc.io.memenc_enable := memenc_enable
      key_extender.io.dec.keyid := ar_tweakenc.io.dec_keyid
      ar_tweakenc.io.dec_mode := key_extender.io.dec.mode

      out.ar.valid := ar_arb.io.out.valid && ar_tweakenc.io.enq.ready
      out.ar.bits := ar_arb.io.out.bits
      out.ar.bits.addr := ar_arb.io.out.bits.addr(PAddrBits-KeyIDBits-1, 0)

      // -------------------
      // AXI4 Rdata Route
      // Unencrypt & Encrypt
      // -------------------
      pre_dec_rdata_route.io.in_r :<>= out.r
      ar_tweakenc.io.dec_r.id := pre_dec_rdata_route.io.dec_rid
      pre_dec_rdata_route.io.dec_mode := ar_tweakenc.io.dec_r.mode

      val undec_r = pre_dec_rdata_route.io.out_r0
      val pre_dec_r = pre_dec_rdata_route.io.out_r1

      // -------------------
      // AXI4 chanel R
      // -------------------
      read_machine.io.in_r :<>= pre_dec_r
      ar_tweakenc.io.req :<>= read_machine.io.kt_req
      read_machine.io.in_kt :<>= ar_tweakenc.io.resp
      axi4r_kt_q.io.enq :<>= read_machine.io.out_r
      rdata_decpipe.io.in_r :<>= axi4r_kt_q.io.deq
      key_extender.io.dec_keyids := rdata_decpipe.io.dec_keyids
      rdata_decpipe.io.dec_round_keys := key_extender.io.dec_round_keys

      // -------------------
      // AXI4 Rdata Arbiter
      // Unencrypt & Encrypt
      // -------------------
      r_arb.io.in(0) :<>= undec_r
      r_arb.io.in(1) :<>= rdata_decpipe.io.out_r

      // out_r1 feeds read responses back to the write machine; out_r0 returns them upstream.
      post_dec_rdata_route.io.in_r :<>= r_arb.io.out
      write_machine.io.in_r :<>= post_dec_rdata_route.io.out_r1
      in.r :<>= post_dec_rdata_route.io.out_r0
    }
  }
}
1188
/** Factory for [[AXI4MemEncrypt]]. */
object AXI4MemEncrypt {
  /** Instantiates the adapter, with its APB control registers at `address`, and returns its AXI4 node. */
  def apply(address: AddressSet)(implicit p: Parameters): AXI4Node = {
    // Kept bound to a named val rather than inlined: LazyModule's implicit ValName is
    // presumably derived from this val's name — verify before inlining.
    val axi4memenc: AXI4MemEncrypt = LazyModule(new AXI4MemEncrypt(address))
    axi4memenc.node
  }
}
1197