本文参考自书籍《深度学习之TensorFlow:入门、原理与进阶实战》
import tensorflow as tf
# 将图(一个计算任务)里面的变量清空
tf.reset_default_graph()
# 共享变量
with tf.variable_scope("test1"):
var1 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)
with tf.variable_scope("test2"):
var2 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)
with tf.variable_scope("test1", reuse=True):
var3 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)
with tf.variable_scope("test2"):
var4 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)
print(var1.name)
print(var2.name)
print(var3.name)
print(var4.name)
# 结果
'''
test1/firstvar:0
test1/test2/firstvar:0
test1/firstvar:0
test1/test2/firstvar:0
'''
var1和var3的输出名字是一样的,var2和var4的名字也是一样的,这表明var1和var3共用了一个变量,var2和var4共用了一个变量,这就实现了共享变量
在实际应用中,可以把var1和var2放在一个网络里去训练,
把var3和var4放在一个网络里去训练,而两个模型的训练结果都会作用于一个模型的学习参数上.
版权声明:本文为qq_27690765原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。