1"""
2Custom Loss functions and metrics for training/analysis
3"""
4
5from tf_funcs import *
6import tensorflow as tf
7
8# The following loss functions all expect the lpcnet model to output the lpc prediction
9
# Compute the excitation by subtracting the LPC prediction from the target,
# then minimize the cross entropy of the rounded mu-law excitation
def res_from_sigloss():
    def loss(y_true, y_pred):
        p = y_pred[:, :, 0:1]
        model_out = y_pred[:, :, 2:]
        e_gt = tf_l2u(y_true - p)  # ground-truth excitation in the mu-law domain
        e_gt = tf.round(e_gt)
        e_gt = tf.cast(e_gt, 'int32')
        sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, model_out)
        return sparse_cel
    return loss

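# For reference: tf_l2u comes from tf_funcs (star import above) and maps linear 16-bit
# samples to real-valued mu-law indices in [0, 255]. The helper below is a minimal sketch
# of that mapping following LPCNet's u-law convention -- an assumption kept purely as
# documentation (see tf_funcs.py for the authoritative version); the losses in this file
# all call the real tf_l2u.
def _l2u_sketch(x):
    s = tf.math.sign(x)
    m = tf.abs(x)
    u = s * (128.0 * tf.math.log(1.0 + 255.0 * m / 32768.0) / tf.math.log(256.0))
    return tf.clip_by_value(128.0 + u, 0.0, 255.0)  # 128 corresponds to zero excitation
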
# Interpolated and compensated loss (for end-to-end LPCNet).
# Interpolates between adjacent output probabilities based on the fractional part of the
# computed excitation (mirroring the embedding interpolation), adds a probability
# compensation term (to account for matching the cross entropy in the linear domain),
# and adds a regularization term on the end-to-end excitation, weighted by gamma
def interp_mulaw(gamma=1):
    def loss(y_true, y_pred):
        y_true = tf.cast(y_true, 'float32')
        p = y_pred[:, :, 0:1]
        real_p = y_pred[:, :, 1:2]
        model_out = y_pred[:, :, 2:]
        e_gt = tf_l2u(y_true - p)
        exc_gt = tf_l2u(y_true - real_p)
        prob_compensation = tf.squeeze((K.abs(e_gt - 128) / 128.0) * K.log(256.0))
        regularization = tf.squeeze((K.abs(exc_gt - 128) / 128.0) * K.log(256.0))
        alpha = e_gt - tf.math.floor(e_gt)  # fractional part of the mu-law excitation
        alpha = tf.tile(alpha, [1, 1, 256])
        e_gt = tf.cast(e_gt, 'int32')
        e_gt = tf.clip_by_value(e_gt, 0, 254)  # keep bin k+1 valid for the interpolation below
        interp_probab = (1 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
        sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, interp_probab)
        loss_mod = sparse_cel + prob_compensation + gamma * regularization
        return loss_mod
    return loss

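# Usage sketch (an assumption, not taken from this repo's training script): with an
# LPCNet-style `model` whose output is packed as described at the top of this file,
# the end-to-end loss and the metrics below can be wired up at compile time:
#
#   model.compile(optimizer='adam',
#                 loss=interp_mulaw(gamma=2.0),
#                 metrics=[metric_cel, metric_icel, metric_exc_sd, metric_oginterploss])
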
# Same as interp_mulaw above, except as a metric (no gamma-weighted regularization term)
def metric_oginterploss(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = tf_l2u(y_true - p)
    prob_compensation = tf.squeeze((K.abs(e_gt - 128) / 128.0) * K.log(256.0))
    alpha = e_gt - tf.math.floor(e_gt)
    alpha = tf.tile(alpha, [1, 1, 256])
    e_gt = tf.cast(e_gt, 'int32')
    e_gt = tf.clip_by_value(e_gt, 0, 254)
    interp_probab = (1 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, interp_probab)
    loss_mod = sparse_cel + prob_compensation
    return loss_mod

# Interpolated cross entropy loss metric
def metric_icel(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = tf_l2u(y_true - p)
    alpha = e_gt - tf.math.floor(e_gt)
    alpha = tf.tile(alpha, [1, 1, 256])
    e_gt = tf.cast(e_gt, 'int32')
    e_gt = tf.clip_by_value(e_gt, 0, 254)  # clip to 254 so that bin k+1 stays in range
    interp_probab = (1 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, interp_probab)
    return sparse_cel

# Non-interpolated (rounded) cross entropy loss metric
def metric_cel(y_true, y_pred):
    y_true = tf.cast(y_true, 'float32')
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = tf_l2u(y_true - p)
    e_gt = tf.round(e_gt)
    e_gt = tf.cast(e_gt, 'int32')
    e_gt = tf.clip_by_value(e_gt, 0, 255)
    sparse_cel = tf.keras.losses.SparseCategoricalCrossentropy(reduction=tf.keras.losses.Reduction.NONE)(e_gt, model_out)
    return sparse_cel

# Spread of the target excitation: mean squared deviation from the mu-law midpoint (128)
def metric_exc_sd(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    e_gt = tf_l2u(y_true - p)
    sd_egt = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.NONE)(e_gt, 128)
    return sd_egt

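# Equivalent formulation, kept as documentation (an assumption about intent): the metric
# above is the second moment of the mu-law excitation about the midpoint 128, i.e. its
# variance once the excitation is centered.
def _exc_sd_equivalent(y_true, y_pred):
    p = y_pred[:, :, 0:1]
    e_gt = tf_l2u(y_true - p)
    return tf.reduce_mean(tf.square(e_gt - 128.0), axis=-1)
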
# Match reflection coefficients in the log-area-ratio (LAR) domain;
# the 1.01 offsets keep the log finite as a coefficient approaches +/-1
def loss_matchlar():
    def loss(y_true, y_pred):
        model_rc = y_pred[:, :, :16]
        #y_true = lpc2rc(y_true)
        loss_lar_diff = K.log((1.01 + model_rc) / (1.01 - model_rc)) - K.log((1.01 + y_true) / (1.01 - y_true))
        loss_lar_diff = tf.square(loss_lar_diff)
        return tf.reduce_mean(loss_lar_diff, axis=-1)
    return loss

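# Minimal smoke test (an assumption, not part of the original file): evaluate the losses
# and metrics on random data with a dummy "packed" prediction tensor, mainly to document
# the expected shapes: y_true is (batch, time, 1) linear samples, y_pred is
# (batch, time, 2 + 256).
if __name__ == '__main__':
    B, T = 2, 160
    y_true = tf.random.uniform([B, T, 1], -1000.0, 1000.0)
    preds = tf.zeros([B, T, 2])  # dummy LPC predictions (channels 0 and 1)
    probs = tf.nn.softmax(tf.random.normal([B, T, 256]), axis=-1)
    y_pred = tf.concat([preds, probs], axis=-1)
    print('interp_mulaw:', float(tf.reduce_mean(interp_mulaw(gamma=1)(y_true, y_pred))))
    print('metric_cel  :', float(tf.reduce_mean(metric_cel(y_true, y_pred))))
    print('metric_icel :', float(tf.reduce_mean(metric_icel(y_true, y_pred))))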