参数的提取,把参数存入文本
1. 参数打印输出。
在实际训练过程中,我们可以将参数的训练通过文本进行记录,或者打印出来进行查看。
其中model.trainable_variables可以返回模型中的参数。
我们可以使用printf进行打印,但是直接使用打印可能为出现很多数据无法显示,我们可以先设置
np.set_printoptions(threshold=10) # 其中threshold表示输出的阈值,超出阈值的参数会用省略号表示,
# 当阈值设置为np.inf时,表示数据全部输出
2. 参数写入文本
还可以使用代码将参数写入文本中进行查看,操作如下:
file = open('./weights.txt', 'w')
# 将训练参数存入文本,其中包括参数名称,大小,数据
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
3. 示例代码
import tensorflow as tf
import os
import numpy as np
np.set_printoptions(threshold=np.inf)
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=['sparse_categorical_accuracy'])
checkpoint_save_path = "./checkpoint/mnist.ckpt"
# 判断是否拥有模型
if os.path.exists(checkpoint_save_path + '.index'):
print('-------------load the model--- --------------')
# 加载模型数据
model.load_weights(checkpoint_save_path)
# 保存模型数据
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1,
callbacks=[cp_callback])
model.summary()
# 设置打印输出格式
# 参数可以设置输出模式,超过阈值threshold的参数用省略号显示,使用inf时表示全部显示
# 其中model.trainable_variables可以返回模型中的参数
np.set_printoptions(threshold=10)
print(model.trainable_variables)
file = open('./weights.txt', 'w')
# 将训练参数存入文本,其中包括参数名称,大小,数据
for v in model.trainable_variables:
file.write(str(v.name) + '\n')
file.write(str(v.shape) + '\n')
file.write(str(v.numpy()) + '\n')
file.close()
4. 运行结果
运行结果显示我们的参数,其中有四个array,分别是第一、二层神经元的w和b。
# 784 = 28 * 28 第一层神经元中一个w的参数个数和输入参数个数是匹配的,拥有128个神经元,也就代表拥有128个b
[<tf.Variable 'sequential/dense/kernel:0' shape=(784, 128) dtype=float32, numpy=
array([[-0.04371355, -0.03728049, 0.00738895, ..., -0.03773092,
......
<tf.Variable 'sequential/dense/bias:0' shape=(128,) dtype=float32, numpy=
array([-0.16389479, 0.1481038 , -0.00113325, ..., 0.10987901,
0.00905252, -0.19073108], dtype=float32)>,
# 第二层神经元输入为128,拥有10个神经元。
<tf.Variable 'sequential/dense_1/kernel:0' shape=(128, 10) dtype=float32, numpy=
array([[ 0.02201208, 0.12389698, -0.16859886, ..., 0.19811475,
......
<tf.Variable 'sequential/dense_1/bias:0' shape=(10,) dtype=float32, numpy=
array([-0.14484479, -0.24807121, 0.11401606, -0.20439751, 0.12878774,
-0.09878621, -0.10151227, -0.23220563, 0.5990565 , -0.03787111],
dtype=float32)>]
然后我们可以在工程文件夹下看见我们建立的weight.txt文件,其中参数和我们打印训练的参数一致。
版权声明:本文为weixin_43115631原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。