Diffusion From Scratch To Generate Cute Anime Faces!
The complete implementation is available on GitHub.
Introduction
Diffusion models have revolutionized the field of AI art, enabling the generation of stunning visuals, including adorable anime faces. While prebuilt tools like Stable Diffusion make it easy to create such images, building a diffusion model from scratch offers invaluable benefits. It deepens your understanding of the underlying principles, gives you the flexibility to customize and innovate, and provides a hands-on learning experience that can be both rewarding and empowering. In this blog post, we’ll explore how to create a diffusion model step-by-step, equipping you with the knowledge to generate charming anime faces and unlock your creative potential.
Setup
Clone the Repository
git clone https://github.com/ramintoosi/diffusion-from-scratch
Install requirements
pip install -r requirements.txt
Train the model
python main.py train
Run inference
python main.py inference
Ok, let’s go through the details.
Diffusion Model
Diffusion models operate on a simple yet powerful intuition: they learn to generate data by reversing a gradual noise process. Imagine starting with a clear image and progressively adding random noise to it, step by step, until it becomes completely unrecognizable. Diffusion models essentially learn this “noising” process in reverse—they begin with random noise and iteratively refine it to recover the original, clear image. By training on vast datasets, they master this denoising process, enabling them to generate entirely new images from pure noise, guided by patterns and structures learned during training. This step-by-step refinement makes diffusion models particularly adept at producing high-quality and detailed outputs.
Model Architecture Implementation
Forward Pass
The DiffusionForwardProcess class implements the forward diffusion process, which systematically adds noise to an image over a series of time steps. This gradual "noising" process is exactly what the model later learns to reverse during training.
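In the standard DDPM formulation, each step adds a small amount of Gaussian noise with variance $\beta_t$, and the noisy image at any timestep $t$ has a convenient closed form, which is what the precomputed square-root terms in the code below implement:

$$q(x_t \mid x_{t-1}) = \mathcal{N}\!\left(x_t;\ \sqrt{1-\beta_t}\,x_{t-1},\ \beta_t I\right)$$

$$x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \bar{\alpha}_t = \prod_{s=1}^{t}(1-\beta_s)$$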
import torch
from torch import Tensor


class DiffusionForwardProcess:
    """
    Implements the forward process of the diffusion model.
    """

    def __init__(self,
                 num_time_steps: int = 1000,
                 beta_start: float = 1e-4,
                 beta_end: float = 0.02
                 ):
        """
        Initializes the DiffusionForwardProcess with the given parameters.

        :param num_time_steps: Number of time steps in the diffusion process.
        :param beta_start: Starting value of beta.
        :param beta_end: Ending value of beta.
        """
        self.betas = torch.linspace(beta_start, beta_end, num_time_steps)
        self.alphas = 1 - self.betas
        self.alpha_bars = torch.cumprod(self.alphas, dim=0)
        self.sqrt_alpha_bars = torch.sqrt(self.alpha_bars)
        self.sqrt_one_minus_alpha_bars = torch.sqrt(1 - self.alpha_bars)

    def add_noise(self, original: Tensor, noise: Tensor, t: Tensor) -> Tensor:
        """
        Adds noise to the original image at the given time step t.

        :param original: Input image tensor
        :param noise: Random noise tensor sampled from a normal distribution
        :param t: Timestep
        :return: Noisy image tensor
        """
        sqrt_alpha_bar_t = self.sqrt_alpha_bars.to(original.device)[t]
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_bars.to(original.device)[t]

        # Broadcast to multiply with the original image.
        sqrt_alpha_bar_t = sqrt_alpha_bar_t[:, None, None, None]
        sqrt_one_minus_alpha_bar_t = sqrt_one_minus_alpha_bar_t[:, None, None, None]

        return (sqrt_alpha_bar_t * original) + (sqrt_one_minus_alpha_bar_t * noise)
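As a quick illustration (a sketch; the tensors below are dummy stand-ins for real training data), the forward process is used by sampling a random timestep and a noise tensor for each image:

dfp = DiffusionForwardProcess()

x0 = torch.randn(8, 3, 64, 64)       # dummy batch of images in [-1, 1]
t = torch.randint(0, 1000, (8,))     # one random timestep per image
noise = torch.randn_like(x0)         # the noise to be added
xt = dfp.add_noise(x0, noise, t)     # noisy images at timestep t
print(xt.shape)                      # torch.Size([8, 3, 64, 64])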
Reverse Pass
The DiffusionReverseProcess class implements the reverse denoising process, which is the core of the generative side of diffusion models. Starting from pure noise, it iteratively predicts and reconstructs progressively cleaner versions of the image until an image resembling the training data emerges.
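Concretely, given the model's noise prediction $\epsilon_\theta(x_t, t)$, the class below implements the standard DDPM sampling step:

$$\hat{x}_0 = \frac{x_t - \sqrt{1-\bar{\alpha}_t}\,\epsilon_\theta}{\sqrt{\bar{\alpha}_t}}, \qquad \mu_t = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_\theta\right), \qquad \sigma_t^2 = \frac{1-\bar{\alpha}_{t-1}}{1-\bar{\alpha}_t}\,\beta_t$$

$$x_{t-1} = \mu_t + \sigma_t z, \quad z \sim \mathcal{N}(0, I)\ \text{for } t > 0, \qquad x_{t-1} = \mu_t\ \text{at } t = 0$$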
class DiffusionReverseProcess:
    """
    Implements the reverse process of the diffusion model.
    """

    def __init__(self,
                 num_time_steps: int = 1000,
                 beta_start: float = 1e-4,
                 beta_end: float = 0.02
                 ):
        """
        Initializes the DiffusionReverseProcess with the given parameters.

        :param num_time_steps: Number of time steps in the diffusion process.
        :param beta_start: Starting value of beta.
        :param beta_end: Ending value of beta.
        """
        # Precompute beta, alpha, and alpha_bar for all t's.
        self.b = torch.linspace(beta_start, beta_end, num_time_steps)  # b -> beta
        self.a = 1 - self.b                                            # a -> alpha
        self.a_bar = torch.cumprod(self.a, dim=0)                      # a_bar -> alpha_bar

    def sample_prev_timestep(self, xt: Tensor, noise_pred: Tensor, t: Tensor) -> tuple[Tensor, Tensor]:
        """
        Samples the previous timestep image given the current timestep image and noise prediction.

        :param xt: Image tensor at timestep t of shape -> B x C x H x W
        :param noise_pred: Noise tensor predicted by the model at timestep t of shape -> B x C x H x W
        :param t: Timestep
        :return: Predicted x_(t-1) and x0
        """
        # Prediction of the original image x0 at timestep t
        x0 = xt - (torch.sqrt(1 - self.a_bar.to(xt.device)[t]) * noise_pred)
        x0 = x0 / torch.sqrt(self.a_bar.to(xt.device)[t])
        x0 = torch.clamp(x0, -1., 1.)

        # Mean of x_(t-1)
        mean = (xt - ((1 - self.a.to(xt.device)[t]) * noise_pred) / (torch.sqrt(1 - self.a_bar.to(xt.device)[t])))
        mean = mean / (torch.sqrt(self.a.to(xt.device)[t]))

        # At the final step (t == 0) no noise is added; return the mean directly.
        if t == 0:
            return mean, x0
        else:
            variance = (1 - self.a_bar.to(xt.device)[t - 1]) / (1 - self.a_bar.to(xt.device)[t])
            variance = variance * self.b.to(xt.device)[t]
            sigma = variance ** 0.5
            z = torch.randn(xt.shape).to(xt.device)
            return mean + sigma * z, x0
Time Embedding
The get_time_embedding function generates a time-step embedding, transforming scalar time-step values into high-dimensional vector representations. These embeddings encode the temporal information that guides the diffusion model in processing noise at specific timesteps.
def get_time_embedding(time_steps: torch.Tensor, t_emb_dim: int) -> torch.Tensor:
    """
    Transform a scalar time-step into a vector representation of size t_emb_dim.

    :param time_steps: 1D tensor of size -> (Batch,)
    :param t_emb_dim: Embedding dimension -> e.g. 128 (scalar value)
    :return: Tensor of size -> (B, t_emb_dim)
    """
    assert t_emb_dim % 2 == 0, "time embedding must be divisible by 2."

    factor = 2 * torch.arange(start=0,
                              end=t_emb_dim // 2,
                              dtype=torch.float32,
                              device=time_steps.device
                              ) / t_emb_dim
    factor = 10000 ** factor

    t_emb = time_steps[:, None]                                      # B -> (B, 1)
    t_emb = t_emb / factor                                           # (B, 1) -> (B, t_emb_dim//2)
    t_emb = torch.cat([torch.sin(t_emb), torch.cos(t_emb)], dim=1)   # (B, t_emb_dim)

    return t_emb
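For example, a quick shape check of the function above:

t = torch.randint(0, 1000, (4,))    # four random timesteps
emb = get_time_embedding(t, 128)    # sinusoidal embedding
print(emb.shape)                    # torch.Size([4, 128])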
The TimeEmbedding class is a simple neural network module that transforms a time-step embedding into the desired output dimension. This transformation is typically used to align the time embedding with the dimensions required for downstream operations in the diffusion model.
from torch import nn


class TimeEmbedding(nn.Module):
    """
    Maps the time embedding to the required output dimension.
    """

    def __init__(self,
                 n_out: int,             # Output dimension
                 t_emb_dim: int = 128    # Time embedding dimension
                 ):
        super().__init__()

        # Time embedding block
        self.te_block = nn.Sequential(
            nn.SiLU(),
            nn.Linear(t_emb_dim, n_out)
        )

    def forward(self, x):
        return self.te_block(x)
Conv Block
The NormActConv class is a module that sequentially applies Group Normalization, Activation, and Convolution operations. This modular design simplifies the construction of neural networks used in diffusion models, particularly for processing image data.
class NormActConv(nn.Module):
    """
    Perform GroupNorm, Activation, and Convolution operations.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 num_groups: int = 8,
                 kernel_size: int = 3,
                 norm: bool = True,
                 act: bool = True
                 ):
        super().__init__()

        # GroupNorm
        self.g_norm = nn.GroupNorm(
            num_groups,
            in_channels
        ) if norm is True else nn.Identity()

        # Activation
        self.act = nn.SiLU() if act is True else nn.Identity()

        # Convolution
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            padding=(kernel_size - 1) // 2
        )

    def forward(self, x):
        x = self.g_norm(x)
        x = self.act(x)
        x = self.conv(x)
        return x
Self Attention
The SelfAttentionBlock class is a neural network module that applies Group Normalization followed by Multi-Headed Self-Attention to capture long-range dependencies in the input data. This is particularly useful in processing spatial information in images for tasks like denoising and generation.
class SelfAttentionBlock(nn.Module):
    """
    Perform GroupNorm and multi-headed self-attention operations.
    """
    ...
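The class body is omitted above for brevity. A minimal sketch of such a block (named SelfAttentionSketch here; the repository's implementation may differ in details) normalizes the feature map, flattens the spatial dimensions into a sequence, and applies multi-headed self-attention over the spatial positions:

class SelfAttentionSketch(nn.Module):
    """
    Sketch: GroupNorm followed by multi-headed self-attention over spatial positions.
    """

    def __init__(self, num_channels: int, num_groups: int = 8, num_heads: int = 4):
        super().__init__()
        self.g_norm = nn.GroupNorm(num_groups, num_channels)
        self.attn = nn.MultiheadAttention(num_channels, num_heads, batch_first=True)

    def forward(self, x):
        b, c, h, w = x.shape
        out = self.g_norm(x)
        out = out.reshape(b, c, h * w).transpose(1, 2)    # B x (H*W) x C
        out, _ = self.attn(out, out, out)                 # attend across spatial positions
        out = out.transpose(1, 2).reshape(b, c, h, w)     # back to B x C x H x W
        return out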
Downsample and Upsample Blocks
The Downsample class performs downsampling on input tensors, reducing their spatial dimensions (height and width) by a factor of k. It provides two downsampling paths, one convolution-based and one max-pooling-based, with the option to combine both for a richer feature representation. The upsampling module follows the same design pattern in the reverse direction, increasing the spatial dimensions of the input tensor (a sketch of such a module follows the Downsample code below).
class Downsample(nn.Module):
    """
    Perform downsampling by a factor of k across height and width.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 k: int = 2,             # Downsampling factor
                 use_conv: bool = True,  # If downsampling using conv-block
                 use_mpool: bool = True  # If downsampling using max-pool
                 ):
        super(Downsample, self).__init__()
        self.use_conv = use_conv
        self.use_mpool = use_mpool

        # Downsampling using convolution
        self.cv = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.Conv2d(
                in_channels,
                out_channels // 2 if use_mpool else out_channels,
                kernel_size=4,
                stride=k,
                padding=1
            )
        ) if use_conv else nn.Identity()

        # Downsampling using max-pooling
        self.mpool = nn.Sequential(
            nn.MaxPool2d(k, k),
            nn.Conv2d(
                in_channels,
                out_channels // 2 if use_conv else out_channels,
                kernel_size=1,
                stride=1,
                padding=0
            )
        ) if use_mpool else nn.Identity()

    def forward(self, x):
        if not self.use_conv:
            return self.mpool(x)
        if not self.use_mpool:
            return self.cv(x)
        # When both paths are enabled, concatenate their outputs along the channel dimension.
        return torch.cat([self.cv(x), self.mpool(x)], dim=1)
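The repository's upsampling module is not reproduced in this post. A sketch of how such a mirror-image module could look (an assumption, not necessarily the repository's exact code) swaps the strided convolution for a transposed convolution and the max-pooling for interpolation:

class UpsampleSketch(nn.Module):
    """
    Sketch: upsample by a factor of k across height and width using a transposed
    convolution and/or nearest-neighbour interpolation, concatenating both paths
    along the channel dimension when both are enabled.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 k: int = 2,                # Upsampling factor
                 use_conv: bool = True,     # If upsampling using transposed conv
                 use_upsample: bool = True  # If upsampling using interpolation
                 ):
        super().__init__()
        self.use_conv = use_conv
        self.use_upsample = use_upsample

        # Upsampling using transposed convolution
        self.cv = nn.ConvTranspose2d(
            in_channels,
            out_channels // 2 if use_upsample else out_channels,
            kernel_size=4,
            stride=k,
            padding=1
        ) if use_conv else nn.Identity()

        # Upsampling using interpolation followed by a 1x1 convolution
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=k, mode='nearest'),
            nn.Conv2d(
                in_channels,
                out_channels // 2 if use_conv else out_channels,
                kernel_size=1
            )
        ) if use_upsample else nn.Identity()

    def forward(self, x):
        if not self.use_conv:
            return self.up(x)
        if not self.use_upsample:
            return self.cv(x)
        return torch.cat([self.cv(x), self.up(x)], dim=1)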
Unet Parts
The DownC, MidC, and UpC modules are key components of a U-Net-like architecture tailored for image generation tasks such as diffusion models. These modules work together to process and transform input data through hierarchical downsampling, mid-level refinement, and upsampling.
class DownC(nn.Module):
    """
    Perform down-convolution on the input using the following approach:
    1. Conv + TimeEmbedding
    2. Conv
    3. Skip-connection from input x.
    4. Self-Attention
    5. Skip-connection from 3.
    6. Downsampling
    """


class MidC(nn.Module):
    """
    Refine the features obtained from the DownC block.
    It refines the features using the following operations:
    1. ResNet block with time embedding
    2. A series of Self-Attention + ResNet blocks with time embedding
    """


class UpC(nn.Module):
    """
    Perform up-convolution on the input using the following approach:
    1. Upsampling
    2. Conv + TimeEmbedding
    3. Conv
    4. Skip-connection from 1.
    5. Self-Attention
    6. Skip-connection from 3.
    """
1. DownC: Down-Convolution Block
Responsible for extracting multi-scale features while reducing spatial resolution. It captures hierarchical features and incorporates temporal context via time embeddings.
Key Steps:
- Convolution + Time Embedding: Extracts spatial features while integrating temporal information (important for diffusion models).
- Additional Convolution: Deepens feature extraction.
- Skip-Connection: Creates a residual path from the original input, preserving critical low-level details.
- Self-Attention: Enables the model to capture long-range dependencies in spatial data.
- Skip-Connection: Adds the features from step 3 back into the flow, enhancing gradient flow and feature reuse.
- Downsampling: Reduces spatial resolution, allowing deeper layers to focus on abstract patterns.
2. MidC: Middle Block
Serves as the bottleneck in the U-Net, refining features extracted by DownC while integrating global context.
Key Steps:
- ResNet Block with Time Embedding: Refines features using residual connections while incorporating temporal embeddings for time-awareness.
- Self-Attention + ResNet Block: A sequence of self-attention layers and ResNet blocks further enriches features by combining spatial attention with robust feature refinement.
3. UpC: Up-Convolution Block
Reconstructs the spatial dimensions by progressively upsampling while merging features from earlier layers via skip-connections.
Key Steps:
- Upsampling: Increases spatial resolution using learned or interpolation-based methods.
- Convolution + Time Embedding: Processes upscaled features while integrating temporal information.
- Additional Convolution: Refines upsampled features.
- Skip-Connection: Reintroduces earlier features, ensuring high-resolution details are preserved.
- Self-Attention: Captures global spatial dependencies, essential for generating coherent outputs.
- Skip-Connection: Combines processed features with those from step 3 for enhanced reconstruction.
Integration in the U-Net:
- Encoder (DownC): Captures features at multiple resolutions, gradually reducing spatial size while increasing feature richness.
- Bottleneck (MidC): Acts as a bridge, blending abstract, low-resolution features with global dependencies.
- Decoder (UpC): Reconstructs the image, integrating skip-connections from the encoder to ensure high-resolution details are preserved.
class Unet(nn.Module):
    """
    U-Net architecture.
    """

    def __init__(self,
                 im_channels: int = 1,       # Number of image channels (3 for RGB)
                 down_ch=None,
                 mid_ch=None,
                 up_ch=None,
                 down_sample=None,
                 t_emb_dim: int = 128,
                 num_downc_layers: int = 2,
                 num_midc_layers: int = 2,
                 num_upc_layers: int = 2
                 ):
        super(Unet, self).__init__()

        if down_sample is None:
            down_sample = [True, True, False]
        if up_ch is None:
            up_ch = [256, 128, 64, 16]
        if mid_ch is None:
            mid_ch = [256, 256, 128]
        if down_ch is None:
            down_ch = [32, 64, 128, 256]

        self.im_channels = im_channels
        self.down_ch = down_ch
        self.mid_ch = mid_ch
        self.up_ch = up_ch
        self.t_emb_dim = t_emb_dim
        self.down_sample = down_sample
        self.num_downc_layers = num_downc_layers
        self.num_midc_layers = num_midc_layers
        self.num_upc_layers = num_upc_layers

        self.up_sample = list(reversed(self.down_sample))  # [False, True, True]

        # Initial Convolution
        self.cv1 = nn.Conv2d(self.im_channels, self.down_ch[0], kernel_size=3, padding=1)

        # Initial Time Embedding Projection
        self.t_proj = nn.Sequential(
            nn.Linear(self.t_emb_dim, self.t_emb_dim),
            nn.SiLU(),
            nn.Linear(self.t_emb_dim, self.t_emb_dim)
        )

        # DownC Blocks
        self.downs = nn.ModuleList([
            DownC(
                self.down_ch[i],
                self.down_ch[i + 1],
                self.t_emb_dim,
                self.num_downc_layers,
                self.down_sample[i]
            ) for i in range(len(self.down_ch) - 1)
        ])

        # MidC Blocks
        self.mids = nn.ModuleList([
            MidC(
                self.mid_ch[i],
                self.mid_ch[i + 1],
                self.t_emb_dim,
                self.num_midc_layers
            ) for i in range(len(self.mid_ch) - 1)
        ])

        # UpC Blocks
        self.ups = nn.ModuleList([
            UpC(
                self.up_ch[i],
                self.up_ch[i + 1],
                self.t_emb_dim,
                self.num_upc_layers,
                self.up_sample[i]
            ) for i in range(len(self.up_ch) - 1)
        ])

        # Final Convolution
        self.cv2 = nn.Sequential(
            nn.GroupNorm(8, self.up_ch[-1]),
            nn.Conv2d(self.up_ch[-1], self.im_channels, kernel_size=3, padding=1)
        )

    def forward(self, x, t):
        out = self.cv1(x)

        # Time Projection
        t_emb = get_time_embedding(t, self.t_emb_dim)
        t_emb = self.t_proj(t_emb)

        # DownC outputs (stored for skip-connections)
        down_outs = []
        for down in self.downs:
            down_outs.append(out)
            out = down(out, t_emb)

        # MidC outputs
        for mid in self.mids:
            out = mid(out, t_emb)

        # UpC blocks consume the stored skip-connections in reverse order
        for up in self.ups:
            down_out = down_outs.pop()
            out = up(out, down_out, t_emb)

        # Final Conv
        out = self.cv2(out)

        return out
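A quick sanity check (assuming the full DownC, MidC, and UpC implementations from the repository) is that the network maps a batch of noisy images and timesteps to a noise prediction of the same shape:

model = Unet(im_channels=3)
x = torch.randn(4, 3, 64, 64)        # dummy batch of noisy RGB images
t = torch.randint(0, 1000, (4,))     # one timestep per image
out = model(x, t)
print(out.shape)                     # expected: torch.Size([4, 3, 64, 64])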
Data
Download the Anime Face Dataset and put it in ./data/anime.
Training
The training function is a straightforward PyTorch training loop that iterates over the dataset and updates the model parameters with the Adam optimizer. The loss is the Mean Squared Error (MSE) between the noise predicted by the model and the actual noise added to each image.
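A minimal sketch of such a loop is shown below (the dummy data loader, batch size, epoch count, and learning rate are illustrative assumptions; the repository's main.py additionally handles the real dataset, configuration, checkpointing, and logging):

from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Unet(im_channels=3).to(device)
dfp = DiffusionForwardProcess()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.MSELoss()

# Stand-in data loader; in the repository this is the anime-face dataset scaled to [-1, 1].
dummy_images = torch.randn(64, 3, 64, 64)
train_loader = DataLoader(TensorDataset(dummy_images), batch_size=16, shuffle=True)

for epoch in range(5):                       # a handful of epochs for illustration
    for (images,) in train_loader:
        images = images.to(device)
        # Sample a random timestep and noise for each image
        t = torch.randint(0, 1000, (images.shape[0],), device=device)
        noise = torch.randn_like(images)
        # Forward process: produce the noisy image x_t
        noisy_images = dfp.add_noise(images, noise, t)
        # The model predicts the noise that was added
        noise_pred = model(noisy_images, t)
        loss = criterion(noise_pred, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()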
Inference
Using the inference module, one can generate images with the trained model. The inference process starts from a sampled noise tensor and iteratively predicts the previous-timestep image until a clean, natural-looking image emerges.
def generate(cfg: CONFIG) -> Tensor:
    """
    Generate an image using the trained model.

    :param cfg: config
    :return: image tensor
    """
    # Device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize the diffusion reverse process
    drp = DiffusionReverseProcess()

    # Load the trained model and set it to eval mode
    model = torch.load(cfg.model_path, map_location=device)
    model.eval()

    # Generate a noise sample from N(0, 1)
    xt = torch.randn(1, cfg.in_channels, cfg.img_size, cfg.img_size).to(device)

    # Denoise step by step by going backward.
    with torch.no_grad():
        for t in reversed(range(cfg.num_timesteps)):
            noise_pred = model(xt, torch.as_tensor(t).unsqueeze(0).to(device))
            xt, x0 = drp.sample_prev_timestep(xt, noise_pred, torch.as_tensor(t).to(device))

    # Convert the image to the proper scale
    xt = torch.clamp(xt, -1., 1.).detach().cpu()
    xt = (xt + 1) / 2

    return xt
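As a usage example (a sketch assuming a populated cfg object as in the repository), the returned tensor has shape 1 x C x H x W with values in [0, 1] and can be written straight to disk with torchvision:

from torchvision.utils import save_image

img = generate(cfg)                  # 1 x C x H x W tensor with values in [0, 1]
save_image(img, 'anime_face.png')    # write the generated face to disk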
Results
GitHub Repository
The complete implementation is available on GitHub.