TensorFlow 2 in Practice (Keras): Generative Adversarial Networks - GAN and WGAN-GP
1. Background
1.1 Dataset
This tutorial uses the Anime dataset of "high-definition anime character face" PNG images. Each image is a 64x64 color image, and there are 21,551 images in total. (Figure: sample images from the dataset.)
1.2 Model Overview
The GAN/WGAN-GP model in this tutorial consists of two parts: a Generator and a Discriminator.
- Generator (produces images): one fully connected layer, three transposed-convolution layers, and two BatchNormalization layers, with a tanh producing the output image. Its input is a randomly initialized latent vector z; its output is a generated image.
- Discriminator (judges images as real or fake): three convolution layers, two BatchNormalization layers, and one fully connected layer. Its inputs are generated images and real images; its output is a logit.
The model generates images with the Generator and feeds them, together with real images, into the Discriminator. The Discriminator updates its parameters with a loss that treats real images as real and generated images as fake, while the Generator updates its parameters with a loss that treats its generated images as real; both losses are written out below.
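Concretely, the two objectives implemented in the training code of Section 2.2 use the sigmoid cross-entropy $H$, where label 1 means "real" and label 0 means "fake":

$$\mathcal{L}_D = \mathbb{E}_{x\sim P_r}\,H\big(1, D(x)\big) + \mathbb{E}_{z}\,H\big(0, D(G(z))\big), \qquad \mathcal{L}_G = \mathbb{E}_{z}\,H\big(1, D(G(z))\big)$$

These two terms correspond directly to celoss_ones and celoss_zeros in the code below.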
In this tutorial, the difference between GAN and WGAN-GP is the Discriminator loss. The WGAN-GP critic loss is:

$$\mathcal{L}_D = \mathbb{E}_{\tilde{x}\sim P_G}[D(\tilde{x})] - \mathbb{E}_{x\sim P_d}[D(x)] + \lambda\,\mathbb{E}_{\hat{x}\sim P_{\hat{x}}}\big[(\lVert\nabla_{\hat{x}} D(\hat{x})\rVert_2 - 1)^2\big]$$

where $\hat{x}$ is sampled along straight lines between real and generated images.
(Figure: the GAN/WGAN-GP model architecture used in this tutorial.)
How GAN and WGAN-GP relate:
The problem with the original GAN: if $P_G$ (the distribution of generated samples) and $P_d$ (the distribution of real samples) have no overlap at all, the JS divergence loses its usefulness as a training signal: it is stuck at the constant $\log 2$ and provides no gradient.
How WGAN-GP addresses this: WGAN-GP measures the distance between the two distributions with the Wasserstein distance, which remains informative even when the distributions do not overlap, so $P_G$ can be pushed toward $P_d$ more reliably.
WGAN-GP requires only small changes on top of the existing GAN code: the Discriminator D gains a gradient-penalty term in its loss. (The implementation in Section 2.3 keeps the cross-entropy terms and adds the penalty on top.)
(The definitions of JS divergence and Wasserstein distance are given below for reference.)
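For reference, the standard definitions of the two quantities. With the mixture $M = \tfrac{1}{2}(P_d + P_G)$,

$$\mathrm{JS}(P_d \,\|\, P_G) = \tfrac{1}{2}\,\mathrm{KL}(P_d \,\|\, M) + \tfrac{1}{2}\,\mathrm{KL}(P_G \,\|\, M)$$

and the Wasserstein-1 ("earth mover's") distance is

$$W(P_d, P_G) = \inf_{\gamma \in \Pi(P_d, P_G)} \mathbb{E}_{(x,y)\sim\gamma}\big[\lVert x - y \rVert\big]$$

where $\Pi(P_d, P_G)$ is the set of all joint distributions whose marginals are $P_d$ and $P_G$. When the two supports are disjoint, JS stays fixed at $\log 2$, while $W$ still shrinks smoothly as the distributions move closer.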
1.2.1 How GANs Work (see the separate WGAN-GP write-up for its details)
1.2.1.1 What can a GAN do?
The original purpose of a generative adversarial network (GAN) is to generate data that does not exist in the real world, giving AI something like creativity or imagination. Examples: AI painters; restoring degraded images (removing rain, fog, motion blur, or mosaics); data augmentation; and so on.
1.2.1.2 How a GAN works
A GAN can be viewed simply as a game between two networks:
- Generator (G): fabricates data out of thin air
- Discriminator (D): judges whether a piece of data is real
The whole GAN pipeline is unsupervised: the real images carry no labels. D has no idea what an incoming image depicts; it only needs to tell real from fake. G does not know what it is producing either; it simply learns to imitate the real images well enough to fool D. In other words, we want the Generator to produce realistic images with the highest possible probability, i.e. to find a distribution that brings the generated images close to the real ones, formalized by the objective below.
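The original GAN paper states this game as a minimax objective:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x\sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z\sim p_z}[\log(1 - D(G(z)))]$$

D is trained to maximize $V$ (separate real from fake) while G is trained to minimize it (fool D); the cross-entropy losses in Section 2.2 are the per-batch version of exactly this objective.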
2. GAN/WGAN-GP Code
2.1 GAN/WGAN-GP model code
2.2 GAN training code
2.3 WGAN-GP training code (difference from the GAN training code: the Discriminator loss changes; only the changed loss parts are commented in detail)
2.4 Utility code: dataset loading
2.1 GAN/WGAN-GP model code
# -*-coding:utf-8-*-
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# Generator network (produces images): one fully connected layer and three
# transposed-convolution layers, with a tanh on the output.
class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()
        # z: [b, 100] => [b, 64, 64, 3]
        # project z up so it can be reshaped into a small feature map
        self.fc = layers.Dense(3*3*512)
        # Conv2DTranspose(filters, kernel_size, strides, padding): transposed
        # convolution, used to enlarge the feature map.
        # With 'valid' padding: output = (N-1)*S - 2P + F, where
        #   N = input size, F = kernel_size, S = strides, P = padding (0 here)
        # (3-1)*3 + 3 = 9
        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()
        # (9-1)*2 + 5 = 21
        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()
        # (21-1)*3 + 4 = 64
        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    # forward pass (image generation)
    def call(self, inputs, training=None):
        # project and reshape to a feature map: [b, 100] => [b, 3*3*512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        # leaky_relu instead of relu: relu's zero gradient for x < 0 can
        # cause vanishing gradients here
        x = tf.nn.leaky_relu(x)
        # [b, 3, 3, 512] => [b, 64, 64, 3]
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        # tanh is more stable in training than sigmoid. Real images are
        # scaled from 0~255 to 0~1 and then mapped with "(0~1)*2 - 1" into
        # -1~1, so the network consumes images in the -1~1 range and the
        # generator outputs the same range; to view a generated image, map
        # -1~1 back to 0~1 and then up to 0~255.
        x = tf.tanh(x)
        return x


# Discriminator network (real/fake classification): three convolution layers
# and one fully connected layer.
class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()
        # [b, 64, 64, 3] => [b, 1, 1, 256]
        # Conv2D(filters, kernel_size, strides, padding)
        # With 'valid' padding: output = floor((N-F)/S) + 1
        # floor((64-5)/3) + 1 = 20
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        # floor((20-5)/3) + 1 = 6
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()
        # floor((6-5)/3) + 1 = 1
        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()
        # [b, h, w, c] => [b, -1]
        # Flatten collapses the multi-dimensional input to one dimension,
        # bridging the convolutional layers and the fully connected layer
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    # forward pass (real/fake classification)
    def call(self, inputs, training=None):
        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))
        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]
        logits = self.fc(x)
        return logits


# quick shape test for Discriminator and Generator
def main():
    d = Discriminator()
    g = Generator()

    x = tf.random.normal([2, 64, 64, 3])
    z = tf.random.normal([2, 100])

    prob = d(x)   # raw logits, shape [2, 1]
    print(prob)
    x_hat = g(z)  # generated images, shape [2, 64, 64, 3]
    print(x_hat.shape)


if __name__ == '__main__':
    main()
2.2 GAN training code
# -*-coding:utf-8-*-
import os
# suppress TensorFlow INFO and WARNING logs (show errors only)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from PIL import Image
import glob
from gan import Generator, Discriminator
from dataset import make_anime_dataset


# tile a batch of generated images into one grid image and save it
def save_result(val_out, val_block_size, image_path, color_mode):
    def preprocess(img):
        # map -1~1 back to 0~255
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img

    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)

        # concat image row to final_image
        if (b+1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)

            # reset single row
            single_row = np.array([])

    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
        # single-channel grid: save with the requested color mode
        Image.fromarray(final_image, mode=color_mode).save(image_path)
    else:
        # 3-channel uint8 grid: let PIL infer RGB (passing mode='P' to
        # fromarray would not match a 3-channel array)
        Image.fromarray(final_image).save(image_path)
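# Example (hypothetical shapes, not from the original post): a batch of 100
# generated images of shape [100, 64, 64, 3] with values in -1~1 is tiled
# into a 10x10 grid image:
#   save_result(fake_image.numpy(), 10, 'gan-sample.png', color_mode='P')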
# cross-entropy loss with all-ones labels ("real")
def celoss_ones(logits):
    # logits: [b, 1]; labels: a tensor of ones with the same shape
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.ones_like(logits))
    return tf.reduce_mean(loss)


# cross-entropy loss with all-zeros labels ("fake")
def celoss_zeros(logits):
    # logits: [b, 1]; labels: a tensor of zeros with the same shape
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.zeros_like(logits))
    return tf.reduce_mean(loss)


# Discriminator loss: treat real images as real and generated images as fake
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # generate images from the randomly sampled z
    fake_image = generator(batch_z, is_training)
    # run the generated images through the discriminator
    d_fake_logits = discriminator(fake_image, is_training)
    # run the real images through the discriminator
    d_real_logits = discriminator(batch_x, is_training)
    # loss for real images labelled as real
    d_loss_real = celoss_ones(d_real_logits)
    # loss for generated images labelled as fake
    d_loss_fake = celoss_zeros(d_fake_logits)
    # total: real as real + fake as fake
    loss = d_loss_fake + d_loss_real
    return loss


# Generator loss: make the discriminator label generated images as real
def g_loss_fn(generator, discriminator, batch_z, is_training):
    # feed the randomly sampled batch_z into the Generator
    fake_image = generator(batch_z, is_training)
    # feed the generated images into the Discriminator
    d_fake_logits = discriminator(fake_image, is_training)
    # loss for generated images labelled as real
    loss = celoss_ones(d_fake_logits)
    return loss


def main():
    # random seeds
    tf.random.set_seed(22)
    np.random.seed(22)
    # startswith('2.') checks that tf.__version__ begins with '2.'
    assert tf.__version__.startswith('2.')

    # hyperparameters
    z_dim = 100
    epochs = 3000000
    batch_size = 256
    learning_rate = 0.002
    is_training = True

    # collect all image paths under the given directory
    img_path = glob.glob(r'D:\PyCharmPro\shu_ju\anime-faces\*.png')
    # load the dataset
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # repeat() with no argument repeats the dataset indefinitely
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    # instantiate Generator and Discriminator and build their weights
    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))
    # optimizer for the Generator
    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    # optimizer for the Discriminator
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # training loop (each "epoch" here processes one batch)
    for epoch in range(epochs):
        # sample z uniformly from -1~1
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # update the Discriminator: generated images as fake, real images as real
        with tf.GradientTape() as tape:
            # Discriminator loss
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        # backpropagate
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        # apply the parameter update
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # update the Generator: generated images as real
        with tf.GradientTape() as tape:
            # Generator loss
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        # backpropagate
        grads = tape.gradient(g_loss, generator.trainable_variables)
        # apply the parameter update
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss))
            # sample images every 100 steps to inspect progress;
            # draw z from the same -1~1 range used during training
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            # generate images
            fake_image = generator(z, training=False)
            # save the image grid
            img_path = os.path.join('images_gan', 'gan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')


if __name__ == '__main__':
    main()
2.3 WGAN-GP training code (difference from the GAN training code: the Discriminator loss changes; only the changed loss parts are commented in detail)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import numpy as np
import tensorflow as tf
from PIL import Image
import glob
from gan import Generator, Discriminator
from dataset import make_anime_dataset


def save_result(val_out, val_block_size, image_path, color_mode):
    def preprocess(img):
        # map -1~1 back to 0~255
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img

    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)

        # concat image row to final_image
        if (b+1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)

            # reset single row
            single_row = np.array([])

    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)


def celoss_ones(logits):
    # logits: [b, 1]; labels: a tensor of ones with the same shape
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.ones_like(logits))
    return tf.reduce_mean(loss)


def celoss_zeros(logits):
    # logits: [b, 1]; labels: a tensor of zeros with the same shape
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.zeros_like(logits))
    return tf.reduce_mean(loss)


def gradient_penalty(discriminator, batch_x, fake_image):
    batchsz = batch_x.shape[0]

    # sample one interpolation factor t per image from a uniform
    # distribution (the whole image shares a single t)
    # [b, 1, 1, 1]
    t = tf.random.uniform([batchsz, 1, 1, 1])
    # broadcast_to expands the array to the new shape
    # [b, 1, 1, 1] => [b, h, w, c]
    t = tf.broadcast_to(t, batch_x.shape)
    # linear interpolation between the real images batch_x and the
    # generated images fake_image (t lies in 0~1)
    interplate = t * batch_x + (1 - t) * fake_image
    # feed the interpolated images into the discriminator and take the
    # gradient of its output with respect to them
    with tf.GradientTape() as tape:
        tape.watch([interplate])
        d_interplote_logits = discriminator(interplate)
    grads = tape.gradient(d_interplote_logits, interplate)

    # L2 norm of the gradient for each sample
    # grads: [b, h, w, c] => [b, -1]
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # [b]
    # penalize deviation of the norm from 1 (the 1-Lipschitz target)
    gp = tf.reduce_mean((gp-1)**2)

    return gp
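# Sanity check for gradient_penalty (a hypothetical toy case, not from the
# original post): for a linear critic D(x) = sum(w * x), the gradient with
# respect to any input is w, so gp reduces to (||w||_2 - 1)^2 -- the penalty
# pulls the critic toward being 1-Lipschitz along the interpolation lines.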
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # 1. treat real images as real
    # 2. treat generated images as fake
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)
    # compute the gradient-penalty term
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    # the coefficient 1. is the penalty weight (lambda), a tunable
    # hyperparameter; the WGAN-GP paper uses lambda = 10
    loss = d_loss_fake + d_loss_real + 1. * gp

    return loss, gp


def g_loss_fn(generator, discriminator, batch_z, is_training):
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    loss = celoss_ones(d_fake_logits)
    return loss


def main():
    tf.random.set_seed(22)
    np.random.seed(22)
    assert tf.__version__.startswith('2.')

    # hyper parameters
    z_dim = 100
    epochs = 3000000
    batch_size = 256
    learning_rate = 0.002
    is_training = True

    img_path = glob.glob(r'D:\PyCharmPro\shu_ju\anime-faces\*.png')
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    generator = Generator()
    generator.build(input_shape=(None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    for epoch in range(epochs):
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # train D
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # train G
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:', float(d_loss), 'g-loss:', float(g_loss),
                  'gp:', float(gp))

            # draw z from the same -1~1 range used during training
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            fake_image = generator(z, training=False)
            img_path = os.path.join('images', 'wgan-%d.png' % epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')


if __name__ == '__main__':
    main()
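One difference from the original WGAN-GP recipe is worth noting: the paper updates the critic n_critic = 5 times per generator update, while the loop above alternates 1:1. Below is a minimal sketch of the paper's schedule, reusing the d_loss_fn and g_loss_fn defined above; the helper name train_step_ncritic is ours, not from the original code.

def train_step_ncritic(generator, discriminator, g_optimizer, d_optimizer,
                       db_iter, batch_size, z_dim, n_critic=5):
    # update the critic n_critic times on fresh batches ...
    for _ in range(n_critic):
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, True)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))
    # ... then update the generator once
    batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
    with tf.GradientTape() as tape:
        g_loss = g_loss_fn(generator, discriminator, batch_z, True)
    grads = tape.gradient(g_loss, generator.trainable_variables)
    g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))
    return d_loss, g_loss, gp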
2.4 Utility code: dataset loading
import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    # preprocessing: resize to resize x resize, clip to 0~255, scale to -1~1
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                       batch_size,
                                       drop_remainder=drop_remainder,
                                       map_fn=_map_fn,
                                       shuffle=shuffle,
                                       repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size

    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048

    # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
    else:  # [*] this is slower
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)
        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of memory data.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists

    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data)
    dataset = batch_dataset(dataset,
                            batch_size,
                            drop_remainder=drop_remainder,
                            n_prefetch_batch=n_prefetch_batch,
                            filter_fn=filter_fn,
                            map_fn=map_fn,
                            n_map_threads=n_map_threads,
                            filter_after_map=filter_after_map,
                            shuffle=shuffle,
                            shuffle_buffer_size=shuffle_buffer_size,
                            repeat=repeat)
    return dataset


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of disk image for PNG and JPEG.

    Parameters
    ----------
    img_paths : 1d-tensor/ndarray/list of str
    labels : nested structure of tensors/ndarrays/lists

    """
    if labels is None:
        memory_data = img_paths
    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, 3)  # fix channels to 3
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn`
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset
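A minimal usage sketch for make_anime_dataset (the directory path is a placeholder; the training scripts above call it the same way):

import glob
import tensorflow as tf
from dataset import make_anime_dataset

# gather image paths (placeholder directory)
img_paths = glob.glob(r'D:\PyCharmPro\shu_ju\anime-faces\*.png')
dataset, img_shape, len_dataset = make_anime_dataset(img_paths, batch_size=256)
print(img_shape)     # (64, 64, 3)
print(len_dataset)   # number of full batches per pass over the data

# batches come out resized to 64x64 and scaled to the -1~1 range
for batch in dataset.take(1):
    print(batch.shape, tf.reduce_min(batch).numpy(), tf.reduce_max(batch).numpy())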