1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import torch

def sparsify_matrix(matrix : torch.Tensor, density : float, block_size, keep_diagonal : bool=False, return_mask : bool=False):
    """ Sparsifies a matrix with the specified block size.

        Parameters:
        -----------
        matrix : torch.Tensor
            matrix to sparsify
        density : float
            target density, i.e. the fraction of blocks to keep
        block_size : [int, int]
            block size dimensions
        keep_diagonal : bool
            If True, the diagonal is kept. This option requires a square matrix and defaults to False.
        return_mask : bool
            If True, the block mask is returned alongside the sparsified matrix. Defaults to False.
    """

    m, n   = matrix.shape
    m1, n1 = block_size

    if m % m1 or n % n1:
        raise ValueError(f"block size {(m1, n1)} does not divide matrix size {(m, n)}")

    # extract diagonal if keep_diagonal = True
    if keep_diagonal:
        if m != n:
            raise ValueError("Attempting to sparsify non-square matrix with keep_diagonal=True")

        to_spare = torch.diag(torch.diag(matrix))
        matrix   = matrix - to_spare
    else:
        to_spare = torch.zeros_like(matrix)

    # calculate energy in sub-blocks
    x = torch.reshape(matrix, (m // m1, m1, n // n1, n1))
    x = x ** 2
    block_energies = torch.sum(torch.sum(x, dim=3), dim=1)

    number_of_blocks = (m * n) // (m1 * n1)
    number_of_survivors = round(number_of_blocks * density)

    # masking threshold
    if number_of_survivors == 0:
        threshold = 0
    else:
        threshold = torch.sort(torch.flatten(block_energies)).values[-number_of_survivors]

    # create mask
    mask = torch.ones_like(block_energies)
    mask[block_energies < threshold] = 0
    mask = torch.repeat_interleave(mask, m1, dim=0)
    mask = torch.repeat_interleave(mask, n1, dim=1)

    # perform masking
    masked_matrix = mask * matrix + to_spare

    if return_mask:
        return masked_matrix, mask
    else:
        return masked_matrix

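# Illustrative usage sketch (added for documentation; not part of the upstream API):
# prune a square weight matrix to roughly 25% block density while keeping its diagonal.
# The shape, block size, and density below are example assumptions.
def _sparsify_matrix_example():
    weight = torch.randn(64, 64)
    sparse_weight, mask = sparsify_matrix(
        weight, density=0.25, block_size=(8, 8), keep_diagonal=True, return_mask=True
    )
    # mask is 1 inside the surviving 8x8 blocks and 0 elsewhere; the diagonal of
    # sparse_weight matches that of weight because keep_diagonal re-adds it after masking.
    return sparse_weight, mask
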
def calculate_gru_flops_per_step(gru, sparsification_dict=dict(), drop_input=False):
    """ Estimates the number of FLOPs per time step of a torch.nn.GRU.

        sparsification_dict optionally maps the sub-matrix names 'W_ir', 'W_iz', 'W_in',
        'W_hr', 'W_hz', 'W_hn' to entries whose first element is the density of that
        sub-matrix; missing entries are treated as dense. If drop_input is True, the
        input matrix-vector products are excluded from the count.
    """
    input_size = gru.input_size
    hidden_size = gru.hidden_size
    flops = 0

    # average density of the input and recurrent weight blocks (default: dense)
    input_density = (
        sparsification_dict.get('W_ir', [1])[0]
        + sparsification_dict.get('W_in', [1])[0]
        + sparsification_dict.get('W_iz', [1])[0]
    ) / 3

    recurrent_density = (
        sparsification_dict.get('W_hr', [1])[0]
        + sparsification_dict.get('W_hn', [1])[0]
        + sparsification_dict.get('W_hz', [1])[0]
    ) / 3

    # input matrix vector multiplications
    if not drop_input:
        flops += 2 * 3 * input_size * hidden_size * input_density

    # recurrent matrix vector multiplications
    flops += 2 * 3 * hidden_size * hidden_size * recurrent_density

    # biases
    flops += 6 * hidden_size

    # activations estimated by 10 flops per activation
    flops += 30 * hidden_size

    return flops
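
# Illustrative sketch (example sizes and densities are assumptions, not part of the
# upstream file): compare the per-step FLOP estimate of a dense GRU with the same GRU
# whose recurrent weights are sparsified to 10% density.
if __name__ == "__main__":
    gru = torch.nn.GRU(input_size=128, hidden_size=384, batch_first=True)

    dense_flops = calculate_gru_flops_per_step(gru)

    # entries are read as sparsification_dict['W_hr'][0] etc., so the density is the
    # first element of each value
    sparse_dict = {'W_hr': [0.1], 'W_hz': [0.1], 'W_hn': [0.1]}
    sparse_flops = calculate_gru_flops_per_step(gru, sparsification_dict=sparse_dict)

    print(f"dense: {dense_flops:.0f} FLOPs/step, recurrent 10% density: {sparse_flops:.0f} FLOPs/step")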