1、API
通过直接使用 tf.keras.Sequential() 函数可以轻松地构建网络,如:
mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
mobile.trainable = False
model = tf.keras.Sequential([
simplified_mobile,
tf.keras.layers.Dropout(0.5),
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(28, activation='softmax')
])
但是,通过 API 定义的方法并不容易自定义复杂的网络。
2、通过函数定义
mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
mobile.trainable = False
def MobileNetV2 (classes):
img_input = tf.keras.layers.Input(shape=(224, 224, 3))
x = mobile(img_input)
x = tf.keras.layers.Dropout(0.5)(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(classes, activation='softmax')(x)
model = tf.keras.Model(img_input, x)
return model
使用函数定义网络时要注意以下几点:
- 1、一个网络中往往包含了多个自定义网络函数,如卷积-批归一化-激活函数等,但在最后构建网络的函数的开头必须定义输入该网络的形状,而不是直接在函数名后面定义一个输入。当然,对前面的函数来说是可以直接多定义一个输入的;
- 2、在函数结尾必须有:model = tf.keras.Model(img_input, x)。
3、通过类定义
mobile = tf.keras.applications.MobileNetV2(include_top=False, weights='imagenet', input_shape=(224, 224, 3))
mobile.trainable = False
class MobileNetV2(tf.keras.Model):
def __init__(self, classes):
super().__init__()
self.mob = mobile
self.dropout = tf.keras.layers.Dropout(0.5)
self.gap = tf.keras.layers.GlobalAveragePooling2D()
self.dense = tf.keras.layers.Dense(classes, activation='softmax')
def call(self, inputs):
x = self.mob(inputs)
x = self.dropout(x)
x = self.gap(x)
x = self.dense(x)
return x
版权声明:本文为qq_36758914原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明。