Lines Matching full:bmm
203 # transposed. If yes, transpose it back before feeding to torch.bmm
208 return torch.bmm(x, y.transpose(-1, -2))
210 return torch.bmm(x, y)
331 self.bmm = BatchMatrixMultiplication(transposed=False)
369 # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
370 # this is equiv. to `qk = torch.bmm(q, k.transpose(-1, -2))`
378 # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
379 # this is equiv. to `attention = torch.bmm(softmax_qk, v)`
380 attention = self.bmm(softmax_qk, v)
449 # bmm((B*H, L, D/H), (B*H, D/H, S)) -> (B*H, L, S).
450 qk = torch.bmm(q, k.transpose(1, 2))
458 # bmm((B*H, L, S), (B*H, S, D/H)) -> (B*H, L, D/H).
459 attention = torch.bmm(softmax_qk, v)