
Building a simple GAN model with PyTorch

2021-08-31 16:05:00 Heart of machine

Author | Ta-Ying Cheng, a doctoral student at the University of Oxford and a technology blogger on Medium, whose articles have been featured in the platform's official publication Towards Data Science

Translator | Song Xian

In the past, generating images was generally considered an impossible task, because under the traditional machine learning paradigm we have no ground truth at all against which to check whether a generated image is acceptable.

In 2014, Goodfellow et al. proposed the Generative Adversarial Network (GAN), which lets us generate extremely realistic images relying entirely on machine learning. The birth of GANs shook the entire AI industry and brought about great changes in computer vision and image generation.

This article explains how GANs work and how to build a simple GAN with PyTorch.

How GANs work

In the traditional setting, a model's predictions can be compared directly against existing ground truth. However, it is hard to define and measure what counts as a "correct" generated image.

Goodfellow et al. proposed an interesting solution: first train a classifier that automatically distinguishes generated images from real ones. We can then use this classifier to train a generative network until it outputs fake images so convincing that even the classifier cannot tell real from fake.

Figure 1. The GAN workflow. Source: the author.

This gives us the GAN: a generator and a discriminator. The generator is responsible for generating images based on a given dataset, and the discriminator is responsible for telling whether an image is real or fake. The GAN workflow is shown in Figure 1.

Loss function

Looking at how a GAN operates, we can spot an obvious tension: it is hard to optimize the generator and the discriminator at the same time. As you might imagine, the two models have completely opposite goals: the generator wants its fakes to pass as real as convincingly as possible, while the discriminator must see through the images the generator produces.

To illustrate this, let $D(x)$ be the output of the discriminator, i.e. the probability that $x$ is a real image, and let $G(z)$ be the output of the generator for a noise input $z$. Writing $p_{\text{data}}$ for the distribution of real images and $p_z$ for the noise distribution, the discriminator behaves like a binary classifier, so its goal is to maximize:

$$\max_D \; \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

This is essentially the binary cross-entropy loss without its leading negative sign. The generator, on the other hand, aims to minimize the probability that the discriminator judges correctly, so its goal is to minimize the same function.
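To make the link between this objective and binary cross-entropy concrete, here is a small check with made-up discriminator outputs (`d_real` and `d_fake` are hypothetical values, not results from the article). PyTorch's `F.binary_cross_entropy`, with label 1 for real images and label 0 for generated ones, gives exactly the negative of the objective, so maximizing the objective is the same as minimizing BCE.

```python
import torch
import torch.nn.functional as F

# Hypothetical discriminator outputs (probabilities in (0, 1))
d_real = torch.tensor([0.9, 0.8, 0.7])  # D(x) for real images
d_fake = torch.tensor([0.2, 0.1, 0.3])  # D(G(z)) for generated images

# Discriminator objective: mean log D(x) + mean log(1 - D(G(z)))
objective = torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# BCE with target 1 for real and target 0 for fake (default reduction='mean')
bce = (F.binary_cross_entropy(d_real, torch.ones_like(d_real))
       + F.binary_cross_entropy(d_fake, torch.zeros_like(d_fake)))

print(f"objective = {objective.item():.4f}, BCE = {bce.item():.4f}")
```

If real and fake outputs are instead concatenated into one batch before applying `nn.BCELoss`, the mean runs over twice as many terms, so with equal batch sizes the result is half of `bce` above: the gradient is rescaled but points in the same direction.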

The overall objective is therefore a minimax game between the two networks:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]$$

In theory, the game ends with the discriminator's probability of guessing correctly converging to 0.5. In practice, however, the minimax game often keeps the networks from converging, so the training parameters must be tuned carefully.
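Why 0.5? The following short derivation paraphrases the analysis in Goodfellow et al.'s 2014 paper rather than anything shown in this article. For a fixed generator whose samples follow a distribution $p_g$, the optimal discriminator and the value of the game once $p_g$ matches $p_{\text{data}}$ are:

```latex
D^*_G(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_g(x)},
\qquad
p_g = p_{\text{data}} \;\Rightarrow\; D^*_G(x) = \tfrac{1}{2}
\;\Rightarrow\; V(D^*_G, G) = \log\tfrac{1}{2} + \log\tfrac{1}{2} = -\log 4 \approx -1.386
```

In words: when the generator reproduces the data distribution perfectly, the best the discriminator can do is output 1/2 for every input.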

When training GANs, pay particular attention to hyperparameters such as the learning rate: a relatively small learning rate helps the GAN produce more consistent output even when the input noise varies more.

Computing environment


This article walks you through building the whole program with PyTorch (including torchvision). We will also use Matplotlib to visualize the images the GAN generates. The following code imports all of these libraries:

# Import the libraries needed to build a generative adversarial network
# The code is developed mainly with the PyTorch library
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from model import discriminator, generator  # network classes defined below, saved as model.py
import numpy as np
import matplotlib.pyplot as plt

Dataset

Datasets are crucial for training GANs, especially since GANs usually handle unstructured data (typically images, videos, and the like), where every class has its own data distribution. That distribution is exactly what the GAN learns from to generate its output.

To demonstrate the GAN setup as clearly as possible, this article uses the simplest dataset, MNIST, which contains 60,000 training images of handwritten digits.

High-quality unstructured datasets such as MNIST can be found on the Titanium Open Datasets website. The platform covers many high-quality public datasets and also offers dataset hosting and one-stop search, which makes it a very practical community platform for AI developers.

Hardware requirements

Generally speaking, although neural networks can be trained on a CPU, a GPU is the best choice because it speeds up training considerably. The following code checks whether our machine can train on a GPU:

# Determine whether a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
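Once `device` is chosen, both the networks and every batch of data must live on it; PyTorch raises an error if a model's weights and its inputs sit on different devices. A minimal sketch using a stand-in linear layer (not one of the article's networks):

```python
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Stand-in model for illustration: Module.to(device) moves its parameters in place
model = nn.Linear(784, 1).to(device)

# Tensor.to(device) returns a copy of the tensor on the target device
batch = torch.randn(4, 784).to(device)
out = model(batch)
print(out.device)
```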


Network structure

Because handwritten digits carry very simple information, we can build both the discriminator and the generator from fully connected layers.

The following code builds the discriminator and the generator in PyTorch:

# Network architectures
# The discriminator and generator (save them as model.py so that `from model import ...` works)
import torch
import torch.nn as nn

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)               # flatten each 1x28x28 image to a 784-vector
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return torch.sigmoid(x)           # probability that x is a real image

class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

    def forward(self, x):                 # x: noise of shape (N, 128)
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, 1, 28, 28)         # reshape to a batch of 1x28x28 images
        return torch.tanh(x)              # pixel values in [-1, 1]
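A quick way to sanity-check the layer sizes is to push a batch through the networks and inspect the shapes. The sketch below rebuilds the same layer sizes with `nn.Sequential` so it runs on its own (an equivalent formulation chosen for brevity, not the author's code): 128-dimensional noise goes in, a 1×28×28 image in [-1, 1] comes out, and the discriminator returns one probability per image.

```python
import torch
import torch.nn as nn

# Same layer sizes as the generator class: 128 -> 1024 -> 2048 -> 784
G = nn.Sequential(
    nn.Linear(128, 1024), nn.ReLU(),
    nn.Linear(1024, 2048), nn.ReLU(),
    nn.Linear(2048, 784), nn.Tanh(),
    nn.Unflatten(1, (1, 28, 28)),  # reshape each 784-vector to a 1x28x28 image
)

# Same layer sizes as the discriminator class: 784 -> 512 -> 1
D = nn.Sequential(
    nn.Flatten(),                  # 1x28x28 image -> 784-vector
    nn.Linear(784, 512), nn.LeakyReLU(0.1),
    nn.Linear(512, 1), nn.Sigmoid(),
)

noise = torch.rand(16, 128) * 2 - 1  # uniform noise in [-1, 1)
fake = G(noise)
prob = D(fake)
print(fake.shape, prob.shape)  # torch.Size([16, 1, 28, 28]) torch.Size([16, 1])
```

Applying Tanh before the reshape, rather than after it as the class does, makes no difference because Tanh acts element-wise.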


When training a GAN, we need to optimize the discriminator while also improving the generator, so in every iteration we optimize two conflicting loss functions at once.
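The training loop assumes that a loss function, two optimizers and the two networks already exist. A minimal setup sketch, assuming `nn.BCELoss` and Adam with a small learning rate of 2e-4 and `betas=(0.5, 0.999)` (common GAN settings popularized by the DCGAN paper, not values given in this article), followed by one alternating update on a random stand-in batch (`real` is hypothetical data, not MNIST):

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Compact stand-ins with the article's layer sizes (operating on flat 784-vectors)
G = nn.Sequential(nn.Linear(128, 1024), nn.ReLU(), nn.Linear(1024, 2048), nn.ReLU(),
                  nn.Linear(2048, 784), nn.Tanh()).to(device)
D = nn.Sequential(nn.Linear(784, 512), nn.LeakyReLU(0.1),
                  nn.Linear(512, 1), nn.Sigmoid()).to(device)

loss = nn.BCELoss()
lr = 2e-4  # assumed value; small learning rates tend to make GAN training more stable
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))

real = torch.rand(32, 784, device=device) * 2 - 1  # hypothetical batch in [-1, 1)
n = real.shape[0]

# Discriminator step: reals -> 1, fakes -> 0; detach() keeps gradients out of G
fake = G(torch.rand(n, 128, device=device) * 2 - 1)
D_loss = loss(torch.cat([D(real), D(fake.detach())]),
              torch.cat([torch.ones(n, 1, device=device),
                         torch.zeros(n, 1, device=device)]))
D_optimizer.zero_grad()
D_loss.backward()
D_optimizer.step()

# Generator step: push D to label fresh fakes as 1
# (this also leaves gradients in D, which D_optimizer.zero_grad() clears next step)
fake = G(torch.rand(n, 128, device=device) * 2 - 1)
G_loss = loss(D(fake), torch.ones(n, 1, device=device))
G_optimizer.zero_grad()
G_loss.backward()
G_optimizer.step()

print(f"D_loss {D_loss.item():.3f}  G_loss {G_loss.item():.3f}")
```

Keeping one optimizer per network is what lets each step update only the network whose loss is being minimized.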

For the generator, we feed in random noise and let it learn to vary its output images according to small changes in that noise:

# Network training procedure
# At every step, both the discriminator loss and the generator loss are updated
# The discriminator aims to classify real and fake images correctly
# The generator aims to generate images as realistic as possible
# Assumes G = generator(), D = discriminator(), loss = nn.BCELoss(), the optimizers
# D_optimizer and G_optimizer, train_loader, epochs and device are defined beforehand
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs.detach())  # detach: this step updates D only
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For the generator, the goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    # Save the generator's weights every 10 epochs
    if (epoch + 1) % 10 == 0:
        torch.save(G.state_dict(), 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')
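The noise expression used twice in the loop above, `(torch.rand(n, 128) - 0.5) / 0.5`, shifts `torch.rand`'s uniform samples from [0, 1) to [-1, 1), so the generator receives zero-centred input. A quick check of that mapping (the sample count `n` is arbitrary); `torch.randn`, i.e. standard normal noise, is a common alternative mentioned here as an option, not something the article uses:

```python
import torch

torch.manual_seed(0)
n = 1000
u = torch.rand(n, 128)        # uniform in [0, 1)
noise = (u - 0.5) / 0.5       # equivalent to 2 * u - 1: uniform in [-1, 1)
gauss = torch.randn(n, 128)   # alternative: standard normal noise

print(noise.min().item(), noise.max().item())
```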


After 100 epochs of training, we can visualize the results and see directly which digits the model generates from random noise. The generated digits look very similar to the real data. Considering that we built only a very simple model here, there is still plenty of room for improvement in practical applications.
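A sketch of how such a grid of samples could be drawn with Matplotlib, assuming the generator's weights were saved as a `state_dict`. So that the code runs on its own, it uses an untrained stand-in generator, whose output will look like noise rather than digits; `plot_samples` and the checkpoint filename are hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

# Untrained stand-in with the article's generator layer sizes.
# With a real checkpoint, instantiate the article's generator class instead, e.g.
# G = generator(); G.load_state_dict(torch.load('Generator_epoch_99.pth', map_location='cpu'))
G = nn.Sequential(nn.Linear(128, 1024), nn.ReLU(), nn.Linear(1024, 2048), nn.ReLU(),
                  nn.Linear(2048, 784), nn.Tanh())

@torch.no_grad()
def plot_samples(G, n_rows=4, n_cols=4):
    """Hypothetical helper: draw n_rows x n_cols generated digits in a grid."""
    G.eval()
    noise = (torch.rand(n_rows * n_cols, 128) - 0.5) / 0.5
    imgs = G(noise).view(-1, 28, 28)
    imgs = (imgs + 1) / 2  # map Tanh output from [-1, 1] back to [0, 1] for display
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
    for ax, img in zip(axes.flat, imgs):
        ax.imshow(img.numpy(), cmap="gray", vmin=0, vmax=1)
        ax.axis("off")
    return fig

fig = plot_samples(G)
# fig.savefig("samples.png") or plt.show() to view the grid
```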

More than just a learning exercise

GANs depart from the ideas put forward by computer vision experts in the past, and concrete applications of GANs have made many people marvel at the enormous potential of deep networks. Let's take a look at two of the best-known GAN extensions.


CycleGAN, published by Jun-Yan Zhu et al. in 2017, can translate an image directly from domain X to domain Y without paired training images: turning horses into zebras, summer scenes into winter ones, Monet's paintings into Van Gogh's, and so on. CycleGAN handles these seemingly fantastical transformations with ease, and the results are remarkably accurate.


NVIDIA's GauGAN uses a GAN so that people can sketch their ideas with just a few strokes and get back a highly realistic image of a real-world scene. Although this application is extremely expensive computationally, GauGAN's transformation ability has opened up previously unexplored fields of research and application.



If you have read this far, you should now understand how GANs work and be able to build a simple GAN yourself.

Copyright notice
Author: Heart of machine. Please include a link to the original when reprinting, thank you.
