TensorFlow 2 in Practice (Keras): Generative Adversarial Networks - GAN and WGAN-GP

1. Background

1.1 Dataset Overview

This tutorial uses the Anime dataset of "high-resolution anime character face" PNG images. There are 21,551 color images, each 64x64 pixels. Sample images are shown below:

[Figure: sample images from the anime-faces dataset]

1.2 Model Overview

The "GAN/WGAN-GP" model in this tutorial consists of two parts: a Generator and a Discriminator.

  • Generator (produces images): one fully connected layer, three transposed convolution layers, and two BatchNormalization layers, with a tanh on the output. Its input is a randomly initialized latent vector z; its output is a generated image.
  • Discriminator (classifies images): three convolution layers, two BatchNormalization layers, and one fully connected layer. Its inputs are generated images and real images; its output is logits.

The Generator produces images, which are fed into the Discriminator together with real images. The Discriminator updates its parameters with a loss that labels real images as real and generated images as fake, while the Generator updates its parameters with a loss that labels its generated images as real; the two objectives are written out below.
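
In terms of the sigmoid cross-entropy CE used by the code below (the celoss_ones / celoss_zeros helpers), the two objectives are:

$$\mathcal{L}_D = \mathrm{CE}\big(D(x),\,1\big) + \mathrm{CE}\big(D(G(z)),\,0\big), \qquad \mathcal{L}_G = \mathrm{CE}\big(D(G(z)),\,1\big)$$

where $x$ is a real image, $z$ is the random latent vector, and $\mathrm{CE}(\cdot,\,y)$ is the sigmoid cross-entropy against label $y$, averaged over the batch.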

The difference between GAN and WGAN-GP in this tutorial is that WGAN-GP changes the Discriminator's loss. The standard WGAN-GP discriminator loss is:

$$\mathcal{L}_D = \mathbb{E}_{\tilde{x}\sim P_G}\big[D(\tilde{x})\big] - \mathbb{E}_{x\sim P_d}\big[D(x)\big] + \lambda\,\mathbb{E}_{\hat{x}}\Big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\Big]$$

where $\hat{x} = t\,x + (1-t)\,\tilde{x}$ with $t \sim U[0,1]$ is a random interpolation between a real and a generated sample. (The code below keeps the cross-entropy terms of the original GAN and simply adds the gradient penalty term to the discriminator loss.)

The GAN/WGAN-GP model architecture used in this tutorial is shown below:

[Figure: GAN/WGAN-GP model architecture]

Differences and connections between GAN and WGAN-GP:
The problem with the original GAN: if $P_G$ (the distribution of generated samples) and $P_d$ (the distribution of real samples) have no overlap at all, the JS divergence loses its meaning as a training signal.
WGAN-GP's fix: WGAN-GP measures the distance between the two distributions with the Wasserstein distance instead, which better drives $P_G$ toward $P_d$.
WGAN-GP needs only minor changes on top of the original GAN code: its discriminator D adds a gradient penalty term.
(See also: JS divergence, Wasserstein distance; a small worked example follows.)
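
A minimal illustration (the classic example from the WGAN paper): let $P_d$ be uniform on the vertical segment $\{0\} \times [0,1]$ and $P_G$ uniform on $\{\theta\} \times [0,1]$. For $\theta \neq 0$ the supports are disjoint, and

$$JS(P_d \,\|\, P_G) = \log 2, \qquad W(P_d, P_G) = |\theta|$$

The JS divergence stays constant no matter how close $\theta$ is to $0$, so it provides no useful gradient, while the Wasserstein distance shrinks smoothly as $P_G$ approaches $P_d$.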

1.2.1 How GAN Works (the WGAN-GP principle is covered in a separate post)

1.2.1.1 What can GAN do?

The original intent of generative adversarial networks is to generate data that "does not exist in the real world (never actually happened)", giving AI a kind of creativity or imagination (for example: AI painters; restoring blurry images by removing rain, fog, motion blur, or mosaics; data augmentation).

1.2.1.2 How does GAN work?

GAN can be viewed simply as a game between two networks:

  • Generator (G): fabricates data out of thin air
  • Discriminator (D): judges whether data is real

The whole GAN process is unsupervised: the real images carry no labels. D does not know what an incoming image depicts; it only has to tell real from fake. G does not know what it is generating either; it simply learns to imitate real images well enough to fool D. In other words, the Generator should produce realistic images with the highest possible probability, i.e., find a distribution that brings the generated images close to the real ones. Formally, this is the minimax game below.
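
The original GAN objective (Goodfellow et al., 2014) can be written as:

$$\min_G \max_D \; \mathbb{E}_{x\sim P_d}\big[\log D(x)\big] + \mathbb{E}_{z\sim p(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

where $P_d$ is the real-image distribution and $p(z)$ is the prior over the latent vector $z$.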

2. "GAN/WGAN-GP" Code

2.1 GAN/WGAN-GP model code
2.2 GAN training code
2.3 WGAN-GP training code (difference from GAN training: the Discriminator loss changes; only the changed loss is commented in detail)
2.4 Utility code: dataset loading

2.1. GAN/WGAN-GP Model Code

# -*-coding:utf-8-*-

import  tensorflow as tf
from    tensorflow import keras
from    tensorflow.keras import layers

# Generator network (image synthesis): one fully connected layer and three
# transposed convolution layers, with tanh on the output
class Generator(keras.Model):

    def __init__(self):
        super(Generator, self).__init__()

        # z: [b, 100] => [b, 64, 64, 3]

        # Project z up so it can be reshaped into a 3x3x512 feature map
        self.fc = layers.Dense(3*3*512)

        # Conv2DTranspose(filters, kernel_size, strides, padding): transposed
        # convolution, used to enlarge the feature map.
        # Output size = (N - 1) * S - 2P + F, where
        #   N: input size, F: kernel_size, S: strides, P: padding (0 for 'valid')
        # (3 - 1) * 3 + 3 = 9
        self.conv1 = layers.Conv2DTranspose(256, 3, 3, 'valid')
        self.bn1 = layers.BatchNormalization()

        # (9-1) * 2 + 5 =21
        self.conv2 = layers.Conv2DTranspose(128, 5, 2, 'valid')
        self.bn2 = layers.BatchNormalization()

        # (21 - 1) * 3 + 4 = 64
        self.conv3 = layers.Conv2DTranspose(3, 4, 3, 'valid')

    # Forward pass (image synthesis)
    def call(self, inputs, training=None):
        # Project and reshape to prepare for image synthesis: [b, 100] => [b, 3*3*512]
        x = self.fc(inputs)
        x = tf.reshape(x, [-1, 3, 3, 512])
        # ReLU has zero gradient for x < 0, which can cause vanishing gradients,
        # so leaky_relu is used instead
        x = tf.nn.leaky_relu(x)

        # [b, 3, 3, 512] => [b, 64, 64, 3]
        x = tf.nn.leaky_relu(self.bn1(self.conv1(x), training=training))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = self.conv3(x)
        # tanh is more stable to train than sigmoid here. During preprocessing,
        # pixel values are scaled from [0, 255] to [0, 1] and then mapped to
        # [-1, 1] via x * 2 - 1, so the network consumes and produces images in
        # [-1, 1]. To view a generated image, map [-1, 1] back to [0, 1] and
        # then to [0, 255].
        x = tf.tanh(x)

        return x

# Discriminator network (image classification): three convolution layers and
# one fully connected layer
class Discriminator(keras.Model):

    def __init__(self):
        super(Discriminator, self).__init__()

        # [b, 64, 64, 3] => [b, 1, 1, 256]
        # Conv2D(filters, kernel_size, strides, padding)
        # For 'valid' padding: output = floor((N - F) / S) + 1
        # floor((64 - 5) / 3) + 1 = 20
        self.conv1 = layers.Conv2D(64, 5, 3, 'valid')
        # floor((20 - 5) / 3) + 1 = 6
        self.conv2 = layers.Conv2D(128, 5, 3, 'valid')
        self.bn2 = layers.BatchNormalization()
        # floor((6 - 5) / 3) + 1 = 1
        self.conv3 = layers.Conv2D(256, 5, 3, 'valid')
        self.bn3 = layers.BatchNormalization()

        # [b, h, w, c] => [b, -1]
        # Flatten: flattens the multi-dimensional input into 1D (commonly used
        # between convolutional and fully connected layers)
        self.flatten = layers.Flatten()
        self.fc = layers.Dense(1)

    # Forward pass (image classification)
    def call(self, inputs, training=None):

        x = tf.nn.leaky_relu(self.conv1(inputs))
        x = tf.nn.leaky_relu(self.bn2(self.conv2(x), training=training))
        x = tf.nn.leaky_relu(self.bn3(self.conv3(x), training=training))

        # [b, h, w, c] => [b, -1]
        x = self.flatten(x)
        # [b, -1] => [b, 1]
        logits = self.fc(x)

        return logits

# Quick shape check for Discriminator and Generator
def main():

    d = Discriminator()
    g = Generator()

    x = tf.random.normal([2, 64, 64, 3])
    z = tf.random.normal([2, 100])

    logits = d(x)    # [2, 1] raw logits (not probabilities)
    print(logits)
    x_hat = g(z)     # [2, 64, 64, 3]
    print(x_hat.shape)

if __name__ == '__main__':
    main()

2.2. GAN Training Code

# -*-coding:utf-8-*-

import  os
# Show only warnings and errors
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import  numpy as np
import  tensorflow as tf
from    PIL import Image
import  glob
from    gan import Generator, Discriminator
from    dataset import make_anime_dataset


# Tile a batch of generated images into a single grid image and save it
def save_result(val_out, val_block_size, image_path, color_mode):
    def preprocess(img):
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        return img
    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)
        # concat image row to final_image
        if (b+1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)
            # reset single row
            single_row = np.array([])
    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    # Let PIL infer the image mode from the array shape (RGB for HxWx3, L for
    # HxW); forcing mode='P' on an RGB array raises an error, so color_mode is
    # kept in the signature only for compatibility
    Image.fromarray(final_image).save(image_path)

# Cross-entropy loss against label 1 (treat inputs as real)
def celoss_ones(logits):
    # logits: [b, 1]; labels: all ones
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.ones_like(logits))
    return tf.reduce_mean(loss)

# Cross-entropy loss against label 0 (treat inputs as fake)
def celoss_zeros(logits):
    # logits: [b, 1]; labels: all zeros
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.zeros_like(logits))
    return tf.reduce_mean(loss)
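
# For reference: tf.nn.sigmoid_cross_entropy_with_logits computes, per element
# (with x = logit, z = label),
#     max(x, 0) - x * z + log(1 + exp(-abs(x)))
# a numerically stable form of -z*log(sigmoid(x)) - (1-z)*log(1-sigmoid(x)).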

# Discriminator loss: classify real images as real and generated images as fake
def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # Generate images by feeding the random z into the generator
    fake_image = generator(batch_z, is_training)
    # Feed the generated images into the discriminator
    d_fake_logits = discriminator(fake_image, is_training)
    # Feed the real images into the discriminator
    d_real_logits = discriminator(batch_x, is_training)

    # Loss for real images, labeled real
    d_loss_real = celoss_ones(d_real_logits)
    # Loss for generated images, labeled fake
    d_loss_fake = celoss_zeros(d_fake_logits)
    # Total loss: real images as real, generated images as fake
    loss = d_loss_fake + d_loss_real

    return loss

# Generator loss: make the discriminator classify generated images as real
def g_loss_fn(generator, discriminator, batch_z, is_training):
    # Feed the random batch_z into the Generator
    fake_image = generator(batch_z, is_training)
    # Feed the generated images into the Discriminator
    d_fake_logits = discriminator(fake_image, is_training)
    # Loss for generated images, labeled real
    loss = celoss_ones(d_fake_logits)

    return loss
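
# Note: minimizing celoss_ones(d_fake_logits) is the "non-saturating" generator
# loss: the generator maximizes log D(G(z)) instead of minimizing
# log(1 - D(G(z))), which gives it stronger gradients while it is still weak.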

def main():
    # Random seeds for reproducibility
    tf.random.set_seed(22)
    np.random.seed(22)
    # startswith('2.') checks that the TensorFlow version begins with '2.'
    assert tf.__version__.startswith('2.')


    # Hyperparameters
    z_dim = 100
    epochs = 3000000
    batch_size = 256
    learning_rate = 0.002
    is_training = True

    # Collect all PNG images under the given directory
    img_path = glob.glob(r'D:\PyCharmPro\shu_ju\anime-faces\*.png')
    # Load the dataset
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    # repeat() with no argument repeats the dataset indefinitely
    dataset = dataset.repeat()
    db_iter = iter(dataset)

    # Instantiate Generator and Discriminator and build them with explicit input shapes
    generator = Generator()
    generator.build(input_shape = (None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    # Optimizer for the Generator
    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    # Optimizer for the Discriminator
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)

    # Training loop (each "epoch" here processes a single batch)
    for epoch in range(epochs):
        # Sample z uniformly from [-1, 1]
        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # Update the Discriminator: real images labeled real, generated images labeled fake
        with tf.GradientTape() as tape:
            # Compute the Discriminator loss
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        # Compute gradients
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        # Apply the gradients with the optimizer
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))

        # Update the Generator: generated images labeled real
        with tf.GradientTape() as tape:
            # Compute the Generator loss
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        # Compute gradients
        grads = tape.gradient(g_loss, generator.trainable_variables)
        # Apply the gradients with the optimizer
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))

            # Inspect generator output (sample images every 100 epochs)
            # Sample z from [-1, 1] to match the training distribution
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            # Generate images
            fake_image = generator(z, training=False)
            # Save the image grid (create the output directory if needed)
            os.makedirs('images_gan', exist_ok=True)
            img_path = os.path.join('images_gan', 'gan-%d.png'%epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')

if __name__ == '__main__':
    main()

2.3. WGAN-GP Training Code (difference from GAN: the Discriminator loss changes; only the changed loss is commented in detail)

import  os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import  numpy as np
import  tensorflow as tf
from    PIL import Image
import  glob
from    gan import Generator, Discriminator

from    dataset import make_anime_dataset


def save_result(val_out, val_block_size, image_path, color_mode):
    def preprocess(img):
        img = ((img + 1.0) * 127.5).astype(np.uint8)
        # img = img.astype(np.uint8)
        return img

    preprocesed = preprocess(val_out)
    final_image = np.array([])
    single_row = np.array([])
    for b in range(val_out.shape[0]):
        # concat image into a row
        if single_row.size == 0:
            single_row = preprocesed[b, :, :, :]
        else:
            single_row = np.concatenate((single_row, preprocesed[b, :, :, :]), axis=1)

        # concat image row to final_image
        if (b+1) % val_block_size == 0:
            if final_image.size == 0:
                final_image = single_row
            else:
                final_image = np.concatenate((final_image, single_row), axis=0)

            # reset single row
            single_row = np.array([])

    if final_image.shape[2] == 1:
        final_image = np.squeeze(final_image, axis=2)
    Image.fromarray(final_image).save(image_path)


def celoss_ones(logits):
    # [b, 1]
    # [b] = [1, 1, 1, 1,]
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.ones_like(logits))
    return tf.reduce_mean(loss)


def celoss_zeros(logits):
    # [b, 1]
    # [b] = [1, 1, 1, 1,]
    loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits,
                                                   labels=tf.zeros_like(logits))
    return tf.reduce_mean(loss)


def gradient_penalty(discriminator, batch_x, fake_image):

    batchsz = batch_x.shape[0]

    # Sample the linear interpolation factor t from a uniform distribution
    # (one t per image)
    # [b, h, w, c]
    t = tf.random.uniform([batchsz, 1, 1, 1])
    # broadcast_to expands the array to the target shape
    # [b, 1, 1, 1] => [b, h, w, c]
    t = tf.broadcast_to(t, batch_x.shape)

    #"真的图像batch_x" 与 “假的图像fake_image” 之间做线性差值(t是0~1之间的线性差值)
    interplate = t * batch_x + (1 - t) * fake_image

    # Feed the interpolated images into the discriminator and take the gradient
    # with respect to them
    with tf.GradientTape() as tape:
        tape.watch([interplate])
        d_interplote_logits = discriminator(interplate)
    grads = tape.gradient(d_interplote_logits, interplate)

    # L2 norm of the gradient for each sample
    # grads:[b, h, w, c] => [b, -1]
    grads = tf.reshape(grads, [grads.shape[0], -1])
    gp = tf.norm(grads, axis=1)  # [b]
    # Two-sided penalty: push the gradient norm toward 1
    gp = tf.reduce_mean( (gp-1)**2 )

    return gp
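
# A quick sanity check for gradient_penalty (a hypothetical snippet, not part of
# the original training flow): with random "real" and "fake" batches the penalty
# should evaluate to a non-negative scalar.
#
#   d = Discriminator()
#   d.build(input_shape=(None, 64, 64, 3))
#   real = tf.random.normal([4, 64, 64, 3])
#   fake = tf.random.normal([4, 64, 64, 3])
#   print(float(gradient_penalty(d, real, fake)))  # >= 0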



def d_loss_fn(generator, discriminator, batch_z, batch_x, is_training):
    # 1. treat real image as real
    # 2. treat generated image as fake
    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    d_real_logits = discriminator(batch_x, is_training)

    d_loss_real = celoss_ones(d_real_logits)
    d_loss_fake = celoss_zeros(d_fake_logits)

    # Compute the gradient penalty term
    gp = gradient_penalty(discriminator, batch_x, fake_image)

    # The coefficient 1. is the penalty weight (lambda), a tunable
    # hyperparameter; the WGAN-GP paper uses 10
    loss = d_loss_fake + d_loss_real + 1. * gp

    return loss, gp


def g_loss_fn(generator, discriminator, batch_z, is_training):

    fake_image = generator(batch_z, is_training)
    d_fake_logits = discriminator(fake_image, is_training)
    loss = celoss_ones(d_fake_logits)

    return loss

def main():

    tf.random.set_seed(22)
    np.random.seed(22)

    assert tf.__version__.startswith('2.')


    # hyper parameters
    z_dim = 100
    epochs = 3000000
    batch_size = 256
    learning_rate = 0.002
    is_training = True


    img_path = glob.glob(r'D:\PyCharmPro\shu_ju\anime-faces\*.png')

    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)
    print(dataset, img_shape)
    sample = next(iter(dataset))
    print(sample.shape, tf.reduce_max(sample).numpy(),
          tf.reduce_min(sample).numpy())
    dataset = dataset.repeat()
    db_iter = iter(dataset)


    generator = Generator()
    generator.build(input_shape = (None, z_dim))
    discriminator = Discriminator()
    discriminator.build(input_shape=(None, 64, 64, 3))

    g_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)
    d_optimizer = tf.optimizers.Adam(learning_rate=learning_rate, beta_1=0.5)


    for epoch in range(epochs):

        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # train D
        with tf.GradientTape() as tape:
            d_loss, gp = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))


        # train G
        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss),
                  'gp:', float(gp))

            # Sample z from [-1, 1] to match the training distribution
            z = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
            fake_image = generator(z, training=False)
            # Create the output directory if needed before saving
            os.makedirs('images', exist_ok=True)
            img_path = os.path.join('images', 'wgan-%d.png'%epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')



if __name__ == '__main__':
    main()

2.4. Utility Code: Dataset Loading

import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size

    return dataset, img_shape, len_dataset
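
# Example usage (hypothetical path; this mirrors how the training scripts call it):
#
#   import glob
#   paths = glob.glob('anime-faces/*.png')
#   dataset, img_shape, len_dataset = make_anime_dataset(paths, batch_size=256)
#   sample = next(iter(dataset))
#   print(sample.shape)  # (256, 64, 64, 3), pixel values in [-1, 1]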


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048

    # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

    else:  # [*] this is slower
        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

        if filter_fn:
            dataset = dataset.filter(filter_fn)

    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)

    dataset = dataset.repeat(repeat).prefetch(n_prefetch_batch)

    return dataset


def memory_data_batch_dataset(memory_data,
                              batch_size,
                              drop_remainder=True,
                              n_prefetch_batch=1,
                              filter_fn=None,
                              map_fn=None,
                              n_map_threads=None,
                              filter_after_map=False,
                              shuffle=True,
                              shuffle_buffer_size=None,
                              repeat=None):
    """Batch dataset of memory data.

    Parameters
    ----------
    memory_data : nested structure of tensors/ndarrays/lists

    """
    dataset = tf.data.Dataset.from_tensor_slices(memory_data)
    dataset = batch_dataset(dataset,
                            batch_size,
                            drop_remainder=drop_remainder,
                            n_prefetch_batch=n_prefetch_batch,
                            filter_fn=filter_fn,
                            map_fn=map_fn,
                            n_map_threads=n_map_threads,
                            filter_after_map=filter_after_map,
                            shuffle=shuffle,
                            shuffle_buffer_size=shuffle_buffer_size,
                            repeat=repeat)
    return dataset


def disk_image_batch_dataset(img_paths,
                             batch_size,
                             labels=None,
                             drop_remainder=True,
                             n_prefetch_batch=1,
                             filter_fn=None,
                             map_fn=None,
                             n_map_threads=None,
                             filter_after_map=False,
                             shuffle=True,
                             shuffle_buffer_size=None,
                             repeat=None):
    """Batch dataset of disk image for PNG and JPEG.

    Parameters
    ----------
        img_paths : 1d-tensor/ndarray/list of str
        labels : nested structure of tensors/ndarrays/lists

    """
    if labels is None:
        memory_data = img_paths
    else:
        memory_data = (img_paths, labels)

    def parse_fn(path, *label):
        img = tf.io.read_file(path)
        img = tf.image.decode_png(img, 3)  # fix channels to 3
        return (img,) + label

    if map_fn:  # fuse `map_fn` and `parse_fn`
        def map_fn_(*args):
            return map_fn(*parse_fn(*args))
    else:
        map_fn_ = parse_fn

    dataset = memory_data_batch_dataset(memory_data,
                                        batch_size,
                                        drop_remainder=drop_remainder,
                                        n_prefetch_batch=n_prefetch_batch,
                                        filter_fn=filter_fn,
                                        map_fn=map_fn_,
                                        n_map_threads=n_map_threads,
                                        filter_after_map=filter_after_map,
                                        shuffle=shuffle,
                                        shuffle_buffer_size=shuffle_buffer_size,
                                        repeat=repeat)

    return dataset


Copyright notice: this is an original article by weixin_43560913, released under the CC 4.0 BY-SA license. Please include a link to the original source and this notice when reposting.