# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.models.wav2letter import Wav2LetterModel
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)


class Conv2D(torch.nn.Module):
    def __init__(self, stride, padding, weight, bias=None):
        super().__init__()
        use_bias = bias is not None
        self.conv = torch.nn.Conv2d(
            in_channels=weight.shape[1],
            out_channels=weight.shape[0],
            kernel_size=[weight.shape[2], 1],
            stride=[*stride, 1],
            padding=[*padding, 0],
            bias=use_bias,
        )
        self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1))
        if use_bias:
            self.conv.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        return self.conv(x)


def get_dataset(data_size, artifact_dir):
    from torch.utils.data import DataLoader
    from torchaudio.datasets import LIBRISPEECH

    def collate_fun(batch):
        waves, labels = [], []

        for wave, _, text, *_ in batch:
            waves.append(wave.squeeze(0))
            labels.append(text)
        # padding is needed here for a static output shape
        waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True)
        return waves, labels

    dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=data_size,
        shuffle=True,
        collate_fn=lambda x: collate_fun(x),
    )
    # prepare input data
    inputs, targets, input_list = [], [], ""
    for wave, label in data_loader:
        for index in range(data_size):
            # reshape input tensor to NCHW
            inputs.append((wave[index].reshape(1, 1, -1, 1),))
            targets.append(label[index])
            input_list += f"input_{index}_0.raw\n"
        # only take the first batch, i.e. 'data_size' tensors
        break

    return inputs, targets, input_list


def eval_metric(pred, target_str):
    from torchmetrics.text import CharErrorRate, WordErrorRate

    def parse(ids):
        vocab = " abcdefghijklmnopqrstuvwxyz'*"
        return ["".join([vocab[c] for c in id]).replace("*", "").upper() for id in ids]

    pred_str = parse(
        [
            torch.unique_consecutive(pred[i, :, :].argmax(0))
            for i in range(pred.shape[0])
        ]
    )
    wer, cer = WordErrorRate(), CharErrorRate()
    return wer(pred_str, target_str), cer(pred_str, target_str)


def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    instance = Wav2LetterModel()
    # target labels " abcdefghijklmnopqrstuvwxyz'*"
    instance.vocab_size = 29
    model = instance.get_eager_model().eval()
    model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True))

    # converting conv1d to conv2d at the nn.Module level introduces only 2 permute
    # nodes around the input & output, which is more quantization friendly.
    for i in range(len(model.acoustic_model)):
        for j in range(len(model.acoustic_model[i])):
            module = model.acoustic_model[i][j]
            if isinstance(module, torch.nn.Conv1d):
                model.acoustic_model[i][j] = Conv2D(
                    stride=module.stride,
                    padding=module.padding,
                    weight=module.weight,
                    bias=module.bias,
                )

    # retrieve dataset; the download will take some time
    data_num = 100
    inputs, targets, input_list = get_dataset(
        data_size=data_num, artifact_dir=args.artifact
    )
    pte_filename = "w2l_qnn"
    build_executorch_binary(
        model,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        sys.exit(0)

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    # evaluate metrics
    wer, cer = 0, 0
    for i, pred in enumerate(predictions):
        pred = torch.from_numpy(pred).reshape(1, instance.vocab_size, -1)
        wer_eval, cer_eval = eval_metric(pred, targets[i])
        wer += wer_eval
        cer += cer_eval

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num})
            )
    else:
        print(f"wer: {wer / data_num}\ncer: {cer / data_num}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. "
        "Default ./wav2letter",
        default="./wav2letter",
        type=str,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of the pretrained weights; please download them via "
            "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch"
            " for the torchaudio.models.Wav2Letter version"
        ),
        default=None,
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            raise Exception(e)
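
# Example invocation (a sketch only; the script name, the -b/-m flag spellings
# coming from setup_common_args_and_variables, and the SM8550 SoC value are
# assumptions, while -a, -p, and -s/--device are defined or referenced above):
#
#   python wav2letter.py -b build-android -m SM8550 -s <device_serial> \
#       -a ./wav2letter -p ./wav2letter_pretrained_weights.pt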