共享变量 tf.variable_scope()

本文参考自书籍《深度学习之TensorFlow:入门、原理与进阶实战》

import tensorflow as tf

# 将图(一个计算任务)里面的变量清空
tf.reset_default_graph()
# 共享变量
with tf.variable_scope("test1"):
    var1 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)
    with tf.variable_scope("test2"):
        var2 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)

with tf.variable_scope("test1", reuse=True):
    var3 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)
    with tf.variable_scope("test2"):
        var4 = tf.get_variable(name="firstvar", shape=[2], dtype=tf.float32)

print(var1.name)
print(var2.name)
print(var3.name)
print(var4.name)

# 结果
'''
test1/firstvar:0
test1/test2/firstvar:0
test1/firstvar:0
test1/test2/firstvar:0
'''

var1和var3的输出名字是一样的,var2和var4的名字也是一样的,这表明var1和var3共用了一个变量,var2和var4共用了一个变量,这就实现了共享变量

在实际应用中,可以把var1和var2放在一个网络里去训练,

把var3和var4放在一个网络里去训练,而两个模型的训练结果都会作用于一个模型的学习参数上.


版权声明:本文为qq_27690765原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。