博客
关于我
生成模型应用——使用变分自编码器(VAE)控制人脸属性生成人脸图片
阅读量:790 次
发布时间:2019-03-25

本文共 9448 字,大约阅读时间需要 31 分钟。

使用变分自编码器(VAE)控制人脸属性生成人脸图片

基本概念

变分自编码器(VAE)是一种生成模型方法,能够根据输入数据生成高质量的图像。VAE通过最大化数据的对数似然来学习数据分布,在学习过程中生成潜在变量,并通过潜在变量生成新的数据。在人脸生成中,我们可以使用VAE以控制人脸的具体属性,比如表情、发型和肤色等。

数据集选择

选择合适的人脸数据集对于实验效果至关重要。在文中讨论的实验中,使用了以下两种数据集:

  • Celeb-A 数据集

    这是一个学术级数据集,包含了丰富的面部标注信息。虽然这类数据集不完全适合商业用途,但由于其丰富的数据特征和高质量的图像,可以满足我们对面部属性研究的需求。

  • DW-A 数据集

    这是一个免费、适合商业用途的数据集。该数据集提供了高分辨率的图像,适合图像生成实验。

  • 网络架构

    VAE的整体网络架构包含编码器(Encoder)和解码器(Decoder)。编码器将输入图像编码为潜在变量,解码器再将潜在变量反解码为图像。具体来说,编码器和解码器的结构如下:

    • 编码器(Encoder)

      • 卷积层(Conv2D):用于逐步减少图像的空间维度并提取特征。
      • 全连接层(Dense):将提取的特征映射为潜在变量。
    • 解码器(Decoder)

      • 卷积逆变换层(UpConvBlock):逐步增加空间维度并细化图像。
      • 分层卷积层(Conv2D):最终生成原始尺寸的图像。

    人脸重建效果

    VAE通过学习输入图像的特征生成新的图像。虽然生成的图像并不完美,但它们展现了VAE基本的生成能力。以下是一些重建效果的示例:

    重建图片

    从这些图像可以看出,VAE能够较好地重建女性人的面部特征。这是由于Celeb-A数据集中女性的比例较高。同时,背景颜色的编码和解码也能初步实现一定程度的图像细节恢复。

    生成新面孔

    为了生成新的面孔,我们从高斯分布中随机采样潜在变量,然后通过解码器生成图像。基本实现代码如下:

    z_samples = np.random.normal(loc=0., scale=1, size=(image_num, z_dim))images = vae.decoder(z_samples.astype(np.float32))

    生成的面孔可能会出现 eyed_Y criticized asToo monochrome 的问题。为了解决此类问题,我们可以采用采样技巧来生成更真实的图像。

    采样技巧

    随机采样可能会导致生成的图像多样性不足。我们可以通过以下方法优化生成效果:

  • 收集数据

    从训练集中获取图像输入VAE解码器,获取潜在变量的均值和方差。

  • 调整采样方式

    使用均值和方差生成标准正态分布的样本,然后添加均值,得到最终采样点。

  • 生成图片

    通过上述方法,可以显著改善图像生成的质量。

    控制人脸属性

    VAE的潜在空间即潜在变量的每个维度,代表了特定的语义信息。通过对潜在向量进行操作,我们可以实现面部属性的编辑与控制。

  • 潜在空间分析

    找到每个潜在变量的平均值和变化范围,帮助理解潜在变量的语义含义。

  • 属性向量提取

    从数据集中获取图像的面部属性注释,使用VAE解码器生成潜在向量,进而提取具有特定属性的潜在变量方向。

  • 属性向量应用

    将提取的属性向量加到潜在变量中,逐步生成具有目标属性的新图像。例如,可以实现以下操作:

  • new_z_samples = z_samples + smiling_magnitude * smiling_vector

    其中,smiling_magnitude 是标量缩放系数,smiling_vector 是属性向量。

    实际应用示例

    通过上述方法,我们可以生成带有特定属性的新图像。例如:

    修改多个属性

    完整代码

    以下是完整的VAE实现代码,供技术人员参考:

    # vae_faces.ipynbimport tensorflow as tffrom tensorflow_probability import distributions as tfdfrom tensorflow.keras import layers, Modelfrom tensorflow.keras.layers import Layer, Input, Conv2D, Dense, Flatten, Reshape, Lambda, Dropoutfrom tensorflow.keras.layers import Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLU, BatchNormalizationfrom tensorflow.keras.activations import relufrom tensorflow.keras.models import Sequential, load_modelfrom tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingfrom tensorflow.keras.preprocessing.image import ImageDataGeneratorimport tensorflow_datasets as tfdsimport cv2import numpy as npimport matplotlib.pyplot as pltimport datetime, osimport warningswarnings.filterwarnings('ignore')print("Tensorflow", tf.__version__)strategy = tf.distribute.MirroredStrategy()num_devices = strategy.num_replicas_in_syncprint('Number of devides: {}'.format(num_devices))(ds_train, ds_test), ds_info = tfds.load(    'celeb_a',    split=['train', 'test'],    shuffle_files=True,    with_info=True)def preprocess(sample):    image = sample['image']    image = tf.image.resize(image, [112, 112])    image = tf.cast(image, tf.float32) / 255.    return image, imageds_train = ds_train.map(preprocess)ds_train = ds_train.shuffle(128)ds_train = ds_train.batch(batch_size, drop_remainder=True).prefetch(batch_size)ds_test = ds_test.map(preprocess).batch(batch_size, drop_remainder=True).prefetch(batch_size)train_num = ds_info.splits['train'].num_examplestest_num = ds_info.splits['test'].num_examplesclass GaussianSampling(Layer):    def call(self, inputs):        means, logvar = inputs        epsilon = tf.random.normal(shape=tf.shape(means), mean=0., stddev=1.)        samples = means + tf.exp(0.5 * logvar) * epsilon        return samplesclass DownConvBlock(Layer):    count = 0    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):        super(DownConvBlock, self).__init__(name=f"DownConvBlock_{DownConvBlock.count}")        DownConvBlock.count += 1        self.forward = Sequential([            Conv2D(filters, kernel_size, strides, padding),            BatchNormalization(),            LeakyReLU(0.2)        ])    def call(self, inputs):        return self.forward(inputs)class UpConvBlock(Layer):    count = 0    def __init__(self, filters, kernel_size=(3,3), padding='same'):        super(UpConvBlock, self).__init__(name=f"UpConvBlock_{UpConvBlock.count}")        UpConvBlock.count += 1        self.forward = Sequential([            Conv2D(filters, kernel_size, 1, padding),            LeakyReLU(0.2),            UpSampling2D((2,2))        ])    def call(self, inputs):        return self.forward(inputs)class Encoder(Layer):    def __init__(self, z_dim, name='encoder'):        super(Encoder, self).__init__(name=name)        self.features_extract = Sequential([            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),            Flatten()        ])        self.dense_mean = Dense(z_dim, name='mean')        self.dense_logvar = Dense(z_dim, name='logvar')        self.sampler = GaussianSampling()    def call(self, inputs):        x = self.features_extract(inputs)        mean = self.dense_mean(x)        logvar = self.dense_logvar(x)        z = self.sampler([mean, logvar])        return z, mean, logvarclass Decoder(Layer):    def __init__(self, z_dim, name='decoder'):        super(Decoder, self).__init__(name=name)        self.forward = Sequential([            Dense(7*7*64, activation='relu'),            Reshape((7,7,64)),            UpConvBlock(filters=64, kernel_size=(3,3)),            UpConvBlock(filters=64, kernel_size=(3,3)),            UpConvBlock(filters=32, kernel_size=(3,3)),            UpConvBlock(filters=32, kernel_size=(3,3)),            Conv2D(filters=3, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')        ])    def call(self, inputs):        return self.forward(inputs)class VAE(Model):    def __init__(self, z_dim, name='VAE'):        super(VAE, self).__init__(name=name)        self.encoder = Encoder(z_dim)        self.decoder = Decoder(z_dim)        self.mean = None        self.logvar = None    def call(self, inputs):        z, self.mean, self.logvar = self.encoder(inputs)        out = self.decoder(z)        return outif num_devices > 1:    with strategy.scope():        vae = VAE(z_dim=200)else:    vae = VAE(z_dim=200)def vae_kl_loss(y_true, y_pred):    kl_loss = -0.5 * tf.reduce_mean(1 + vae.logvar - tf.square(vae.mean) - tf.exp(vae.logvar))    return kl_lossdef vae_rc_loss(y_true, y_pred):    rc_loss = tf.keras.losses.MSE(y_true, y_pred)    return rc_lossdef vae_loss(y_true, y_pred):    kl_loss = vae_kl_loss(y_true, y_pred)    rc_loss = vae_rc_loss(y_true, y_pred)    kl_weight_const = 0.01    return kl_weight_const * kl_loss + rc_lossmodel_path = "vae_faces_cele_a.h5"checkpoint = ModelCheckpoint(    model_path,    monitor='vae_rc_loss',    verbose=1,    save_best_only=True,    mode='auto',    save_weights_only=True)early = EarlyStopping(    monitor='vae_rc_loss',    mode='auto',    patience=3)callbacks_list = [checkpoint, early]initial_learning_rate = 1e-3steps_per_epoch = int(np.round(train_num / batch_size))lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(    initial_learning_rate,    decay_steps=steps_per_epoch,    decay_rate=0.96,    staircase=True)vae.compile(    loss=[vae_loss],    optimizer=tf.keras.optimizers.RMSprop(learning_rate=3e-3),    metrics=[vae_kl_loss, vae_rc_loss])history = vae.fit(ds_train, validation_data=ds_test, epochs=50, callbacks=callbacks_list)images, labels = next(iter(ds_train))vae.load_weights(model_path)outputs = vae.predict(images)# Displaygrid_col = 8grid_row = 2f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2, grid_row*2))i = 0for row in range(0, grid_row, 2):    for col in range(grid_col):        axarr[row, col].imshow(images[i])        axarr[row, col].axis('off')        axarr[row+1, col].imshow(outputs[i])        axarr[row+1, col].axis('off')        i += 1f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)plt.show()avg_z_mean = []avg_z_std = []for i in range(steps_per_epoch):    images, labels = next(iter(ds_train))    z, z_mean, z_logvar = vae.encoder(images)    avg_z_mean.append(np.mean(z_mean, axis=0))    avg_z_std.append(np.mean(np.exp(0.5 * z_logvar), axis=0))avg_z_mean = np.mean(avg_z_mean, axis=0)avg_z_std = np.mean(avg_z_std, axis=0)plt.plot(avg_z_mean)plt.ylabel("Average z mean")plt.xlabel("z dimension")grid_col = 10grid_row = 10f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col, 1.5*grid_row))i = 0for row in range(grid_row):    for col in range(grid_col):        axarr[row, col].hist(z[:, i], bins=20)        i += 1z_dim = 200z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(25, z_dim))images = vae.decoder(z_samples.astype(np.float32))grid_col = 7grid_row = 2f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col, 2*grid_row))i = 0for row in range(grid_row):    for col in range(grid_col):        axarr[row, col].imshow(images[i])        axarr[row, col].axis('off')        i += 1f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)plt.show()# 采样技巧z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(1, 200))input_tensor = np.expand_dims(z_samples, 0).astype(np.float32) / 255.

    属性控制示例

    通过上述方法,我们可以对人脸属性进行深度控制。以下是一些具体的功能示例:

    explore_latent_variable(Male=(-5,5,0.1), Eyeglasses=(-5,5,0.1), Young=(-5,5,0.1), Smiling=(-5,5,0.1), Blond_Hair=(-5,5,0.1), Pale_Skin=(-5,5,0.1), Mustache=(-5,5,0.1))

    实用工具

    本文提供了一个Jupyter notebook,方便读者在本地环境中进行实践:

    https://github.com/yourGitHub/vae_faces/blob/master/vae_faces.ipynb

    总结

    通过本文的实践,我们展示了如何利用变分自编码器(VAE)进行人脸图片生成以及面部属性控制。从基础的图像生成到复杂的属性编辑,VAE提供了一种强大的工具实现图像生成任务。

    转载地址:http://qynuk.baihongyu.com/

    你可能感兴趣的文章
    MySQL与Oracle的数据迁移注意事项,另附转换工具链接
    查看>>
    mysql丢失更新问题
    查看>>
    MySQL两千万数据优化&迁移
    查看>>
    MySql中 delimiter 详解
    查看>>
    MYSQL中 find_in_set() 函数用法详解
    查看>>
    MySQL中auto_increment有什么作用?(IT枫斗者)
    查看>>
    MySQL中B+Tree索引原理
    查看>>
    mysql中cast() 和convert()的用法讲解
    查看>>
    mysql中datetime与timestamp类型有什么区别
    查看>>
    MySQL中DQL语言的执行顺序
    查看>>
    mysql中floor函数的作用是什么?
    查看>>
    MySQL中group by 与 order by 一起使用排序问题
    查看>>
    mysql中having的用法
    查看>>
    MySQL中interactive_timeout和wait_timeout的区别
    查看>>
    mysql中int、bigint、smallint 和 tinyint的区别、char和varchar的区别详细介绍
    查看>>
    mysql中json_extract的使用方法
    查看>>
    mysql中json_extract的使用方法
    查看>>
    mysql中kill掉所有锁表的进程
    查看>>
    mysql中like % %模糊查询
    查看>>
    MySql中mvcc学习记录
    查看>>