博客
关于我
生成模型应用——使用变分自编码器(VAE)控制人脸属性生成人脸图片
阅读量:790 次
发布时间:2019-03-25

本文共 9448 字,大约阅读时间需要 31 分钟。

使用变分自编码器(VAE)控制人脸属性生成人脸图片

基本概念

变分自编码器(VAE)是一种生成模型方法,能够根据输入数据生成高质量的图像。VAE通过最大化数据的对数似然来学习数据分布,在学习过程中生成潜在变量,并通过潜在变量生成新的数据。在人脸生成中,我们可以使用VAE以控制人脸的具体属性,比如表情、发型和肤色等。

数据集选择

选择合适的人脸数据集对于实验效果至关重要。在文中讨论的实验中,使用了以下两种数据集:

  • Celeb-A 数据集

    这是一个学术级数据集,包含了丰富的面部标注信息。虽然这类数据集不完全适合商业用途,但由于其丰富的数据特征和高质量的图像,可以满足我们对面部属性研究的需求。

  • DW-A 数据集

    这是一个免费、适合商业用途的数据集。该数据集提供了高分辨率的图像,适合图像生成实验。

  • 网络架构

    VAE的整体网络架构包含编码器(Encoder)和解码器(Decoder)。编码器将输入图像编码为潜在变量,解码器再将潜在变量反解码为图像。具体来说,编码器和解码器的结构如下:

    • 编码器(Encoder)

      • 卷积层(Conv2D):用于逐步减少图像的空间维度并提取特征。
      • 全连接层(Dense):将提取的特征映射为潜在变量。
    • 解码器(Decoder)

      • 卷积逆变换层(UpConvBlock):逐步增加空间维度并细化图像。
      • 分层卷积层(Conv2D):最终生成原始尺寸的图像。

    人脸重建效果

    VAE通过学习输入图像的特征生成新的图像。虽然生成的图像并不完美,但它们展现了VAE基本的生成能力。以下是一些重建效果的示例:

    重建图片

    从这些图像可以看出,VAE能够较好地重建女性人的面部特征。这是由于Celeb-A数据集中女性的比例较高。同时,背景颜色的编码和解码也能初步实现一定程度的图像细节恢复。

    生成新面孔

    为了生成新的面孔,我们从高斯分布中随机采样潜在变量,然后通过解码器生成图像。基本实现代码如下:

    z_samples = np.random.normal(loc=0., scale=1, size=(image_num, z_dim))images = vae.decoder(z_samples.astype(np.float32))

    生成的面孔可能会出现 eyed_Y criticized asToo monochrome 的问题。为了解决此类问题,我们可以采用采样技巧来生成更真实的图像。

    采样技巧

    随机采样可能会导致生成的图像多样性不足。我们可以通过以下方法优化生成效果:

  • 收集数据

    从训练集中获取图像输入VAE解码器,获取潜在变量的均值和方差。

  • 调整采样方式

    使用均值和方差生成标准正态分布的样本,然后添加均值,得到最终采样点。

  • 生成图片

    通过上述方法,可以显著改善图像生成的质量。

    控制人脸属性

    VAE的潜在空间即潜在变量的每个维度,代表了特定的语义信息。通过对潜在向量进行操作,我们可以实现面部属性的编辑与控制。

  • 潜在空间分析

    找到每个潜在变量的平均值和变化范围,帮助理解潜在变量的语义含义。

  • 属性向量提取

    从数据集中获取图像的面部属性注释,使用VAE解码器生成潜在向量,进而提取具有特定属性的潜在变量方向。

  • 属性向量应用

    将提取的属性向量加到潜在变量中,逐步生成具有目标属性的新图像。例如,可以实现以下操作:

  • new_z_samples = z_samples + smiling_magnitude * smiling_vector

    其中,smiling_magnitude 是标量缩放系数,smiling_vector 是属性向量。

    实际应用示例

    通过上述方法,我们可以生成带有特定属性的新图像。例如:

    修改多个属性

    完整代码

    以下是完整的VAE实现代码,供技术人员参考:

    # vae_faces.ipynbimport tensorflow as tffrom tensorflow_probability import distributions as tfdfrom tensorflow.keras import layers, Modelfrom tensorflow.keras.layers import Layer, Input, Conv2D, Dense, Flatten, Reshape, Lambda, Dropoutfrom tensorflow.keras.layers import Conv2DTranspose, MaxPooling2D, UpSampling2D, LeakyReLU, BatchNormalizationfrom tensorflow.keras.activations import relufrom tensorflow.keras.models import Sequential, load_modelfrom tensorflow.keras.callbacks import ModelCheckpoint, EarlyStoppingfrom tensorflow.keras.preprocessing.image import ImageDataGeneratorimport tensorflow_datasets as tfdsimport cv2import numpy as npimport matplotlib.pyplot as pltimport datetime, osimport warningswarnings.filterwarnings('ignore')print("Tensorflow", tf.__version__)strategy = tf.distribute.MirroredStrategy()num_devices = strategy.num_replicas_in_syncprint('Number of devides: {}'.format(num_devices))(ds_train, ds_test), ds_info = tfds.load(    'celeb_a',    split=['train', 'test'],    shuffle_files=True,    with_info=True)def preprocess(sample):    image = sample['image']    image = tf.image.resize(image, [112, 112])    image = tf.cast(image, tf.float32) / 255.    return image, imageds_train = ds_train.map(preprocess)ds_train = ds_train.shuffle(128)ds_train = ds_train.batch(batch_size, drop_remainder=True).prefetch(batch_size)ds_test = ds_test.map(preprocess).batch(batch_size, drop_remainder=True).prefetch(batch_size)train_num = ds_info.splits['train'].num_examplestest_num = ds_info.splits['test'].num_examplesclass GaussianSampling(Layer):    def call(self, inputs):        means, logvar = inputs        epsilon = tf.random.normal(shape=tf.shape(means), mean=0., stddev=1.)        samples = means + tf.exp(0.5 * logvar) * epsilon        return samplesclass DownConvBlock(Layer):    count = 0    def __init__(self, filters, kernel_size=(3,3), strides=1, padding='same'):        super(DownConvBlock, self).__init__(name=f"DownConvBlock_{DownConvBlock.count}")        DownConvBlock.count += 1        self.forward = Sequential([            Conv2D(filters, kernel_size, strides, padding),            BatchNormalization(),            LeakyReLU(0.2)        ])    def call(self, inputs):        return self.forward(inputs)class UpConvBlock(Layer):    count = 0    def __init__(self, filters, kernel_size=(3,3), padding='same'):        super(UpConvBlock, self).__init__(name=f"UpConvBlock_{UpConvBlock.count}")        UpConvBlock.count += 1        self.forward = Sequential([            Conv2D(filters, kernel_size, 1, padding),            LeakyReLU(0.2),            UpSampling2D((2,2))        ])    def call(self, inputs):        return self.forward(inputs)class Encoder(Layer):    def __init__(self, z_dim, name='encoder'):        super(Encoder, self).__init__(name=name)        self.features_extract = Sequential([            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),            DownConvBlock(filters=32, kernel_size=(3,3), strides=2),            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),            DownConvBlock(filters=64, kernel_size=(3,3), strides=2),            Flatten()        ])        self.dense_mean = Dense(z_dim, name='mean')        self.dense_logvar = Dense(z_dim, name='logvar')        self.sampler = GaussianSampling()    def call(self, inputs):        x = self.features_extract(inputs)        mean = self.dense_mean(x)        logvar = self.dense_logvar(x)        z = self.sampler([mean, logvar])        return z, mean, logvarclass Decoder(Layer):    def __init__(self, z_dim, name='decoder'):        super(Decoder, self).__init__(name=name)        self.forward = Sequential([            Dense(7*7*64, activation='relu'),            Reshape((7,7,64)),            UpConvBlock(filters=64, kernel_size=(3,3)),            UpConvBlock(filters=64, kernel_size=(3,3)),            UpConvBlock(filters=32, kernel_size=(3,3)),            UpConvBlock(filters=32, kernel_size=(3,3)),            Conv2D(filters=3, kernel_size=(3,3), strides=1, padding='same', activation='sigmoid')        ])    def call(self, inputs):        return self.forward(inputs)class VAE(Model):    def __init__(self, z_dim, name='VAE'):        super(VAE, self).__init__(name=name)        self.encoder = Encoder(z_dim)        self.decoder = Decoder(z_dim)        self.mean = None        self.logvar = None    def call(self, inputs):        z, self.mean, self.logvar = self.encoder(inputs)        out = self.decoder(z)        return outif num_devices > 1:    with strategy.scope():        vae = VAE(z_dim=200)else:    vae = VAE(z_dim=200)def vae_kl_loss(y_true, y_pred):    kl_loss = -0.5 * tf.reduce_mean(1 + vae.logvar - tf.square(vae.mean) - tf.exp(vae.logvar))    return kl_lossdef vae_rc_loss(y_true, y_pred):    rc_loss = tf.keras.losses.MSE(y_true, y_pred)    return rc_lossdef vae_loss(y_true, y_pred):    kl_loss = vae_kl_loss(y_true, y_pred)    rc_loss = vae_rc_loss(y_true, y_pred)    kl_weight_const = 0.01    return kl_weight_const * kl_loss + rc_lossmodel_path = "vae_faces_cele_a.h5"checkpoint = ModelCheckpoint(    model_path,    monitor='vae_rc_loss',    verbose=1,    save_best_only=True,    mode='auto',    save_weights_only=True)early = EarlyStopping(    monitor='vae_rc_loss',    mode='auto',    patience=3)callbacks_list = [checkpoint, early]initial_learning_rate = 1e-3steps_per_epoch = int(np.round(train_num / batch_size))lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(    initial_learning_rate,    decay_steps=steps_per_epoch,    decay_rate=0.96,    staircase=True)vae.compile(    loss=[vae_loss],    optimizer=tf.keras.optimizers.RMSprop(learning_rate=3e-3),    metrics=[vae_kl_loss, vae_rc_loss])history = vae.fit(ds_train, validation_data=ds_test, epochs=50, callbacks=callbacks_list)images, labels = next(iter(ds_train))vae.load_weights(model_path)outputs = vae.predict(images)# Displaygrid_col = 8grid_row = 2f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col*2, grid_row*2))i = 0for row in range(0, grid_row, 2):    for col in range(grid_col):        axarr[row, col].imshow(images[i])        axarr[row, col].axis('off')        axarr[row+1, col].imshow(outputs[i])        axarr[row+1, col].axis('off')        i += 1f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)plt.show()avg_z_mean = []avg_z_std = []for i in range(steps_per_epoch):    images, labels = next(iter(ds_train))    z, z_mean, z_logvar = vae.encoder(images)    avg_z_mean.append(np.mean(z_mean, axis=0))    avg_z_std.append(np.mean(np.exp(0.5 * z_logvar), axis=0))avg_z_mean = np.mean(avg_z_mean, axis=0)avg_z_std = np.mean(avg_z_std, axis=0)plt.plot(avg_z_mean)plt.ylabel("Average z mean")plt.xlabel("z dimension")grid_col = 10grid_row = 10f, axarr = plt.subplots(grid_row, grid_col, figsize=(grid_col, 1.5*grid_row))i = 0for row in range(grid_row):    for col in range(grid_col):        axarr[row, col].hist(z[:, i], bins=20)        i += 1z_dim = 200z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(25, z_dim))images = vae.decoder(z_samples.astype(np.float32))grid_col = 7grid_row = 2f, axarr = plt.subplots(grid_row, grid_col, figsize=(2*grid_col, 2*grid_row))i = 0for row in range(grid_row):    for col in range(grid_col):        axarr[row, col].imshow(images[i])        axarr[row, col].axis('off')        i += 1f.tight_layout(0.1, h_pad=0.2, w_pad=0.1)plt.show()# 采样技巧z_samples = np.random.normal(loc=0., scale=np.mean(avg_z_std), size=(1, 200))input_tensor = np.expand_dims(z_samples, 0).astype(np.float32) / 255.

    属性控制示例

    通过上述方法,我们可以对人脸属性进行深度控制。以下是一些具体的功能示例:

    explore_latent_variable(Male=(-5,5,0.1), Eyeglasses=(-5,5,0.1), Young=(-5,5,0.1), Smiling=(-5,5,0.1), Blond_Hair=(-5,5,0.1), Pale_Skin=(-5,5,0.1), Mustache=(-5,5,0.1))

    实用工具

    本文提供了一个Jupyter notebook,方便读者在本地环境中进行实践:

    https://github.com/yourGitHub/vae_faces/blob/master/vae_faces.ipynb

    总结

    通过本文的实践,我们展示了如何利用变分自编码器(VAE)进行人脸图片生成以及面部属性控制。从基础的图像生成到复杂的属性编辑,VAE提供了一种强大的工具实现图像生成任务。

    转载地址:http://qynuk.baihongyu.com/

    你可能感兴趣的文章
    Mysql8 数据库安装及主从配置 | Spring Cloud 2
    查看>>
    mysql8 配置文件配置group 问题 sql语句group不能使用报错解决 mysql8.X版本的my.cnf配置文件 my.cnf文件 能够使用的my.cnf配置文件
    查看>>
    MySQL8.0.29启动报错Different lower_case_table_names settings for server (‘0‘) and data dictionary (‘1‘)
    查看>>
    MYSQL8.0以上忘记root密码
    查看>>
    Mysql8.0以上重置初始密码的方法
    查看>>
    mysql8.0新特性-自增变量的持久化
    查看>>
    Mysql8.0注意url变更写法
    查看>>
    Mysql8.0的特性
    查看>>
    MySQL8修改密码报错ERROR 1819 (HY000): Your password does not satisfy the current policy requirements
    查看>>
    MySQL8修改密码的方法
    查看>>
    Mysql8在Centos上安装后忘记root密码如何重新设置
    查看>>
    Mysql8在Windows上离线安装时忘记root密码
    查看>>
    MySQL8找不到my.ini配置文件以及报sql_mode=only_full_group_by解决方案
    查看>>
    mysql8的安装与卸载
    查看>>
    MySQL8,体验不一样的安装方式!
    查看>>
    MySQL: Host '127.0.0.1' is not allowed to connect to this MySQL server
    查看>>
    Mysql: 对换(替换)两条记录的同一个字段值
    查看>>
    mysql:Can‘t connect to local MySQL server through socket ‘/var/run/mysqld/mysqld.sock‘解决方法
    查看>>
    MYSQL:基础——3N范式的表结构设计
    查看>>
    MYSQL:基础——触发器
    查看>>