Training a Diffusion Model
Now that we've got an intuitive understanding of latent diffusion models and their main components, let's dive more technically into what happens during the training process.
Iterative Steps
First, we need to understand that the work done by a diffusion model occurs over several small iterative steps, rather than in one large step as we may be used to with other generative models.
With a GAN, for example, training on a single sample works like this: a noise vector is passed to a generator network, which outputs an image of what we want to produce. Then we need to determine whether the generator did a good job. To do that, we pass the image to the discriminator network, which classifies the image as real (meaning it is from the training set) or fake (meaning it was produced by the generator).
This works, but GANs often run into problems like mode collapse, where the generator ends up producing only a single image (or a small set of images) that it can reliably fool the discriminator with. Intuitively, going from a random noise vector to a realistic, high-quality, high-resolution image in one step is a hard problem.
With diffusion models, we simplify the process into small, iterative steps so that the work the generative model has to do at each step is much smaller. In other words, within a single training epoch, we break the problem up from one step into several small steps, and we run the model iteratively over this sequence of small steps to complete the job of one training epoch.
Training Process
We'll now elaborate further on what occurs during these small steps to give a more thorough overview of the training process.
Noise Prediction
Just as we're used to with other model types, with diffusion models, we work with a training set of images. Since our focus is on latent diffusion models, we compress these images using a variational autoencoder (VAE). As mentioned in our introduction to latent diffusion models, these compressed images are called latents.
Using what's called a noise scheduler, we add various amounts of noise to these compressed training images. As we also discussed in our intuitive intro, we then pass the noisy images to the diffusion model, and we want it to predict the noise that is present in the images.
We can then take the predicted noise and subtract it from the noisy input sample, which gives us an estimate of the clean (non-noisy) training image that is, we hope, accurate. Having the network predict the noise, rather than the original image itself, turns out to be a simpler task.
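The subtraction above is a simplification. In the DDPM formulation that many diffusion models follow, the noisy sample is a weighted mix \(x_t = \sqrt{\bar\alpha_t}\,x_0 + \sqrt{1-\bar\alpha_t}\,\epsilon\), where \(\bar\alpha_t\) is set by the noise schedule, so recovering \(x_0\) means subtracting the scaled noise and then dividing by \(\sqrt{\bar\alpha_t}\). The minimal sketch below assumes that formulation; the function names, the toy four-value latent, and \(\bar\alpha_t = 0.4\) are illustrative choices, not part of any particular library.

```python
import math
import random

def add_noise(x0, noise, alpha_bar_t):
    """Forward noising in the DDPM form:
    x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * noise."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [a * x + b * n for x, n in zip(x0, noise)]

def estimate_x0(x_t, predicted_noise, alpha_bar_t):
    """Subtract the scaled predicted noise, then undo the signal scaling."""
    a, b = math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)
    return [(x - b * n) / a for x, n in zip(x_t, predicted_noise)]

random.seed(0)
latent = [0.5, -1.2, 0.3, 0.8]                 # toy "latent" with 4 values
noise = [random.gauss(0.0, 1.0) for _ in latent]
alpha_bar_t = 0.4                              # hypothetical noise level

noisy = add_noise(latent, noise, alpha_bar_t)
# With a perfect noise prediction, the original latent comes back
# (up to floating-point rounding).
recovered = estimate_x0(noisy, noise, alpha_bar_t)
print([round(v, 6) for v in recovered])        # [0.5, -1.2, 0.3, 0.8]
```

In practice the network's prediction is imperfect, so `estimate_x0` returns only an approximation of the clean latent.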
During training, the network's job is to iteratively undo or remove this noise. We might think the best way to do this is to just train the network on the pairs of the original images and their corresponding final noisy image counterparts in order for the network to learn how to identify noise. In other words, we pass both images to the network and say, "here is the clear image, here is the noisy version of this image. Find out what needs to be done to the noisy image to arrive at the clear one."
We could take this approach. However, we would then face a problem just as hard as the one GANs face: trying to do all the work in a single step, here by removing all of the noise at once. Instead, diffusion models use incrementally noisier versions of each image, and the network learns to undo that noise incrementally over several small steps.
Noise Scheduler
The noise scheduler is a tool that determines how much noise is added to a training image according to some predefined schedule. For each training image, noise is randomly sampled from a distribution (typically Gaussian) and added to the image, with the schedule setting its magnitude. This amount of noise is often described by \(\beta\) (a variance) or \(\sigma\) (a standard deviation).
The amount of noise is determined by a randomly selected number \(t\) ranging from \(0\) to some set maximum \(T\). The higher the value of \(t\), the more noise: \(t=0\) represents no noise, and \(t=T\) represents the maximum amount of noise.
Noise schedules are often discussed with \(t\) running from \(T\) down to \(0\), the direction in which images are denoised. Along that direction, the variance of the noise decreases, since larger values of \(t\) correspond to more noise.
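As one concrete example, the linear schedule used in the original DDPM paper sets \(T = 1000\) and lets the per-step variance \(\beta_t\) rise linearly from \(10^{-4}\) to \(0.02\). The sketch below (function names are our own, and other schedules such as cosine ones exist) computes the cumulative value \(\bar\alpha_t = \prod_{s \le t}(1-\beta_s)\), the fraction of the original signal's variance that survives at \(t\). It starts at \(1\) for \(t=0\) (no noise) and shrinks toward \(0\) as \(t\) approaches \(T\).

```python
import math

def linear_beta_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances beta_1..beta_T, spaced linearly."""
    return [beta_start + (beta_end - beta_start) * i / (T - 1) for i in range(T)]

def cumulative_alpha_bar(betas):
    """alpha_bars[t] = prod_{s=1..t} (1 - beta_s); alpha_bars[0] = 1.0 means no noise."""
    alpha_bars = [1.0]
    for beta in betas:
        alpha_bars.append(alpha_bars[-1] * (1.0 - beta))
    return alpha_bars

T = 1000
betas = linear_beta_schedule(T)
alpha_bars = cumulative_alpha_bar(betas)   # length T + 1, indexed directly by t

for t in (0, 10, 250, 500, 1000):
    ab = alpha_bars[t]
    print(f"t={t:4d}  signal scale={math.sqrt(ab):.3f}  noise scale={math.sqrt(1 - ab):.3f}")
```

Printing a few values shows the signal scale falling and the noise scale rising as \(t\) grows.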
Note that \(t\) is sometimes referred to as a timestep, as in the amount of noise added to an image at timestep \(t\). We won't refer to \(t\) as a timestep, as it would complicate the explanation. We'll just refer to \(t\) as a random number between \(0\) and \(T\) that determines the amount of noise added to a given image.
During training, we randomly select a training sample along with a random \(t\), and then the noise corresponding to \(t\) is added to the image. Each training image therefore receives a different amount of noise, ranging from a slight amount to so much that the image looks like pure random noise.
For example, in a batch of noisy training images passed to the network, the noise added to image \(1\) might correspond to \(t=10\), and the noise added to image \(2\) might correspond to \(t=600\). For each of these images, we're asking the network to predict the noise present in the image.
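Building a batch like this can be sketched as follows, assuming DDPM-style noising and a hypothetical linear schedule with \(T = 1000\). Each sample gets its own independently drawn \(t\); the network would then receive the noisy samples together with their \(t\) values, and the sampled noise serves as the prediction target. The helper name `make_noisy_batch` and the toy three-value latents are illustrative only.

```python
import math
import random

T = 1000
# Hypothetical linear schedule (DDPM-style): beta rises from 1e-4 to 0.02.
betas = [1e-4 + (0.02 - 1e-4) * i / (T - 1) for i in range(T)]
alpha_bars = [1.0]                      # alpha_bars[0] = 1.0 means no noise
for beta in betas:
    alpha_bars.append(alpha_bars[-1] * (1.0 - beta))

def make_noisy_batch(batch, rng):
    """Draw an independent t in 1..T per sample and noise it to that level.

    Returns (noisy_batch, noise_batch, ts). t = 0 is skipped because it
    would add no noise at all.
    """
    noisy_batch, noise_batch, ts = [], [], []
    for x0 in batch:
        t = rng.randint(1, T)                      # a different t per image
        noise = [rng.gauss(0.0, 1.0) for _ in x0]  # Gaussian noise, same shape
        a, b = math.sqrt(alpha_bars[t]), math.sqrt(1.0 - alpha_bars[t])
        noisy_batch.append([a * x + b * n for x, n in zip(x0, noise)])
        noise_batch.append(noise)
        ts.append(t)
    return noisy_batch, noise_batch, ts

rng = random.Random(42)
batch = [[0.5, -1.2, 0.3], [1.0, 0.0, -0.7]]      # two toy "latents"
noisy, targets, ts = make_noisy_batch(batch, rng)
print(ts)  # the sampled t values depend on the seed
```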
By receiving images with varying amounts of noise, the network will learn how to denoise images incrementally over several small steps, rather than all at once.
Training Example in Steps
Suppose we pass the network a training image to which the noise scheduler has added an amount of noise corresponding to \(t=3\). We want the network to predict the total amount of noise in this image.
In a perfect world, subtracting the network's predicted noise from this image would give us the perfectly clear image at \(t=0\).
In reality, however, this is not the case. The predicted result of the image at \(t=0\) may be only a vague representation of the original training image, and so it is to be viewed only as a very first rough estimate of the original training image at time \(t=0\).
Next, we add most (but not all) of the predicted noise back to this estimated image at \(t=0\). Suppose we have a constant \(c\) by which we multiply the predicted noise to determine how much of it is added back to the estimated \(t=0\) image. For this example, let's suppose \(c=0.9\), meaning that \(90\%\) of the predicted noise will be added back to the \(t=0\) estimate.
Then we pass this result, which is slightly less noisy than the image we started with, to the network and have it predict the noise. We again subtract the predicted noise from the new input image to get another, improved \(t=0\) estimate.
Again, we add back \(90\%\) of this newly predicted noise to the new \(t=0\) estimate. The result of this sum is the slightly less noisy input to the network in the next step.
We repeat this process in a loop over a predefined number of training steps. Suppose for this example, we have \(100\) total steps. Each time we do an iteration, we're passing in a slightly less noisy image as input to the network and getting closer and closer estimates to the original image.
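The loop as described can be sketched in a few lines. Everything here is hypothetical and toy-sized: `toy_predict_noise` stands in for the network and, unlike a real network, peeks at the clean image so that it always recovers exactly half of the remaining noise. The values \(c = 0.9\) and \(100\) steps are the example values from the text. With this stand-in, each pass leaves \(0.5 + 0.5 \cdot 0.9 = 0.95\) of the previous noise, so every \(t=0\) estimate is closer to the original than the one before.

```python
def toy_predict_noise(x, x0_true):
    """Hypothetical imperfect 'network': returns half of the true remaining noise.
    (A real network never sees the clean image.)"""
    return [0.5 * (xi - ci) for xi, ci in zip(x, x0_true)]

def iterative_estimates(x_noisy, x0_true, steps=100, c=0.9):
    """Loop from the text: predict noise, subtract it to get a t=0 estimate,
    then add back c of the predicted noise to form the next input.
    Returns the max absolute error of each t=0 estimate."""
    x = list(x_noisy)
    errors = []
    for _ in range(steps):
        predicted = toy_predict_noise(x, x0_true)
        x0_estimate = [xi - pi for xi, pi in zip(x, predicted)]
        errors.append(max(abs(e - c0) for e, c0 in zip(x0_estimate, x0_true)))
        x = [e + c * pi for e, pi in zip(x0_estimate, predicted)]
    return errors

clean = [0.5, -1.2, 0.3]
noisy = [1.9, -2.0, 1.1]
errors = iterative_estimates(noisy, clean)
print(f"first estimate error: {errors[0]:.3f}, last: {errors[-1]:.5f}")
```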
This process is how a typical step of training as we know it is broken up into multiple smaller steps for diffusion models.
After completing these steps for a batch of images, we calculate the gradients of the loss and use a gradient descent-based optimizer to update the network's weights, improving its ability to predict noise in the next batch. This process continues for a defined number of epochs.
The overall objective of the model is the same as with the neural networks we've encountered in the past, which is of course to minimize the loss. The loss measures how well the network is predicting the noise by comparing the predicted noise to the target noise, that is, the noise that was actually added to the compressed training image.
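A common choice for this comparison is the mean squared error between predicted and actual noise, as in the simplified DDPM objective. The toy below is entirely hypothetical: the clean latents are all zeros, every sample uses the same noise level (\(\bar\alpha = 0.75\)), and a single weight `w` stands in for the whole network by predicting the noise as `w * x_t`. Plain gradient descent with a hand-derived gradient then drives the loss toward zero, with `w` approaching the ideal value \(1/\sqrt{1-\bar\alpha} = 2\).

```python
import math
import random

def mse(predicted, target):
    """Mean squared error between predicted and actual noise."""
    return sum((p - q) ** 2 for p, q in zip(predicted, target)) / len(target)

rng = random.Random(0)
alpha_bar = 0.75
scale = math.sqrt(1.0 - alpha_bar)           # 0.5
noise = [rng.gauss(0.0, 1.0) for _ in range(256)]
x_t = [scale * n for n in noise]             # clean latent is 0, so x_t is pure scaled noise

w, lr = 0.0, 0.5
losses = []
for step in range(50):
    predicted = [w * x for x in x_t]
    losses.append(mse(predicted, noise))
    # d(loss)/dw = mean(2 * (w * x - n) * x)
    grad = sum(2.0 * (w * x - n) * x for x, n in zip(x_t, noise)) / len(noise)
    w -= lr * grad                           # gradient descent update

print(f"loss: {losses[0]:.3f} -> {losses[-1]:.2e}, w = {w:.4f} (ideal {1 / scale:.4f})")
```

A real network would compute the same kind of loss over many images and \(t\) values, with gradients obtained by backpropagation rather than by hand.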
Training Summary
To quickly summarize this process, during training, we choose a random training image and a random number \(t\) between \(0\) and a predefined maximum \(T\). A noise scheduler determines the amount of noise that corresponds to \(t\) and adds this noise to the image. We then pass this noisy image along with the random number \(t\) to the network.
The network predicts the noise present in the image. This predicted noise is subtracted from the noisy image to give an estimate of the original image at \(t=0\).
We then add back the majority of this predicted noise to the \(t=0\) estimate, and this slightly less noisy image is used as the input for the next step. This process is repeated for a set number of training steps, which results in clearer and clearer estimates of the original non-noisy image at each iteration.
This summarizes the training process. Now we can see how a diffusion model generates images by denoising noisy input images. We'll later expand more on how we can direct the types of images we want the model to generate by passing a text prompt.