Setting model.trainable=False in tf.keras (in a GAN)

The Question

While reading an implementation of a GAN, I came across this piece of code:

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Input, Dense, Reshape, Flatten,
                                     BatchNormalization, LeakyReLU)
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

Narrowing the scope of the code above, look at these two lines:

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

This sets the discriminator to a non-trainable state.

Does that mean the discriminator can no longer be trained at all?

No!

Let's first keep reading the code that follows:

    def build_generator(self):

        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        model.summary()

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):

        model = Sequential()

        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

This is the code that builds the generator and discriminator networks, respectively.

Nothing suspicious so far.

So let's keep looking:

    def train(self, epochs, batch_size=128, sample_interval=50):

        # Load the dataset
        (X_train, _), (_, _) = mnist.load_data()

        # Rescale -1 to 1
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

            # Plot the progress
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => save generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)
In the train function, something finally turns up. Again, narrowing the scope:

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Train the generator (to have the discriminator label samples as valid)
            g_loss = self.combined.train_on_batch(noise, valid)

Look! There is a # Train the discriminator step in the code!

Didn't we already set the discriminator to a non-trainable state earlier?

Hmm...

 

The Answer

After some searching, I finally found the answer:

By setting trainable=False after the discriminator has been compiled the discriminator is still trained during discriminator.train_on_batch but since it's set to non-trainable before the combined model is compiled it's not trained during combined.train_on_batch
  • After the discriminator has been compiled, even if discriminator.trainable=False is set afterwards, the discriminator can still be trained via train_on_batch;

  • But if discriminator.trainable is set to False before the discriminator is compiled, then even discriminator.train_on_batch cannot train it. A minimal sketch of both cases follows below.
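Here is a minimal, self-contained sketch of these two points (toy shapes and random data, not from the original post; exact behavior can vary between Keras versions):

import numpy as np
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.models import Model

# Case 1: flip the flag AFTER compile -- train_on_batch still updates the
# weights, because the training function was built while trainable was True.
inp = Input(shape=(4,))
disc = Model(inp, Dense(1, activation='sigmoid')(inp))
disc.compile(loss='binary_crossentropy', optimizer='adam')
disc.trainable = False

w_before = disc.get_weights()[0].copy()
disc.train_on_batch(np.random.rand(8, 4), np.ones((8, 1)))
print(np.allclose(w_before, disc.get_weights()[0]))  # expected: False (it trained)

# Case 2: freeze BEFORE compile -- the frozen part really is excluded from
# training. A frozen layer is used here so the model still has some trainable
# weights left and the optimizer has gradients to apply.
inp2 = Input(shape=(4,))
frozen = Dense(3)
frozen.trainable = False
disc2 = Model(inp2, Dense(1, activation='sigmoid')(frozen(inp2)))
disc2.compile(loss='binary_crossentropy', optimizer='adam')

w_before = frozen.get_weights()[0].copy()
disc2.train_on_batch(np.random.rand(8, 4), np.ones((8, 1)))
print(np.allclose(w_before, frozen.get_weights()[0]))  # expected: True (frozen)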

When you call compile, it builds a trainable model, and uses the current trainable flags. A compiled model can then not have its trainable flags changed, so we are free to change them and compile another model with different flags.

Then, for some reason, those two compiled models can still have the same weights (I guess?) even though tensorflow/keras still sees them as clearly separated things.
  • When a model is compiled, its trainable state is frozen as whatever it is at that moment (e.g. discriminator.trainable=True);

  • When you change the trainable state of an already compiled discriminator, you are in effect configuring a different compiled discriminator, namely the one that is about to be compiled next (here, inside combined);

  • Because of the framework's own mechanics, the two compiled discriminators share the same network weights, even though they are treated as two separate models. The sketch below demonstrates this.
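A short sketch of the weight-sharing point, reusing the GAN class from this post (random arrays stand in for real data):

import numpy as np

gan = GAN()

# One combined step: only the generator should move, because the
# discriminator was frozen when `combined` was compiled.
d_before = [w.copy() for w in gan.discriminator.get_weights()]
noise = np.random.normal(0, 1, (32, gan.latent_dim))
gan.combined.train_on_batch(noise, np.ones((32, 1)))
d_after = gan.discriminator.get_weights()
print(all(np.allclose(a, b) for a, b in zip(d_before, d_after)))  # expected: True

# The standalone discriminator was compiled while trainable was True, so it
# can still be trained directly -- and since both compiled models share one
# set of weights, this update is visible through `combined` as well.
fake_imgs = np.random.normal(0, 1, (32,) + gan.img_shape)
gan.discriminator.train_on_batch(fake_imgs, np.zeros((32, 1)))
d_trained = gan.discriminator.get_weights()
print(all(np.allclose(a, b) for a, b in zip(d_after, d_trained)))  # expected: False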

References

https://github.com/eriklindernoren/Keras-GAN/issues/73

 


Copyright notice: this is an original article by qq_38669138, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.