# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import sys
from multiprocessing.connection import Client

import numpy as np

import torch
from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
from executorch.examples.models.wav2letter import Wav2LetterModel
from executorch.examples.qualcomm.utils import (
    build_executorch_binary,
    make_output_dir,
    parse_skip_delegation_node,
    setup_common_args_and_variables,
    SimpleADB,
)


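# Wraps a Conv1d's parameters into an equivalent Conv2d by appending a
# singleton spatial dimension to the kernel, stride, and padding; inputs are
# expected in NCHW layout with the extra width dimension of size 1.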
class Conv2D(torch.nn.Module):
    def __init__(self, stride, padding, weight, bias=None):
        super().__init__()
        use_bias = bias is not None
        self.conv = torch.nn.Conv2d(
            in_channels=weight.shape[1],
            out_channels=weight.shape[0],
            kernel_size=[weight.shape[2], 1],
            stride=[*stride, 1],
            padding=[*padding, 0],
            bias=use_bias,
        )
        self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1))
        if use_bias:
            self.conv.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        return self.conv(x)


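# Downloads the LibriSpeech test-clean split and returns `data_size` padded
# waveforms as NCHW tensors, their reference transcripts, and the input_list
# string used when pushing inputs to the device.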
def get_dataset(data_size, artifact_dir):
    from torch.utils.data import DataLoader
    from torchaudio.datasets import LIBRISPEECH

    def collate_fun(batch):
        waves, labels = [], []

        for wave, _, text, *_ in batch:
            waves.append(wave.squeeze(0))
            labels.append(text)
        # padding is needed here for a static output shape
        waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True)
        return waves, labels

    dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=data_size,
        shuffle=True,
        collate_fn=collate_fun,
    )
    # prepare input data
    inputs, targets, input_list = [], [], ""
    for wave, label in data_loader:
        for index in range(data_size):
            # reshape input tensor to NCHW
            inputs.append((wave[index].reshape(1, 1, -1, 1),))
            targets.append(label[index])
            input_list += f"input_{index}_0.raw\n"
        # here we only take the first batch, i.e. 'data_size' tensors
        break

    return inputs, targets, input_list


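# Greedy CTC-style decoding: take the argmax over the vocabulary axis,
# collapse repeated indices, strip the '*' token, and score the decoded text
# against the reference transcript with word/character error rate.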
def eval_metric(pred, target_str):
    from torchmetrics.text import CharErrorRate, WordErrorRate

    def parse(ids):
        vocab = " abcdefghijklmnopqrstuvwxyz'*"
        return ["".join([vocab[c] for c in id]).replace("*", "").upper() for id in ids]

    pred_str = parse(
        [
            torch.unique_consecutive(pred[i, :, :].argmax(0))
            for i in range(pred.shape[0])
        ]
    )
    wer, cer = WordErrorRate(), CharErrorRate()
    return wer(pred_str, target_str), cer(pred_str, target_str)


def main(args):
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # ensure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    if not args.compile_only and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial via the -s/--device argument."
        )

    instance = Wav2LetterModel()
    # target labels " abcdefghijklmnopqrstuvwxyz'*"
    instance.vocab_size = 29
    model = instance.get_eager_model().eval()
    model.load_state_dict(torch.load(args.pretrained_weight, weights_only=True))

    # converting conv1d to conv2d at the nn.Module level introduces only 2 permute
    # nodes around the input & output, which is more quantization friendly.
    for i in range(len(model.acoustic_model)):
        for j in range(len(model.acoustic_model[i])):
            module = model.acoustic_model[i][j]
            if isinstance(module, torch.nn.Conv1d):
                model.acoustic_model[i][j] = Conv2D(
                    stride=module.stride,
                    padding=module.padding,
                    weight=module.weight,
                    bias=module.bias,
                )

    # retrieve the dataset; this will take some time to download
    data_num = 100
    inputs, targets, input_list = get_dataset(
        data_size=data_num, artifact_dir=args.artifact
    )
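    # lower the model through the QNN backend and export an 8-bit (8a8w)
    # quantized .pte, using the dataset above as calibration inputs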
    pte_filename = "w2l_qnn"
    build_executorch_binary(
        model,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        sys.exit(0)

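    # push the compiled .pte and raw inputs to the device over ADB, execute
    # on target, then pull the results back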
    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

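    # each output_{i}_0.raw holds the float32 per-frame scores over the
    # vocabulary for one utterance; it is reshaped to (1, vocab_size, frames)
    # before decoding below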
    predictions = []
    for i in range(data_num):
        predictions.append(
            np.fromfile(
                os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
            )
        )

    # evaluate metrics
    wer, cer = 0, 0
    for i, pred in enumerate(predictions):
        pred = torch.from_numpy(pred).reshape(1, instance.vocab_size, -1)
        wer_eval, cer_eval = eval_metric(pred, targets[i])
        wer += wer_eval
        cer += cer_eval

    if args.ip and args.port != -1:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num})
            )
    else:
        print(f"wer: {wer / data_num}\ncer: {cer / data_num}")


if __name__ == "__main__":
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing artifacts generated by this example. "
        "Default: ./wav2letter",
        default="./wav2letter",
        type=str,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of the pretrained weights; please download them via "
            "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch"
            " (torchaudio.models.Wav2Letter version)"
        ),
        default=None,
        type=str,
        required=True,
    )

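    # Example invocation (the remaining flags come from
    # setup_common_args_and_variables; exact spellings may differ):
    #   python wav2letter.py -s <device_serial> -m <soc_model> \
    #       -b <build_folder> -a ./wav2letter -p <pretrained_weight_file>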
    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # re-raise to keep the original traceback
            raise