TensorFlow:自定义类

在创建自定义网络层类时,需要继承自layers.Layer 基类;创建自定义的网络类,需要继承自 keras.Model 基类,这样产生的自定义类才能够方便的利用Layer/Model 基类提供的参数管理功能,同时也能够与其他的标准网络层类交互使用。

一、自定义网络层

对于自定义的网络层,需要实现初始化 __inti__方法和前向传播逻辑call方法
假设我们需要一个没有偏置的全连接层,即bias 为0,同时固定激活函数为ReLU 函数。

class MyDense(layers.Layer):
	# 自定义网络层
	def __init__(self, inp_dim, outp_dim):
	super(MyDense, self).__init__()
	# 创建权值张量并添加到类管理列表中,设置为需要优化
	self.kernel = self.add_variable('w', [inp_dim, outp_dim],trainable=True)
net = MyDense(4,3) #创建输入为4,输出为3节点的自定义层
print(net.variables,net.trainable_variables)

在这里插入图片描述
通过修改为self.kernel = self.add_variable('w', [inp_dim, outp_dim], trainable=False),我们可以设置张量不需要被优化,此时再来观测张量的管理状态:
在这里插入图片描述
看出此时张量并不会被 trainable_variables管理

完成自定义类的初始化工作后,可以设计自定义类的前项运算逻辑

def call(self, inputs, training=None):
	# 实现自定义类的前向计算逻辑
	# X@W
	out = inputs @ self.kernel
	# 执行激活函数运算
	out = tf.nn.relu(out)
	return out

如上所示,自定义类的前向运算逻辑需要实现在call(inputs, training)函数中,其中inputs 代表输入,由用户在调用时传入;training 参数用于指定模型的状态:training 为True 时执行训练模式,training 为False 时执行测试模式,默认参数为None,即测试模式。由于全连接层的训练模式和测试模式逻辑一致,此处不需要额外处理。对于部份测试模式和训练模式不一致的网络层,需要根据training 参数来设计需要执行的逻辑。

二、自定义网络

在完成了我们自定义的全连接层类之后,我们基于上述的“无偏置的全连接层”来实
现MNIST 手写数字图片模型的创建。
自定义的类可以和其他标准类一样,通过Sequential 容器方便地包裹成一个网络模
型:

network = keras.Sequential([MyDense(784, 256), # 使用自定义的层
    MyDense(256, 128),
    MyDense(128, 64),
    MyDense(64, 32),
    MyDense(32, 10)])
network.build(input_shape=(None, 28*28))
network.summary()

更普遍地,我们可以继承基类来实现任意逻辑的自定义网络类。下面我们来创建自定义网络类,首先创建并继承Model 基类,分布创建对应的网络层对象:

# 自定义网络类,继承自Model 基类
def __init__(self):
	super(MyModel, self).__init__()
	# 完成网络内需要的网络层的创建工作
	self.fc1 = MyDense(28*28, 256)
	self.fc2 = MyDense(256, 128)
	self.fc3 = MyDense(128, 64)
	self.fc4 = MyDense(64, 32)
	self.fc5 = MyDense(32, 10)
#然后实现自定义网络的前向运算逻辑:
def call(self, inputs, training=None):
	# 自定义前向运算逻辑
	x = self.fc1(inputs)
	x = self.fc2(x)
	x = self.fc3(x)
	x = self.fc4(x)
	x = self.fc5(x)
	return x

这个例子可以直接使用第一种方式通过Sequential 容器包裹。但是由于Sequential 在前向传播是依次调用每个网络层的前向传播函数,灵活性一般,而自定义网络的前向逻辑可以任意定制,两者各有优缺点


版权声明:本文为nanhuaibeian原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。