xref: /aosp_15_r20/external/executorch/examples/qualcomm/scripts/wav2letter.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Qualcomm Innovation Center, Inc.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerimport json
8*523fa7a6SAndroid Build Coastguard Workerimport os
9*523fa7a6SAndroid Build Coastguard Workerimport sys
10*523fa7a6SAndroid Build Coastguard Workerfrom multiprocessing.connection import Client
11*523fa7a6SAndroid Build Coastguard Worker
12*523fa7a6SAndroid Build Coastguard Workerimport numpy as np
13*523fa7a6SAndroid Build Coastguard Worker
14*523fa7a6SAndroid Build Coastguard Workerimport torch
15*523fa7a6SAndroid Build Coastguard Workerfrom executorch.backends.qualcomm.quantizer.quantizer import QuantDtype
16*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.models.wav2letter import Wav2LetterModel
17*523fa7a6SAndroid Build Coastguard Workerfrom executorch.examples.qualcomm.utils import (
18*523fa7a6SAndroid Build Coastguard Worker    build_executorch_binary,
19*523fa7a6SAndroid Build Coastguard Worker    make_output_dir,
20*523fa7a6SAndroid Build Coastguard Worker    parse_skip_delegation_node,
21*523fa7a6SAndroid Build Coastguard Worker    setup_common_args_and_variables,
22*523fa7a6SAndroid Build Coastguard Worker    SimpleADB,
23*523fa7a6SAndroid Build Coastguard Worker)
24*523fa7a6SAndroid Build Coastguard Worker
25*523fa7a6SAndroid Build Coastguard Worker
class Conv2D(torch.nn.Module):
    """Conv2d stand-in for a 1-D convolution.

    The supplied 3-D weight is read as (out_channels, in_channels, kernel)
    and gets a trailing unit dimension, so the wrapped ``torch.nn.Conv2d``
    uses a ``(kernel, 1)`` kernel. Stride and padding are extended with
    ``1`` and ``0`` respectively for the added dimension.

    Args:
        stride: iterable with the stride of the original convolution.
        padding: iterable with the padding of the original convolution.
        weight: 3-D weight tensor taken from the original convolution.
        bias: optional bias tensor; when ``None`` the Conv2d has no bias.
    """

    def __init__(self, stride, padding, weight, bias=None):
        super().__init__()
        out_ch = weight.shape[0]
        in_ch = weight.shape[1]
        kernel_len = weight.shape[2]
        has_bias = bias is not None

        self.conv = torch.nn.Conv2d(
            in_ch,
            out_ch,
            (kernel_len, 1),
            stride=(*stride, 1),
            padding=(*padding, 0),
            bias=has_bias,
        )
        # Replace the freshly initialised parameters with the given ones;
        # the weight gains a trailing axis to become a 4-D Conv2d kernel.
        self.conv.weight = torch.nn.Parameter(weight.unsqueeze(-1))
        if has_bias:
            self.conv.bias = torch.nn.Parameter(bias)

    def forward(self, x):
        """Apply the wrapped 2-D convolution to ``x``."""
        return self.conv(x)
44*523fa7a6SAndroid Build Coastguard Worker
45*523fa7a6SAndroid Build Coastguard Worker
def get_dataset(data_size, artifact_dir):
    """Fetch one shuffled batch of LibriSpeech ``test-clean`` samples.

    Args:
        data_size: number of utterances requested; used as the DataLoader
            batch size.
        artifact_dir: root directory handed to ``LIBRISPEECH`` (the dataset
            is downloaded there when missing).

    Returns:
        A tuple ``(inputs, targets, input_list)`` where

        * ``inputs`` is a list of 1-tuples, each holding one waveform
          reshaped to ``(1, 1, -1, 1)``,
        * ``targets`` is the list of matching transcript strings,
        * ``input_list`` is a string with one ``input_{i}_0.raw`` line per
          sample.

        If the loader yields fewer than ``data_size`` samples, only the
        available ones are returned.
    """
    from torch.utils.data import DataLoader
    from torchaudio.datasets import LIBRISPEECH

    def collate_fun(batch):
        waves, labels = [], []

        for wave, _, text, *_ in batch:
            waves.append(wave.squeeze(0))
            labels.append(text)
        # pad every waveform to the longest one in the batch so that all
        # samples share a single static shape
        waves = torch.nn.utils.rnn.pad_sequence(waves, batch_first=True)
        return waves, labels

    dataset = LIBRISPEECH(artifact_dir, url="test-clean", download=True)
    data_loader = DataLoader(
        dataset=dataset,
        batch_size=data_size,
        shuffle=True,
        collate_fn=collate_fun,
    )

    # prepare input data
    inputs, targets, file_names = [], [], []
    for waves, labels in data_loader:
        # Walk the batch itself rather than range(data_size): the batch can
        # be shorter than requested when the dataset is small.
        for index, (wave, label) in enumerate(zip(waves, labels)):
            # reshape input tensor to NCHW
            inputs.append((wave.reshape(1, 1, -1, 1),))
            targets.append(label)
            file_names.append(f"input_{index}_0.raw\n")
        # only the first batch is used, i.e. at most 'data_size' tensors
        break

    return inputs, targets, "".join(file_names)
79*523fa7a6SAndroid Build Coastguard Worker
80*523fa7a6SAndroid Build Coastguard Worker
def eval_metric(pred, target_str):
    """Greedy-decode predictions and score them against reference text.

    For every sample ``pred[i]`` the argmax over dimension 0 picks one
    vocabulary index per position, consecutive repeats are collapsed, the
    indices are mapped through the 29-symbol vocabulary, ``*`` is dropped
    and the text is upper-cased.

    Args:
        pred: 3-D tensor indexed as ``pred[i, :, :]``; dimension 1 is the
            vocabulary axis.
        target_str: reference transcript(s) passed to the metrics.

    Returns:
        ``(wer, cer)`` as produced by torchmetrics ``WordErrorRate`` and
        ``CharErrorRate``.
    """
    from torchmetrics.text import CharErrorRate, WordErrorRate

    vocab = " abcdefghijklmnopqrstuvwxyz'*"

    def decode(token_ids):
        # map indices to characters, then strip the '*' symbol
        text = "".join(vocab[token] for token in token_ids)
        return text.replace("*", "").upper()

    pred_str = []
    for i in range(pred.shape[0]):
        best_ids = pred[i, :, :].argmax(0)
        collapsed = torch.unique_consecutive(best_ids)
        pred_str.append(decode(collapsed))

    word_metric, char_metric = WordErrorRate(), CharErrorRate()
    word_score = word_metric(pred_str, target_str)
    char_score = char_metric(pred_str, target_str)
    return word_score, char_score
96*523fa7a6SAndroid Build Coastguard Worker
97*523fa7a6SAndroid Build Coastguard Worker
def main(args):
    """Build the quantized Wav2Letter program and optionally evaluate it.

    Steps, in order:
      1. create the artifact directory and check a device serial was given
         unless only compiling,
      2. load the pretrained Wav2Letter weights and swap every ``Conv1d``
         inside ``model.acoustic_model`` for the ``Conv2D`` wrapper,
      3. fetch 100 dataset samples and call ``build_executorch_binary``
         with 8-bit activation / 8-bit weight quantization,
      4. exit here when ``args.compile_only`` is set,
      5. otherwise push, execute and pull through ``SimpleADB``, read the
         raw float32 outputs and compute the average WER / CER,
      6. send the result as JSON to ``(args.ip, args.port)`` when both are
         set, else print it.

    Args:
        args: parsed command-line namespace (reads ``artifact``,
            ``compile_only``, ``device``, ``pretrained_weight``, ``model``,
            ``shared_buffer``, ``build_folder``, ``host``, ``ip``, ``port``).

    Raises:
        RuntimeError: when not compile-only and no device serial is given.
    """
    skip_node_id_set, skip_node_op_set = parse_skip_delegation_node(args)

    # make sure the working directory exists
    os.makedirs(args.artifact, exist_ok=True)

    runs_on_device = not args.compile_only
    if runs_on_device and args.device is None:
        raise RuntimeError(
            "device serial is required if not compile only. "
            "Please specify a device serial by -s/--device argument."
        )

    instance = Wav2LetterModel()
    # target labels " abcdefghijklmnopqrstuvwxyz'*"
    instance.vocab_size = 29
    model = instance.get_eager_model().eval()
    state_dict = torch.load(args.pretrained_weight, weights_only=True)
    model.load_state_dict(state_dict)

    # Converting conv1d to conv2d at nn.Module level only introduces 2
    # permute nodes around input & output, which is more quantization
    # friendly.
    for stage in model.acoustic_model:
        for idx in range(len(stage)):
            layer = stage[idx]
            if not isinstance(layer, torch.nn.Conv1d):
                continue
            stage[idx] = Conv2D(
                stride=layer.stride,
                padding=layer.padding,
                weight=layer.weight,
                bias=layer.bias,
            )

    # retrieve dataset, will take some time to download
    data_num = 100
    inputs, targets, input_list = get_dataset(
        data_size=data_num, artifact_dir=args.artifact
    )

    pte_filename = "w2l_qnn"
    build_executorch_binary(
        model,
        inputs[0],
        args.model,
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_node_id_set=skip_node_id_set,
        skip_node_op_set=skip_node_op_set,
        quant_dtype=QuantDtype.use_8a8w,
        shared_buffer=args.shared_buffer,
    )

    if args.compile_only:
        sys.exit(0)

    adb = SimpleADB(
        qnn_sdk=os.getenv("QNN_SDK_ROOT"),
        build_path=f"{args.build_folder}",
        pte_path=f"{args.artifact}/{pte_filename}.pte",
        workspace=f"/data/local/tmp/executorch/{pte_filename}",
        device_id=args.device,
        host_id=args.host,
        soc_model=args.model,
        shared_buffer=args.shared_buffer,
    )
    adb.push(inputs=inputs, input_list=input_list)
    adb.execute()

    # collect output data
    output_data_folder = f"{args.artifact}/outputs"
    make_output_dir(output_data_folder)
    adb.pull(output_path=args.artifact)

    predictions = [
        np.fromfile(
            os.path.join(output_data_folder, f"output_{i}_0.raw"), dtype=np.float32
        )
        for i in range(data_num)
    ]

    # evaluate metrics: accumulate per-sample scores, averaged when reported
    wer, cer = 0, 0
    for i, raw_output in enumerate(predictions):
        scores = torch.from_numpy(raw_output).reshape(1, instance.vocab_size, -1)
        sample_wer, sample_cer = eval_metric(scores, targets[i])
        wer = wer + sample_wer
        cer = cer + sample_cer

    report_remotely = args.ip and args.port != -1
    if report_remotely:
        with Client((args.ip, args.port)) as conn:
            conn.send(
                json.dumps({"wer": wer.item() / data_num, "cer": cer.item() / data_num})
            )
    else:
        print(f"wer: {wer / data_num}\ncer: {cer / data_num}")
191*523fa7a6SAndroid Build Coastguard Worker
192*523fa7a6SAndroid Build Coastguard Worker
if __name__ == "__main__":
    # Script entry point: extend the shared example parser with the
    # wav2letter-specific options, then run main() and report failures.
    parser = setup_common_args_and_variables()

    parser.add_argument(
        "-a",
        "--artifact",
        help="path for storing generated artifacts by this example. "
        "Default ./wav2letter",
        default="./wav2letter",
        type=str,
    )

    parser.add_argument(
        "-p",
        "--pretrained_weight",
        help=(
            "Location of pretrained weight, please download via "
            "https://github.com/nipponjo/wav2letter-ctc-pytorch/tree/main?tab=readme-ov-file#wav2letter-ctc-pytorch"
            " for torchaudio.models.Wav2Letter version"
        ),
        default=None,
        type=str,
        required=True,
    )

    args = parser.parse_args()
    try:
        main(args)
    except Exception as e:
        if args.ip and args.port != -1:
            # a listener is configured: forward the error message to it
            with Client((args.ip, args.port)) as conn:
                conn.send(json.dumps({"Error": str(e)}))
        else:
            # Re-raise the original exception as-is so its type and
            # traceback are preserved instead of being wrapped in a
            # generic Exception.
            raise
227