# This is a copy of rnn_attention from MLPerf, with some common sizes hardcoded
# for benchmarking and some control flow stripped out.
# https://github.com/mlperf/training/blob/master/rnn_translator/pytorch/seq2seq/models/attention.py

import torch

from . import benchmark


class BahdanauAttention(benchmark.Benchmark):
    def __init__(self, mode, device, dtype, b, t_q, t_k, n):
        super().__init__(mode, device, dtype)
        self.b = b
        self.t_q = t_q
        self.t_k = t_k
        self.n = n
        self.att_query = self.rand(
            [b, t_q, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.att_keys = self.rand(
            [b, t_k, n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.normalize_bias = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.linear_att = self.rand(
            [n], device=device, dtype=dtype, requires_grad=self.requires_grad
        )
        self.inputs = [
            self.att_query,
            self.att_keys,
            self.normalize_bias,
            self.linear_att,
        ]

    def forward(self, att_query, att_keys, normalize_bias, linear_att):
        """
        Calculate Bahdanau attention scores.

        :param att_query: b x t_q x n
        :param att_keys: b x t_k x n
        :param normalize_bias: n
        :param linear_att: n

        :return: b x t_q x t_k scores
        """

        b, t_k, n = att_keys.size()
        t_q = att_query.size(1)

        # Broadcast queries and keys to a common b x t_q x t_k x n shape,
        # then contract the feature dimension against linear_att.
        att_query = att_query.unsqueeze(2).expand(b, t_q, t_k, n)
        att_keys = att_keys.unsqueeze(1).expand(b, t_q, t_k, n)
        sum_qk = att_query + att_keys + normalize_bias
        out = torch.tanh(sum_qk).matmul(linear_att)
        return out

    def reference(self):
        return self.numpy(self.forward(*self.inputs))

    def config(self):
        return [self.b, self.t_q, self.t_k, self.n]

    @staticmethod
    def module():
        return "attention"

    def memory_workload(self):
        def memsize(t):
            return t.numel() * t.element_size()

        input_size = (
            memsize(self.att_query)
            + memsize(self.att_keys)
            + memsize(self.normalize_bias)
            + memsize(self.linear_att)
        )
        # The output is b x t_q x t_k; the factor of 4 assumes 4-byte elements.
        output_size = 4 * torch.Size([self.b, self.t_q, self.t_k]).numel()
        io_size = input_size + output_size

        # If matmul is not fused, must write and then read `sum_qk`.
        intermediate_size = (
            2 * 4 * torch.Size([self.b, self.t_q, self.t_k, self.n]).numel()
        )
        return {"sol": io_size, "algorithmic": io_size + intermediate_size}

    @staticmethod
    def default_configs():
        # [b, t_q, t_k, n]
        mlperf_inference = [1280, 1, 66, 1024]
        nvidia = [128, 10, 128, 1024]
        return [mlperf_inference, nvidia]


benchmark.register_benchmark_class(BahdanauAttention)
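
# A minimal standalone smoke test, assuming only torch: `forward` above never
# touches `self`, so it can be called unbound on plain tensors without
# constructing the benchmark harness. The shapes below are illustrative, not
# one of the registered configs. Because of the relative import at the top,
# this only runs when the module is executed in its package context
# (`python -m <package>.attention`).
if __name__ == "__main__":
    b, t_q, t_k, n = 2, 3, 5, 8
    scores = BahdanauAttention.forward(
        None,  # `self` is unused in forward
        torch.randn(b, t_q, n),  # att_query
        torch.randn(b, t_k, n),  # att_keys
        torch.randn(n),  # normalize_bias
        torch.randn(n),  # linear_att
    )
    assert scores.shape == (b, t_q, t_k)
    print("Bahdanau score shape:", tuple(scores.shape))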