Bring Your Photos to Life with the First Order Motion Model (4): Adversarial Networks and Model Training

1. Overview

This chapter introduces the remaining parts of the model and the losses used to train it.

2. GeneratorFullModel, the complete generator

2.1 The image pyramid (ImagePyramide)

This module produces downscaled copies of the input image at several scales.

class ImagePyramide(torch.nn.Module):
    """
    Create image pyramide for computing pyramide perceptual loss. See Sec 3.3
    """
    def __init__(self, scales, num_channels):
        super(ImagePyramide, self).__init__()
        downs = {}
        for scale in scales:
            downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
        self.downs = nn.ModuleDict(downs)

    def forward(self, x):
        out_dict = {}
        for scale, down_module in self.downs.items():
            out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
        return out_dict
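ImagePyramide relies on AntiAliasInterpolation2d, which was defined in an earlier part of this series (a Gaussian blur followed by subsampling). If you only want to run this chapter's snippets on their own, a rough stand-in like the one below can be used; it is a minimal sketch built on F.interpolate's antialias option (available in PyTorch 1.11+), not the original implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class AntiAliasInterpolation2d(nn.Module):
    """Rough stand-in for the anti-aliased downsampler from the earlier chapter."""
    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        self.scale = scale

    def forward(self, x):
        if self.scale == 1:
            return x
        # anti-aliased bilinear resize instead of the original Gaussian blur + subsample
        return F.interpolate(x, scale_factor=self.scale, mode='bilinear',
                             align_corners=False, antialias=True)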

Test code

# `source` is a (1, 3, H, W) image tensor loaded in an earlier chapter
scales = [1, 0.5, 0.25, 0.125]
pyramide = ImagePyramide(scales, 3)
pyramide_source = pyramide(source)

# collect the four scaled images and show them in a 2x2 grid
pyramide_source_list = [w2 for (w1, w2) in pyramide_source.items()]
figure, ax = plt.subplots(2, 2, figsize=(8, 4))
for i in range(2):
    for j in range(2):
        show_item = pyramide_source_list[(2 * i) + j]
        ax[i, j].imshow(show_item[0].permute(1, 2, 0).data)

The result is shown below (the four downscaled versions of the source image).

2.2 The VGG19 network and the perceptual loss

VGG19 is a pretrained network and a classic choice in style transfer. Feature maps are taken from several of its convolutional layers, and the feature maps of the two images are compared with an L1 loss (mean absolute error). These feature maps capture the content of an image rather than its exact appearance, so the perceptual loss measures how similar the content of the two images is. Naturally, we want the generated image to contain the motion of the driving image.
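In symbols, writing \phi_i for the VGG19 feature maps extracted below and w_i for the per-layer weights taken from the config file, the loss computed in this section is roughly L_{perceptual} = \sum_{s}\sum_{i} w_i \cdot \mathrm{mean}(|\phi_i(\hat{D}_s) - \phi_i(D_s)|), where D_s and \hat{D}_s are the real (driving) and generated frames at pyramid scale s.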

The code below does two things:

  1. Normalizes the input with the given mean and standard deviation.
  2. Runs the input through VGG19 and returns the feature maps produced after layers 2, 7, 12, 21 and 30 of the feature extractor.
class Vgg19(torch.nn.Module):
    """
    Vgg19 network for perceptual loss. See Sec 3.3.
    """
    def __init__(self, requires_grad=False):
        super(Vgg19, self).__init__()
        vgg_pretrained_features = models.vgg19(pretrained=True).features
        self.slice1 = torch.nn.Sequential()
        self.slice2 = torch.nn.Sequential()
        self.slice3 = torch.nn.Sequential()
        self.slice4 = torch.nn.Sequential()
        self.slice5 = torch.nn.Sequential()
        for x in range(2):
            self.slice1.add_module(str(x), vgg_pretrained_features[x])
        for x in range(2, 7):
            self.slice2.add_module(str(x), vgg_pretrained_features[x])
        for x in range(7, 12):
            self.slice3.add_module(str(x), vgg_pretrained_features[x])
        for x in range(12, 21):
            self.slice4.add_module(str(x), vgg_pretrained_features[x])
        for x in range(21, 30):
            self.slice5.add_module(str(x), vgg_pretrained_features[x])

        self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
                                       requires_grad=False)
        self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
                                      requires_grad=False)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        # normalize the input with the ImageNet mean and std
        X = (X - self.mean) / self.std
        h_relu1 = self.slice1(X)
        h_relu2 = self.slice2(h_relu1)
        h_relu3 = self.slice3(h_relu2)
        h_relu4 = self.slice4(h_relu3)
        h_relu5 = self.slice5(h_relu4)
        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
        return out

Test code

vgg = Vgg19()
x_vgg=vgg(source)

The key code of the perceptual loss is shown below. It first checks whether the config file assigns non-zero weights to the perceptual loss:
if sum(self.loss_weights['perceptual']) != 0:
    value_total = 0
    # loop over the image scales produced by the pyramid
    for scale in self.scales:
        # VGG features of the generated image
        x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
        # VGG features of the real (driving) image
        y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])
        # weight the L1 difference of each layer's feature maps according to the config
        for i, weight in enumerate(self.loss_weights['perceptual']):
            value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
            value_total += self.loss_weights['perceptual'][i] * value
        loss_values['perceptual'] = value_total
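To make the snippet above concrete, here is a small self-contained sketch. The scale list and the per-layer weights are hypothetical placeholders; the actual values come from the training section of the config file.

# assumes Vgg19, ImagePyramide and the AntiAliasInterpolation2d stand-in from above
scales = [1, 0.5]
perceptual_weights = [10, 10, 10, 10, 10]   # hypothetical per-layer weights

vgg = Vgg19().eval()
pyramide = ImagePyramide(scales, 3)

real = torch.rand(1, 3, 256, 256)   # stands in for the driving frame
fake = torch.rand(1, 3, 256, 256)   # stands in for the generator output

pyramide_real = pyramide(real)
pyramide_generated = pyramide(fake)

value_total = 0
for scale in scales:
    x_vgg = vgg(pyramide_generated['prediction_' + str(scale)])
    y_vgg = vgg(pyramide_real['prediction_' + str(scale)])
    for i, weight in enumerate(perceptual_weights):
        value_total += weight * torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
print(value_total)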

2.3 The discriminator

Calling this module a discriminator is a somewhat loose use of the term: it simply takes the image (optionally concatenated with Gaussian heatmaps built from the keypoint information) and returns a per-pixel confidence map together with the intermediate feature maps.

class DownBlock2d_disc(nn.Module):
    """
    Simple block for processing video (encoder).
    """

    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
        super(DownBlock2d_disc, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)

        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)

        if norm:
            self.norm = nn.InstanceNorm2d(out_features, affine=True)
        else:
            self.norm = None
        self.pool = pool

    def forward(self, x):
        out = x
        out = self.conv(out)
        if self.norm:
            out = self.norm(out)
        out = F.leaky_relu(out, 0.2)
        if self.pool:
            out = F.avg_pool2d(out, (2, 2))
        return out


class Discriminator(nn.Module):
    """
    Discriminator similar to Pix2Pix
    """

    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
                 sn=False, use_kp=False, num_kp=10, kp_variance=0.01, **kwargs):
        super(Discriminator, self).__init__()

        down_blocks = []
        for i in range(num_blocks):
            down_blocks.append(
                DownBlock2d_disc(num_channels + num_kp * use_kp if i == 0 else min(max_features, block_expansion * (2 ** i)),
                            min(max_features, block_expansion * (2 ** (i + 1))),
                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))

        self.down_blocks = nn.ModuleList(down_blocks)
        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
        if sn:
            self.conv = nn.utils.spectral_norm(self.conv)
        self.use_kp = use_kp
        self.kp_variance = kp_variance

    def forward(self, x, kp=None):
        feature_maps = []
        out = x
        if self.use_kp:
            heatmap = kp2gaussian(kp, x.shape[2:], self.kp_variance)
            out = torch.cat([out, heatmap], dim=1)

        for down_block in self.down_blocks:
            feature_maps.append(down_block(out))
            out = feature_maps[-1]
        prediction_map = self.conv(out)

        return feature_maps, prediction_map


class MultiScaleDiscriminator(nn.Module):
    """
    Multi-scale (scale) discriminator
    """

    def __init__(self, scales=(), **kwargs):
        super(MultiScaleDiscriminator, self).__init__()
        self.scales = scales
        discs = {}
        for scale in scales:
            discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
        self.discs = nn.ModuleDict(discs)

    def forward(self, x, kp=None):
        out_dict = {}
        for scale, disc in self.discs.items():
            scale = str(scale).replace('-', '.')
            key = 'prediction_' + scale
            feature_maps, prediction_map = disc(x[key], kp)
            out_dict['feature_maps_' + scale] = feature_maps
            out_dict['prediction_map_' + scale] = prediction_map
        return out_dict

Note that the first value returned by Discriminator is the list of feature maps produced by the encoder (down) blocks, while the second is a per-pixel prediction of whether that location belongs to a generated image. feature_maps keeps the output of every down block, which is later used for the feature-matching loss. MultiScaleDiscriminator simply wraps one Discriminator per scale so that every downscaled version of the image gets processed.
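As a quick sanity check of the shapes (a sketch with made-up hyper-parameters; the real ones live in the config file, and keypoint conditioning is switched off here so that kp2gaussian from the earlier chapter is not required):

disc = MultiScaleDiscriminator(scales=[1, 0.5], num_channels=3, block_expansion=32,
                               num_blocks=4, max_features=512, sn=True, use_kp=False)
pyramide = ImagePyramide([1, 0.5], 3)
frames = pyramide(torch.rand(1, 3, 256, 256))

out = disc(frames)
print(out['prediction_map_1'].shape)    # per-pixel real/fake scores at scale 1
print(len(out['feature_maps_0.5']))     # one feature map per down block at scale 0.5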
The key code of the generator's adversarial loss is as follows:

# check whether the config assigns a weight to the generator's GAN loss
if self.loss_weights['generator_gan'] != 0:
    # feature maps and per-pixel predictions for the generated images
    discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
    # feature maps and per-pixel predictions for the real images
    discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
    # initialize the loss accumulator
    value_total = 0
    # loop over the discriminator scales
    for scale in self.disc_scales:
        key = 'prediction_map_%s' % scale
        # LSGAN-style objective: push the discriminator output on generated pixels towards 1
        value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
        # accumulate the weighted loss
        value_total += self.loss_weights['generator_gan'] * value
    loss_values['gen_gan'] = value_total

    # check whether the config assigns weights to the feature-matching loss
    if sum(self.loss_weights['feature_matching']) != 0:
        value_total = 0
        for scale in self.disc_scales:
            key = 'feature_maps_%s' % scale
            for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                if self.loss_weights['feature_matching'][i] == 0:
                    continue
                # push the discriminator features of generated and real images towards each other
                value = torch.abs(a - b).mean()
                value_total += self.loss_weights['feature_matching'][i] * value
            loss_values['feature_matching'] = value_total

2.4 The equivariance loss

  • Let D be the driving frame and let Y be the image obtained by warping D with a known, randomly generated TPS transformation T_{D \leftarrow Y} (the warp maps coordinates of Y to coordinates of D).

  • The keypoint detector estimates, for any frame X, a transformation \hat{T}_{X \leftarrow R} from the abstract reference frame R to X; concretely, it predicts its value and its Jacobian at the keypoints p_k.

  • If the detector is equivariant to the known warp, its two estimates must be consistent with it:

    \hat{T}_{D \leftarrow R} = T_{D \leftarrow Y} \circ \hat{T}_{Y \leftarrow R}    (1)

    Here the hatted transforms are predicted by the network, while T_{D \leftarrow Y} is the known TPS warp.

  • First-order Taylor expansion of the left-hand side of (1) around a keypoint p_k:

    \hat{T}_{D \leftarrow R}(p) = \hat{T}_{D \leftarrow R}(p_k) + (\frac{d}{dp}\hat{T}_{D \leftarrow R}|_{p=p_k})(p - p_k) + o(\|p - p_k\|)    (2)

  • First-order Taylor expansion of the right-hand side of (1):

    T_{D \leftarrow Y} \circ \hat{T}_{Y \leftarrow R}(p) = T_{D \leftarrow Y}(\hat{T}_{Y \leftarrow R}(p_k)) + (\frac{d}{dp} T_{D \leftarrow Y} \circ \hat{T}_{Y \leftarrow R}|_{p=p_k})(p - p_k) + o(\|p - p_k\|)    (3)

  • Applying the chain rule to the first-order term of (3):

    \frac{d}{dp} T_{D \leftarrow Y} \circ \hat{T}_{Y \leftarrow R}|_{p=p_k} = (\frac{d}{dp} T_{D \leftarrow Y}|_{p=\hat{T}_{Y \leftarrow R}(p_k)})(\frac{d}{dp}\hat{T}_{Y \leftarrow R}|_{p=p_k})    (4)

    Substituting (4) into (3) and matching the zero-order and first-order terms of (2) and (3) gives the two constraints that the losses below enforce:

    \hat{T}_{D \leftarrow R}(p_k) = T_{D \leftarrow Y}(\hat{T}_{Y \leftarrow R}(p_k))    (5)

    \frac{d}{dp}\hat{T}_{D \leftarrow R}|_{p=p_k} = (\frac{d}{dp} T_{D \leftarrow Y}|_{p=\hat{T}_{Y \leftarrow R}(p_k)})(\frac{d}{dp}\hat{T}_{Y \leftarrow R}|_{p=p_k})    (6)
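    As a concrete special case (not part of the original derivation): if the known warp is purely affine, T_{D \leftarrow Y}(p) = A p + b, its derivative is the constant matrix A, and (6) simply states that the Jacobians predicted at corresponding keypoints of the two frames must differ by exactly that known matrix: \frac{d}{dp}\hat{T}_{D \leftarrow R}|_{p=p_k} = A \, \frac{d}{dp}\hat{T}_{Y \leftarrow R}|_{p=p_k}.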

The TPS transform class Transform

The core of the transform is the radial basis function plotted below.

Looking at its 3-D plot helps build intuition: imagine a control point above the image; when the control point moves, it drags the surrounding pixels with it, and how far each pixel moves depends on its distance to the control point. The plot below uses a single control point at the origin (the origin itself is excluded because the function is undefined there). With many control points, each one contributes its own small deformation, and this, roughly speaking, is the TPS warp. In practice each contribution is multiplied by a small coefficient to keep the deformation modest.

# plot z = -(x^2 + y^2) * log(x^2 + y^2), skipping 0 where the log is undefined
x = y = np.concatenate((np.arange(-10, 0), np.arange(1, 11)))
X, Y = np.meshgrid(x, y)
figure = plt.figure(figsize=(5, 4))
ax3d = figure.add_subplot(projection='3d')
ax3d.plot_surface(X, Y, -(X**2 + Y**2) * np.log((X**2 + Y**2)), rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))
ax3d.set_xlabel('x')
ax3d.set_ylabel('y')
ax3d.set_zlabel('z')
plt.show()

The resulting surface is shown below.

The source code of Transform, which implements a random affine plus TPS warp, is as follows:

class Transform:
    """
    Random tps transformation for equivariance constraints. See Sec 3.3
    """
    def __init__(self, bs, **kwargs):  # bs is the batch size
        noise = torch.normal(mean=0, std=kwargs['sigma_affine'] * torch.ones([bs, 2, 3]))
        self.theta = noise + torch.eye(2, 3).view(1, 2, 3)
        self.bs = bs

        if ('sigma_tps' in kwargs) and ('points_tps' in kwargs):
            self.tps = True
            self.control_points = make_coordinate_grid((kwargs['points_tps'], kwargs['points_tps']), type=noise.type())
            self.control_points = self.control_points.unsqueeze(0)
            self.control_params = torch.normal(mean=0,
                                               std=kwargs['sigma_tps'] * torch.ones([bs, 1, kwargs['points_tps'] ** 2]))
        else:
            self.tps = False

    def transform_frame(self, frame):
        grid = make_coordinate_grid(frame.shape[2:], type=frame.type()).unsqueeze(0)
        grid = grid.view(1, frame.shape[2] * frame.shape[3], 2)
        grid = self.warp_coordinates(grid).view(self.bs, frame.shape[2], frame.shape[3], 2)
        return F.grid_sample(frame, grid, padding_mode="reflection")

    def warp_coordinates(self, coordinates):
        theta = self.theta.type(coordinates.type())
        theta = theta.unsqueeze(1)
        transformed = torch.matmul(theta[:, :, :, :2], coordinates.unsqueeze(-1)) + theta[:, :, :, 2:]
        transformed = transformed.squeeze(-1)

        if self.tps:
            control_points = self.control_points.type(coordinates.type())
            control_params = self.control_params.type(coordinates.type())
            distances = coordinates.view(coordinates.shape[0], -1, 1, 2) - control_points.view(1, 1, -1, 2)
            distances = torch.abs(distances).sum(-1)

            result = distances ** 2
            result = result * torch.log(distances + 1e-6)
            result = result * control_params
            result = result.sum(dim=2).view(self.bs, coordinates.shape[1], 1)
            transformed = transformed + result

        return transformed

    def jacobian(self, coordinates):
        new_coordinates = self.warp_coordinates(coordinates)
        grad_x = grad(new_coordinates[..., 0].sum(), coordinates, create_graph=True)
        grad_y = grad(new_coordinates[..., 1].sum(), coordinates, create_graph=True)
        jacobian = torch.cat([grad_x[0].unsqueeze(-2), grad_y[0].unsqueeze(-2)], dim=-2)
        return jacobian


def detach_kp(kp):
    return {key: value.detach() for key, value in kp.items()}
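Transform above depends on make_coordinate_grid, which was defined in an earlier chapter. If you want to run the snippets below on their own, a minimal stand-in could look like this (a sketch that returns an (H, W, 2) grid of (x, y) coordinates normalised to [-1, 1], the convention F.grid_sample expects):

def make_coordinate_grid(spatial_size, type):
    """Rough stand-in for the coordinate-grid helper from the earlier chapter."""
    h, w = spatial_size
    x = torch.linspace(-1, 1, w).type(type)   # x runs along the width
    y = torch.linspace(-1, 1, h).type(type)   # y runs along the height
    xx = x.view(1, -1).repeat(h, 1)
    yy = y.view(-1, 1).repeat(1, w)
    return torch.cat([xx.unsqueeze(-1), yy.unsqueeze(-1)], dim=-1)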

transform = Transform(1, **config['train_params']['transform_params'])

The loss for Eq. (5) is implemented as follows: warp the driving frame D with a known TPS transform, run the keypoint detector on the warped frame to predict its 10 keypoints, map those keypoint coordinates back into D with the same known warp, and require the result to coincide with the 10 keypoints detected directly on D.

  1. First, look at what the TPS warp does when applied directly to the original frame (the snippet below).
  2. Note that in the full model the Transform is created inside forward: once a Transform has been instantiated its deformation is fixed, so constructing a new one on every forward pass keeps the random transformations varied.
# `drive` is a (1, 3, H, W) driving-frame tensor loaded in an earlier chapter
transformed_frame = transform.transform_frame(drive)
figure, ax = plt.subplots(1, 2, figsize=(6, 3))
ax[0].imshow(drive[0].permute(1, 2, 0).data)
ax[1].imshow(transformed_frame[0].permute(1, 2, 0).data)

The key code of the keypoint (equivariance value) loss is as follows:

# create a fresh random transform for this batch
transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
# apply the TPS warp to the driving frames
transformed_frame = transform.transform_frame(x['driving'])
# detect keypoints on the warped frames
transformed_kp = self.kp_extractor(transformed_frame)

generated['transformed_frame'] = transformed_frame
generated['transformed_kp'] = transformed_kp

## Value loss part
if self.loss_weights['equivariance_value'] != 0:
    # keypoints detected on the warped frame, mapped back through the known warp,
    # should coincide with the keypoints detected on the original driving frame (Eq. 5)
    value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
    loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

The Jacobian part of the constraint is Eq. (6):

\frac{d}{dp}\hat{T}_{D \leftarrow R}|_{p=p_k} = (\frac{d}{dp} T_{D \leftarrow Y}|_{p=\hat{T}_{Y \leftarrow R}(p_k)})(\frac{d}{dp}\hat{T}_{Y \leftarrow R}|_{p=p_k})    (6)

The Jacobian of the known warp is obtained with PyTorch's automatic differentiation (the Transform.jacobian method above). To make the loss easier to express, (6) is rewritten as

1 = (\frac{d}{dp}\hat{T}_{D \leftarrow R}|_{p=p_k})^{-1}(\frac{d}{dp} T_{D \leftarrow Y}|_{p=\hat{T}_{Y \leftarrow R}(p_k)})(\frac{d}{dp}\hat{T}_{Y \leftarrow R}|_{p=p_k})    (7)

The key code is as follows:

# product of the second and third factors on the right-hand side of Eq. (7)
jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']), transformed_kp['jacobian'])
# first factor on the right-hand side of Eq. (7)
normed_driving = torch.inverse(kp_driving['jacobian'])
normed_transformed = jacobian_transformed
# the full right-hand side of Eq. (7)
value = torch.matmul(normed_driving, normed_transformed)
# the left-hand side of Eq. (7): the 2x2 identity
eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())
# the loss drives the product towards the identity so that Eq. (7) holds
value = torch.abs(eye - value).mean()
# apply the configured weight
loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value
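As a sanity check of Transform.jacobian (not part of the original code), the autograd Jacobian can be compared against a finite-difference approximation. The transform parameters below are hypothetical placeholders for the values in the config file.

# build a random transform and a few random query points in [-1, 1]
transform = Transform(1, sigma_affine=0.05, sigma_tps=0.005, points_tps=5)
pts = (torch.rand(1, 3, 2) * 2 - 1).requires_grad_(True)

jac = transform.jacobian(pts)            # autograd Jacobian, shape (1, 3, 2, 2)

# forward finite differences: perturb x, then y, and divide by the step size
eps = 1e-4
fd = torch.zeros_like(jac)
with torch.no_grad():
    base = transform.warp_coordinates(pts)
    for d in range(2):
        shifted = pts.clone()
        shifted[..., d] += eps
        fd[..., d] = (transform.warp_coordinates(shifted) - base) / eps

print(torch.allclose(jac.detach(), fd, atol=1e-2))   # expected to print True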

2.5 The complete generator code (GeneratorFullModel)

class GeneratorFullModel(torch.nn.Module):
    """
    Merge all generator related updates into single model for better multi-gpu usage
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(GeneratorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = train_params['scales']
        self.disc_scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

        if sum(self.loss_weights['perceptual']) != 0:
            self.vgg = Vgg19()
            if torch.cuda.is_available():
                self.vgg = self.vgg.cuda()

    def forward(self, x):
        kp_source = self.kp_extractor(x['source'])
        kp_driving = self.kp_extractor(x['driving'])

        generated = self.generator(x['source'], kp_source=kp_source, kp_driving=kp_driving)
        generated.update({'kp_source': kp_source, 'kp_driving': kp_driving})

        loss_values = {}

        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'])

        if sum(self.loss_weights['perceptual']) != 0:
            value_total = 0
            for scale in self.scales:
                x_vgg = self.vgg(pyramide_generated['prediction_' + str(scale)])
                y_vgg = self.vgg(pyramide_real['prediction_' + str(scale)])

                for i, weight in enumerate(self.loss_weights['perceptual']):
                    value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
                    value_total += self.loss_weights['perceptual'][i] * value
                loss_values['perceptual'] = value_total

        if self.loss_weights['generator_gan'] != 0:
            discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
            discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))
            value_total = 0
            for scale in self.disc_scales:
                key = 'prediction_map_%s' % scale
                value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
                value_total += self.loss_weights['generator_gan'] * value
            loss_values['gen_gan'] = value_total

            if sum(self.loss_weights['feature_matching']) != 0:
                value_total = 0
                for scale in self.disc_scales:
                    key = 'feature_maps_%s' % scale
                    for i, (a, b) in enumerate(zip(discriminator_maps_real[key], discriminator_maps_generated[key])):
                        if self.loss_weights['feature_matching'][i] == 0:
                            continue
                        value = torch.abs(a - b).mean()
                        value_total += self.loss_weights['feature_matching'][i] * value
                    loss_values['feature_matching'] = value_total

        if (self.loss_weights['equivariance_value'] + self.loss_weights['equivariance_jacobian']) != 0:
            transform = Transform(x['driving'].shape[0], **self.train_params['transform_params'])
            transformed_frame = transform.transform_frame(x['driving'])
            transformed_kp = self.kp_extractor(transformed_frame)

            generated['transformed_frame'] = transformed_frame
            generated['transformed_kp'] = transformed_kp

            ## Value loss part
            if self.loss_weights['equivariance_value'] != 0:
                value = torch.abs(kp_driving['value'] - transform.warp_coordinates(transformed_kp['value'])).mean()
                loss_values['equivariance_value'] = self.loss_weights['equivariance_value'] * value

            ## jacobian loss part
            if self.loss_weights['equivariance_jacobian'] != 0:
                jacobian_transformed = torch.matmul(transform.jacobian(transformed_kp['value']),
                                                    transformed_kp['jacobian'])

                normed_driving = torch.inverse(kp_driving['jacobian'])
                normed_transformed = jacobian_transformed
                value = torch.matmul(normed_driving, normed_transformed)

                eye = torch.eye(2).view(1, 1, 2, 2).type(value.type())

                value = torch.abs(eye - value).mean()
                loss_values['equivariance_jacobian'] = self.loss_weights['equivariance_jacobian'] * value

        return loss_values, generated

3. The complete discriminator code (DiscriminatorFullModel)

class DiscriminatorFullModel(torch.nn.Module):
    """
    Merge all discriminator related updates into single model for better multi-gpu usage
    """

    def __init__(self, kp_extractor, generator, discriminator, train_params):
        super(DiscriminatorFullModel, self).__init__()
        self.kp_extractor = kp_extractor
        self.generator = generator
        self.discriminator = discriminator
        self.train_params = train_params
        self.scales = self.discriminator.scales
        self.pyramid = ImagePyramide(self.scales, generator.num_channels)
        if torch.cuda.is_available():
            self.pyramid = self.pyramid.cuda()

        self.loss_weights = train_params['loss_weights']

    def forward(self, x, generated):
        pyramide_real = self.pyramid(x['driving'])
        pyramide_generated = self.pyramid(generated['prediction'].detach())

        kp_driving = generated['kp_driving']
        discriminator_maps_generated = self.discriminator(pyramide_generated, kp=detach_kp(kp_driving))
        discriminator_maps_real = self.discriminator(pyramide_real, kp=detach_kp(kp_driving))

        loss_values = {}
        value_total = 0
        for scale in self.scales:
            key = 'prediction_map_%s' % scale
            value = (1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2
            value_total += self.loss_weights['discriminator_gan'] * value.mean()
        loss_values['disc_gan'] = value_total

        return loss_values

4. Summary

This concludes the walkthrough of the model. Dataset preparation and the training loop are not covered here; please refer to the official source code.
Copyright notice: this is an original article by amao1998, released under the CC 4.0 BY-SA license. Please include a link to the original article and this notice when reposting.