xref: /btstack/test/hfp/cvsd_plc_test.cpp (revision 46d6c6044a82d8cc98f35a2ca93e895adb50ca82)
1 #include <stdint.h>
2 #include <stdio.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <fcntl.h>
6 #include <unistd.h>
7 
8 #include "CppUTest/TestHarness.h"
9 #include "CppUTest/CommandLineTestRunner.h"
10 
11 #include "classic/btstack_cvsd_plc.h"
12 #include "wav_util.h"
13 
// Number of 16-bit samples per frame read from / written to the WAV files.
static const int audio_samples_per_frame = 60;
// Scratch buffer holding one input frame; shared by the WAV helpers below.
static int16_t audio_frame_in[audio_samples_per_frame];
16 
17 // static int16_t test_data[][audio_samples_per_frame] = {
18 //     { 0x05, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff },
19 //     { 0xff, 0xff, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x05 },
20 //     { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05, 0x05 },
21 //     { 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
22 // };
23 
// File-wide PLC state shared by the helpers and tests in this file.
static btstack_cvsd_plc_state_t plc_state;
25 
// Input signal: one full period of a pre-computed sine wave, 100 samples long
// (i.e. 160 Hz at a 16 kHz sample rate), full-scale amplitude 32767.
// create_sine_wave_int16_data() uses every second entry, which yields the same
// 160 Hz tone at 8 kHz.
static const int16_t sine_int16[] = {
     0,    2057,    4107,    6140,    8149,   10126,   12062,   13952,   15786,   17557,
 19260,   20886,   22431,   23886,   25247,   26509,   27666,   28714,   29648,   30466,
 31163,   31738,   32187,   32509,   32702,   32767,   32702,   32509,   32187,   31738,
 31163,   30466,   29648,   28714,   27666,   26509,   25247,   23886,   22431,   20886,
 19260,   17557,   15786,   13952,   12062,   10126,    8149,    6140,    4107,    2057,
     0,   -2057,   -4107,   -6140,   -8149,  -10126,  -12062,  -13952,  -15786,  -17557,
-19260,  -20886,  -22431,  -23886,  -25247,  -26509,  -27666,  -28714,  -29648,  -30466,
-31163,  -31738,  -32187,  -32509,  -32702,  -32767,  -32702,  -32509,  -32187,  -31738,
-31163,  -30466,  -29648,  -28714,  -27666,  -26509,  -25247,  -23886,  -22431,  -20886,
-19260,  -17557,  -15786,  -13952,  -12062,  -10126,   -8149,   -6140,   -4107,   -2057,
};
39 
// Return the length of the longest run of identical consecutive samples in
// packet[0..size-1]. An empty packet yields 0; a single sample yields 1.
static int count_equal_samples(const int16_t * packet, uint16_t size){
    if (size == 0){
        return 0;
    }
    int longest_run = 1;
    int current_run = 1;
    for (int i = 1; i < size; i++){
        if (packet[i] == packet[i-1]){
            current_run++;
        } else {
            current_run = 1;
        }
        // Update on every step so a run that ends at the last sample is
        // counted just like a run that ends in the middle of the packet.
        if (current_run > longest_run){
            longest_run = current_run;
        }
    }
    return longest_run;
}
59 
60 // @assumption frame len 24 samples
61 static int bad_frame(int16_t * frame, uint16_t size){
62     return count_equal_samples(frame, size) > audio_samples_per_frame - 4;
63 }
64 
65 static void btstack_cvsd_plc_mark_bad_frame(btstack_cvsd_plc_state_t * state, int16_t * in, uint16_t size, int16_t * out){
66     state->frame_count++;
67     if (bad_frame(in,size)){
68         memcpy(out, in, size * 2);
69         if (state->good_frames_nr > CVSD_LHIST/audio_samples_per_frame){
70             memset(out, 0x33, size * 2);
71             state->bad_frames_nr++;
72         }
73     } else {
74         memcpy(out, in, size);
75         state->good_frames_nr++;
76         if (state->good_frames_nr == 1){
77             printf("First good frame at index %d\n", state->frame_count-1);
78         }
79     }
80 }
81 
82 static int phase = 0;
83 static void create_sine_wave_int16_data(int num_samples, int16_t * data){
84     int i;
85     for (i=0; i < num_samples; i++){
86         data[i] = sine_int16[phase++] * 90/100;
87         phase++;
88         if (phase >= (sizeof(sine_int16) / sizeof(int16_t))){
89             phase = 0;
90         }
91     }
92 }
93 
94 // static int count_equal_bytes(int16_t * packet, uint16_t size){
95 //     int count = 0;
96 //     int temp_count = 1;
97 //     int i;
98 //     for (i = 0; i < size-1; i++){
99 //         if (packet[i] == packet[i+1]){
100 //             temp_count++;
101 //             continue;
102 //         }
103 //         if (count < temp_count){
104 //             count = temp_count;
105 //         }
106 //         temp_count = 1;
107 //     }
108 //     if (temp_count > count + 1){
109 //         count = temp_count;
110 //     }
111 //     return count;
112 // }
113 
114 static void create_sine_wav(const char * out_filename){
115     btstack_cvsd_plc_init(&plc_state);
116     wav_writer_open(out_filename, 1, 8000);
117 
118     int i;
119     for (i=0; i<2000; i++){
120         create_sine_wave_int16_data(audio_samples_per_frame, audio_frame_in);
121         wav_writer_write_int16(audio_samples_per_frame, audio_frame_in);
122     }
123     wav_writer_close();
124 }
125 
126 static int introduce_bad_frames_to_wav_file(const char * in_filename, const char * out_filename, int corruption_step){
127     btstack_cvsd_plc_init(&plc_state);
128     wav_writer_open(out_filename, 1, 8000);
129     wav_reader_open(in_filename);
130     int total_corruption_times = 0;
131     int fc = 0;
132     int start_corruption = 0;
133 
134     while (wav_reader_read_int16(audio_samples_per_frame, audio_frame_in) == 0){
135         if (corruption_step > 0 && fc >= corruption_step && fc%corruption_step == 0){
136             printf("corrupt fc %d, corruption_step %d\n", fc, corruption_step);
137             start_corruption = 1;
138         }
139         if (start_corruption > 0 && start_corruption < 4){
140             memset(audio_frame_in, 50,  audio_samples_per_frame * 2);
141             start_corruption++;
142         }
143         if (start_corruption > 4){
144             start_corruption = 0;
145             total_corruption_times++;
146             // printf("corupted 3 frames\n");
147         }
148         wav_writer_write_int16(audio_samples_per_frame, audio_frame_in);
149         fc++;
150     }
151     wav_reader_close();
152     wav_writer_close();
153     return total_corruption_times;
154 }
155 
156 static void process_wav_file_with_plc(const char * in_filename, const char * out_filename){
157     // printf("\nProcess %s -> %s\n", in_filename, out_filename);
158     btstack_cvsd_plc_init(&plc_state);
159     wav_writer_open(out_filename, 1, 8000);
160     wav_reader_open(in_filename);
161 
162     while (wav_reader_read_int16(audio_samples_per_frame, audio_frame_in) == 0){
163         int16_t audio_frame_out[audio_samples_per_frame];
164         btstack_cvsd_plc_process_data(&plc_state, false, audio_frame_in, audio_samples_per_frame, audio_frame_out);
165         wav_writer_write_int16(audio_samples_per_frame, audio_frame_out);
166     }
167     wav_reader_close();
168     wav_writer_close();
169     btstack_cvsd_dump_statistics(&plc_state);
170 }
171 
172 void mark_bad_frames_wav_file(const char * in_filename, const char * out_filename);
173 void mark_bad_frames_wav_file(const char * in_filename, const char * out_filename){
174     // printf("\nMark bad frame %s -> %s\n", in_filename, out_filename);
175     btstack_cvsd_plc_init(&plc_state);
176     CHECK_EQUAL(wav_writer_open(out_filename, 1, 8000), 0);
177     CHECK_EQUAL(wav_reader_open(in_filename), 0);
178 
179     while (wav_reader_read_int16(audio_samples_per_frame, audio_frame_in) == 0){
180         int16_t audio_frame_out[audio_samples_per_frame];
181         btstack_cvsd_plc_mark_bad_frame(&plc_state, audio_frame_in, audio_samples_per_frame, audio_frame_out);
182         wav_writer_write_int16(audio_samples_per_frame, audio_frame_out);
183     }
184     wav_reader_close();
185     wav_writer_close();
186     btstack_cvsd_dump_statistics(&plc_state);
187 }
188 
// CppUTest group for all TEST(CVSD_PLC, ...) cases in this file. It declares
// no setup/teardown, so the tests operate on the file-level plc_state and
// audio buffers directly.
TEST_GROUP(CVSD_PLC){

};
192 
// Write data[0..data_len-1] to oct_file as an Octave/MATLAB row-vector
// assignment, e.g. "name = [1, -2, 3];\n". An empty (or negative-length)
// array is written as "name = [];\n" without touching `data`.
static void fprintf_array_int16(FILE * oct_file, const char * name, int data_len, const int16_t * data){
    fprintf(oct_file, "%s = [", name);
    for (int i = 0; i < data_len; i++){
        if (i > 0){
            fputs(", ", oct_file);
        }
        fprintf(oct_file, "%d", data[i]);
    }
    fprintf(oct_file, "%s", "];\n");
}
202 
// Emit Octave commands that print the first CVSD_LHIST samples of
// plc_state.hist as variable `name`, and plot it with vertical marker lines at
// the CVSD_FS shift point and at the ends of the history / OLA / frame /
// reconvergence regions. Diamond markers highlight the best-lag template
// (bestlag .. bestlag+CVSD_M-1, red), the following audio_samples_per_frame
// samples (black), and the template at the end of the history
// (CVSD_LHIST-CVSD_M .. CVSD_LHIST-1, red).
// NOTE(review): the `data` and `data_len` parameters are never read; the
// function always uses plc_state.hist and CVSD_LHIST. Confirm this is intended.
// NOTE(review): x0/x1 are 0-based C indices but are emitted as Octave
// indices, which are 1-based. An index of 0 would be invalid in Octave, and
// indices past CVSD_LHIST exceed the printed vector. Verify the plots.
static void fprintf_plot_history(FILE * oct_file, char * name, int data_len, int16_t * data){
    fprintf_array_int16(oct_file, name, CVSD_LHIST, plc_state.hist);

    // y spans the signal range; x is a zero row of matching length, so that
    // "x + k" draws a vertical line at sample position k.
    fprintf(oct_file, "y = [min(%s):1000:max(%s)];\n", name, name);
    fprintf(oct_file, "x = zeros(1, size(y,2));\n");
    fprintf(oct_file, "b = [0:500];\n");

    int pos = CVSD_FS;
    fprintf(oct_file, "shift_x = x + %d;\n", pos);

    // Region boundaries after the history: OLA, rest of frame, OLA, RT.
    pos = CVSD_LHIST - 1;
    fprintf(oct_file, "lhist_x = x + %d;\n", pos);
    pos += CVSD_OLAL;
    fprintf(oct_file, "lhist_olal1_x = x + %d;\n", pos);
    pos += CVSD_FS - CVSD_OLAL;
    fprintf(oct_file, "lhist_fs_x = x + %d;\n", pos);
    pos += CVSD_OLAL;
    fprintf(oct_file, "lhist_olal2_x = x + %d;\n", pos);
    pos += CVSD_RT;
    fprintf(oct_file, "lhist_rt_x = x + %d;\n", pos);

    fprintf(oct_file, "pattern_window_x = x + %d;\n", CVSD_LHIST - CVSD_M);

    fprintf(oct_file, "hold on;\n");
    fprintf(oct_file, "plot(%s); \n", name);

    fprintf(oct_file, "plot(shift_x, y, 'k--'); \n");
    fprintf(oct_file, "plot(lhist_x, y, 'k'); \n");
    fprintf(oct_file, "plot(lhist_olal1_x, y, 'k'); \n");
    fprintf(oct_file, "plot(lhist_fs_x, y, 'k'); \n");
    fprintf(oct_file, "plot(lhist_olal2_x, y, 'k'); \n");
    fprintf(oct_file, "plot(lhist_rt_x, y, 'k');\n");

    // Template found at the best lag.
    int x0 = plc_state.bestlag;
    int x1 = plc_state.bestlag + CVSD_M - 1;
    fprintf(oct_file, "plot(b(%d:%d), %s(%d:%d), 'rd'); \n", x0, x1, name, x0, x1);

    // One frame of samples following that template.
    x0 = plc_state.bestlag + CVSD_M ;
    x1 = plc_state.bestlag + CVSD_M + audio_samples_per_frame - 1;
    fprintf(oct_file, "plot(b(%d:%d), %s(%d:%d), 'kd'); \n", x0, x1, name, x0, x1);

    // Template at the end of the history that is being matched.
    x0 = CVSD_LHIST - CVSD_M;
    x1 = CVSD_LHIST - 1;
    fprintf(oct_file, "plot(b(%d:%d), %s(%d:%d), 'rd'); \n", x0, x1, name, x0, x1);
    fprintf(oct_file, "plot(pattern_window_x, y, 'g'); \n");
}
249 
// Visualisation test (no assertions). Despite its name, it no longer counts
// equal bytes. It fills plc_state.hist with the test sine, performs one
// concealment step by hand (pattern match, amplitude match, scaled copy,
// overlap-add, reconvergence tail, history shift), and writes an Octave script
// plotting each stage. It silently returns if the hard-coded output path
// cannot be opened.
// NOTE(review): plc_state is not re-initialized here, so fields other than
// hist/bestlag keep whatever earlier tests left behind.
TEST(CVSD_PLC, CountEqBytes){
    // Scratch variables for the hand-rolled concealment step below.
    float val, sf;
    int i, x0, x1;

    char * name;
    // NOTE(review): `out` is filled below but never read.
    BTSTACK_CVSD_PLC_SAMPLE_FORMAT out[CVSD_FS];
    BTSTACK_CVSD_PLC_SAMPLE_FORMAT hist[CVSD_LHIST+CVSD_FS+CVSD_RT+CVSD_OLAL];
    FILE * oct_file = fopen("/Users/mringwal/octave/plc.m", "wb");
    if (!oct_file) return;
    // "1;" marks the Octave file as a script rather than a function file.
    fprintf(oct_file, "%s", "1;\n\n");

    int hist_len = sizeof(plc_state.hist)/2;
    create_sine_wave_int16_data(CVSD_LHIST, hist);
    // NOTE(review): memset writes a single byte value, so this fills the
    // buffer with the low byte of the last sample repeated, not with the
    // sample itself. The memcpy then overwrites the first CVSD_LHIST samples.
    memset(plc_state.hist, hist[CVSD_LHIST-1], sizeof(plc_state.hist));
    memcpy(plc_state.hist, hist, CVSD_LHIST*2);

    // Perform pattern matching to find where to replicate
    plc_state.bestlag = btstack_cvsd_plc_pattern_match(plc_state.hist);
    name = (char *) "hist0";
    fprintf_plot_history(oct_file, name, hist_len, plc_state.hist);

    // Continue right after the matched template and compute the scale factor.
    plc_state.bestlag += CVSD_M;
    sf = btstack_cvsd_plc_amplitude_match(&plc_state, audio_samples_per_frame, plc_state.hist, plc_state.bestlag);

    // First CVSD_OLAL samples of the replacement frame: scaled copy.
    for (i=0;i<CVSD_OLAL;i++){
        val = sf*plc_state.hist[plc_state.bestlag+i];
        plc_state.hist[CVSD_LHIST+i] = btstack_cvsd_plc_crop_sample(val);
    }
    name = (char *) "olal1";
    x0 = CVSD_LHIST;
    x1 = x0 + CVSD_OLAL - 1;
    fprintf_array_int16(oct_file, name, CVSD_OLAL, plc_state.hist+x0);
    fprintf(oct_file, "plot(b(%d:%d), %s, 'b.'); \n", x0, x1, name);

    // Remainder of the frame: scaled copy continues.
    for (;i<CVSD_FS;i++){
        val = sf*plc_state.hist[plc_state.bestlag+i];
        plc_state.hist[CVSD_LHIST+i] = btstack_cvsd_plc_crop_sample(val);
    }
    name = (char *)"fs_minus_olal";
    x0  = x1 + 1;
    x1  = x0 + CVSD_FS - CVSD_OLAL - 1;
    fprintf_array_int16(oct_file, name, CVSD_FS - CVSD_OLAL, plc_state.hist+x0);
    fprintf(oct_file, "plot(b(%d:%d), %s, 'b.'); \n", x0, x1, name);


    // Overlap-add past the frame end: cross-fade scaled and unscaled copies
    // using btstack_cvsd_plc_rcos() weights.
    // NOTE(review): left and right read the same sample and differ only by
    // sf. Confirm this matches the library's implementation.
    for (;i<CVSD_FS+CVSD_OLAL;i++){
        float left  = sf*plc_state.hist[plc_state.bestlag+i];
        float right = plc_state.hist[plc_state.bestlag+i];
        val = left*btstack_cvsd_plc_rcos(i-CVSD_FS) + right*btstack_cvsd_plc_rcos(CVSD_OLAL-1-i+CVSD_FS);
        plc_state.hist[CVSD_LHIST+i]  = btstack_cvsd_plc_crop_sample(val);
    }
    name = (char *)"olal2";
    x0  = x1 + 1;
    x1  = x0 + CVSD_OLAL - 1;
    fprintf_array_int16(oct_file, name, CVSD_OLAL, plc_state.hist+x0);
    fprintf(oct_file, "plot(b(%d:%d), %s, 'b.'); \n", x0, x1, name);

    // Reconvergence tail: unscaled copy.
    for (;i<CVSD_FS+CVSD_RT+CVSD_OLAL;i++){
        plc_state.hist[CVSD_LHIST+i] = plc_state.hist[plc_state.bestlag+i];
    }
    name = (char *)"rt";
    x0  = x1 + 1;
    x1  = x0 + CVSD_RT - 1;
    fprintf_array_int16(oct_file, name, CVSD_RT, plc_state.hist+x0);
    fprintf(oct_file, "plot(b(%d:%d), %s, 'b.'); \n", x0, x1, name);

    // The concealed output frame is the CVSD_FS samples after the history.
    for (i=0;i<CVSD_FS;i++){
        out[i] = plc_state.hist[CVSD_LHIST+i];
    }
    name = (char *)"out";
    x0  = CVSD_LHIST;
    x1  = x0 + CVSD_FS - 1;
    fprintf_array_int16(oct_file, name, CVSD_FS, plc_state.hist+x0);
    fprintf(oct_file, "plot(b(%d:%d), %s, 'cd'); \n", x0, x1, name);

    // shift the history buffer
    for (i=0;i<CVSD_LHIST+CVSD_RT+CVSD_OLAL;i++){
        plc_state.hist[i] = plc_state.hist[i+CVSD_FS];
    }
    fclose(oct_file);
}
332 
333 
334 // TEST(CVSD_PLC, CountEqBytes){
335 //     CHECK_EQUAL(23, count_equal_bytes(test_data[0],24));
336 //     CHECK_EQUAL(11, count_equal_bytes(test_data[1],24));
337 //     CHECK_EQUAL(12, count_equal_bytes(test_data[2],24));
338 //     CHECK_EQUAL(23, count_equal_bytes(test_data[3],24));
339 // }
340 
341 // TEST(CVSD_PLC, TestLiveWavFile){
342 //     int corruption_step = 10;
343 //     introduce_bad_frames_to_wav_file("data/sco_input-16bit.wav", "results/sco_input.wav", 0);
344 //     introduce_bad_frames_to_wav_file("data/sco_input-16bit.wav", "results/sco_input_with_bad_frames.wav", corruption_step);
345 
346 //     mark_bad_frames_wav_file("results/sco_input.wav", "results/sco_input_detected_frames.wav");
347 //     process_wav_file_with_plc("results/sco_input.wav", "results/sco_input_after_plc.wav");
348 //     process_wav_file_with_plc("results/sco_input_with_bad_frames.wav", "results/sco_input_with_bad_frames_after_plc.wav");
349 // }
350 
351 // TEST(CVSD_PLC, TestFanfareFile){
352 //     int corruption_step = 10;
353 //     introduce_bad_frames_to_wav_file("data/fanfare_mono.wav", "results/fanfare_mono.wav", 0);
354 //     introduce_bad_frames_to_wav_file("results/fanfare_mono.wav", "results/fanfare_mono_with_bad_frames.wav", corruption_step);
355 
356 //     mark_bad_frames_wav_file("results/fanfare_mono.wav", "results/fanfare_mono_detected_frames.wav");
357 //     process_wav_file_with_plc("results/fanfare_mono.wav", "results/fanfare_mono_after_plc.wav");
358 //     process_wav_file_with_plc("results/fanfare_mono_with_bad_frames.wav", "results/fanfare_mono_with_bad_frames_after_plc.wav");
359 // }
360 
361 TEST(CVSD_PLC, TestSineWave){
362     int corruption_step = 600;
363     create_sine_wav("results/sine_test.wav");
364     int total_corruption_times = introduce_bad_frames_to_wav_file("results/sine_test.wav", "results/sine_test_with_bad_frames.wav", corruption_step);
365     printf("corruptions %d\n", total_corruption_times);
366     process_wav_file_with_plc("results/sine_test.wav", "results/sine_test_after_plc.wav");
367     process_wav_file_with_plc("results/sine_test_with_bad_frames.wav", "results/sine_test_with_bad_frames_after_plc.wav");
368 }
369 
370 int main (int argc, const char * argv[]){
371     return CommandLineTestRunner::RunAllTests(argc, argv);
372 }
373