""" module for handling extra model parameters for tf.keras models """

import tensorflow as tf


def _find_parameter_weights(model, parameter_name):
    """ returns all weights of model whose name identifies parameter_name

    tf.keras (Keras 2) reports variable names with a ":0" output suffix
    (e.g. "my_param:0"), while Keras 3 variables report the bare name
    (e.g. "my_param"). Both spellings are accepted so the lookup works with
    either Keras version.

    NOTE(review): under Keras 3 the bare name is not prefixed by the owning
    layer, so a parameter_name equal to a layer weight name (e.g. "kernel")
    could match several weights -- confirm parameter names are unique.
    """
    candidate_names = {parameter_name, parameter_name + ":0"}
    return [weight for weight in model.weights if weight.name in candidate_names]


def set_parameter(model, parameter_name, parameter_value, dtype='float32'):
    """ stores parameter_value as non-trainable weight with name parameter_name

    If model has no weight with that name yet, a new non-trainable weight
    initialised with parameter_value (and of the given dtype) is added to it.
    If exactly one such weight exists, parameter_value is assigned to it in
    place; dtype is not used in that case.

    Raises:
        ValueError: if more than one weight with that name exists in model.
    """
    weights = _find_parameter_weights(model, parameter_name)

    if not weights:
        # name is passed by keyword: in Keras 3 the first positional argument
        # of add_weight is `shape`, not `name`
        model.add_weight(
            name=parameter_name,
            trainable=False,
            initializer=tf.keras.initializers.Constant(parameter_value),
            dtype=dtype,
        )
    elif len(weights) == 1:
        weights[0].assign(parameter_value)
    else:
        raise ValueError(f"more than one weight named {parameter_name} in model")


def get_parameter(model, parameter_name, default=None):
    """ returns parameter value if parameter is present in model and otherwise default

    The stored weight is converted to a Python scalar with .item(), so this
    presumes a single-element weight such as the ones set_parameter creates.

    Raises:
        ValueError: if more than one weight with that name exists in model.
    """
    weights = _find_parameter_weights(model, parameter_name)

    if not weights:
        return default
    if len(weights) > 1:
        raise ValueError(f"more than one weight named {parameter_name} in model")
    return weights[0].numpy().item()