Conditional GAN

A conditional GAN extends the standard GAN by feeding a class label to both the generator and the discriminator, which lets us control which MNIST digit the generator produces.

# Import all the necessary libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import notebook
class Generator(nn.Module):
    def __init__(self, latent_shape, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        self.flatten = nn.Flatten()
        self.mlp = nn.Sequential(
            nn.Linear(int(np.prod(latent_shape)) + 10, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, x, label):
        batch_size = x.shape[0]
        # the generator conditions on a one-hot encoded label: we concatenate it
        # with the flattened latent noise x before passing it through the MLP
        x = self.flatten(x)
        x = torch.cat([x, label], dim=1)
        # reshape the flat output back into an image
        return self.mlp(x).reshape(batch_size, 1, *self.img_shape)
class Discriminator(nn.Module):
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        self.flatten = nn.Flatten()
        self.mlp = nn.Sequential(
            nn.Linear(int(np.prod(img_shape)) + 10, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, label):
        # the discriminator also conditions on the one-hot label,
        # concatenated with the flattened image
        x = self.flatten(x)
        x = torch.cat([x, label], dim=1)
        return self.mlp(x)
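As a quick sanity check (a minimal sketch, not part of the original pipeline), we can push a dummy batch through both networks and confirm the shapes line up: the generator maps (batch, 1, 28, 28) noise plus a (batch, 10) one-hot label to a (batch, 1, 28, 28) image, and the discriminator maps that image plus the label to a (batch, 1) realness score.

# hypothetical smoke test: verify input/output shapes before training
_g = Generator((28, 28), (28, 28))
_d = Discriminator((28, 28))
_z = torch.randn(4, 1, 28, 28)                          # dummy latent noise
_y = F.one_hot(torch.tensor([0, 1, 2, 3]), 10).float()  # dummy one-hot labels
_fake = _g(_z, _y)
print(_fake.shape)          # torch.Size([4, 1, 28, 28])
print(_d(_fake, _y).shape)  # torch.Size([4, 1])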
# load our data
latent_shape = (28, 28)
img_shape = (28, 28)
batch_size = 64

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))])  # scale pixels to [-1, 1] to match the Tanh output
train_dataset = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
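To confirm the data pipeline works (an optional check, not required for training), pull one batch and look at its shapes and value range:

# quick look at one batch
images, targets = next(iter(train_dataloader))
print(images.shape, targets.shape)               # torch.Size([64, 1, 28, 28]) torch.Size([64])
print(images.min().item(), images.max().item())  # roughly -1.0 and 1.0 after normalization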
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # for gpu usage if possible

generator = Generator(latent_shape, img_shape)
discriminator = Discriminator(img_shape)

gen_optim = torch.optim.Adam(generator.parameters(), lr=2e-4)
disc_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-4)

# use gpu if possible
generator = generator.to(device)
discriminator = discriminator.to(device)
def train(generator, discriminator, generator_optim: torch.optim.Optimizer, discriminator_optim: torch.optim.Optimizer, epochs=100):
    adversarial_loss = torch.nn.BCELoss()

    for epoch in range(1, epochs + 1):
        print("Epoch {}".format(epoch))
        avg_g_loss = 0
        avg_d_loss = 0
        pbar = notebook.tqdm(train_dataloader, total=len(train_dataloader))
        i = 0
        for data in pbar:
            i += 1
            real_images = data[0].to(device)
            labels = data[1].to(device)

            # one-hot encode the digit labels for conditioning
            one_hot_labels = F.one_hot(labels, num_classes=10).float()

            ### Train Generator ###

            generator_optim.zero_grad()

            latent_input = torch.randn((len(real_images), 1, *latent_shape)).to(device)

            fake_images = generator(latent_input, one_hot_labels)

            fake_res = discriminator(fake_images, one_hot_labels)

            # the generator tries to make the discriminator output 1 ("real") for its fakes
            generator_loss = adversarial_loss(fake_res, torch.ones_like(fake_res))
            generator_loss.backward()
            generator_optim.step()

            ### Train Discriminator ###
            discriminator_optim.zero_grad()

            real_res = discriminator(real_images, one_hot_labels)

            # detach so gradients don't flow back into the generator
            fake_res = discriminator(fake_images.detach(), one_hot_labels)

            discriminator_real_loss = adversarial_loss(real_res, torch.ones_like(real_res))
            discriminator_fake_loss = adversarial_loss(fake_res, torch.zeros_like(fake_res))
            discriminator_loss = (discriminator_real_loss + discriminator_fake_loss) / 2
            discriminator_loss.backward()
            discriminator_optim.step()

            avg_g_loss += generator_loss.item()
            avg_d_loss += discriminator_loss.item()
            pbar.set_postfix({"G_loss": generator_loss.item(), "D_loss": discriminator_loss.item()})
        print("Avg G_loss {} - Avg D_loss {}".format(avg_g_loss / i, avg_d_loss / i))
# train our generator and discriminator
# Note: don't expect the losses to go down simultaneously for both models. They are
# competing against each other, so at times one model will perform better than the other.
train(generator=generator, discriminator=discriminator, generator_optim=gen_optim, discriminator_optim=disc_optim)
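Once training finishes you will probably want to keep the weights around. A minimal sketch using torch.save (the file name is just an example):

# save the trained generator so we can sample from it later (hypothetical path)
torch.save(generator.state_dict(), "cgan_generator.pt")
# later: generator.load_state_dict(torch.load("cgan_generator.pt"))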
# test it out!
latent_input = torch.randn((batch_size, 1, *latent_shape))

# generate one-hot encoded labels (all zeros here, i.e. we ask for the digit 0)
labels = torch.zeros(batch_size, dtype=torch.long)
one_hot_labels = torch.zeros((batch_size, 10))
one_hot_labels[torch.arange(batch_size), labels] = 1

test = generator(latent_input.to(device), one_hot_labels.to(device))
k = 0
plt.title("Generating a fake {} digit".format(labels[k].item()))
plt.imshow(test[k].reshape(28, 28).cpu().detach().numpy())
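Because the generator is conditioned on the label, we can also ask for every digit 0-9 at once. A minimal sketch (plot layout is illustrative) that generates one sample per class and shows them in a row:

# generate one fake sample for each digit class 0-9
digits = torch.arange(10)
z = torch.randn(10, 1, *latent_shape).to(device)
y = F.one_hot(digits, num_classes=10).float().to(device)
samples = generator(z, y)

fig, axes = plt.subplots(1, 10, figsize=(15, 2))
for d, ax in enumerate(axes):
    ax.imshow(samples[d].reshape(28, 28).cpu().detach().numpy(), cmap="gray")
    ax.set_title(str(d))
    ax.axis("off")
plt.show()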