xref: /aosp_15_r20/external/libopus/dnn/torch/osce/export_model_weights.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1"""
2/* Copyright (c) 2023 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31import argparse
32import sys
33
34import hashlib
35
36sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))
37
38import torch
39import wexchange.torch
40from wexchange.torch import dump_torch_weights
41from models import model_dict
42
43from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
44from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
45from utils.layers.td_shaper import TDShaper
46from utils.misc import remove_all_weight_norm
47from wexchange.torch import dump_torch_weights
48
49
50
# Command-line interface: the checkpoint to export, the directory that receives
# the generated C files, and an opt-in switch for schedule-driven quantization.
parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', help='LACE or NoLACE model checkpoint', type=str)
parser.add_argument('output_dir', help='output folder', type=str)
parser.add_argument(
    '--quantize',
    help='quantization according to schedule',
    action="store_true",
)
56
# Value passed for every ``sparse`` / ``input_sparse`` / ``recurrent_sparse``
# option in the schedules below.
sparse_default = False

# Per-model export schedules consumed by ``osce_scheduled_dump``: ordered lists
# of (submodule path, keyword arguments forwarded to ``dump_torch_weights``).
# An empty dict means the layer is dumped with no extra options.
schedules = {
    'nolace': [
        ('pitch_embedding', {}),
        ('feature_net.conv1', {}),
        ('feature_net.conv2', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('feature_net.tconv', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('feature_net.gru', {'quantize': True, 'scale': None, 'recurrent_scale': None,
                             'input_sparse': sparse_default, 'recurrent_sparse': sparse_default}),
        ('cf1', {'quantize': True, 'scale': None}),
        ('cf2', {'quantize': True, 'scale': None}),
        ('af1', {'quantize': True, 'scale': None}),
        ('tdshape1', {'quantize': True, 'scale': None}),
        ('tdshape2', {'quantize': True, 'scale': None}),
        ('tdshape3', {'quantize': True, 'scale': None}),
        ('af2', {'quantize': True, 'scale': None}),
        ('af3', {'quantize': True, 'scale': None}),
        ('af4', {'quantize': True, 'scale': None}),
        ('post_cf1', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('post_cf2', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('post_af1', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('post_af2', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('post_af3', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
    ],
    'lace': [
        ('pitch_embedding', {}),
        ('feature_net.conv1', {}),
        ('feature_net.conv2', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('feature_net.tconv', {'quantize': True, 'scale': None, 'sparse': sparse_default}),
        ('feature_net.gru', {'quantize': True, 'scale': None, 'recurrent_scale': None,
                             'input_sparse': sparse_default, 'recurrent_sparse': sparse_default}),
        ('cf1', {'quantize': True, 'scale': None}),
        ('cf2', {'quantize': True, 'scale': None}),
        ('af1', {'quantize': True, 'scale': None}),
    ],
}
91
92
93# auxiliary functions
def sha1(filename):
    """Return the hex-encoded SHA-1 digest of the file at ``filename``.

    The file is opened in binary mode and fed to the hash in 64 KiB chunks,
    so large checkpoints are hashed without reading them into memory at once.
    """
    chunk_size = 65536
    digest = hashlib.sha1()

    with open(filename, 'rb') as stream:
        # iter(callable, sentinel) keeps reading until read() returns b'' (EOF).
        for chunk in iter(lambda: stream.read(chunk_size), b''):
            digest.update(chunk)

    return digest.hexdigest()
106
def osce_dump_generic(writer, name, module):
    """Recursively export every supported layer contained in ``module``.

    A module whose type is in the exportable set below is written with
    ``dump_torch_weights(writer, module, name=name, verbose=True)``. Any other
    module is descended into: each child is visited under the name
    ``<name>_<child_name>``, with ``feature_net`` shortened to ``fnet``.
    Modules that are neither exportable nor have children produce no output.
    """
    exportable_types = (
        torch.nn.Linear,
        torch.nn.Conv1d,
        torch.nn.ConvTranspose1d,
        torch.nn.Embedding,
        LimitedAdaptiveConv1d,
        LimitedAdaptiveComb1d,
        TDShaper,
        torch.nn.GRU,
    )

    if not isinstance(module, exportable_types):
        # Container module: recurse into its direct children.
        for child_name, child in module.named_children():
            child_export_name = (name + "_" + child_name).replace("feature_net", "fnet")
            osce_dump_generic(writer, child_export_name, child)
        return

    dump_torch_weights(writer, module, name=name, verbose=True)
116
117
def export_name(name):
    """Turn a dotted submodule path into the identifier used in the C export.

    Dots are replaced by underscores first, then ``feature_net`` is abbreviated
    to ``fnet`` (e.g. ``feature_net.conv1`` -> ``fnet_conv1``).
    """
    # Order matters: underscores introduced by the first replace can complete a
    # ``feature_net`` match for the second.
    return name.replace('.', '_').replace('feature_net', 'fnet')
122
def osce_scheduled_dump(writer, prefix, model, schedule):
    """Export the submodules listed in ``schedule`` with per-layer options.

    ``schedule`` is an iterable of ``(submodule_path, kwargs)`` pairs. Each
    submodule is fetched with ``model.get_submodule(submodule_path)`` and
    written under ``<prefix>_<export_name(submodule_path)>``. The pair's
    ``kwargs`` are forwarded to ``dump_torch_weights`` together with
    ``verbose=True``. An underscore separator is inserted only when ``prefix``
    does not already end with one.
    """
    separator = '' if prefix.endswith('_') else '_'

    for submodule_path, dump_options in schedule:
        layer = model.get_submodule(submodule_path)
        layer_name = prefix + separator + export_name(submodule_path)
        dump_torch_weights(writer, layer, layer_name, **dump_options, verbose=True)
129
if __name__ == "__main__":
    # CWriter lives in wexchange.c_export. Import it explicitly instead of
    # relying on `import wexchange.torch` loading that submodule as a side effect.
    import wexchange.c_export

    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # Provenance message passed to CWriter: checkpoint file name plus its SHA-1.
    message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"

    # Rebuild the model from the constructor setup stored in the checkpoint,
    # then load the trained weights.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_setup = checkpoint['setup']['model']
    model_name = model_setup['name']
    model = model_dict[model_name](*model_setup['args'], **model_setup['kwargs'])
    model.load_state_dict(checkpoint['state_dict'])
    # NOTE(review): presumably folds weight normalization into plain weights so
    # the exported tensors are the effective ones -- confirm in utils.misc.
    remove_all_weight_norm(model, verbose=True)

    # CWriter output base path is <outdir>/<model_name>_data. The C symbols
    # below use the upper-cased model name as prefix.
    macro_prefix = model_name.upper()
    cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=macro_prefix + 'Layers', add_typedef=True)

    # Global model parameters exposed to the C side as preprocessor constants.
    # NOTE(review): OVERLAP_SIZE is hard-coded to 40 rather than read from the model.
    cwriter.header.write(f'''
#define {macro_prefix}_PREEMPH {model.preemph}f
#define {macro_prefix}_FRAME_SIZE {model.FRAME_SIZE}
#define {macro_prefix}_OVERLAP_SIZE 40
#define {macro_prefix}_NUM_FEATURES {model.num_features}
#define {macro_prefix}_PITCH_MAX {model.pitch_max}
#define {macro_prefix}_PITCH_EMBEDDING_DIM {model.pitch_embedding_dim}
#define {macro_prefix}_NUMBITS_RANGE_LOW {model.numbits_range[0]}
#define {macro_prefix}_NUMBITS_RANGE_HIGH {model.numbits_range[1]}
#define {macro_prefix}_NUMBITS_EMBEDDING_DIM {model.numbits_embedding_dim}
#define {macro_prefix}_COND_DIM {model.cond_dim}
#define {macro_prefix}_HIDDEN_FEATURE_DIM {model.hidden_feature_dim}
''')

    # One float constant per numbits-embedding scale factor.
    for i, s in enumerate(model.numbits_embedding.scale_factors):
        cwriter.header.write(f"#define {macro_prefix}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")

    # Dump layers. With --quantize, models that have an entry in `schedules`
    # are exported layer by layer with the scheduled per-layer options.
    # Otherwise all supported layers go through the generic recursive dump.
    schedule = schedules.get(model_name)
    if args.quantize and schedule is not None:
        osce_scheduled_dump(cwriter, model_name, model, schedule)
    else:
        if args.quantize:
            # --quantize cannot take effect without a schedule. Say so
            # instead of ignoring the flag silently.
            print(f"warning: no quantization schedule for model '{model_name}'; "
                  "falling back to the generic dump", file=sys.stderr)
        osce_dump_generic(cwriter, model_name, model)

    cwriter.close()
175