xref: /aosp_15_r20/external/tensorflow/tensorflow/python/keras/utils/kernelized_utils.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility methods related to kernelized layers."""
16
17from tensorflow.python.ops import array_ops
18from tensorflow.python.ops import math_ops
19
20
def _to_matrix(u):
  """Returns `u` as a matrix, promoting a vector to a single-row matrix.

  Args:
    u: a tensor-like object exposing `.shape`, expected to have rank 1 or 2.

  Returns:
    `u` itself when it already has rank 2; otherwise (rank 1) `u` with a new
    leading axis of size 1, i.e. shape [dim] becomes [1, dim].

  Raises:
    ValueError: if the rank of `u` is neither 1 nor 2.
  """
  rank = len(u.shape)
  if rank == 2:
    return u
  if rank == 1:
    # Insert a leading axis so the vector becomes one row of a matrix.
    return array_ops.expand_dims(u, 0)
  raise ValueError('The input tensor should have rank 1 or 2. Given rank: {}'
                   .format(rank))
30
31
def _align_matrices(x, y):
  """Aligns x and y tensors to allow computations over pairs of their rows.

  Both inputs are first promoted to matrices with `_to_matrix`. They are then
  expanded and tiled so that, for the returned pair (x_tile, y_tile),
  x_tile[i, j, :] is row i of x and y_tile[i, j, :] is row j of y.

  Args:
    x: a tensor of rank 1 or 2, with shape [dim] or [m, dim].
    y: a tensor of rank 1 or 2, with shape [dim] or [n, dim].

  Returns:
    A pair (x_tile, y_tile) of tensors, each of shape [m, n, dim], where m
    (resp. n) is 1 when x (resp. y) is a vector.

  Raises:
    ValueError: if either input does not have rank 1 or 2, or if the innermost
      (last) dimensions of x and y differ.
  """
  x_matrix = _to_matrix(x)
  y_matrix = _to_matrix(y)
  x_shape = x_matrix.shape
  y_shape = y_matrix.shape
  # Rows can only be compared pairwise when they have the same length, i.e.
  # the last (column) dimension of the two matrices must agree.
  if x_shape[1] != y_shape[1]:
    raise ValueError(
        'The innermost dimensions of the input tensors should match. Given: {} '
        'vs {}.'.format(x_shape[1], y_shape[1]))

  # NOTE(review): the tile multiples below come from the static `.shape`, so
  # an unknown (None) row count is presumably unsupported here — confirm.
  # [m, dim] -> [m, 1, dim] -> [m, n, dim]: each row of x repeated n times.
  x_tile = array_ops.tile(
      array_ops.expand_dims(x_matrix, 1), [1, y_shape[0], 1])
  # [n, dim] -> [1, n, dim] -> [m, n, dim]: all rows of y repeated m times.
  y_tile = array_ops.tile(
      array_ops.expand_dims(y_matrix, 0), [x_shape[0], 1, 1])
  return x_tile, y_tile
48
49
def inner_product(u, v):
  """Computes inner products between every row of u and every row of v.

  Vectors (rank-1 inputs) are treated as single-row matrices.

  Args:
    u: a tensor of rank 1 or 2, with shape [dim] or [m, dim].
    v: a tensor of rank 1 or 2, with shape [dim] or [n, dim].

  Returns:
    A tensor of shape [m, n] (m or n is 1 for a vector input) whose (i, j)
    entry is the dot product of row i of u with row j of v.

  Raises:
    ValueError: if u or v does not have rank 1 or 2.
  """
  u_matrix, v_matrix = _to_matrix(u), _to_matrix(v)
  # u @ v^T: transposing v turns its rows into columns of the product.
  return math_ops.matmul(u_matrix, v_matrix, transpose_b=True)
54
55
def exact_gaussian_kernel(x, y, stddev):
  r"""Computes exact Gaussian kernel value(s) for tensors x and y and stddev.

  For vectors u and v the Gaussian kernel is
       K(u, v) = exp(-||u - v||_2^2 / (2 * stddev^2)),
  using the squared l2-norm of the difference. Each of x and y may be a vector
  or a matrix. Vectors must have the same dimension; matrices must have the
  same number of columns. The result holds K(u, v) for every pair (u, v) where
  u is a row of x and v is a row of y (a vector counts as a single row).

  Args:
    x: a tensor of rank 1 or 2. Its shape should be either [dim] or [m, dim].
    y: a tensor of rank 1 or 2. Its shape should be either [dim] or [n, dim].
    stddev: The width of the Gaussian kernel.

  Returns:
    A tensor of shape (m, n) whose (i, j) entry is K(row i of x, row j of y).
    When x and y are both vectors this is a (1, 1) tensor holding one value.

  Raises:
    ValueError: if the shapes of x, y are not compatible.
  """
  x_rows, y_rows = _align_matrices(x, y)
  # Sum the per-coordinate squared differences over the feature axis.
  squared_distances = math_ops.reduce_sum(
      math_ops.squared_difference(x_rows, y_rows), axis=2)
  bandwidth = 2 * stddev * stddev
  return math_ops.exp(-squared_distances / bandwidth)
84
85
def exact_laplacian_kernel(x, y, stddev):
  r"""Computes exact Laplacian kernel value(s) for tensors x and y using stddev.

  For vectors u and v the Laplacian kernel is
       K(u, v) = exp(-||u - v||_1 / stddev),
  using the l1-norm of the difference. Each of x and y may be a vector or a
  matrix. Vectors must have the same dimension; matrices must have the same
  number of columns. The result holds K(u, v) for every pair (u, v) where u is
  a row of x and v is a row of y (a vector counts as a single row).

  Args:
    x: a tensor of rank 1 or 2. Its shape should be either [dim] or [m, dim].
    y: a tensor of rank 1 or 2. Its shape should be either [dim] or [n, dim].
    stddev: The width of the Laplacian kernel.

  Returns:
    A tensor of shape (m, n) whose (i, j) entry is K(row i of x, row j of y).
    When x and y are both vectors this is a (1, 1) tensor holding one value.

  Raises:
    ValueError: if the shapes of x, y are not compatible.
  """
  x_rows, y_rows = _align_matrices(x, y)
  abs_differences = math_ops.abs(math_ops.subtract(x_rows, y_rows))
  # l1 distance: sum of absolute coordinate differences over the feature axis.
  l1_distances = math_ops.reduce_sum(abs_differences, axis=2)
  return math_ops.exp(-l1_distances / stddev)
114