xref: /aosp_15_r20/external/libopus/dnn/training_tf2/parameters.py (revision a58d3d2adb790c104798cd88c8a3aff4fa8b82cc)
1*a58d3d2aSXin Li""" module for handling extra model parameters for tf.keras models """
2*a58d3d2aSXin Li
3*a58d3d2aSXin Liimport tensorflow as tf
4*a58d3d2aSXin Li
5*a58d3d2aSXin Li
def set_parameter(model, parameter_name, parameter_value, dtype='float32'):
    """ stores parameter_value as a non-trainable weight named parameter_name

    If the model has no weight with that name yet, one is created through
    model.add_weight with a constant initializer; if exactly one exists, its
    value is overwritten in place via assign().

    Args:
        model: tf.keras model the parameter is attached to
        parameter_name: name of the weight, without the ":0" variable suffix
        parameter_value: value to store
        dtype: dtype used when the weight is created; ignored when the weight
            already exists

    Raises:
        ValueError: if more than one weight with that name is present in the model
    """

    # tf.keras 2.x reports variable names with a ":0" suffix while Keras 3
    # reports the bare name; accept both so the lookup works with either
    names = (parameter_name, parameter_name + ":0")
    weights = [weight for weight in model.weights if weight.name in names]

    if not weights:
        # name must be passed by keyword: in Keras 3 the first positional
        # parameter of add_weight is `shape`, not `name`
        model.add_weight(name=parameter_name, trainable=False, initializer=tf.keras.initializers.Constant(parameter_value), dtype=dtype)
    elif len(weights) == 1:
        weights[0].assign(parameter_value)
    else:
        raise ValueError(f"more than one weight named {parameter_name} in model")
17*a58d3d2aSXin Li
18*a58d3d2aSXin Li
def get_parameter(model, parameter_name, default=None):
    """ returns the value of the weight named parameter_name, or default if absent

    Args:
        model: tf.keras model to search
        parameter_name: name of the weight, without the ":0" variable suffix
        default: value returned when the model has no weight with that name

    Returns:
        the weight's value as a plain Python scalar, or default

    Raises:
        ValueError: if more than one weight with that name is present in the model
    """

    # tf.keras 2.x reports variable names with a ":0" suffix while Keras 3
    # reports the bare name; accept both so the lookup works with either
    names = (parameter_name, parameter_name + ":0")
    weights = [weight for weight in model.weights if weight.name in names]

    if not weights:
        return default
    if len(weights) > 1:
        raise ValueError(f"more than one weight named {parameter_name} in model")

    # .item() converts a single-element array to a Python scalar (and raises
    # ValueError if the stored value has more than one element)
    return weights[0].numpy().item()
30