Mô hình WGAN
Trong các bài trước chúng ta đã tìm hiểu cách xây dựng mô hình GAN truyền thống. Ở đó chúng ta sử dụng binary cross entropy (BCE) làm cost function. Tuy nhiên việc sử dụng BCE có hai nhược điểm:
- Xảy ra mode collapse (“sụp đổ mô hình”). Hiểu đơn giản thì giả sử ban đầu chúng ta có 10 classes các chữ số viết tay từ 0 đến 9, tuy nhiên sau khi training GAN khi sinh dữ liệu chúng ta thường chỉ nhận được samples từ một class nào đó.
- Vanishing gradient - dẫn đến việc học rất chậm, điều này do discriminator có thể quá xuất sắc và nó không đưa ra được feedback tốt cho generator cách cải thiện như nào.
Để cải thiện mô hình GAN đòi hỏi sử dụng cost function mới (có liên quan đến độ đo sự khác biệt giữa hai phân phối real samples và generated samples). Martin Arjovsky và cộng sự đã đưa ra mô hình Wasserstein GAN - WGAN, đưa vào cost function - Wasserstein.
Trong mô hình GAN truyền thống, discriminator có nhiệm vụ dự đoán xác suất ảnh tạo ra là thật hay giả. Trong khi đó đối với WGAN vai trò của discriminator được thay thế vởi critic. Critic bây giờ sẽ chấm điểm cho ảnh dựa trên độ thật giả (giá trị không còn bị giới hạn trong khoảng 0, 1).
Triển khai Wasserstein GAN
Dưới đây là thuật toán cho Wasserstein GAN được các tác giả đưa ra.
Sự khác biệt trong triển khai WGAN khác với GAN truyền thống như sau:
- Output layer của critic sử dụng linear activation function (không sử dụng sigmoid function)
model.add(Dense(1))
- Sử dụng
label = -1
cho real images,label = 1
cho fake images (GAN truyền thông sử dụng 1, 0 cho real và fake images). Đây chỉ là một cách giúp triển khai WGAN, đôi khi không cần gán nhãn vẫn có thể triển khai được, do trong discriminator có phân biệt rõ ảnh thật giả.
# tạo class labels -1 cho real images
y = - np.ones((n_samples, 1))
# tạo class labels 1 cho fake images
y = np.ones((n_samples, 1))
- Dùng Wasserstein loss để train critic và generator (không dùng BCE)
Chú ý: $\mathbb{E}$ thể hiện lấy kì vọng, khi triển khai đối với các examples thì chúng ta lấy trung bình. $\mathbb{E}(c(x))$ chính là điểm số critic đánh giá ảnh thật.
Critic sẽ đi maximize $\mathbb{E}(c(x)) - \mathbb{E}(c(g(z))$ do nó muốn phân bố của ảnh thật và ảnh giả khác nhau càng nhiều càng tốt. Generator lại muốn minimize biểu thức đó do muốn tạo ảnh giả gần với thật nhất. Tuy nhiên thông thường chúng ta hay đi xây dựng bài toán dựa trên gradient descent, do đó chúng ta sẽ đảo dấu của biểu thức (1) và đi tìm min cho nó $- \mathbb{E}(c(x)) + \mathbb{E}(c(g(z))$. Để đơn giản sẽ đi gán nhãn cho ảnh thật có label = -1
và ảnh giả có label = 1
. Lúc này ta đi tìm min cho:
ở đây, $y(x)$ là label của example $x$. Xin nhấn mạnh lại việc gán nhãn như này chỉ là một cách triển khai WGAN, chúng ta hoàn toàn có thể làm cách khác.
import tensorflow as tf
# wasserstein loss
def wasserstein_loss(y_true, y_pred):
return tf.reduce_mean(y_true * y_pred)
Khi compile model chúng ta sẽ sử dụng tên loss function vừa xây dựng:
# compile the model
model.compile(loss=wasserstein_loss, ...)
- Giới hạn weights của critic trong khoảng cho phép sau mỗi lần cập nhật bằng weight clipping. Đây là một cách thỏa mãn điều kiện L-1 continuous.
Chúng ta có thể thực hiện weight clipping bằng Keras constraint. Để thực hiện điều này sẽ đi tạo class mới kế thừa từ Constranit class và định nghĩa method __call__()
thực hiện việc giới hạn các giá trị và get_config()
để trả về các cấu hình.
# clip model weights
class ClipConstraint(Constraint):
# truyền vào giá trị cận trên
def __init__(self, clip_value):
self.clip_value = clip_value
# clip model weights to hypercube
def __call__(self, weights):
return tf.clip_by_value(weights, -self.clip_value, self.clip_value)
# get the config
def get_config(self):
return {'clip_value': self.clip_value}
Để sử dụng constraint chúng ta đi khởi tạo object và truyền nó cho kernel_constraint trong layer.
...
# define the constraint
const = ClipConstraint(0.01)
...
# use the constraint in a layer
model.add(Conv2D(..., kernel_constraint=const))
- Update critic nhiều lần hơn generator trong mỗi iteration, thường lấy 5 (GAN truyền thống hay để 1)
...
# main gan training loop
for i in range(n_steps):
# update the critic
for _ in range(n_critic):
# get randomly selected 'real' samples
X_real, y_real = generate_real_samples(dataset, half_batch)
# update critic model weights
c_loss1 = c_model.train_on_batch(X_real, y_real)
# generate 'fake' examples
X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
# update critic model weights
c_loss2 = c_model.train_on_batch(X_fake, y_fake)
# update generator
# prepare points in latent space as input for the generator
X_gan = generate_latent_points(latent_dim, n_batch)
# create inverted labels for the fake samples
y_gan = np.ones((n_batch, 1))
# update the generator via the critic's error
g_loss = gan_model.train_on_batch(X_gan, y_gan)
- Sử dụng RMSProp với learning rate nhỏ (ví dụ 0.00005) và không có momentum (truyền thống hay dùng SGD với momentum)
opt = RMSprop(lr=0.00005)
Train Wasserstein GAN model
Trong phần này sẽ xây dựng WGAN để tạo ra một chữ số viết tay từ bộ dữ liệu MNIST. Sau đó có thể thử cho toàn bộ các chữ số của MNIST.
Critic model
Xây dựng critic model có sử dụng:
- LeakyReLU activation function với alpha = 0.2
- Batch Normalization
- Conv layer với
stride=2
để giảm kích thước thay cho Pooling layer.
# định nghĩa critic model
def define_critic(in_shape=(28,28,1)):
# khởi tạo weights
init = RandomNormal(stddev=0.02)
# weight constraint - weight clipping
const = ClipConstraint(0.01)
# định nghĩa model
model = Sequential()
# giảm kích thước xuống 14x14
model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const, input_shape=in_shape))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# giảm kích thước xuống 7x7
model.add(Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init, kernel_constraint=const))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# linear activation
model.add(Flatten())
model.add(Dense(1))
# compile model
opt = RMSprop(lr=0.00005)
model.compile(loss=wasserstein_loss, optimizer=opt)
return model
Generator model
Khi xây dựng generator model chúng ta không thực hiện compile vì việc train generator sẽ thông qua critic.
# định nghĩa generator model
def define_generator(latent_dim):
# khởi tạo weights
init = RandomNormal(stddev=0.02)
# define model
model = Sequential()
# foundation for 7x7 image
n_nodes = 128 * 7 * 7
model.add(Dense(n_nodes, kernel_initializer=init, input_dim=latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Reshape((7, 7, 128)))
# tăng kích thước lên 14x14
model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# tăng kích thước lên 28x28
model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init))
model.add(BatchNormalization())
model.add(LeakyReLU(alpha=0.2))
# output 28x28x1, dùng tanh chuyển về (-1, 1)
model.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=init))
return model
GAN model
# define the combined generator and critic model, for updating the generator
def define_gan(generator, critic):
# make weights in the critic not trainable
for layer in critic.layers:
if not isinstance(layer, BatchNormalization):
layer.trainable = False
# connect them
model = Sequential()
# add generator
model.add(generator)
# add the critic
model.add(critic)
# compile model
opt = RMSprop(lr=0.00005)
model.compile(loss=wasserstein_loss, optimizer=opt)
return model
Đây là ảnh được tạo ra sau khi training với 10 epochs.
Source code mọi người có thể tham khảo tại đây.
Tài liệu tham khảo
- https://machinelearningmastery.com/how-to-code-a-wasserstein-generative-adversarial-network-wgan-from-scratch/
- https://www.coursera.org/specializations/generative-adversarial-networks-gans
- https://arxiv.org/abs/1701.07875
- https://github.com/TanyaChutani/WGAN-TF2.x