# Vanilla GAN
# Import all the necessary libraries
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
from tqdm import notebook
# Define our simple vanilla generator
class Generator(nn.Module):
    """Fully connected GAN generator that maps latent noise to an image.

    Architecture
    ------------
    Latent input: ``latent_shape``, flattened from dim 1 onward (nn.Flatten)
    Linear MLP widths: 256 -> 512 -> 1024 -> prod(img_shape)
    LeakyReLU(0.2) after every layer except the last. (Important!)
    Tanh after the last layer, so output values lie in [-1, 1].

    ``forward`` returns a tensor of shape (batch, 1, *img_shape).
    """

    def __init__(self, latent_shape, img_shape):
        super().__init__()
        self.img_shape = img_shape
        # Hidden layer widths, starting from the flattened latent size.
        widths = [np.prod(latent_shape), 256, 512, 1024]
        layers = [nn.Flatten()]
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        # Output layer: one unit per pixel, squashed into [-1, 1].
        layers += [nn.Linear(widths[-1], np.prod(img_shape)), nn.Tanh()]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        flat_pixels = self.mlp(x)
        # Restore the image layout with a single channel dimension.
        return flat_pixels.reshape(x.size(0), 1, *self.img_shape)
# Define our simple vanilla discriminator
class Discriminator(nn.Module):
    """Fully connected GAN discriminator that scores an image as real vs. fake.

    Architecture
    ------------
    Input image: ``img_shape``, flattened from dim 1 onward (nn.Flatten)
    Linear MLP widths: 1024 -> 512 -> 256 -> 1
    LeakyReLU(0.2) after every layer except the last.
    Sigmoid after the last layer, so each score lies in the range 0 to 1.

    ``forward`` returns a tensor of shape (batch, 1).
    """

    def __init__(self, img_shape):
        super().__init__()
        widths = [np.prod(img_shape), 1024, 512, 256]
        hidden = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            hidden.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)))
        self.mlp = nn.Sequential(
            nn.Flatten(),
            *hidden,
            nn.Linear(widths[-1], 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        probabilities = self.mlp(x)
        return probabilities
# load our data
latent_shape = (28, 28)  # per-sample shape of the noise fed to the generator
img_shape = (28, 28)  # MNIST digits are 28x28 grayscale images
batch_size = 64
# ToTensor() scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1], the same range as the generator's final Tanh. Without this, real images
# are never negative while fakes can be, which hands the discriminator a trivial cue.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # for gpu usage if possible
# .to(device) moves the networks / models to that device, which is either the CPU or the GPU depending on what was detected.
# If moved to the GPU, the networks can use it for computations, which is much faster!
generator = Generator(latent_shape, img_shape).to(device)
discriminator = Discriminator(img_shape).to(device)
# Build the optimizers only after the models are on their final device,
# as the torch.optim documentation recommends.
gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
def _generator_step(generator, discriminator, generator_optim, adversarial_loss, batch_len, noise_shape, device):
    """Do one generator update and return ``(generator_loss, fake_images)``."""
    # .zero_grad() is important in PyTorch: without it, gradients accumulate across .backward() calls.
    generator_optim.zero_grad()
    latent_input = torch.randn((batch_len, 1, *noise_shape), device=device)
    fake_images = generator(latent_input)
    fake_res = discriminator(fake_images)
    # Penalize the generator when the discriminator fails to predict 1s ("real") for its fakes.
    generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))
    # This backward pass also writes gradients into the discriminator's parameters;
    # _discriminator_step clears them with zero_grad() before its own update.
    generator_loss.backward()
    # .step() updates only the parameters generator_optim was built over.
    generator_optim.step()
    return generator_loss, fake_images


def _discriminator_step(discriminator, discriminator_optim, adversarial_loss, real_images, fake_images):
    """Do one discriminator update on a real batch and a fake batch; return the loss."""
    discriminator_optim.zero_grad()
    real_res = discriminator(real_images)
    # .detach() cuts fake_images out of the autograd graph, so this loss does not
    # backpropagate into the generator.
    fake_res = discriminator(fake_images.detach())
    # Penalize not predicting 1s for real images and 0s for generated ones.
    real_loss = adversarial_loss(real_res, torch.ones_like(real_res))
    fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))
    discriminator_loss = (real_loss + fake_loss) / 2
    discriminator_loss.backward()
    discriminator_optim.step()
    return discriminator_loss


def train(generator, discriminator, generator_optim: torch.optim.Optimizer,
          discriminator_optim: torch.optim.Optimizer, epochs=100, *,
          dataloader=None, noise_shape=None):
    """Train a GAN with the BCE adversarial loss.

    Each batch does one generator update followed by one discriminator update.
    Both updates use the same batch of generated images.

    Args:
        generator: module mapping noise of shape (B, 1, *noise_shape) to images.
        discriminator: module mapping images to scores in [0, 1], as BCELoss requires.
        generator_optim: optimizer over the generator's parameters.
        discriminator_optim: optimizer over the discriminator's parameters.
        epochs: number of passes over ``dataloader``.
        dataloader: iterable of batches whose first element is a tensor of real
            images. Defaults to the module-level ``train_dataloader``.
        noise_shape: per-sample latent shape, without the leading dim of 1.
            Defaults to the module-level ``latent_shape``.

    Inputs are placed on the device of the generator's first parameter, which
    assumes the generator has at least one parameter. Per-epoch average losses
    are printed; nothing is returned.
    """
    if dataloader is None:
        dataloader = train_dataloader
    if noise_shape is None:
        noise_shape = latent_shape
    device = next(generator.parameters()).device
    adversarial_loss = torch.nn.BCELoss()
    # Ensure training mode in case the models were switched to eval() for sampling.
    generator.train()
    discriminator.train()
    for epoch in range(1, epochs + 1):
        print("Epoch {}".format(epoch))
        total_g_loss = 0.0
        total_d_loss = 0.0
        num_batches = 0
        # notebook.tqdm shows progress in a Jupyter/Colab notebook; it reads len(dataloader) itself when available.
        pbar = notebook.tqdm(dataloader)
        for num_batches, data in enumerate(pbar, start=1):
            real_images = data[0].to(device)
            generator_loss, fake_images = _generator_step(
                generator, discriminator, generator_optim, adversarial_loss,
                real_images.size(0), noise_shape, device,
            )
            discriminator_loss = _discriminator_step(
                discriminator, discriminator_optim, adversarial_loss, real_images, fake_images,
            )
            g_loss = generator_loss.item()
            d_loss = discriminator_loss.item()
            total_g_loss += g_loss
            total_d_loss += d_loss
            pbar.set_postfix({"G_loss": g_loss, "D_loss": d_loss})
        # max(..., 1) avoids ZeroDivisionError on an empty dataloader.
        denom = max(num_batches, 1)
        print("Avg G_loss {} - Avg D_loss {}".format(total_g_loss / denom, total_d_loss / denom))
# Train our generator and discriminator.
# Note: don't expect both losses to go down at the same time. The two models are
# competing against each other, so at times one will perform better than the other.
train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)

# Test it out: sample a batch from the trained generator and show the first image.
# These models have no dropout or batch-norm, so eval() changes nothing numerically,
# but it makes the inference intent explicit.
generator.eval()
with torch.no_grad():  # inference only, so skip building an autograd graph
    latent_input = torch.randn((batch_size, 1, *latent_shape), device=device)
    test = generator(latent_input)
plt.imshow(test[0].reshape(img_shape).cpu().numpy(), cmap="gray")
plt.show()