{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "SMFRWjaAf1uM"
},
"source": [
"# (Exercise) Autoencoders, Generative Adversarial Networks, and Diffusion Models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jsooOEcW2RSr"
},
"source": [
"\n",
"\n",
"Théâtre D'opéra Spatial, 2022 artwork created by Jason M. Allen with Midjourney"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "EouTfDLJ1-ls"
},
"source": [
"**Autoencoders**, **GANs**, and **Diffusion Models** are all machine learning algorithms that can generate new data, often in an unsupervised manner. **Autoencoders** learn to compress and decompress data, capturing underlying patterns. **GAN**s (Generative adversarial networks) use two competing neural networks: a generator that creates new data and a discriminator that evaluates its authenticity. **Diffusion Models** gradually add noise to data and then learn to remove it, producing realistic samples. These models have applications in image generation, style transfer, and more."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "d3yiNZ023D2G"
},
"source": [
"We'll be implementing them on the [CIFAR-10](https://www.cs.toronto.edu/%7Ekriz/cifar.html) dataset to explore their capabilities in **capturing patterns, image generation and style transfer**. These models offer powerful techniques for learning latent representations, generating new data, and understanding complex patterns within data.\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_kYZ6r2aZ4Jr"
},
"source": [
"*Note* : CIFAR10 classes are: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck"
]
},
{
"cell_type": "markdown",
"source": [
"## Imports and Data Loading"
],
"metadata": {
"id": "dKdffE043l_w"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ghSLgH1LEkaa"
},
"outputs": [],
"source": [
"%%capture\n",
"!pip install tensorflow-gpu==2.8.0\n",
"\n",
"# You may need to restart the runtime to use the\n",
"# specific `tf` version installed for this notebook"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0bOF8HCGu5c0"
},
"outputs": [],
"source": [
"import sys\n",
"\n",
"import matplotlib as mpl\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import sklearn\n",
"import tensorflow as tf\n",
"from packaging import version\n",
"from sklearn.manifold import TSNE\n",
"\n",
"# make notebook reproducible\n",
"tf.random.set_seed(42)\n",
"\n",
"# make plot prettier\n",
"plt.rc('font', size=14)\n",
"plt.rc('axes', labelsize=14, titlesize=14)\n",
"plt.rc('legend', fontsize=14)\n",
"plt.rc('xtick', labelsize=10)\n",
"plt.rc('ytick', labelsize=10)\n",
"\n",
"# Check if right modules are installed\n",
"assert sys.version_info >= (3, 7)\n",
"assert version.parse(sklearn.__version__) >= version.parse(\"1.0.1\")\n",
"assert version.parse(tf.__version__) >= version.parse(\"2.8.0\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TDpxOdlSfvWZ"
},
"outputs": [],
"source": [
"# for easy plotting later on\n",
"def plot_multiple_images(images, n_cols=None):\n",
" n_cols = n_cols or len(images)\n",
" n_rows = (len(images) - 1) // n_cols + 1\n",
" if images.shape[-1] == 1:\n",
" images = images.squeeze(axis=-1)\n",
" plt.figure(figsize=(n_cols, n_rows))\n",
" for index, image in enumerate(images):\n",
" plt.subplot(n_rows, n_cols, index + 1)\n",
" plt.imshow(image, cmap=\"binary\")\n",
" plt.axis(\"off\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-vcuOLBB14fF"
},
"source": [
"You may need to restart the runtime the first time you run it to use the specific `tf` version installed for this notebook"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xRFgMq0hfFHp"
},
"source": [
"### **Q1) Load the dataset, scale it, and split it into a training set, a validation set, and a test set**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RVAVmIznyLhY"
},
"outputs": [],
"source": [
"# Load the CIFAR-10 dataset\n",
"(X_train_full, y_train_full), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
"\n",
"# Normalize the pixel values to the range [0, 1]\n",
"# RGB values are between 0 and 256\n",
"X_train_full = X_train_full / __\n",
"X_test = X_test / __\n",
"\n",
"# Split the training set into a training set and a validation set\n",
"# Get 5000 images as the validation set\n",
"X_train, X_valid = X_train_full[___], X_train_full[___]\n",
"y_train, y_valid = y_train_full[___], y_train_full[___]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "X5n5ZHWv96lj"
},
"outputs": [],
"source": [
"### Get familiar with CIFAR-10 dataset\n",
"# Get the shape of the images\n",
"print(___.shape, ___.shape)\n",
"# Display a sample image\n",
"____.____"
]
},
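{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch, assuming Q1 has been completed so that `X_train` and `y_train` exist and the images are scaled to [0, 1]), the snippet below maps the integer labels to the class names listed above and displays a few labeled training images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Class names in label order (0-9), taken from the note above\n",
"class_names = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\",\n",
"               \"dog\", \"frog\", \"horse\", \"ship\", \"truck\"]\n",
"\n",
"# Show a few training images with their class names\n",
"plt.figure(figsize=(10, 2))\n",
"for i in range(8):\n",
"    plt.subplot(1, 8, i + 1)\n",
"    plt.imshow(X_train[i])\n",
"    plt.title(class_names[int(y_train[i][0])], fontsize=8)\n",
"    plt.axis(\"off\")\n",
"plt.show()"
]
},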
{
"cell_type": "markdown",
"metadata": {
"id": "n2KGyEnvdVUi"
},
"source": [
"## 亖 Stacked Autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KX5Z7cZDhmLP"
},
"source": [
"Autoencoders, like other neural networks, can employ multiple hidden layers, often referred to as **stacked autoencoders** or **deep autoencoders**. This layered architecture enables autoencoders to learn progressively more complex representations of the input data.\n",
"\n",
"Let's build and train a stacked Autoencoder with 3 hidden layers and 1 output layer (i.e., 2 stacked Autoencoders).\n",
"\n"
]
},
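{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the architecture concrete before filling in the blanks: with the recommended sizes, the encoder maps each flattened image from 32 × 32 × 3 = 3072 values down to 512 and then 256 units, and the decoder mirrors this back to 3072 values that are reshaped to 32 × 32 × 3. Training minimizes the reconstruction error $\\lVert x - \\hat{x} \\rVert^2$ with $\\hat{x} = \\mathrm{decoder}(\\mathrm{encoder}(x))$, which is why the inputs also serve as the targets."
]
},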
{
"cell_type": "markdown",
"metadata": {
"id": "jWvPsHy2jiqx"
},
"source": [
"Each epoch with this recommended parameters will take ~ 1 min 10 sec;\n"
]
},
{
"cell_type": "markdown",
"source": [
"### **Q2) Complete the stacked autoencoder architecture below**"
],
"metadata": {
"id": "M8M8ToMx4idP"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2Ei5DsAB8v9C"
},
"outputs": [],
"source": [
"# Define the stacked encoder architecture\n",
"# Recommended 512, 256 units for the Dense layers and ReLU activation\n",
"stacked_encoder = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(___, activation=___),\n",
" tf.keras.layers.Dense(___, activation=___),\n",
"])\n",
"\n",
"# Define the stacked decoder architecture\n",
"# Recommended 512 and pixel count in one image as units for the Dense layers\n",
"stacked_decoder = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(___, activation=\"relu\"),\n",
" tf.keras.layers.Dense(__ * __ * _),\n",
" tf.keras.layers.Reshape([32, 32, 3])\n",
"])\n",
"\n",
"# Combine encoder and decoder into the stacked autoencoder\n",
"stacked_ae = tf.keras.Sequential([___, ___])\n",
"\n",
"# Compile the stacked autoencoder\n",
"# Recommended loss is MSE and recommended optimizer is NAdam\n",
"stacked_ae.compile(loss=___, optimizer=___)\n",
"\n",
"# Train the stacked autoencoder\n",
"# We want the predictions to be as the input (X == y)\n",
"# Recommended epochs is 10\n",
"history = stacked_ae.fit(___, ___, epochs=__,\n",
" validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "mPigM-nH6XiF"
},
"source": [
"This function processes a few validation images through the autoencoder and displays the original images and their reconstructions:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "u8LVmynxeOSo"
},
"outputs": [],
"source": [
"def plot_reconstructions(model, images=X_valid, n_images=5):\n",
" reconstructions = np.clip(model.predict(images[:n_images]), 0, 1)\n",
" fig = plt.figure(figsize=(n_images * 1.5, 3))\n",
" for image_index in range(n_images):\n",
" plt.subplot(2, n_images, 1 + image_index)\n",
" plt.imshow(images[image_index]) #, cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
" plt.subplot(2, n_images, 1 + n_images + image_index)\n",
" plt.imshow(reconstructions[image_index]) #, cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
"\n",
"plot_reconstructions(stacked_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5UiJ7Akso4rH"
},
"source": [
"The reconstructions look fuzzy, but remember that the images were compressed down to just 256 numbers, instead of 3072."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Yf6VcPRd7M0z"
},
"source": [
"### **Q3) Visualize the CIFAR-10 dataset using [tsne](https://en.wikipedia.org/wiki/T-distributed_stochastic_neighbor_embedding)**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kJTyAhVweP85"
},
"outputs": [],
"source": [
"# Predict the validation set\n",
"X_valid_compressed = stacked_encoder.predict(___)\n",
"\n",
"# Apply t-SNE for dimensionality reduction\n",
"# You can initializes with PCA, and the learning_rate to auto\n",
"tsne = TSNE(init=___, learning_rate=___, random_state=42)\n",
"\n",
"# Transform the compressed validation set into 2D\n",
"X_valid_2D = tsne.fit_transform(___)\n",
"\n",
"# Plot the 2D data\n",
"plt.scatter(X_valid_2D[:, 0], X_valid_2D[:, 1], c=y_valid, s=10, cmap=\"tab10\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-kklgzHBYAeE"
},
"source": [
"Let's make this diagram prettier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tc7zZNdqeQ-Q"
},
"outputs": [],
"source": [
"plt.figure(figsize=(10, 8))\n",
"cmap = plt.cm.tab10\n",
"Z = X_valid_2D\n",
"Z = (Z - Z.min()) / (Z.max() - Z.min()) # normalize to the 0-1 range\n",
"plt.scatter(Z[:, 0], Z[:, 1], c=y_valid, s=10, cmap=cmap)\n",
"image_positions = np.array([[1., 1.]])\n",
"for index, position in enumerate(Z):\n",
" dist = ((position - image_positions) ** 2).sum(axis=1)\n",
" if dist.min() > 0.02: # if far enough from other images\n",
" image_positions = np.r_[image_positions, [position]]\n",
" imagebox = mpl.offsetbox.AnnotationBbox(\n",
" mpl.offsetbox.OffsetImage(X_valid[index], cmap=\"binary\"),\n",
" position, bboxprops={\"edgecolor\": cmap(y_valid[index]), \"lw\": 2})\n",
" plt.gca().add_artist(imagebox)\n",
"\n",
"plt.axis(\"off\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5GLr_5VgYhQm"
},
"source": [
"## ⊩ Denoising Autoencoders"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PW0SabeVZVCC"
},
"source": [
"To make autoencoders learn better features, we can add noise to their inputs and train them to remove the noise and recover the original data. This is called **denoising autoencoding**."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "KIkMPVMXfmE4"
},
"source": [
"The implementation is straightforward: it's a standard stacked autoencoder with an additional Dropout layer applied to the encoder's inputs. You could also use a GaussianNoise layer instead.\n",
"\n",
"The noise can be pure Gaussian noise added to the inputs, or it can be randomly switched-off inputs, just like in dropout.\n",
"\n",
"*Note* : both Dropout and GaussianNoise layers are only active during training."
]
},
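{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what the note above means in practice, here is a small sketch (not part of the exercise) showing that a `GaussianNoise` layer only perturbs its input when it is called in training mode:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Small demo: GaussianNoise only adds noise in training mode\n",
"noise_layer = tf.keras.layers.GaussianNoise(0.1)\n",
"sample = tf.ones([1, 4])\n",
"\n",
"print(noise_layer(sample, training=False).numpy())  # unchanged: all ones\n",
"print(noise_layer(sample, training=True).numpy())   # perturbed by Gaussian noise"
]
},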
{
"cell_type": "markdown",
"source": [
"### **Q4) Complete the denoising autoencoder architecture below**"
],
"metadata": {
"id": "j0icEoGO7FwJ"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qXcft-kHyLhY"
},
"outputs": [],
"source": [
"# Define the denoising encoder\n",
"denoising_encoder = tf.keras.Sequential([\n",
" # GaussianNoise adds noise for robustness (0.1)\n",
" tf.keras.layers.GaussianNoise(___),\n",
" # Conv2D extracts features with 32 filters and (3x3) kernel, `same` padding, ReLU as activation\n",
" tf.keras.layers.Conv2D(___, ___, padding=___, activation=___),\n",
" tf.keras.layers.MaxPool2D(),\n",
" tf.keras.layers.Flatten(),\n",
" # Dense layer with 512 units for feature processing, ReLU as activation\n",
" tf.keras.layers.Dense(___, activation=___)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "a2lC7XBqyLhY"
},
"outputs": [],
"source": [
"# Define the denoising decoder architecture\n",
"denoising_decoder = tf.keras.Sequential([\n",
" # Dense layer reshapes the compressed data back to match the decoder's input shape\n",
" tf.keras.layers.Dense(___ * ___ * ___, activation=___),\n",
" # Reshape changes the 1D vector back into 16x16x32 feature maps\n",
" tf.keras.layers.Reshape([___, ___, ___]),\n",
" # Conv2DTranspose performs upsampling (opposite of Conv2D) to restore the original image size\n",
" # Use 3 filters, 3 as kernel size, 2 as strides, `same` padding, `sigmoid` as activation\n",
" tf.keras.layers.Conv2DTranspose(filters=___, kernel_size=___, strides=___,\n",
" padding=___, activation=___)\n",
"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RgtxA9BjyLhY"
},
"outputs": [],
"source": [
"# Combine encoder and decoder into the denoising autoencoder\n",
"denoising_ae = tf.keras.Sequential([___, ___])\n",
"\n",
"# Compile the autoencoder\n",
"# Using binary crossentropy for the loss function and Nadam optimizer and MSE metric\n",
"# Metrics: Mean Squared Error to monitor reconstruction quality\n",
"denoising_ae.compile(loss=___, optimizer=___, metrics=[___])\n",
"\n",
"# Train the autoencoder\n",
"# Input and target are the same (denoising task), with 10 epochs and validation data\n",
"history = denoising_ae.fit(___, ___, epochs=___, validation_data=(___, ___))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yLKbOwmoktmh"
},
"source": [
"### **Q5) Try generating images from noisy inputs. What do you notice?**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dTlFyrmSyLhY"
},
"outputs": [],
"source": [
"# Number of images to process (e.g. 5)\n",
"n_images = __\n",
"\n",
"# Select a subset of test images\n",
"new_images = X_test[:___]\n",
"\n",
"# Add noise to these images and scale it by various factors (e.g., 0.1)\n",
"new_images_noisy = new_images + np.random.randn(___, 32, 32, 3) * ___\n",
"\n",
"# Predict denoised images using the autoencoder\n",
"new_images_denoised = denoising_ae.predict(___)\n",
"\n",
"# Plot the original, noisy and denoised images\n",
"plt.figure(figsize=(6, n_images * 2))\n",
"for index in range(n_images):\n",
" plt.subplot(n_images, 3, index * 3 + 1)\n",
" plt.imshow(new_images[index])\n",
" plt.axis('off')\n",
" if index == 0:\n",
" plt.title(\"Original\")\n",
" plt.subplot(n_images, 3, index * 3 + 2)\n",
" plt.imshow(new_images_noisy[index].clip(0., 1.))\n",
" plt.axis('off')\n",
" if index == 0:\n",
" plt.title(\"Noisy\")\n",
" plt.subplot(n_images, 3, index * 3 + 3)\n",
" plt.imshow(new_images_denoised[index])\n",
" plt.axis('off')\n",
" if index == 0:\n",
" plt.title(\"Denoised\")\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7gjQbjLKgpMm"
},
"source": [
"The images show examples of noisy images and the corresponding images reconstructed by the GaussianNoise-based denoising autoencoder.\n",
"\n",
"This demonstrates that denoising autoencoders can not only be used for data visualization or unsupervised pretraining but also for effectively removing noise from images."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ddB3aKyFyLhZ"
},
"source": [
"## ଽ Variational Autoencoders\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "e9rkRByUhtHN"
},
"source": [
"**Variational autoencoders (VAEs)** are different from other autoencoders because they use randomness to create their outputs. Instead of just producing a single code for an input, VAEs create a range of possible codes. This randomness helps them create new data that looks like the original data.\n",
"\n",
"Here's how it works:\n",
"\n",
"1. **Encoder:** The encoder takes an input and creates two things: a mean code and a standard deviation.\n",
"2. **Sampling:** A random code is chosen from a range based on the mean and standard deviation.\n",
"3. **Decoder:** The decoder uses this random code to create an output that looks similar to the original input."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "dUeGBTcs1bxl"
},
"outputs": [],
"source": [
"# Define a custom Keras layer for sampling from a normal distribution\n",
"class Sampling(tf.keras.layers.Layer):\n",
" def call(self, inputs):\n",
" mean, log_var = inputs\n",
" # Sample using reparameterization trick\n",
" return tf.random.normal(tf.shape(log_var)) * tf.exp(log_var / 2) + mean\n"
]
},
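{
"cell_type": "markdown",
"metadata": {},
"source": [
"To get a feel for the reparameterization trick implemented above (a small illustration, not part of the exercise), you can feed the `Sampling` layer a fixed mean and log variance and draw a few codings: they scatter around the mean with standard deviation $\\exp(\\gamma / 2)$."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrate the Sampling layer: codings scatter around `mean` with std exp(log_var / 2)\n",
"mean = tf.constant([[0.0, 1.0]])\n",
"log_var = tf.constant([[0.0, -2.0]])  # std = 1.0 and roughly 0.37\n",
"\n",
"sampling_layer = Sampling()\n",
"for _ in range(3):\n",
"    print(sampling_layer([mean, log_var]).numpy())"
]
},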
{
"cell_type": "markdown",
"source": [
"### **Q6) Complete the VAE architecture below**"
],
"metadata": {
"id": "L6ByhEI88BnE"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nwKr2Yfy1xsS"
},
"outputs": [],
"source": [
"# Define the size of the latent space (e.g. 10)\n",
"codings_size = ___\n",
"\n",
"# Create input layer for 32x32x3 CIFAR images\n",
"inputs = tf.keras.layers.Input(shape=[__, __, _])\n",
"\n",
"# Flatten input, then pass through Dense layers with ReLU activation\n",
"Z = tf.keras.layers.Flatten()(inputs)\n",
"Z = tf.keras.layers.Dense(___, activation=\"relu\")(Z)\n",
"Z = tf.keras.layers.Dense(___, activation=\"relu\")(Z)\n",
"\n",
"# Compute mean and log variance for the latent space\n",
"codings_mean = tf.keras.layers.Dense(___)(Z) # μ\n",
"codings_log_var = tf.keras.layers.Dense(___)(Z) # γ\n",
"\n",
"# Sample from the latent space using the mean and log variance\n",
"codings = Sampling()([___, ___])\n",
"\n",
"# Define the encoder model\n",
"variational_encoder = tf.keras.Model(\n",
" inputs=[___],\n",
" # the outputs are: mean, log variance and codings\n",
" outputs=[___, ___, ___]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kkcYhO-KXFhM"
},
"outputs": [],
"source": [
"# Create input layer for the latent space\n",
"decoder_inputs = tf.keras.layers.Input(shape=[___])\n",
"\n",
"# Pass through Dense layers to reconstruct the original image\n",
"# Recommended number of units are: 100, 150, and image size\n",
"x = tf.keras.layers.Dense(___, activation=\"relu\")(decoder_inputs)\n",
"x = tf.keras.layers.Dense(___, activation=\"relu\")(x)\n",
"x = tf.keras.layers.Dense(___)(x)\n",
"# Reshape as 32x32x3 CIFAR images\n",
"outputs = tf.keras.layers.Reshape([__, __, _])(x)\n",
"\n",
"# Define the decoder model\n",
"variational_decoder = tf.keras.Model(inputs=[___], outputs=[___])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Q-KBo_JOXsd-"
},
"outputs": [],
"source": [
"# Encode inputs to get latent space codings\n",
"dummy_var1, dummy_var2, codings = variational_encoder(___)\n",
"\n",
"# Decode codings to reconstruct the inputs\n",
"reconstructions = variational_decoder(___)\n",
"\n",
"# Define the variational autoencoder model with the reconstructions as output\n",
"variational_ae = tf.keras.Model(inputs=[___], outputs=[___])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "W13BRGkhFEOO"
},
"outputs": [],
"source": [
"latent_loss = -0.5 * tf.reduce_sum(\n",
" 1 + codings_log_var - tf.exp(codings_log_var) - tf.square(codings_mean),\n",
" axis=-1)\n",
"\n",
"variational_ae.add_loss(tf.reduce_mean(latent_loss) / 784.)"
]
},
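{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the latent loss above is the closed-form KL divergence between the encoder's Gaussian $\\mathcal{N}(\\mu, \\sigma^2)$ (with $\\gamma = \\log \\sigma^2$) and the standard normal prior, summed over the latent dimensions:\n",
"\n",
"$$\\mathcal{L}_{\\text{latent}} = -\\frac{1}{2} \\sum_{i} \\left(1 + \\gamma_i - e^{\\gamma_i} - \\mu_i^2\\right)$$\n",
"\n",
"It is then averaged over the batch and scaled down by the number of pixels so that it stays comparable to the per-pixel reconstruction loss."
]
},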
{
"cell_type": "markdown",
"source": [
"### **Q7) Train the variational autoencoder to reconstruct the CIFAR images**"
],
"metadata": {
"id": "32kE-Aq18LjI"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "SYkax4YHYSpR"
},
"outputs": [],
"source": [
"# Compile the variational autoencoder\n",
"# Use Mean Squared Error for loss and Nadam optimizer for training\n",
"variational_ae.compile(loss=___, optimizer=___)\n",
"\n",
"# Train the variational autoencoder\n",
"# Fit the model using training data with e.g. 25 epochs and e.g. 128 batch size\n",
"history = variational_ae.fit(X_train, X_train, epochs=___, batch_size=___,\n",
" validation_data=(X_valid, X_valid))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PqoX073nYF5M"
},
"outputs": [],
"source": [
"def plot_reconstructions(model, images=X_valid, n_images=5):\n",
" reconstructions = np.clip(model.predict(images[:n_images]), 0, 1)\n",
" fig = plt.figure(figsize=(n_images * 1.5, 3))\n",
" for image_index in range(n_images):\n",
" plt.subplot(2, n_images, 1 + image_index)\n",
" plt.imshow(images[image_index]) #, cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
" plt.subplot(2, n_images, 1 + n_images + image_index)\n",
" plt.imshow(reconstructions[image_index]) #, cmap=\"binary\")\n",
" plt.axis(\"off\")\n",
"\n",
"plot_reconstructions(variational_ae)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7fthkaZoltip"
},
"source": [
"## 🅶 GANs\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QJU8C1gbbwl0"
},
"source": [
"Generative Adversarial Networks (**GAN**s) represent one of the most fascinating concepts in computer science today. They involve training two models in tandem through an adversarial process. The generator, often called \"the artist,\" learns to produce images that appear realistic, while the discriminator, known as \"the art critic,\" learns to distinguish between genuine images and those created by the generator.\n",
"\n",
"During training, the generator gets better at making realistic images, while the discriminator gets better at spotting fakes. They reach a balance when the discriminator can't tell real images from fake ones anymore."
]
},
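{
"cell_type": "markdown",
"metadata": {},
"source": [
"Formally, the two networks play a minimax game (the original GAN objective, shown here for reference):\n",
"\n",
"$$\\min_G \\max_D \\; \\mathbb{E}_{x \\sim p_{\\text{data}}}[\\log D(x)] + \\mathbb{E}_{z \\sim p(z)}[\\log(1 - D(G(z)))]$$\n",
"\n",
"In the training loop further below, this is implemented with binary cross-entropy: the discriminator is trained on fake images labeled 0 and real images labeled 1, and the generator is trained through the frozen discriminator with its fake images labeled 1."
]
},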
{
"cell_type": "markdown",
"source": [
"### **Q8) Complete the GAN architecture below**"
],
"metadata": {
"id": "NPCurK4s8cac"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "nVK3jdSQcUXa"
},
"outputs": [],
"source": [
"# Define the size of the latent space, e.g. 30\n",
"codings_size = ___\n",
"\n",
"# Build the generator model\n",
"Dense = tf.keras.layers.Dense\n",
"generator = tf.keras.Sequential([\n",
" Dense(___, activation=\"relu\", kernel_initializer=\"he_normal\"), # Expand to 300 units\n",
" Dense(___, activation=\"relu\", kernel_initializer=\"he_normal\"), # Expand to 450 units\n",
" Dense(___ * ___ * ___, activation=\"sigmoid\"), # Output layer to match 32x32x3 image\n",
" tf.keras.layers.Reshape([___, ___, ___]) # Reshape to 32x32x3 CIFAR image\n",
"])\n",
"\n",
"# Build the discriminator model\n",
"discriminator = tf.keras.Sequential([\n",
" tf.keras.layers.Flatten(), # Flatten the input image\n",
" Dense(___, activation=\"relu\", kernel_initializer=\"he_normal\"), # Hidden layer with 450 units\n",
" Dense(___, activation=\"relu\", kernel_initializer=\"he_normal\"), # Hidden layer with 300 units\n",
" Dense(1, activation=\"sigmoid\") # Output layer for binary classification\n",
"])\n",
"\n",
"# Combine generator and discriminator into a GAN\n",
"# The GAN model consists of the generator followed by the discriminator\n",
"gan = tf.keras.Sequential([generator, discriminator])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "3fgErTNUcpnF"
},
"outputs": [],
"source": [
"# Compile the discriminator model\n",
"# Uses binary cross-entropy loss for binary classification and RMSprop optimizer\n",
"discriminator.compile(loss=___, optimizer=___)\n",
"\n",
"# Set discriminator to non-trainable when training the GAN to freeze its weights\n",
"discriminator.trainable = False\n",
"\n",
"# Compile the GAN model\n",
"# Uses binary cross-entropy loss and RMSprop optimizer for training the GAN\n",
"gan.compile(loss=___, optimizer=___)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "CzmHiugEczJP"
},
"outputs": [],
"source": [
"# Define batch size for training\n",
"batch_size = ___\n",
"\n",
"# Create a TensorFlow dataset from the training data\n",
"# Shuffle the data with a buffer size of 1000 for randomness\n",
"dataset = tf.data.Dataset.from_tensor_slices(___).shuffle(___)\n",
"\n",
"# Batch the data into chunks of size batch_size\n",
"# Prefetch data to improve performance by overlapping data preprocessing and model training\n",
"dataset = dataset.batch(___, drop_remainder=True).prefetch(___)"
]
},
{
"cell_type": "markdown",
"source": [
"### **Q9) Train the GAN to generate new images**"
],
"metadata": {
"id": "PCiNMvwF8kPZ"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "HX9PG2NMX-_m"
},
"outputs": [],
"source": [
"# Helper function to train the GAN\n",
"def train_gan(gan, dataset, batch_size, codings_size, n_epochs):\n",
" generator, discriminator = gan.layers\n",
" for epoch in range(n_epochs):\n",
" print(f\"Epoch {epoch + 1}/{n_epochs}\")\n",
"\n",
" for X_batch in dataset:\n",
" # phase 1 - training the discriminator\n",
" noise = tf.random.normal(shape=[batch_size, codings_size])\n",
" generated_images = generator(noise)\n",
" X_batch = tf.cast(X_batch, tf.float32)\n",
" X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)\n",
" y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)\n",
" discriminator.train_on_batch(X_fake_and_real, y1)\n",
"\n",
" # phase 2 - training the generator\n",
" noise = tf.random.normal(shape=[batch_size, codings_size])\n",
" y2 = tf.constant([[1.]] * batch_size)\n",
" gan.train_on_batch(noise, y2)\n",
"\n",
" plot_multiple_images(generated_images.numpy(), 8)\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "TmhaIC6JdgYi"
},
"outputs": [],
"source": [
"# Train the GAN model, with e.g. 50 epochs\n",
"train_gan(gan, dataset, batch_size, codings_size, n_epochs=___)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "v-0hAJOSdqyS"
},
"outputs": [],
"source": [
"# Generate a batch of latent vectors\n",
"codings = tf.random.normal(shape=[batch_size, codings_size])\n",
"\n",
"# Generate images using the trained generator (using the `codings`)\n",
"generated_images = generator.predict(___)\n",
"\n",
"# Plot the generated images, e.g. 5\n",
"plot_multiple_images(generated_images, ___)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7Ceb7lmMoB-e"
},
"source": [
"## 灬🅶 Deep Convolutional GANs"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iTrNRBYde47v"
},
"source": [
"Deep GANs (Generative Adversarial Networks) are a type of GAN that use deep neural networks in both the generator and the discriminator. By leveraging deep architectures, these models can create more complex and realistic images or data."
]
},
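{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick size check may help when filling in the generator below: with `padding=\"same\"` and `strides=2`, each `Conv2DTranspose` layer doubles the spatial dimensions, so the 8 × 8 feature maps produced by the first Dense + Reshape become 16 × 16 and then 32 × 32, matching the CIFAR image size. Symmetrically, each strided `Conv2D` in the discriminator halves the spatial dimensions (32 → 16 → 8)."
]
},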
{
"cell_type": "markdown",
"source": [
"### **Q10) Complete the deep convolutional GAN architecture below**"
],
"metadata": {
"id": "SGJc3U5Q9ATw"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "atLv0EN6ZO2u"
},
"outputs": [],
"source": [
"# Define the size of the latent space, e.g. 100\n",
"codings_size = ___\n",
"\n",
"# Build the generator model\n",
"# Generates images from the latent space vector\n",
"# First Dense layer expands to 8x8x128, then to be reshaped\n",
"generator = tf.keras.Sequential([\n",
" tf.keras.layers.Dense(___ * ___ * ___), # Expand to 8x8x128\n",
" tf.keras.layers.Reshape([___, ___, ___]), # Reshape to 8x8x128\n",
" tf.keras.layers.BatchNormalization(),\n",
" # Upsample to 64 channels, 5 as kernel size, 2 strides, `same` padding\n",
" tf.keras.layers.Conv2DTranspose(___, kernel_size=___, strides=___,\n",
" padding=___, activation=\"relu\"),\n",
" tf.keras.layers.BatchNormalization(),\n",
" # Output layer with 3 channels, 5 as kernel size, 2 strides, `same` padding\n",
" tf.keras.layers.Conv2DTranspose(___, kernel_size=___, strides=___,\n",
" padding=___, activation=\"tanh\"),\n",
"])\n",
"\n",
"# Build the discriminator model that classifies images as real or fake\n",
"discriminator = tf.keras.Sequential([\n",
" # Downsample to 64; 5 as kernel size, 2 strides, `same` padding\n",
" tf.keras.layers.Conv2D(___, kernel_size=___, strides=___, padding=___,\n",
" activation=tf.keras.layers.LeakyReLU(0.2)),\n",
" tf.keras.layers.Dropout(___), # e.g. 0.4\n",
" # Downsample to 128; 5 as kernel size, 2 strides, `same` padding\n",
" tf.keras.layers.Conv2D(___, kernel_size=___, strides=___, padding=___,\n",
" activation=tf.keras.layers.LeakyReLU(0.2)),\n",
" tf.keras.layers.Dropout(___), # e.g. 0.4\n",
" tf.keras.layers.Flatten(),\n",
" tf.keras.layers.Dense(1, activation=\"sigmoid\")\n",
"])\n",
"\n",
"# The GAN model consists of the generator followed by the discriminator\n",
"gan = tf.keras.Sequential([___, ___])"
]
},
{
"cell_type": "markdown",
"source": [
"### **Q11) Train this new model to generate images**\n",
"\n",
"Do you notice improvements?"
],
"metadata": {
"id": "Lv5xpHOL9gf-"
}
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "JH4hoBDqnkN-"
},
"outputs": [],
"source": [
"# Compile the discriminator with binary cross-entropy loss and RMSprop optimizer\n",
"discriminator.compile(loss=___, optimizer=___)\n",
"\n",
"# Freeze the discriminator during GAN training\n",
"discriminator.trainable = False\n",
"\n",
"# Compile the GAN with the same loss and optimizer\n",
"gan.compile(loss=___, optimizer=___)\n",
"\n",
"# Reshape to 32x32x3 and scale to match the generator's expected output range\n",
"X_train_dcgan = X_train.reshape(-1, ___, ___, ___) * 2. - 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "91yT8NWQnwpD"
},
"outputs": [],
"source": [
"# Set the batch size for training\n",
"batch_size = ___\n",
"\n",
"# Create a dataset from reshaped and rescaled training data\n",
"dataset = tf.data.Dataset.from_tensor_slices(___)\n",
"\n",
"# Shuffle the dataset with buffer size 1000, batch the data, and prefetch once at a time\n",
"dataset = dataset.shuffle(___).batch(___, drop_remainder=True).prefetch(___)\n",
"\n",
"# Train the GAN model with e.g. 50 epochs\n",
"train_gan(gan, dataset, batch_size, codings_size, n_epochs=___)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YkacDq_ZoQMq"
},
"outputs": [],
"source": [
"# Generate random noise for input to the generator\n",
"# `batch_size` is the number of samples, `codings_size` is the latent space size\n",
"noise = tf.random.normal(shape=[___, ___])\n",
"\n",
"# Generate images using the generator model\n",
"generated_images = generator.predict(noise)\n",
"\n",
"# Plot e.g. 5 generated images\n",
"plot_multiple_images(generated_images, ___)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "sIkCgj2g4k-z"
},
"source": [
"## య Diffusion models"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dyvKvczoil-7"
},
"source": [
"Starting with an image from the dataset, at each time step $t$, the diffusion process adds Gaussian noise with mean 0 and variance $\\beta_t$. The model is then trained to reverse that process. More specifically, given a noisy image produced by the forward process, and given the time $t$, the model is trained to predict the total noise that was added to the original image, scaled to variance 1.\n",
"\n",
"The [DDPM paper](https://arxiv.org/abs/2006.11239) increased $\\beta_t$ from $\\beta_1$ = 0.0001 to $\\beta_T = $0.02 ($T$ is the max step), but the [Improved DDPM paper](https://arxiv.org/pdf/2102.09672.pdf) suggested using the following $\\cos^2(\\ldots)$ schedule instead, which gradually decreases $\\bar{\\alpha_t} = \\prod_{i=0}^{t} \\alpha_i$ from 1 to 0, where $\\alpha_t = 1 - \\beta_t$:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yxPNTP3QpHAw"
},
"outputs": [],
"source": [
"def variance_schedule(T, s=0.008, max_beta=0.999):\n",
" t = np.arange(T + 1)\n",
" f = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2\n",
" alpha = np.clip(f[1:] / f[:-1], 1 - max_beta, 1)\n",
" alpha = np.append(1, alpha).astype(np.float32) # add α₀ = 1\n",
" beta = 1 - alpha\n",
" alpha_cumprod = np.cumprod(alpha)\n",
" return alpha, alpha_cumprod, beta # αₜ , α̅ₜ , βₜ for t = 0 to T\n",
"\n",
"np.random.seed(42) # extra code – for reproducibility\n",
"T = 4000\n",
"alpha, alpha_cumprod, beta = variance_schedule(T)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "P4ijMDaVkyXi"
},
"source": [
"In the DDPM paper, the authors used $T = 1,000$, while in the Improved DDPM, they bumped this up to $T = 4,000$, so we use this value. The variable `alpha` is a vector containing $\\alpha_0, \\alpha_1, ..., \\alpha_T$. The variable `alpha_cumprod` is a vector containing $\\bar{\\alpha_0}, \\bar{\\alpha_1}, ..., \\bar{\\alpha_T}$."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9z6rfMntnSMi"
},
"source": [
"Let's plot `alpha_cumprod`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sBCgEsb8Smyt"
},
"outputs": [],
"source": [
"plt.figure(figsize=(6, 3))\n",
"plt.plot(beta, \"r--\", label=r\"$\\beta_t$\")\n",
"plt.plot(alpha_cumprod, \"b\", label=r\"$\\bar{\\alpha}_t$\")\n",
"plt.axis([0, T, 0, 1])\n",
"plt.grid(True)\n",
"plt.xlabel(r\"t\")\n",
"plt.legend()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZL7ZBkTKndy-"
},
"source": [
"The `prepare_batch()` function takes a batch of images and adds noise to each of them, using a different random time between 1 and $T$ for each image, and it returns a tuple containing the inputs and the targets:\n",
"\n",
"* The inputs are a `dict` containing the noisy images and the corresponding times. The function uses equation (4) from the DDPM paper to compute the noisy images in one shot, directly from the original images. It's a shortcut for the forward diffusion process.\n",
"* The target is the noise that was used to produce the noisy images."
]
},
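{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, equation (4) of the DDPM paper gives the forward process in closed form, which is exactly what `prepare_batch()` computes:\n",
"\n",
"$$x_t = \\sqrt{\\bar{\\alpha}_t}\\, x_0 + \\sqrt{1 - \\bar{\\alpha}_t}\\, \\epsilon, \\qquad \\epsilon \\sim \\mathcal{N}(0, \\mathbf{I})$$\n",
"\n",
"so a noisy image at any time step $t$ can be produced directly from the original image $x_0$, without simulating the $t$ intermediate noising steps."
]
},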
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sd0VYZCaSmyt"
},
"outputs": [],
"source": [
"def prepare_batch(X):\n",
" X = tf.cast(X[..., tf.newaxis], tf.float32) * 2 - 1 # scale from –1 to +1\n",
" X_shape = tf.shape(X)\n",
" t = tf.random.uniform([X_shape[0]], minval=1, maxval=T + 1, dtype=tf.int32)\n",
" alpha_cm = tf.gather(alpha_cumprod, t)\n",
" alpha_cm = tf.reshape(alpha_cm, [X_shape[0]] + [1] * (len(X_shape) - 1))\n",
" noise = tf.random.normal(X_shape)\n",
" return {\n",
" \"X_noisy\": alpha_cm ** 0.5 * X + (1 - alpha_cm) ** 0.5 * noise,\n",
" \"time\": t,\n",
" }, noise"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "tG-_liChSmyt"
},
"source": [
"### **Q12) Prepare one `tf.data.Dataset` for training, and one for validation.**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tQNNTi_dopTm"
},
"outputs": [],
"source": [
"def prepare_dataset(X, batch_size=32, shuffle=False):\n",
" # Create a dataset from input data\n",
" ds = tf.data.Dataset.from_tensor_slices(___)\n",
"\n",
" # Optionally shuffle the dataset\n",
" if shuffle:\n",
" ds = ds.shuffle(___) # e.g. 10000\n",
"\n",
" # Batch the data, apply necessary preparation and prefetch one at time\n",
" return ds.batch(batch_size).map(prepare_batch).prefetch(___)\n",
"\n",
"# Prepare the training and validation datasets\n",
"train_set = prepare_dataset(X_train, batch_size=___, shuffle=___)\n",
"valid_set = prepare_dataset(X_valid, batch_size=___)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "oAxL8EmF3Hcl"
},
"source": [
"As a quick sanity check, let's take a look at a few training samples, along with the corresponding noise to predict, and the original images (which we get by subtracting the appropriately scaled noise from the appropriately scaled noisy image):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "qNAqiAtDqu_B"
},
"outputs": [],
"source": [
"def subtract_noise(X_noisy, time, noise):\n",
" X_shape = tf.shape(X_noisy)\n",
" alpha_cm = tf.gather(alpha_cumprod, time)\n",
" alpha_cm = tf.reshape(alpha_cm, [X_shape[0]] + [1] * (len(X_shape) - 1))\n",
" return (X_noisy - (1 - alpha_cm) ** 0.5 * noise) / alpha_cm ** 0.5\n",
"\n",
"X_dict, Y_noise = list(train_set.take(1))[0] # get the first batch\n",
"X_original = subtract_noise(X_dict[\"X_noisy\"], X_dict[\"time\"], Y_noise)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y8mwYDDtz6Zm"
},
"outputs": [],
"source": [
"# Plot original images, noisy images and the noise to predict\n",
"print(\"Original images\")\n",
"plot_multiple_images(((X_original[:8].numpy()+1)*128).astype(np.uint8))\n",
"plt.show()\n",
"print(\"Time steps:\", X_dict[\"time\"].numpy()[:8])\n",
"print(\"Noisy images\")\n",
"plot_multiple_images(((X_dict[\"X_noisy\"][:8].numpy()+1)*128).astype(np.uint8))\n",
"plt.show()\n",
"print(\"Noise to predict\")\n",
"plot_multiple_images(((Y_noise[:8].numpy()+1)*128).astype(np.uint8))\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"source": [
"### **Q13) Complete the diffusion model architecture below**"
],
"metadata": {
"id": "ZtRETFah9_Pz"
}
},
{
"cell_type": "markdown",
"metadata": {
"id": "dcBpIkNxrPw9"
},
"source": [
"Now we're ready to build the diffusion model itself. It will need to process both images and times. We will encode the times using a sinusoidal encoding, as suggested in the DDPM paper, just like in the [Attention is all you need](https://arxiv.org/abs/1706.03762) paper. Given a vector of _m_ integers representing time indices (integers), the layer returns an _m_ × _d_ matrix, where _d_ is the chosen embedding size."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "oT72nUUuqz7z"
},
"outputs": [],
"source": [
"# the size of the embedding e.g. 64\n",
"embed_size = ___\n",
"\n",
"# Custom layer to encode time steps with sinusoidal embeddings\n",
"class TimeEncoding(tf.keras.layers.Layer):\n",
" def __init__(self, T, embed_size, dtype=tf.float32, **kwargs):\n",
" # Initialize layer and ensure embed_size is even\n",
" super().__init__(dtype=dtype, **kwargs)\n",
" assert embed_size % 2 == 0, \"embed_size must be even\"\n",
"\n",
" # Create a meshgrid for time steps and embedding indices\n",
" p, i = np.meshgrid(np.arange(T + 1), 2 * np.arange(embed_size // 2))\n",
"\n",
" # Initialize the time embeddings matrix\n",
" t_emb = np.empty((T + 1, embed_size))\n",
"\n",
" # Fill even indices with sine values and odd indices with cosine values\n",
" t_emb[:, ::2] = np.sin(p / 10_000 ** (i / embed_size)).T\n",
" t_emb[:, 1::2] = np.cos(p / 10_000 ** (i / embed_size)).T\n",
"\n",
" # Convert the embeddings to TensorFlow constant\n",
" self.time_encodings = tf.constant(t_emb.astype(self.dtype))\n",
"\n",
" # Method to fetch time encodings for `input` time steps\n",
" def call(self, inputs):\n",
" return tf.gather(self.time_encodings, ___)"
]
},
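{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick visual check (a small sketch, assuming the `TimeEncoding` layer above has been completed), you can look at the sinusoidal encodings as an image: each row is the embedding for one time step, with frequencies decreasing from left to right."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Visualize the sinusoidal time encodings (rows = time steps, columns = embedding dimensions)\n",
"time_encoding_layer = TimeEncoding(T, embed_size)\n",
"encodings = time_encoding_layer(tf.range(T + 1)).numpy()\n",
"\n",
"plt.figure(figsize=(8, 3))\n",
"plt.imshow(encodings, aspect=\"auto\", cmap=\"coolwarm\")\n",
"plt.xlabel(\"embedding dimension\")\n",
"plt.ylabel(\"time step t\")\n",
"plt.colorbar()\n",
"plt.show()"
]
},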
{
"cell_type": "markdown",
"metadata": {
"id": "18SYrUNysJ62"
},
"source": [
"Now let's build the model. In the Improved DDPM paper, they use a UNet model. We'll create a UNet-like model, that processes the image through `Conv2D` + `BatchNormalization` layers and skip connections, gradually downsampling the image (using `MaxPooling` layers with `strides=2`), then growing it back again (using `Upsampling2D` layers). Skip connections are also added across the downsampling part and the upsampling part. We also add the time encodings to the output of each block, after passing them through a `Dense` layer to resize them to the right dimension.\n",
"\n",
"* **Note**: an image's time encoding is added to every pixel in the image, along the last axis (channels). So the number of units in the `Conv2D` layer must correspond to the embedding size, and we must reshape the `time_enc` tensor to add the width and height dimensions.\n",
"* This UNet implementation was inspired by keras.io's [image segmentation example](https://keras.io/examples/vision/oxford_pets_image_segmentation/), as well as from the [official diffusion models implementation](https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/models/unet.py). Compared to the first implementation, I added a few things, especially time encodings and skip connections across down/up parts. Compared to the second implementation, I removed a few things, especially the attention layers. It seemed like overkill for Fashion MNIST, but feel free to add them."
]
},
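{
"cell_type": "markdown",
"metadata": {},
"source": [
"The note about adding the time encoding to every pixel boils down to broadcasting: reshaping the Dense-projected encoding to shape `[batch, 1, 1, channels]` lets TensorFlow add it to a `[batch, height, width, channels]` feature map. A tiny sketch with hypothetical shapes (not part of the model):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Broadcasting a per-image vector across all pixels of a feature map\n",
"feature_map = tf.zeros([2, 8, 8, 16])  # [batch, height, width, channels]\n",
"time_vector = tf.ones([2, 16])         # one 16-dimensional time encoding per image\n",
"\n",
"result = time_vector[:, tf.newaxis, tf.newaxis, :] + feature_map\n",
"print(result.shape)  # (2, 8, 8, 16): the vector was added to every pixel"
]
},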
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "IbBIF8TvrF7C"
},
"outputs": [],
"source": [
"def build_diffusion_model():\n",
" # Define inputs: noisy image and time step\n",
" X_noisy = tf.keras.layers.Input(shape=[32, 32, 3], name=\"X_noisy\")\n",
" time_input = tf.keras.layers.Input(shape=[], dtype=tf.int32, name=\"time\")\n",
"\n",
" # Encode the time step using the custom TimeEncoding layer that takes `T` and `embed_size`\n",
" time_enc = TimeEncoding(__, ___)(time_input)\n",
"\n",
" dim = ___ # e.g 16\n",
" # Initial convolution with zero padding for the noisy input\n",
" Z = tf.keras.layers.ZeroPadding2D((3, 3))(X_noisy)\n",
" Z = tf.keras.layers.Conv2D(dim, 3)(Z)\n",
" Z = tf.keras.layers.BatchNormalization()(Z)\n",
" Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
"\n",
" # Adapt the time encoding and add it to the image feature map\n",
" time = tf.keras.layers.Dense(dim)(time_enc)\n",
" Z = time[:, tf.newaxis, tf.newaxis, :] + Z # add time info to every pixel\n",
"\n",
" # Keep track of skip connections and initiate a residual connection\n",
" skip = Z\n",
" cross_skips = [] # for skip connections in UNet structure\n",
"\n",
" # Downsampling block\n",
" for dim in (32, 64, 128):\n",
" Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
" Z = tf.keras.layers.SeparableConv2D(dim, 3, padding=\"same\")(Z)\n",
" Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
" Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
" Z = tf.keras.layers.SeparableConv2D(dim, 3, padding=\"same\")(Z)\n",
" Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
" # Store intermediate output for skip connection\n",
" cross_skips.append(Z)\n",
"\n",
" # Downsample and add residual connection\n",
" Z = tf.keras.layers.MaxPooling2D(3, strides=2, padding=\"same\")(Z)\n",
" skip_link = tf.keras.layers.Conv2D(dim, 1, strides=2, padding=\"same\")(skip)\n",
" Z = tf.keras.layers.add([Z, skip_link])\n",
"\n",
" # Add time information to downsampled feature maps\n",
" time = tf.keras.layers.Dense(dim)(time_enc)\n",
" Z = time[:, tf.newaxis, tf.newaxis, :] + Z\n",
" skip = Z\n",
"\n",
" # Upsampling block\n",
" for dim in (64, 32, 16):\n",
" Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
" Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding=\"same\")(Z)\n",
" Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
" Z = tf.keras.layers.Activation(\"relu\")(Z)\n",
" Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding=\"same\")(Z)\n",
" Z = tf.keras.layers.BatchNormalization()(Z)\n",
"\n",
" # Upsample and add residual connection\n",
" Z = tf.keras.layers.UpSampling2D(2)(Z)\n",
" skip_link = tf.keras.layers.UpSampling2D(2)(skip)\n",
" skip_link = tf.keras.layers.Conv2D(dim, 1, padding=\"same\")(skip_link)\n",
" Z = tf.keras.layers.add([Z, skip_link])\n",
"\n",
" # Add time encoding and merge with corresponding downsample skip connection\n",
" time = tf.keras.layers.Dense(dim)(time_enc)\n",
" Z = time[:, tf.newaxis, tf.newaxis, :] + Z\n",
" Z = tf.keras.layers.concatenate([Z, cross_skips.pop()], axis=-1)\n",
" skip = Z\n",
"\n",
" # Final convolution layer, output cropped to remove padding\n",
" outputs = tf.keras.layers.Conv2D(1, 3, padding=\"same\")(Z)[:, 2:-2, 2:-2]\n",
"\n",
" # Return the model with inputs and outputs\n",
" return tf.keras.Model(inputs=[X_noisy, time_input], outputs=[outputs])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "R60NMcUT0b1K"
},
"source": [
"Let's train the model!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sO7wA3JQtlOl"
},
"outputs": [],
"source": [
"# Build and compile the diffusion model\n",
"model = build_diffusion_model()\n",
"model.compile(loss=tf.keras.losses.Huber(), optimizer=\"nadam\")\n",
"\n",
"# Create a checkpoint callback to save the best model during training\n",
"checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(\"my_diffusion_model\",\n",
" save_best_only=True)\n",
"\n",
"# Train the model with the training and validation datasets with e.g. 100 epochs\n",
"history = model.fit(train_set, validation_data=valid_set, epochs=___,\n",
" callbacks=[checkpoint_cb])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HY4LX4TZSmyu"
},
"source": [
"Now that the model is trained, we can use it to generate new images. For this, we just generate Gaussian noise, and pretend this is the result of the diffusion process, and we're at time $T$. Then we use the model to predict the image at time $T - 1$, then we call it again to get $T - 2$, and so on, removing a bit of noise at each step. At the end, we get an image that looks like it's from the Fashion MNIST dataset. The equation for this reverse process is at the top of page 4 in the DDPM paper (step 4 in algorithm 2)."
]
},
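{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, the update implemented in `generate()` below follows step 4 of Algorithm 2 in the DDPM paper:\n",
"\n",
"$$x_{t-1} = \\frac{1}{\\sqrt{\\alpha_t}} \\left( x_t - \\frac{\\beta_t}{\\sqrt{1 - \\bar{\\alpha}_t}} \\, \\epsilon_\\theta(x_t, t) \\right) + \\sigma_t z, \\qquad z \\sim \\mathcal{N}(0, \\mathbf{I})$$\n",
"\n",
"where $\\epsilon_\\theta(x_t, t)$ is the model's noise prediction, $\\sigma_t z$ is only added for $t > 1$, and the code uses $\\sigma_t = \\sqrt{1 - \\alpha_t} = \\sqrt{\\beta_t}$."
]
},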
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "vMXxz4qV8Luk"
},
"outputs": [],
"source": [
"def generate(model, batch_size=32):\n",
" X = tf.random.normal([batch_size, 28, 28, 1])\n",
" for t in range(T - 1, 0, -1):\n",
" print(f\"\\rt = {t}\", end=\" \") # extra code – show progress\n",
" noise = (tf.random.normal if t > 1 else tf.zeros)(tf.shape(X))\n",
" X_noise = model({\"X_noisy\": X, \"time\": tf.constant([t] * batch_size)})\n",
" X = (\n",
" 1 / alpha[t] ** 0.5\n",
" * (X - beta[t] / (1 - alpha_cumprod[t]) ** 0.5 * X_noise)\n",
" + (1 - alpha[t]) ** 0.5 * noise\n",
" )\n",
" return X"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Fntgnuvkt1Qq"
},
"outputs": [],
"source": [
"# Generate images\n",
"X_gen = generate(model)\n",
"\n",
"# Plot the generated images\n",
"plot_multiple_images(X_gen.numpy(), 5)\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UV7onCCASmyu"
},
"source": [
"Some of these images are really convincing! Compared to GANs, diffusion models tend to generate more diverse images, and they have surpassed GANs in image quality. Moreover, training is much more stable. However, generating images takes *much* longer."
]
}
],
"metadata": {
"colab": {
"gpuType": "T4",
"provenance": [],
"toc_visible": true,
"include_colab_link": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
},
"nav_menu": {
"height": "381px",
"width": "453px"
},
"toc": {
"navigate_menu": true,
"number_sections": true,
"sideBar": true,
"threshold": 6,
"toc_cell": false,
"toc_section_display": "block",
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 0
}