The Problem
While reading through a GAN implementation, I noticed this part:
```python
# Imports used by the snippets in this post (standalone Keras 2.x)
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, BatchNormalization
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
import numpy as np


class GAN():
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator
        self.generator = self.build_generator()

        # The generator takes noise as input and generates imgs
        z = Input(shape=(self.latent_dim,))
        img = self.generator(z)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated images as input and determines validity
        validity = self.discriminator(img)

        # The combined model (stacked generator and discriminator)
        # Trains the generator to fool the discriminator
        self.combined = Model(z, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)
```
Narrowing down the code above, look at these lines:
```python
# For the combined model we will only train the generator
self.discriminator.trainable = False
```
Here the discriminator is set to a non-trainable state.
Does that mean the discriminator can no longer be trained?
No!
Let's first keep reading the code that follows:
```python
def build_generator(self):

    model = Sequential()

    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))

    model.summary()

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)

def build_discriminator(self):

    model = Sequential()

    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))

    model.summary()

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)
```
This is the code that builds the generator and the discriminator networks.
Nothing suspicious so far.
So let's keep going:
```python
def train(self, epochs, batch_size=128, sample_interval=50):

    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()

    # Rescale -1 to 1
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)

    # Adversarial ground truths
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Select a random batch of images
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Generate a batch of new images
        gen_imgs = self.generator.predict(noise)

        # Train the discriminator
        d_loss_real = self.discriminator.train_on_batch(imgs, valid)
        d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ---------------------
        #  Train Generator
        # ---------------------

        noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

        # Train the generator (to have the discriminator label samples as valid)
        g_loss = self.combined.train_on_batch(noise, valid)

        # Plot the progress
        print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
            self.sample_images(epoch)
```
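For completeness, the class is typically driven like this. This is just a sketch: it assumes a sample_images helper (omitted from the excerpts here) that saves generated samples every sample_interval epochs, and the hyperparameter values are only illustrative:

```python
if __name__ == '__main__':
    gan = GAN()
    # Values here are illustrative, not prescriptive
    gan.train(epochs=30000, batch_size=32, sample_interval=200)
```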
In the train function, something interesting turns up. Again, let's narrow down the code:
```python
noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

# Generate a batch of new images
gen_imgs = self.generator.predict(noise)

# Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# ---------------------
#  Train Generator
# ---------------------

noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

# Train the generator (to have the discriminator label samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
```
Look! There is a # Train the discriminator step in the code!
Didn't we set the discriminator to a non-trainable state earlier?
Hmm...
The Solution
After some searching, I finally found the answer (in the Keras-GAN issue linked under Reference below):
By setting trainable=False after the discriminator has been compiled the discriminator is still trained during discriminator.train_on_batch but since it's set to non-trainable before the combined model is compiled it's not trained during combined.train_on_batch
In other words: once the discriminator has been compiled, setting discriminator.trainable = False afterwards does not stop it from being trained via train_on_batch; but if the discriminator's trainable flag had been set to False before it was compiled, then even calling discriminator.train_on_batch would not train it.
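This timing rule can be checked with a toy model. Below is a minimal sketch of my own (not from the original repo), assuming standalone Keras 2.x as used by Keras-GAN; the tiny Dense model and random data are arbitrary placeholders:

```python
import numpy as np
from keras.layers import Dense, Input
from keras.models import Model

x = np.random.normal(size=(8, 4))
y = np.ones((8, 1))

# Case 1: flip the flag AFTER compile -> train_on_batch still updates weights
inp = Input(shape=(4,))
m1 = Model(inp, Dense(1, activation='sigmoid')(inp))
m1.compile(loss='binary_crossentropy', optimizer='adam')
m1.trainable = False                # too late: the training function is already built
w0 = m1.get_weights()[0].copy()
m1.train_on_batch(x, y)
print(np.allclose(w0, m1.get_weights()[0]))   # False: weights changed anyway

# Case 2: flip the flag BEFORE compile -> train_on_batch cannot update weights
inp2 = Input(shape=(4,))
m2 = Model(inp2, Dense(1, activation='sigmoid')(inp2))
m2.trainable = False                # frozen into the compiled training function
m2.compile(loss='binary_crossentropy', optimizer='adam')
w0 = m2.get_weights()[0].copy()
m2.train_on_batch(x, y)
print(np.allclose(w0, m2.get_weights()[0]))   # True: weights unchanged
```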
The same issue thread explains why this works:

When you call compile, it builds a trainable model, and uses the current trainable flags. A compiled model can then not have its trainable flags changed, so we are free to change them and compile another model with different flags. Then, for some reason, those two compiled models can still have the same weights (I guess?) even though tensorflow/keras still sees them as clearly separated things.
That is: when a model is compiled, its trainable state is frozen as whatever it is at that moment (for example, discriminator.trainable = True); changing the trainable state after compile is in effect configuring another discriminator (which will perhaps be compiled next, here as part of combined); and due to the framework's own mechanics, these two compiled discriminators share the same network weights, even though they are treated as two independent models.
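You can observe exactly this on the GAN class above. A rough sketch, assuming the class defined in this post is in scope; random arrays stand in for real MNIST batches:

```python
import numpy as np

gan = GAN()

# gan.combined wraps the very same discriminator object, so weights are shared.
noise = np.random.normal(0, 1, (32, gan.latent_dim))
w_before = [w.copy() for w in gan.discriminator.get_weights()]

# Training the combined model (discriminator frozen at its compile time)
# leaves the discriminator's weights untouched...
gan.combined.train_on_batch(noise, np.ones((32, 1)))
w_after = gan.discriminator.get_weights()
print(all(np.array_equal(a, b) for a, b in zip(w_before, w_after)))   # True

# ...but train_on_batch on the discriminator itself still updates them,
# since the discriminator was compiled while trainable was True.
fake_imgs = np.random.normal(size=(32, 28, 28, 1))
gan.discriminator.train_on_batch(fake_imgs, np.zeros((32, 1)))
w_final = gan.discriminator.get_weights()
print(all(np.array_equal(a, b) for a, b in zip(w_after, w_final)))    # False
```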
Reference
https://github.com/eriklindernoren/Keras-GAN/issues/73