"""
Custom loss functions and metrics for LPCNet training and analysis.

Unless noted otherwise, every loss/metric below reads ``y_pred`` as a 3-D
tensor whose last axis is packed as follows (layout taken from the slicing
in this module):

    y_pred[:, :, 0:1]  LPC prediction ``p`` of the target signal
    y_pred[:, :, 1:2]  a second prediction ``real_p`` (used only by ``interp_mulaw``)
    y_pred[:, :, 2:]   the model's probability distribution over the mu-law bins

``y_true`` is the target signal; it is subtracted from ``p``, so it is assumed
to have the same shape as ``p`` (presumably (batch, time, 1) -- TODO confirm
against the training script). ``loss_matchlar`` uses a different layout; see
its docstring.
"""

import tensorflow as tf

from tf_funcs import *  # provides tf_l2u (presumably linear -> mu-law conversion)

# Number of mu-law bins covered by the model's output distribution.
_NUM_BINS = 256
# Mid-point of the mu-law scale (presumably the code of a zero excitation).
_MU_MIDPOINT = 128.0


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

def _sparse_cce(labels, probs):
    """Per-element sparse categorical cross entropy (no reduction) of ``probs`` at ``labels``."""
    cce = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction=tf.keras.losses.Reduction.NONE)
    return cce(labels, probs)


def _target_excitation(y_true, pred):
    """Return ``tf_l2u(y_true - pred)``, the target excitation on the mu-law scale.

    ``y_true`` is cast to ``pred``'s dtype first so integer (or float64) targets
    do not make the subtraction fail.
    """
    y_true = tf.cast(y_true, pred.dtype)
    return tf_l2u(y_true - pred)


def _prob_compensation(e):
    """Return ``|e - 128| / 128 * log(256)`` with the trailing size-1 axis removed.

    Only the last axis is squeezed. An unqualified ``tf.squeeze`` would also drop
    a batch or time axis of size 1 and then broadcast incorrectly against the
    cross entropy (e.g. a (B, 1) cross entropy plus a (B,) term yields (B, B)).
    """
    scale = tf.math.log(float(_NUM_BINS))
    return tf.squeeze(tf.abs(e - _MU_MIDPOINT) / _MU_MIDPOINT * scale, axis=-1)


def _interp_cel(e_gt, model_out):
    """Cross entropy of ``model_out`` linearly interpolated at the fractional bin ``e_gt``.

    ``e_gt`` is split as ``k + alpha`` with integer ``k`` in [0, _NUM_BINS - 2]
    and, for ``e_gt`` in [0, _NUM_BINS - 1], ``alpha`` in [0, 1]. Label ``k`` is
    then scored against ``(1 - alpha) * P[k] + alpha * P[k + 1]``, where the
    ``P[k + 1]`` term comes from rolling the bin axis by -1.

    ``k`` is clipped *before* ``alpha`` is derived from it, so a saturated
    ``e_gt == 255`` yields ``k = 254, alpha = 1`` and scores ``P[255]``.
    Splitting first and clipping afterwards would score ``P[254]`` instead.
    For ``e_gt < 255`` both orders give the same result.
    """
    k = tf.clip_by_value(tf.math.floor(e_gt), 0.0, _NUM_BINS - 2.0)
    alpha = e_gt - k  # trailing axis of size 1 broadcasts across the bin axis
    interp_probab = (1.0 - alpha) * model_out + alpha * tf.roll(model_out, shift=-1, axis=-1)
    return _sparse_cce(tf.cast(k, 'int32'), interp_probab)


def _rounded_cel(e_gt, model_out):
    """Cross entropy of ``model_out`` at the bin nearest to ``e_gt`` (clipped to a valid bin)."""
    labels = tf.clip_by_value(tf.cast(tf.round(e_gt), 'int32'), 0, _NUM_BINS - 1)
    return _sparse_cce(labels, model_out)


# ---------------------------------------------------------------------------
# Losses and metrics. All of them expect the LPCNet model to output the LPC
# prediction alongside its bin distribution (see the module docstring).
# ---------------------------------------------------------------------------

def res_from_sigloss():
    """Return a loss that takes the excitation as the target minus the LPC prediction.

    The excitation is mapped to mu-law, rounded to the nearest bin, and scored
    with cross entropy against the model's bin distribution.

    Returns:
        ``loss(y_true, y_pred)`` returning the per-element cross entropy.
    """
    def loss(y_true, y_pred):
        p = y_pred[:, :, 0:1]
        model_out = y_pred[:, :, 2:]
        e_gt = _target_excitation(y_true, p)
        return _rounded_cel(e_gt, model_out)
    return loss


def interp_mulaw(gamma=1):
    """Return the interpolated and compensated loss (for end-to-end LPCNet).

    The loss is the sum of three terms:

    * cross entropy of the distribution interpolated between adjacent bins at
      the fractional excitation, similar to the embedding interpolation;
    * a probability compensation ``|e - 128| / 128 * log(256)`` on the
      excitation from ``p``, meant to account for matching cross entropy in
      the linear domain;
    * the same term on the excitation from ``real_p``, weighted by ``gamma``.

    Args:
        gamma: weight of the regularisation term computed from ``real_p``.

    Returns:
        ``loss(y_true, y_pred)`` returning the per-element loss.
    """
    def loss(y_true, y_pred):
        p = y_pred[:, :, 0:1]
        real_p = y_pred[:, :, 1:2]
        model_out = y_pred[:, :, 2:]
        e_gt = _target_excitation(y_true, p)
        exc_gt = _target_excitation(y_true, real_p)
        prob_compensation = _prob_compensation(e_gt)
        regularization = _prob_compensation(exc_gt)
        return _interp_cel(e_gt, model_out) + prob_compensation + gamma * regularization
    return loss


def metric_oginterploss(y_true, y_pred):
    """Metric form of ``interp_mulaw`` without the ``real_p`` regularisation term."""
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = _target_excitation(y_true, p)
    return _interp_cel(e_gt, model_out) + _prob_compensation(e_gt)


def metric_icel(y_true, y_pred):
    """Interpolated cross-entropy metric (no probability compensation)."""
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = _target_excitation(y_true, p)
    return _interp_cel(e_gt, model_out)


def metric_cel(y_true, y_pred):
    """Non-interpolated cross-entropy metric, with the excitation rounded to the nearest bin."""
    p = y_pred[:, :, 0:1]
    model_out = y_pred[:, :, 2:]
    e_gt = _target_excitation(y_true, p)
    return _rounded_cel(e_gt, model_out)


def metric_exc_sd(y_true, y_pred):
    """Mean squared deviation of the target mu-law excitation from the mid-point 128.

    This measures spread about 128, not about the batch mean. It is computed
    explicitly in floating point. Passing the Python int ``128`` as ``y_pred``
    to Keras' ``MeanSquaredError`` makes tf.keras 2.x cast ``y_true`` to
    ``y_pred``'s dtype (int32), which truncates the excitation.
    """
    p = y_pred[:, :, 0:1]
    e_gt = _target_excitation(y_true, p)
    return tf.reduce_mean(tf.square(e_gt - _MU_MIDPOINT), axis=-1)


def loss_matchlar(lpc_order=16):
    """Return a loss matching (approximate) log-area ratios of reflection coefficients.

    Unlike the losses above, ``y_pred[:, :, :lpc_order]`` is read as the model's
    reflection coefficients. ``y_true`` is used directly as the target
    reflection coefficients; no LPC -> RC conversion is applied here.

    Args:
        lpc_order: number of leading ``y_pred`` channels compared against
            ``y_true``. Defaults to 16, the previously hard-coded value.

    Returns:
        ``loss(y_true, y_pred)`` giving the mean squared log-area-ratio
        difference over the last axis.
    """
    def lar(rc):
        # 1.01 rather than 1 keeps numerator and denominator positive for |rc| <= 1.
        return tf.math.log((1.01 + rc) / (1.01 - rc))

    def loss(y_true, y_pred):
        model_rc = y_pred[:, :, :lpc_order]
        y_true = tf.cast(y_true, model_rc.dtype)
        lar_diff = lar(model_rc) - lar(y_true)
        return tf.reduce_mean(tf.square(lar_diff), axis=-1)
    return loss