Introduction

As the name suggests, a generative adversarial network (GAN) sets up an adversarial relationship between two networks: a generator and a discriminator. The generator is a generative model that learns the data distribution of a population. Note that it does not learn from the data directly, but through the feedback (the predictions) of the discriminator. The discriminator's job is to tell whether a sample comes from the data distribution or was produced by the generator. A real-world analogy is a counterfeiter who tries to forge money while the police try to tell fake money from real. In game-theoretic terms, this is a two-player zero-sum game whose solution is found through a minimax strategy: one player tries to maximize the value of the game and the other tries to minimize it. There is a global solution in which the forgeries look the same as the real images (the generator has learned the data distribution). At that point the discriminator can no longer tell the difference: it estimates the probability that any sample is fake as one half, which amounts to a random guess.

On the technical side, both players can be fully connected neural networks or convolutional networks. In the examples below, a deep convolutional network achieves the best result. The input noise variable \(z\) is drawn from a prior \(p_z(z)\), here a uniform distribution. Over training, the generator's parameters gradually learn to convert random noise \(z \sim p_z(z)\) into forgeries that fool the discriminator; we call the distribution of these forgeries \(p_g(x)\). The generator \(G_{\theta}\), with parameters \(\theta\), is the mapping from the noise space to the data space. We also define \(D_{\phi}\), with parameters \(\phi\), as the discriminator network, which outputs a single scalar. Since D is the neural network that classifies images as real or fake, D(x) is the probability that the image x is real. D is therefore trained to maximize its ability to distinguish real from fake: it should make D(x) large for real images and D(G(z)) small for forgeries. G is trained to minimize the log-probability of being caught, hence the term \(\log(1-D(G(z)))\). In game-theoretic terms, the two players solve the minimax problem:

\[\min_G \max_D V(D,G) = \mathbb{E}_{x \sim p_{data}(x)}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log(1-D(G(z)))\right]\]

The first term is the expected log-probability that D recognizes a real image as real, and the discriminator maximizes it. In the second term, D(G(z)) is the probability that a forgery passes as real, so \(1-D(G(z))\) is the probability that it gets caught. The discriminator maximizes the expected log of this probability, while the generator minimizes it.
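
To make the value function concrete, the sketch below estimates V(D, G) from a minibatch of discriminator outputs by replacing each expectation with a sample mean. The function name `value_estimate` and the probabilities fed to it are illustrative assumptions, not values from any trained model.

```python
import math


def value_estimate(d_real, d_fake):
    """Monte Carlo estimate of V(D, G).

    d_real: discriminator outputs D(x) on a minibatch of real samples
    d_fake: discriminator outputs D(G(z)) on a minibatch of generated samples
    """
    real_term = sum(math.log(p) for p in d_real) / len(d_real)
    fake_term = sum(math.log(1.0 - p) for p in d_fake) / len(d_fake)
    return real_term + fake_term


# A confident discriminator: high D(x) on real data, low D(G(z)) on forgeries
confident = value_estimate([0.9, 0.95, 0.8], [0.1, 0.05, 0.2])
# A discriminator that outputs 1/2 for every sample (the equilibrium case)
coin_flip = value_estimate([0.5, 0.5], [0.5, 0.5])

print(round(confident, 4))
print(round(coin_flip, 4), round(-math.log(4), 4))
```

The discriminator pushes this estimate up toward 0, its maximum, and the generator pushes it down. When D outputs 1/2 everywhere, the estimate equals \(2\log\frac{1}{2} = -\log 4\).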

The generator has no hand-crafted loss function. Its training signal comes entirely from the discriminator, so in effect the GAN learns its own loss function as training proceeds.

Algorithm

For each epoch:

  • in each step, we first train the discriminator:

    • sample a minibatch of m noise samples \(z^{(1)}, \dots, z^{(m)}\) from the noise prior \(p_z(z)\)

    • sample a minibatch of m examples \(x^{(1)}, \dots, x^{(m)}\) from the data-generating distribution \(p_{data}(x)\)

    • update the discriminator by ascending its stochastic gradient:

\[\nabla_{\phi} \frac{1}{m} \sum_{i=1}^{m} \left[\log D(x^{(i)}) + \log(1-D(G(z^{(i)})))\right]\]

  • after training the discriminator, we train the generator:

    • sample a minibatch of m noise samples \(z^{(1)}, \dots, z^{(m)}\) from the noise prior \(p_z(z)\)

    • update the generator by descending its stochastic gradient:

\[\nabla_{\theta} \frac{1}{m} \sum_{i=1}^{m} \log(1-D(G(z^{(i)})))\]
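
The alternating procedure can be sketched without a deep-learning framework on a toy one-dimensional problem. Everything in this sketch is an assumption made for illustration:

- real data come from a Gaussian with mean 4 and standard deviation 0.5;
- the noise prior is uniform on [-1, 1];
- the generator is an affine map G(z) = a·z + b;
- the discriminator is logistic regression, D(x) = σ(w·x + c).

The gradients are written out by hand. The generator descends \(\log(1-D(G(z)))\) exactly as in the algorithm. Function names and hyperparameters are hypothetical.

```python
import math
import random


def sigmoid(s):
    """Numerically stable logistic function."""
    if s >= 0:
        return 1.0 / (1.0 + math.exp(-s))
    e = math.exp(s)
    return e / (1.0 + e)


def discriminator_objective(w, c, real, fake):
    """(1/m) * sum[log D(x) + log(1 - D(G(z)))] with D(x) = sigmoid(w*x + c)."""
    total = 0.0
    for x, g in zip(real, fake):
        total += math.log(sigmoid(w * x + c)) + math.log(1.0 - sigmoid(w * g + c))
    return total / len(real)


def discriminator_step(w, c, real, fake, lr):
    """One step of gradient ascent on the minibatch discriminator objective."""
    m = len(real)
    dw = dc = 0.0
    for x, g in zip(real, fake):
        d_real, d_fake = sigmoid(w * x + c), sigmoid(w * g + c)
        # d/ds log(sigmoid(s)) = 1 - sigmoid(s);  d/ds log(1 - sigmoid(s)) = -sigmoid(s)
        dw += (1.0 - d_real) * x - d_fake * g
        dc += (1.0 - d_real) - d_fake
    return w + lr * dw / m, c + lr * dc / m


def generator_step(a, b, w, c, noise, lr):
    """One step of gradient descent on (1/m) * sum log(1 - D(G(z))), G(z) = a*z + b."""
    m = len(noise)
    da = db = 0.0
    for z in noise:
        d_fake = sigmoid(w * (a * z + b) + c)
        # Chain rule: d/dg log(1 - sigmoid(w*g + c)) = -sigmoid(w*g + c) * w
        da += -d_fake * w * z
        db += -d_fake * w
    return a - lr * da / m, b - lr * db / m


def train(epochs=1000, m=64, lr=0.05, seed=0):
    """Alternate one discriminator step and one generator step per epoch."""
    rng = random.Random(seed)
    a, b = 1.0, 0.0  # generator parameters (theta)
    w, c = 0.1, 0.0  # discriminator parameters (phi)
    for _ in range(epochs):
        # Discriminator step: fresh noise and real minibatches, G held fixed
        noise = [rng.uniform(-1.0, 1.0) for _ in range(m)]
        real = [rng.gauss(4.0, 0.5) for _ in range(m)]
        fake = [a * z + b for z in noise]
        w, c = discriminator_step(w, c, real, fake, lr)
        # Generator step: fresh noise, D held fixed
        noise = [rng.uniform(-1.0, 1.0) for _ in range(m)]
        a, b = generator_step(a, b, w, c, noise, lr)
    return a, b, w, c


a, b, w, c = train()
print(f"a = {a:.3f}, b = {b:.3f}, w = {w:.3f}, c = {c:.3f}")
```

While the discriminator learns w > 0 ("larger x looks real"), the generator's offset b is pushed up from 0 toward the real data. Plain alternating gradient steps like these may overshoot and oscillate around the real mean rather than settle. This is a small-scale version of the instability GAN training is known for.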

When the generator is held fixed during discriminator training, the discriminator can reach its optimum. For each image x, it outputs the ratio of real images to all images (real plus fake) at x, that is \(D^*(x) = p_{data}(x) / (p_{data}(x) + p_g(x))\). The generator minimizes its chance of getting caught using this feedback, which pushes its distribution \(p_g\) toward \(p_{data}\). When the two distributions coincide, \(D^*(x) = 1/2\) everywhere, so the discriminator's optimal strategy is no better than flipping a coin.
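
The claim that the discriminator ends up outputting one half follows from a short derivation, sketched here along the lines of the original GAN paper (Goodfellow et al., 2014). Since \(x = G(z)\) with \(z \sim p_z(z)\) is distributed as \(p_g(x)\), for a fixed G the value function can be written as a single integral over the data space:

\[V(D,G) = \int_x \left[p_{data}(x)\log D(x) + p_g(x)\log(1-D(x))\right] dx\]

For each x where \(p_{data}(x) + p_g(x) > 0\), the function \(y \mapsto a\log y + b\log(1-y)\) with \(a = p_{data}(x)\) and \(b = p_g(x)\) is concave. Its derivative \(a/y - b/(1-y)\) vanishes at \(y = a/(a+b)\), so the optimal discriminator is

\[D^*(x) = \frac{p_{data}(x)}{p_{data}(x) + p_g(x)}\]

When the generator has learned the data distribution, \(p_g = p_{data}\), this gives \(D^*(x) = \frac{1}{2}\) for every x. Substituting back into the value function:

\[V(D^*,G) = \mathbb{E}_{x \sim p_{data}(x)}\left[\log \tfrac{1}{2}\right] + \mathbb{E}_{x \sim p_g(x)}\left[\log \tfrac{1}{2}\right] = -\log 4\]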

Training

Initially, we feed the discriminator correctly labeled images, both generated (label 0) and drawn from the data distribution (label 1), so that it learns to classify them through backpropagation. Training the generator is more complicated because it needs feedback from the discriminator. So we freeze the discriminator's weights and use it only as a source of feedback. We feed the generator a random vector drawn from the latent space (the space of noise vectors from which images are generated), and the generator outputs a fake image. This fake image is passed through the discriminator for classification, but we set its label to 1 (real). Lowering the loss then means making the discriminator mistake the forgery for a real image. The discriminator's prediction is compared with this label, and the resulting gradient flows back through the frozen discriminator into the generator, which updates its weights. This is roughly supervised learning.

Alternatively, we can see it as a search problem: starting from a random point in the latent space, the generator searches for convincing images under the guidance of the discriminator. The binary cross-entropy between the discriminator's prediction \(\hat{y}\) and the target label y is used as the cost function. In TensorFlow, the fully connected version of the GAN uses ReLU activations in the hidden layers, He normal initialization, binary cross-entropy loss, and the RMSProp optimizer.
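
With the fake images labeled 1, the generator minimizes the binary cross-entropy \(-\log D(G(z))\) rather than \(\log(1-D(G(z)))\) from the minimax objective. Both losses reward fooling the discriminator. They behave differently early in training, however, when D confidently rejects forgeries. The sketch below compares their gradients with respect to the discriminator's logit s, where \(D(G(z)) = \sigma(s)\). The function names are illustrative.

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def grad_minimax(s):
    """d/ds of log(1 - sigmoid(s)): the generator term of the minimax objective."""
    return -sigmoid(s)


def grad_label_one(s):
    """d/ds of binary cross-entropy with label 1, i.e. -log(sigmoid(s))."""
    return sigmoid(s) - 1.0


for s in (-6.0, -2.0, 0.0, 2.0):
    print(f"D(G(z)) = {sigmoid(s):.4f}  "
          f"minimax grad = {grad_minimax(s):+.4f}  "
          f"label-1 grad = {grad_label_one(s):+.4f}")
```

When D(G(z)) is close to 0 (s = -6 gives about 0.0025), the minimax gradient almost vanishes while the label-1 gradient stays close to -1. The generator therefore still receives a useful learning signal. The original GAN paper suggests this "non-saturating" variant of the generator loss for the same reason.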

Fully Connected GAN


import tensorflow as tf

Dense = tf.keras.layers.Dense
# Generator: noise vector -> 28x28 image with pixel values in [0, 1]
generator = tf.keras.Sequential([
    Dense(100, activation="relu", kernel_initializer="he_normal"),
    Dense(150, activation="relu", kernel_initializer="he_normal"),
    Dense(28 * 28, activation="sigmoid"),
    tf.keras.layers.Reshape([28, 28])
])
# Discriminator: 28x28 image -> probability that the image is real
discriminator = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    Dense(150, activation="relu", kernel_initializer="he_normal"),
    Dense(100, activation="relu", kernel_initializer="he_normal"),
    Dense(1, activation="sigmoid")
])
# Combined model used to train the generator: noise -> fake image -> D's verdict
gan = tf.keras.Sequential([generator, discriminator])

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
# Freeze D before compiling `gan`, so that training `gan` updates only G
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
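
For a sense of scale, the parameters of these dense layers can be counted by hand. A Dense layer with n_in inputs and n_out units has n_in·n_out weights plus n_out biases. The code above leaves the generator's input size to be inferred from the first batch of noise, so the 30-dimensional noise vector assumed below is our choice, not the author's. The helper names are hypothetical.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer."""
    return n_in * n_out + n_out


def mlp_params(sizes):
    """Total parameters of a stack of Dense layers with the given layer sizes."""
    return sum(dense_params(n_in, n_out) for n_in, n_out in zip(sizes, sizes[1:]))


codings_size = 30  # assumed dimension of the noise vector z
generator_params = mlp_params([codings_size, 100, 150, 28 * 28])
discriminator_params = mlp_params([28 * 28, 150, 100, 1])
print("generator:", generator_params)
print("discriminator:", discriminator_params)
```

Reshape and Flatten have no parameters. Once the models are built with a 30-dimensional input, these totals should match what `generator.summary()` and `discriminator.summary()` report.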

Figure: MNIST samples generated by the fully connected GAN.

Deep CNN GAN

import tensorflow as tf

# Generator: noise vector -> 7x7x128 feature map -> 14x14x64 -> 28x28x1 image
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(7 * 7 * 128),
    tf.keras.layers.Reshape([7, 7, 128]),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2,
                                    padding="same", activation="relu"),
    tf.keras.layers.BatchNormalization(),
    # tanh outputs lie in [-1, 1], so real images must be scaled to [-1, 1] too
    tf.keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2,
                                    padding="same", activation="tanh"),
])
# Discriminator: 28x28x1 image -> 14x14x64 -> 7x7x128 -> probability of "real"
discriminator = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, kernel_size=5, strides=2, padding="same",
                           activation=tf.keras.layers.LeakyReLU(0.2)),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Conv2D(128, kernel_size=5, strides=2, padding="same",
                           activation=tf.keras.layers.LeakyReLU(0.2)),
    tf.keras.layers.Dropout(0.4),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation="sigmoid")
])
gan = tf.keras.Sequential([generator, discriminator])

discriminator.compile(loss="binary_crossentropy", optimizer="rmsprop")
# Freeze D before compiling `gan`, so that training `gan` updates only G
discriminator.trainable = False
gan.compile(loss="binary_crossentropy", optimizer="rmsprop")
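
With `padding="same"`, a Conv2DTranspose layer with stride 2 doubles the spatial size, and a Conv2D layer with stride 2 halves it (rounding up). The shapes can therefore be traced by hand. The sketch below checks that the generator produces 28x28x1 images. It also computes how many features the discriminator's Flatten layer passes to its final Dense layer. Because the generator ends in tanh, real MNIST pixels in [0, 255] should be rescaled to [-1, 1] before they reach the discriminator. The helper names are hypothetical.

```python
import math


def conv_same(size, stride):
    """Spatial size after Conv2D with padding="same"."""
    return math.ceil(size / stride)


def conv_transpose_same(size, stride):
    """Spatial size after Conv2DTranspose with padding="same"."""
    return size * stride


# Generator: Reshape to 7x7x128, then two stride-2 transposed convolutions
g = 7
for _ in range(2):
    g = conv_transpose_same(g, 2)
generator_output = (g, g, 1)  # the last layer has a single filter

# Discriminator: two stride-2 convolutions, the second with 128 filters
d = 28
for _ in range(2):
    d = conv_same(d, 2)
flatten_features = d * d * 128


def to_tanh_range(pixel):
    """Map a pixel value in [0, 255] to [-1, 1], the range of a tanh output."""
    return pixel / 127.5 - 1.0


print(generator_output, flatten_features)
print(to_tanh_range(0), to_tanh_range(255))
```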

Figure: MNIST samples generated by the deep CNN GAN.

Deeper CNN GAN

import tensorflow as tf
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (Activation, BatchNormalization, Conv2D,
                                     Conv2DTranspose, Dense, Dropout, Flatten,
                                     Reshape, UpSampling2D)

img_w, img_h = 28, 28  # MNIST image size
z_dimensions = 32      # size of the noise vector z (assumed value)


def rmsprop(lr, decay):
    # lr / (1 + decay * step): the schedule of the legacy Keras `decay` argument,
    # which newer Keras optimizers no longer accept
    schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
        lr, decay_steps=1, decay_rate=decay)
    return tf.keras.optimizers.RMSprop(learning_rate=schedule, clipvalue=1.0)


def build_discriminator(depth=64, p=0.4):
    # Define inputs
    image = Input((img_w, img_h, 1))
    # Convolutional layers: 28x28 -> 14x14 -> 7x7 -> 4x4 -> 4x4
    conv1 = Conv2D(depth*1, 5, strides=2,
                   padding='same', activation='relu')(image)
    conv1 = Dropout(p)(conv1)
    conv2 = Conv2D(depth*2, 5, strides=2,
                   padding='same', activation='relu')(conv1)
    conv2 = Dropout(p)(conv2)
    conv3 = Conv2D(depth*4, 5, strides=2,
                   padding='same', activation='relu')(conv2)
    conv3 = Dropout(p)(conv3)
    conv4 = Conv2D(depth*8, 5, strides=1,
                   padding='same', activation='relu')(conv3)
    conv4 = Flatten()(Dropout(p)(conv4))
    # Output layer: probability that the image is real
    prediction = Dense(1, activation='sigmoid')(conv4)
    # Model definition
    model = Model(inputs=image, outputs=prediction)
    return model


discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=rmsprop(0.0008, 6e-8),
                      metrics=['accuracy'])


def build_generator(latent_dim=z_dimensions, depth=64, p=0.4):
    noise = Input((latent_dim,))
    # First dense layer, reshaped into a 7x7 feature map
    dense1 = Dense(7*7*depth)(noise)
    dense1 = BatchNormalization(momentum=0.9)(dense1)
    dense1 = Activation(activation='relu')(dense1)
    dense1 = Reshape((7, 7, depth))(dense1)
    dense1 = Dropout(p)(dense1)
    # De-convolutional layers: 7x7 -> 14x14 -> 28x28
    conv1 = UpSampling2D()(dense1)
    conv1 = Conv2DTranspose(depth // 2, kernel_size=5, padding='same',
                            activation=None)(conv1)
    conv1 = BatchNormalization(momentum=0.9)(conv1)
    conv1 = Activation(activation='relu')(conv1)
    conv2 = UpSampling2D()(conv1)
    conv2 = Conv2DTranspose(depth // 4, kernel_size=5, padding='same',
                            activation=None)(conv2)
    conv2 = BatchNormalization(momentum=0.9)(conv2)
    conv2 = Activation(activation='relu')(conv2)
    conv3 = Conv2DTranspose(depth // 8, kernel_size=5, padding='same',
                            activation=None)(conv2)
    conv3 = BatchNormalization(momentum=0.9)(conv3)
    conv3 = Activation(activation='relu')(conv3)
    # Output layer: one channel, pixel values in [0, 1]
    image = Conv2D(1, kernel_size=5, padding='same',
                   activation='sigmoid')(conv3)
    # Model definition
    model = Model(inputs=noise, outputs=image)
    return model


generator = build_generator()

# Adversarial model: noise -> generator -> frozen discriminator
z = Input(shape=(z_dimensions,))
img = generator(z)
discriminator.trainable = False
pred = discriminator(img)
adversarial_model = Model(z, pred)

adversarial_model.compile(loss='binary_crossentropy',
                          optimizer=rmsprop(0.0004, 3e-8),
                          metrics=['accuracy'])
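
Both optimizers above use a small learning-rate decay of the form \(\text{lr}_t = \text{lr} / (1 + \text{decay} \cdot t)\) after t updates, which is the schedule behind the legacy Keras `decay` argument. A quick calculation shows how small that effect is at these settings. It also shows that the discriminator keeps roughly twice the adversarial model's learning rate throughout. The helper `decayed_lr` is hypothetical and only reproduces the formula.

```python
def decayed_lr(lr, decay, step):
    """Inverse-time decay: lr / (1 + decay * step)."""
    return lr / (1.0 + decay * step)


for step in (0, 10_000, 100_000):
    d_lr = decayed_lr(0.0008, 6e-8, step)  # discriminator
    a_lr = decayed_lr(0.0004, 3e-8, step)  # adversarial (generator) model
    print(f"step {step:>7}: discriminator {d_lr:.8f}, "
          f"adversarial {a_lr:.8f}, ratio {d_lr / a_lr:.4f}")
```

Even after 100,000 updates, the discriminator's learning rate has dropped by less than 1%, so at these settings the decay is close to negligible. The main difference between the two optimizers is the 2:1 ratio of their base learning rates.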

Figure: MNIST samples generated by the deeper CNN GAN.

As the samples show, the deeper network is better suited to this task. Even though there are only 10 digit classes, handwritten digits contain fine details that the extra convolutional layers help capture.

CIFAR-10

CIFAR-10 is another popular image dataset for machine learning, this time in RGB (color). Each image is 32x32 pixels, so each input is a 32x32x3 tensor. We reuse the deeper CNN above, since it yields the best result of the three architectures, adapting its input and output shapes to 32x32x3 images.
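
Reusing the deeper CNN here requires a few shape changes. What follows is our own sketch of them, not code from the original:

- the discriminator's `Input` becomes (32, 32, 3);
- the generator's output `Conv2D` needs 3 filters instead of 1;
- the generator upsamples twice by a factor of 2 (two `UpSampling2D` layers), so its first dense layer must reshape to 8x8 rather than 7x7.

The hypothetical helper `start_size` below computes that starting size.

```python
def start_size(image_size, n_upsamplings, factor=2):
    """Side length of the first feature map so that repeated upsampling
    by `factor` reaches image_size exactly."""
    total = factor ** n_upsamplings
    if image_size % total != 0:
        raise ValueError(f"{image_size} is not divisible by {total}")
    return image_size // total


mnist_start = start_size(28, 2)  # 28x28x1 MNIST digits
cifar_start = start_size(32, 2)  # 32x32x3 CIFAR-10 images
print(mnist_start, cifar_start)
print("CIFAR-10 input tensor:", (32, 32, 3), "values:", 32 * 32 * 3)
```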

Figure: CIFAR-10 (RGB) samples generated by the deeper CNN GAN.