import torch.nn.functional as F
from torch import Tensor

from .module import Module


__all__ = ["PairwiseDistance", "CosineSimilarity"]


class PairwiseDistance(Module):
    r"""
    Computes the pairwise distance between input vectors, or between corresponding rows of input matrices.

    Distances are computed using ``p``-norm, with constant ``eps`` added to avoid division by zero
    if ``p`` is negative, i.e.:

    .. math ::
        \mathrm{dist}\left(x, y\right) = \left\Vert x-y + \epsilon e \right\Vert_p,

    where :math:`e` is the vector of ones and the ``p``-norm is given by:

    .. math ::
        \Vert x \Vert _p = \left( \sum_{i=1}^n \vert x_i \vert ^ p \right) ^ {1/p}.

    Args:
        p (real, optional): the norm degree. Can be negative. Default: 2
        eps (float, optional): Small value to avoid division by zero.
            Default: 1e-6
        keepdim (bool, optional): Determines whether or not to keep the vector dimension.
            Default: False

    Shape:
        - Input1: :math:`(N, D)` or :math:`(D)` where `N = batch dimension` and `D = vector dimension`
        - Input2: :math:`(N, D)` or :math:`(D)`, same shape as the Input1
        - Output: :math:`(N)` or :math:`()` based on input dimension.
          If :attr:`keepdim` is ``True``, then :math:`(N, 1)` or :math:`(1)` based on input dimension.

    Examples::

        >>> pdist = nn.PairwiseDistance(p=2)
        >>> input1 = torch.randn(100, 128)
        >>> input2 = torch.randn(100, 128)
        >>> output = pdist(input1, input2)
    """

    # Hyperparameters, named in ``__constants__`` so TorchScript can treat them as constants.
    __constants__ = ["norm", "eps", "keepdim"]
    norm: float  # the norm degree passed to the constructor as ``p``
    eps: float
    keepdim: bool

    def __init__(
        self, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False
    ) -> None:
        super().__init__()
        # The constructor argument ``p`` is stored under the attribute name ``norm``.
        self.norm = p
        self.eps = eps
        self.keepdim = keepdim

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return ``F.pairwise_distance`` of ``x1`` and ``x2`` using this module's settings."""
        # Keyword arguments make explicit that ``self.norm`` is the ``p`` of the norm.
        return F.pairwise_distance(
            x1, x2, p=self.norm, eps=self.eps, keepdim=self.keepdim
        )


class CosineSimilarity(Module):
    r"""Returns cosine similarity between :math:`x_1` and :math:`x_2`, computed along `dim`.

    .. math ::
        \text{similarity} = \dfrac{x_1 \cdot x_2}{\max(\Vert x_1 \Vert _2 \cdot \Vert x_2 \Vert _2, \epsilon)}.

    Args:
        dim (int, optional): Dimension where cosine similarity is computed. Default: 1
        eps (float, optional): Small value to avoid division by zero.
            Default: 1e-8

    Shape:
        - Input1: :math:`(\ast_1, D, \ast_2)` where D is at position `dim`
        - Input2: :math:`(\ast_1, D, \ast_2)`, same number of dimensions as x1, matching x1 size at dimension `dim`,
          and broadcastable with x1 at other dimensions.
        - Output: :math:`(\ast_1, \ast_2)`

    Examples::

        >>> input1 = torch.randn(100, 128)
        >>> input2 = torch.randn(100, 128)
        >>> cos = nn.CosineSimilarity(dim=1, eps=1e-6)
        >>> output = cos(input1, input2)
    """

    # Hyperparameters, named in ``__constants__`` so TorchScript can treat them as constants.
    __constants__ = ["dim", "eps"]
    dim: int
    eps: float

    def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
        super().__init__()
        self.dim = dim
        self.eps = eps

    def forward(self, x1: Tensor, x2: Tensor) -> Tensor:
        """Return ``F.cosine_similarity`` of ``x1`` and ``x2`` along ``self.dim``."""
        # NOTE(review): the exact eps clamping is done inside F.cosine_similarity;
        # confirm the formula in the class docstring matches the installed implementation.
        return F.cosine_similarity(x1, x2, dim=self.dim, eps=self.eps)