"""
/* Copyright (c) 2023 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

# Export the weights of a LACE or NoLACE checkpoint as C data through
# wexchange's CWriter. With --quantize, a model listed in `schedules` is dumped
# submodule by submodule using the per-submodule options given there; otherwise
# every supported leaf module is dumped with dump_torch_weights' default options.

import os
import argparse
import sys

import hashlib

# wexchange lives in a sibling directory and is not installed as a package.
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))

import torch
import wexchange.torch
import wexchange.c_export  # explicit: CWriter is accessed as wexchange.c_export.CWriter
from wexchange.torch import dump_torch_weights
from models import model_dict

from utils.layers.limited_adaptive_comb1d import LimitedAdaptiveComb1d
from utils.layers.limited_adaptive_conv1d import LimitedAdaptiveConv1d
from utils.layers.td_shaper import TDShaper
from utils.misc import remove_all_weight_norm


parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='LACE or NoLACE model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--quantize', action="store_true", help='quantization according to schedule')

sparse_default = False

# Per-model export schedules: (submodule path, extra kwargs for dump_torch_weights).
# Only used when --quantize is given and the model name has an entry here.
schedules = {
    'nolace': [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None)),
        ('tdshape1', dict(quantize=True, scale=None)),
        ('tdshape2', dict(quantize=True, scale=None)),
        ('tdshape3', dict(quantize=True, scale=None)),
        ('af2', dict(quantize=True, scale=None)),
        ('af3', dict(quantize=True, scale=None)),
        ('af4', dict(quantize=True, scale=None)),
        ('post_cf1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_cf2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af1', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('post_af3', dict(quantize=True, scale=None, sparse=sparse_default))
    ],
    'lace': [
        ('pitch_embedding', dict()),
        ('feature_net.conv1', dict()),
        ('feature_net.conv2', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.tconv', dict(quantize=True, scale=None, sparse=sparse_default)),
        ('feature_net.gru', dict(quantize=True, scale=None, recurrent_scale=None, input_sparse=sparse_default, recurrent_sparse=sparse_default)),
        ('cf1', dict(quantize=True, scale=None)),
        ('cf2', dict(quantize=True, scale=None)),
        ('af1', dict(quantize=True, scale=None))
    ]
}

# Module types that osce_dump_generic hands straight to dump_torch_weights;
# any other module is recursed into via its children.
_DUMPABLE_MODULE_TYPES = (
    torch.nn.Linear,
    torch.nn.Conv1d,
    torch.nn.ConvTranspose1d,
    torch.nn.Embedding,
    torch.nn.GRU,
    LimitedAdaptiveConv1d,
    LimitedAdaptiveComb1d,
    TDShaper,
)


# auxiliary functions
def sha1(filename):
    """Return the hex SHA-1 digest of the file at `filename`.

    The file is read in 64 KiB chunks, so large checkpoints are never loaded
    into memory at once.
    """
    buf_size = 65536
    digest = hashlib.sha1()

    with open(filename, 'rb') as f:
        # iter(callable, sentinel) stops when read() returns b'' at EOF.
        for chunk in iter(lambda: f.read(buf_size), b''):
            digest.update(chunk)

    return digest.hexdigest()


def osce_dump_generic(writer, name, module):
    """Recursively dump every supported leaf module of `module` to `writer`.

    A module whose type is in _DUMPABLE_MODULE_TYPES is written with
    dump_torch_weights under `name`. Any other module is traversed: each child
    is dumped under ``name + "_" + child_name``, with "feature_net" shortened
    to "fnet".
    """
    if isinstance(module, _DUMPABLE_MODULE_TYPES):
        dump_torch_weights(writer, module, name=name, verbose=True)
        return

    for child_name, child in module.named_children():
        child_export_name = (name + "_" + child_name).replace("feature_net", "fnet")
        osce_dump_generic(writer, child_export_name, child)


def export_name(name):
    """Turn a dotted submodule path into a C identifier suffix.

    Dots become underscores and "feature_net" is shortened to "fnet",
    e.g. "feature_net.conv1" -> "fnet_conv1".
    """
    name = name.replace('.', '_')
    name = name.replace('feature_net', 'fnet')
    return name


def osce_scheduled_dump(writer, prefix, model, schedule):
    """Dump the submodules listed in `schedule` with their per-entry options.

    Args:
        writer: wexchange CWriter receiving the weights.
        prefix: name prefix for every exported layer; "_" is appended if missing.
        model: torch module containing the scheduled submodules.
        schedule: iterable of (submodule path, kwargs) pairs; the kwargs are
            forwarded to dump_torch_weights together with verbose=True.
    """
    if not prefix.endswith('_'):
        prefix += '_'

    for name, kwargs in schedule:
        dump_torch_weights(writer, model.get_submodule(name), prefix + export_name(name), **kwargs, verbose=True)


def _write_global_defines(cwriter, model_name, model):
    """Write the model's global parameters as #define lines into the C header."""
    prefix = model_name.upper()
    defines = [
        ('PREEMPH', f'{model.preemph}f'),
        ('FRAME_SIZE', model.FRAME_SIZE),
        # NOTE(review): hard-coded rather than read from the model -- confirm it
        # still matches the model's overlap.
        ('OVERLAP_SIZE', 40),
        ('NUM_FEATURES', model.num_features),
        ('PITCH_MAX', model.pitch_max),
        ('PITCH_EMBEDDING_DIM', model.pitch_embedding_dim),
        ('NUMBITS_RANGE_LOW', model.numbits_range[0]),
        ('NUMBITS_RANGE_HIGH', model.numbits_range[1]),
        ('NUMBITS_EMBEDDING_DIM', model.numbits_embedding_dim),
        ('COND_DIM', model.cond_dim),
        ('HIDDEN_FEATURE_DIM', model.hidden_feature_dim),
    ]
    # The block starts with an empty line and terminates every define with '\n'.
    cwriter.header.write('\n' + ''.join(f'#define {prefix}_{key} {value}\n' for key, value in defines))

    for i, s in enumerate(model.numbits_embedding.scale_factors):
        cwriter.header.write(f"#define {prefix}_NUMBITS_SCALE_{i} {float(s.detach().cpu())}f\n")


def main():
    """Load the checkpoint given on the command line and export it as C data."""
    args = parser.parse_args()

    checkpoint_path = args.checkpoint
    outdir = args.output_dir
    os.makedirs(outdir, exist_ok=True)

    # provenance message embedded in the generated files
    message = f"Auto generated from checkpoint {os.path.basename(checkpoint_path)} (sha1: {sha1(checkpoint_path)})"

    # create model and load weights
    # NOTE: torch.load is pickle-based; only export checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    model_setup = checkpoint['setup']['model']
    model_name = model_setup['name']
    model = model_dict[model_name](*model_setup['args'], **model_setup['kwargs'])
    model.load_state_dict(checkpoint['state_dict'])
    remove_all_weight_norm(model, verbose=True)

    # CWriter
    cwriter = wexchange.c_export.CWriter(os.path.join(outdir, model_name + "_data"), message=message, model_struct_name=model_name.upper() + 'Layers', add_typedef=True)

    # Add global parameters
    _write_global_defines(cwriter, model_name, model)

    # dump layers
    if args.quantize and model_name in schedules:
        osce_scheduled_dump(cwriter, model_name, model, schedules[model_name])
    else:
        if args.quantize:
            # Previously a silent fallback; make it visible.
            print(f"warning: no quantization schedule for model '{model_name}', dumping with default options", file=sys.stderr)
        osce_dump_generic(cwriter, model_name, model)

    cwriter.close()


if __name__ == "__main__":
    main()