Diffusion Transformers (DiT) replace the U‑Net backbone of a diffusion model with a Vision Transformer. The idea is to keep everything that works in diffusion training while letting a standard Transformer do the denoising. The result is strong scaling behavior, a state-of-the-art FID on class‑conditional ImageNet at 256×256, and the best diffusion-model FID at 512×512.
What changes compared to a U‑Net
DiT keeps the diffusion pipeline but moves it to the latent space of a pretrained autoencoder. Instead of predicting noise directly on pixels, the model predicts noise on a compact latent grid. The paper uses the Stable Diffusion VAE, where an input image at 256 by 256 becomes a latent of size 32 by 32 with 4 channels. For 512 by 512, the latent is 64 by 64 by 4. This 8x spatial downsample makes training much cheaper while preserving fidelity when decoding back to pixels.
The latent grid is turned into a sequence of non‑overlapping p×p patches, just like in a Vision Transformer, and each patch is linearly embedded to the model's hidden size d. With patch size p, the sequence length is T = (I/p)², where I is the latent spatial size. For 256×256 images, I = 32, so p = 2 gives 256 tokens, p = 4 gives 64 tokens, and p = 8 gives 16 tokens. The model adds standard frequency‑based sine‑cosine positional embeddings to all tokens.
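A minimal NumPy sketch of the patchify step and its inverse, assuming a channels-last [I, I, C] latent. The names `patchify` and `unpatchify` are illustrative, and the within-patch flattening order is an arbitrary choice (a learned linear embedding follows, so any fixed order works). It checks the token counts above for I = 32.

```python
import numpy as np


def patchify(z: np.ndarray, p: int) -> np.ndarray:
    """Split an [I, I, C] latent into T = (I/p)**2 flattened patches of size p*p*C."""
    I, I2, C = z.shape
    assert I == I2 and I % p == 0, "latent must be square and divisible by p"
    n = I // p  # patches per side
    # [n, p, n, p, C] -> [n, n, p, p, C] -> [T, p*p*C], tokens in row-major grid order
    x = z.reshape(n, p, n, p, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(n * n, p * p * C)


def unpatchify(tokens: np.ndarray, p: int, C: int) -> np.ndarray:
    """Inverse of patchify: [T, p*p*C] -> [I, I, C]."""
    n = int(round(tokens.shape[0] ** 0.5))
    x = tokens.reshape(n, n, p, p, C).transpose(0, 2, 1, 3, 4)
    return x.reshape(n * p, n * p, C)


z = np.random.default_rng(0).standard_normal((32, 32, 4))  # latent of a 256x256 image
for p in (2, 4, 8):
    tokens = patchify(z, p)
    assert tokens.shape == ((32 // p) ** 2, p * p * 4)
    assert np.array_equal(unpatchify(tokens, p, C=4), z)  # lossless round trip
```

A convolution with kernel size and stride both equal to p is an equivalent way to fuse the patchify step with the linear embedding.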
How conditioning is injected
Diffusion models must be conditioned on the timestep and often on a class label. The paper evaluates four ways to inject this information into a Transformer block:
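- In‑context conditioning: append the embeddings of t and c as two extra tokens in the input sequence, treat them like image tokens, and drop them after the final block.
- Cross‑attention: image tokens attend to a separate length‑two sequence built from the t and c embeddings.
- Adaptive LayerNorm (adaLN): an MLP regresses per‑dimension scale and shift parameters from the sum of the t and c embeddings and applies them to the normalized activations in each block.
- adaLN‑Zero: adaLN plus per‑dimension gates applied just before each residual connection, with the MLP initialized to output zero for the gates, so each block starts as the identity map.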
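To make adaLN‑Zero concrete, here is a toy NumPy sketch of one DiT block, assuming tiny sizes (16 tokens, width 64, 4 heads) and a random vector `e` standing in for the summed timestep and label embeddings. The helper names (`dit_block`, `init_block`, `modulate`) are hypothetical, and the `(1 + scale)` form follows the official PyTorch code rather than the paper text. With the adaLN weights zero-initialized, every gate is zero and the block returns its input unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, heads = 16, 64, 4  # toy sizes, far smaller than any DiT configuration


def layer_norm(x, eps=1e-6):
    # No learned affine parameters: adaLN supplies the scale and shift instead.
    mu = x.mean(-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(-1, keepdims=True) + eps)


def silu(x):
    return x / (1.0 + np.exp(-x))


def gelu(x):  # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def modulate(x, shift, scale):
    return x * (1.0 + scale) + shift


def self_attention(h, Wq, Wk, Wv, Wo):
    dh = d // heads

    def split(a):  # [T, d] -> [heads, T, dh]
        return a.reshape(T, heads, dh).transpose(1, 0, 2)

    q, k, v = split(h @ Wq), split(h @ Wk), split(h @ Wv)
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(dh)   # [heads, T, T]
    logits -= logits.max(-1, keepdims=True)           # numerically stable softmax
    attn = np.exp(logits)
    attn /= attn.sum(-1, keepdims=True)
    return (attn @ v).transpose(1, 0, 2).reshape(T, d) @ Wo  # merge heads, project


def dit_block(x, e, P):
    # One linear layer (after SiLU) regresses shift, scale and gate for both sublayers.
    shift1, scale1, gate1, shift2, scale2, gate2 = np.split(silu(e) @ P["W_ada"] + P["b_ada"], 6)
    x = x + gate1 * self_attention(modulate(layer_norm(x), shift1, scale1), *P["attn"])
    W1, W2 = P["mlp"]
    return x + gate2 * (gelu(modulate(layer_norm(x), shift2, scale2) @ W1) @ W2)


def init_block(zero_ada=True, std=0.02):
    W_ada = np.zeros((d, 6 * d)) if zero_ada else rng.normal(0.0, std, (d, 6 * d))
    return {
        "W_ada": W_ada,
        "b_ada": np.zeros(6 * d),
        "attn": [rng.normal(0.0, std, (d, d)) for _ in range(4)],
        "mlp": [rng.normal(0.0, std, (d, 4 * d)), rng.normal(0.0, std, (4 * d, d))],
    }


x = rng.standard_normal((T, d))  # image tokens
e = rng.standard_normal(d)       # stand-in for timestep embedding + label embedding
y_init = dit_block(x, e, init_block(zero_ada=True))      # adaLN-Zero at initialization
y_trained = dit_block(x, e, init_block(zero_ada=False))  # non-zero gates, as after training
```

Because every block starts as the identity, a freshly initialized stack passes its input straight through the residual path, and training gradually opens the gates.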
adaLN‑Zero gives the best FID throughout training and adds almost no extra Gflops. Cross‑attention is the most expensive option, adding roughly 15% more Gflops, and still does not match adaLN‑Zero in their tests. The authors therefore recommend adaLN‑Zero and use it for all main results.
Putting the sequence back into an image
After the final Transformer block, DiT applies a final layer norm (adaptive when adaLN conditioning is used) and a linear decoder that maps each token to a p×p×2C tensor, then rearranges the tokens back into their spatial layout ("un‑patchify"). The output holds two tensors over the original latent grid: the noise prediction and a diagonal covariance. The noise term is trained with the simple MSE loss and the covariance term with the full variational bound, following Nichol and Dhariwal.
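The covariance output is usually not a raw variance. The NumPy sketch here shows the split and Nichol and Dhariwal's "learned range" parameterization, in which the second output interpolates in log space between β_t and the posterior variance β̃_t. It assumes the linear 1000-step schedule described in the training section and C = 4 latent channels; `split_head` and `log_variance` are illustrative names, and the `(v + 1) / 2` mapping mirrors ADM's code, which treats the raw output as lying in [−1, 1].

```python
import numpy as np

betas = np.linspace(1e-4, 2e-2, 1000)             # linear schedule, 1000 steps
alphas_bar = np.cumprod(1.0 - betas)
alphas_bar_prev = np.append(1.0, alphas_bar[:-1])
# Posterior variance: beta_tilde_t = beta_t * (1 - abar_{t-1}) / (1 - abar_t)
beta_tilde = betas * (1.0 - alphas_bar_prev) / (1.0 - alphas_bar)
# beta_tilde_0 is exactly 0, so substitute beta_tilde_1 there before taking the log.
log_beta_tilde = np.log(np.append(beta_tilde[1], beta_tilde[1:]))


def split_head(pred, C=4):
    """Split the un-patchified output [I, I, 2C] into (eps_hat, v), each [I, I, C]."""
    return pred[..., :C], pred[..., C:]


def log_variance(frac, t):
    """log Sigma_t = frac * log(beta_t) + (1 - frac) * log(beta_tilde_t), per dimension."""
    return frac * np.log(betas[t]) + (1.0 - frac) * log_beta_tilde[t]


pred = np.random.default_rng(0).standard_normal((32, 32, 8))  # stand-in network output
eps_hat, v = split_head(pred)
log_sigma = log_variance((v + 1.0) / 2.0, t=500)  # map raw output to an interpolation weight
```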
Training and evaluation choices that matter
The paper keeps the training recipe simple: AdamW with a constant learning rate of 1e‑4, no weight decay, batch size 256, an exponential moving average (EMA) of the weights with decay 0.9999, and horizontal flips as the only augmentation. They reuse ADM's diffusion hyperparameters (1000 steps with a linear variance schedule from 1e‑4 to 2e‑2) and ADM's timestep and label embeddings. For fair FID comparisons, they export samples and evaluate them with ADM's TensorFlow evaluation suite, using 250 DDPM sampling steps. Models were implemented in JAX and trained on TPU v3 pods.
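Under this schedule the forward (noising) process has a closed form, z_t = √ᾱ_t·z_0 + √(1 − ᾱ_t)·ε with ᾱ_t = ∏_{s≤t}(1 − β_s). A minimal NumPy sketch, assuming a random 32×32×4 array in place of a real VAE latent and 0-indexed timesteps; `q_sample` is an illustrative name.

```python
import numpy as np

NUM_STEPS = 1000
betas = np.linspace(1e-4, 2e-2, NUM_STEPS)   # linear variance schedule
alphas_bar = np.cumprod(1.0 - betas)         # abar_t = prod_{s<=t} (1 - beta_s)


def q_sample(z0, t, eps):
    """Draw z_t ~ q(z_t | z_0) in a single step."""
    a = alphas_bar[t]
    return np.sqrt(a) * z0 + np.sqrt(1.0 - a) * eps


rng = np.random.default_rng(0)
z0 = rng.standard_normal((32, 32, 4))        # stand-in for a VAE latent of a 256x256 image
t = int(rng.integers(0, NUM_STEPS))          # timestep drawn uniformly
eps = rng.standard_normal(z0.shape)          # regression target for the noise prediction
z_t = q_sample(z0, t, eps)
```

By the last step ᾱ has fallen below 1e‑4, so z_999 is essentially pure Gaussian noise.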
Scaling behavior
This is the central finding. Increasing model Gflops improves FID steadily, whether the extra compute comes from more depth, more width, or more tokens via a smaller patch size (which barely changes the parameter count). At 400k training steps, the correlation between Transformer Gflops and FID‑50k is about −0.93. Larger models also use training compute more efficiently: at a given training-compute budget, they tend to reach a lower FID than smaller models trained for more steps. Giving a small model more sampling steps does not close the gap with a larger model.
Headline results
On class‑conditional ImageNet 256×256, DiT‑XL/2 (the XL model with patch size 2) trained for 7 million steps reaches an FID‑50k of 2.27 with classifier‑free guidance at scale 1.5. That beats the previous best diffusion result, 3.60 from LDM, under the same evaluation protocol. The forward pass costs about 118.6 Gflops at this setting.
On ImageNet 512×512, DiT‑XL/2 trained for 3 million steps reaches an FID‑50k of 3.04 with the same guidance scale. At this resolution it processes 1024 tokens (patch size 2 on a 64×64 latent) and costs about 524.6 Gflops, far fewer than strong pixel‑space U‑Net baselines such as ADM and ADM‑U.
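Both headline results use classifier‑free guidance at scale 1.5. A minimal sketch of the standard combination, ε̂ = ε̂(z_t, ∅) + s·(ε̂(z_t, c) − ε̂(z_t, ∅)), where ∅ is the null label the model also sees when labels are dropped during training; `cfg_combine` is a hypothetical helper and the two predictions are random stand-ins for real network outputs.

```python
import numpy as np


def cfg_combine(eps_cond, eps_uncond, scale):
    """Guided noise estimate: eps_uncond + scale * (eps_cond - eps_uncond)."""
    return eps_uncond + scale * (eps_cond - eps_uncond)


rng = np.random.default_rng(0)
eps_c = rng.standard_normal((32, 32, 4))  # stand-in: prediction given the class label
eps_u = rng.standard_normal((32, 32, 4))  # stand-in: prediction given the null label
guided = cfg_combine(eps_c, eps_u, scale=1.5)
```

In practice both predictions are usually computed in one forward pass by duplicating the batch, one half with real labels and one with the null label. A scale of 1 recovers the plain conditional model; larger scales trade sample diversity for fidelity.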
Model sizes and naming
DiT uses the familiar ViT scales, Small (S), Base (B), Large (L), and XLarge (XL), combined with the patch size p, so DiT‑B/4 is the Base configuration with patch size 4. A table in the paper lists the number of layers, hidden size, heads, and Gflops (measured at p = 4) for each size. The largest model, DiT‑XL/2, produces the best results reported above.
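A back‑of‑the‑envelope estimate helps explain these Gflops figures and why shrinking the patch size is so costly. The sketch assumes the ViT‑style layer counts and widths of the four scales (S: 12 layers, width 384; B: 12, 768; L: 24, 1024; XL: 28, 1152), an MLP ratio of 4, and one flop per multiply‑accumulate, and it counts only the large matrix multiplies. `approx_gflops` is an illustrative helper, not the paper's measurement tool; head count does not affect the total, so it is omitted.

```python
# (layers N, hidden size d) for each DiT scale
CONFIGS = {"S": (12, 384), "B": (12, 768), "L": (24, 1024), "XL": (28, 1152)}


def approx_gflops(size, p, latent=32, mlp_ratio=4):
    """Rough forward cost of DiT-<size>/<p>, counting one multiply-accumulate as one flop.

    Per block: Q, K, V and output projections (4*T*d^2), the MLP
    (2*mlp_ratio*T*d^2), attention scores and weighted sum (2*T^2*d), and the
    adaLN-Zero linear layer mapping the conditioning vector to 6*d values (6*d^2).
    Patch embedding, norms, softmax and the final decoder are ignored.
    """
    n_layers, d = CONFIGS[size]
    T = (latent // p) ** 2
    per_block = (4 + 2 * mlp_ratio) * T * d * d + 2 * T * T * d + 6 * d * d
    return n_layers * per_block / 1e9


for size in CONFIGS:
    for p in (8, 4, 2):
        print(f"DiT-{size}/{p}: ~{approx_gflops(size, p):.1f} Gflops")
print(f"DiT-XL/2 at 512x512 (64x64 latent): ~{approx_gflops('XL', 2, latent=64):.1f} Gflops")
```

Under these assumptions DiT‑XL/2 comes out at about 118.6 Gflops and the 512×512 setting at about 524.5, close to the figures quoted in this article. Halving p quadruples T, so the linear layers cost about 4× more and the attention‑score term about 16× more, while the parameter count barely changes.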
Why this matters
DiT shows that diffusion models do not need the U‑Net's inductive bias to work well. A plain Transformer backbone can benefit from the scaling behavior and standardized design that already work in language and vision. This makes it easier to compare and transfer methods across domains, and it points toward text‑to‑image and class‑conditional systems sharing one simple backbone family.
Pseudocode
This sketch captures the core data flow described in the paper, with adaLN‑Zero conditioning. The model predicts noise and a diagonal covariance over the latent grid, and training follows Nichol and Dhariwal's objective.
```python
# Pseudocode: UPPERCASE names are placeholders for learned layers or helpers.
# z_vae: VAE latent of shape [I, I, C] from a frozen encoder (8x spatial downsample)
# t: diffusion timestep
# c: class label (replaced by a learned null label with prob. p_drop during training, for CFG)
# p: patch size, so the number of tokens is T = (I / p) ** 2
def DIT_FORWARD(z_vae, t, c):
    # 1) Patchify and embed
    tokens = PATCHIFY(z_vae, patch=p)          # [T, p*p*C]
    tokens = LINEAR(tokens)                    # [T, d]
    tokens = tokens + SINCOS_POS_EMBED_2D(T)   # fixed ViT-style sine-cosine embedding

    # 2) Build the conditioning vector
    e_t = TIMESTEP_EMBED(t)                    # sinusoidal features + MLP, as in ADM
    e_c = LABEL_EMBED(c)                       # c may be the null label
    e = e_t + e_c                              # the paper sums the two embeddings

    # 3) N Transformer blocks with adaLN-Zero
    x = tokens
    for block in blocks:                       # N blocks
        # One zero-initialized MLP per block regresses shift, scale and gate
        # for both sublayers, so every gate starts at 0.
        b1, s1, a1, b2, s2, a2 = block.ADALN_MLP(e)
        h = LAYER_NORM(x)                      # no learned affine; adaLN replaces it
        h = (1 + s1) * h + b1                  # official code uses (1 + scale)
        x = x + a1 * SELF_ATTENTION(h)         # gated residual, identity at init
        h = LAYER_NORM(x)
        h = (1 + s2) * h + b2
        x = x + a2 * MLP(h)

    # 4) Adaptive final norm, linear decode, un-patchify
    b, s = FINAL_ADALN_MLP(e)
    y = (1 + s) * LAYER_NORM(x) + b
    patches = LINEAR(y)                        # [T, p*p*2C]
    pred = UNPATCHIFY(patches, patch=p)        # [I, I, 2C]
    eps_hat, var_out = SPLIT(pred, channels=C) # noise prediction and covariance output
    return eps_hat, var_out

# Training step:
# Sample (x0, c) from ImageNet.
# Encode with the frozen VAE: z0 = E(x0).
# Sample t uniformly and noise epsilon ~ N(0, I).
# Form z_t with the forward process under ADM's linear beta schedule.
# With probability p_drop, replace c with the null label (classifier-free guidance).
# Run eps_hat, var_out = DIT_FORWARD(z_t, t, c).
# Loss = MSE(eps_hat, epsilon) + VLB_term(var_out)  # per Nichol & Dhariwal; the VLB term trains only the covariance
# Optimize with AdamW (constant lr = 1e-4, no weight decay); keep an EMA of weights with decay 0.9999.
```
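Two small pieces of the training step are easy to show directly: label dropout for classifier‑free guidance and the EMA weight update. A NumPy sketch, assuming a hypothetical null‑label index of 1000 (one past the 1000 ImageNet classes) and parameters stored in a dict of arrays; the drop probability 0.1 is an illustrative value, since the text leaves p_drop open.

```python
import numpy as np

NUM_CLASSES = 1000
NULL_LABEL = NUM_CLASSES  # hypothetical: extra embedding row reserved for "no label"


def drop_labels(labels, p_drop, rng):
    """Replace each label with NULL_LABEL independently with probability p_drop."""
    drop = rng.random(labels.shape) < p_drop
    return np.where(drop, NULL_LABEL, labels)


def ema_update(ema, params, decay=0.9999):
    """In place: ema <- decay * ema + (1 - decay) * params, for every tensor."""
    for name, value in params.items():
        ema[name] *= decay
        ema[name] += (1.0 - decay) * value


rng = np.random.default_rng(0)
labels = rng.integers(0, NUM_CLASSES, size=256)   # one batch of 256 class labels
train_labels = drop_labels(labels, p_drop=0.1, rng=rng)

params = {"w": np.ones(3)}
ema = {"w": np.zeros(3)}
ema_update(ema, params)  # ema["w"] moves 1e-4 of the way toward params["w"]
```

Samples are normally drawn from the EMA copy rather than the raw weights; with decay 0.9999 it averages over roughly the last 10,000 steps.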
Key numbers from the paper
On ImageNet 256×256, DiT‑XL/2 reaches FID 2.27 with classifier‑free guidance at scale 1.5 after 7 million steps, at about 118.6 Gflops. On ImageNet 512×512, it reaches FID 3.04 with the same guidance scale after 3 million steps, using patch size 2 (1024 tokens) and about 524.6 Gflops. At 256×256 this beats all prior diffusion and GAN baselines in the paper's comparison under the same protocol; at 512×512 it beats all prior diffusion models.
Paper and official code
Scalable Diffusion Models with Transformers, William Peebles and Saining Xie, ICCV 2023 (open‑access PDF, CVF Open Access)
Scalable Diffusion Models with Transformers, arXiv preprint (arXiv)
Official PyTorch implementation with pretrained weights and sampling scripts (GitHub)