Model overview

This directory contains the implementation of the Diffusion Transformer (DiT). The Diffusion Transformer [1], as the name suggests, belongs to the class of diffusion models. The key difference is that it replaces the U-Net backbone typically used in earlier diffusion models with a Transformer backbone, along with a few other modifications. This model outperforms previous diffusion models on the FID-50K evaluation metric.

A DiT model consists of N DiT blocks stacked in sequence. We currently support the AdaLN-Zero variant of the DiT block (see Implementation notes). More details can be found in Section 3.2 of the DiT paper [1] and in Step 4: Training the model on CS system or GPU using run.py.

In addition, we support patch sizes of 2 (default), 4, and 8. The Patchify block in Figure 1 takes the noised latent tensor from the dataloader as input and converts it into patches of size patch_size. The smaller the patch size, the larger the number of patches (i.e., the maximum sequence length (MSL)) and hence the larger the number of FLOPs.

In order to change the patch size, for example to 4 x 4, set model.patch_size: [4, 4] in the provided YAML configs.
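
For intuition, the number of patches (and hence the MSL) follows directly from the latent size and the patch size. Below is a minimal sketch, assuming the default 32 x 32 spatial latent used for 256 x 256 images (see Available Configurations):

# Minimal sketch: number of patches (i.e., MSL) as a function of patch size,
# assuming the default 32 x 32 spatial latent for 256 x 256 images.
latent_h, latent_w = 32, 32
for patch in (2, 4, 8):
    num_patches = (latent_h // patch) * (latent_w // patch)
    print(f"patch_size={patch}x{patch} -> sequence length {num_patches}")
# patch_size=2x2 -> sequence length 256
# patch_size=4x4 -> sequence length 64
# patch_size=8x8 -> sequence length 16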

During training, an image from the dataset is passed through a frozen VAE (Variational Auto Encoder) Encoder to convert it into a lower-dimensional latent. Random Gaussian noise is then added to the latent tensor (Algorithm 1 of Denoising Diffusion Implicit Models) and the result is passed as input to the DiT. Since the VAE Encoder is frozen and not updated during training, we prefetch the latents for all images in the dataset using the script create_imagenet_latents.py. This saves computation and memory during training. Refer to Step 3.
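
For intuition, the forward-noising step applied to a prefetched latent looks roughly as follows. This is a minimal sketch, not the exact dataloader code; the linear beta schedule and the number of steps are assumptions (the actual schedule is configured in the YAML):

import torch

# Minimal sketch of the forward (noising) process q(x_t | x_0).
num_diffusion_steps = 1000                               # assumption
betas = torch.linspace(1e-4, 0.02, num_diffusion_steps)  # linear schedule (assumption)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def noise_latent(latent, timestep):
    """latent: (B, C, H, W) VAE latent; timestep: (B,) integers in [0, T)."""
    noise = torch.randn_like(latent)
    a_bar = alphas_cumprod[timestep].view(-1, 1, 1, 1)
    noised = a_bar.sqrt() * latent + (1.0 - a_bar).sqrt() * noise
    return noised, noise  # the model is trained to predict `noise`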

Structure of the code

  • configs/: YAML configuration files.
  • modeling_dit.py: Defines the core model DiT.
  • model.py: The entry point to the model. Defines DiTModel.
  • utils.py: Miscellaneous scripts to parse the params dictionary from the YAML files.
  • data/vision/diffusion/: Folder containing Dataloader and preprocessing scripts.
  • samplers/: Folder containing samplers used in Diffusion models to sample images from checkpoints.
  • layers/vae/: Defines VAE (Variational Auto Encoder) model layers.
  • layers/*: Defines building block layers of DiT Model.
  • display_images.py: Utility script to display images in a folder in a grid format to look at all images at once.
  • pipeline.py: Defines a DiffusionPipeline object that takes in a random gaussian input and performs sampling.
  • sample_generator.py: Defines an Abstract Base Class, SampleGenerator, used to define sample generators for diffusion models.
  • sample_generator_dit.py: Defines DiTSampleGenerator, which inherits from the SampleGenerator class and is used to generate the required number of samples from a DiT model using a given trained checkpoint.

Available Configurations

  • params_dit_large_patchsize_2x2.yaml: DiT-L/2 model with ~458M parameters. Patch size 2×2, latent size 32×32×4.
  • params_dit_xlarge_patchsize_2x2.yaml: DiT-XL/2 model with ~675M parameters. Patch size 2×2, latent size 32×32×4.
  • params_dit_2B_patchsize_2x2.yaml: DiT-2B/2 model with ~2B parameters. Patch size 2×2, latent size 32×32×4.

Sequence of the steps to perform

The high-level steps for training a model are relatively simple, involving data processing followed by model training and evaluation:

  • Step 1: ImageNet dataset download and preparation
  • Step 2: Checkpoint Conversion of Pre-trained VAE.
  • Step 3: Preprocessing and saving Latent tensors from images and VAE Encoder on GPU
  • Step 4: Training the model on CS system or GPU using run.py
  • Step 5: Generating 50K samples from trained checkpoint on GPUs
  • Step 6: Using OpenAI FID evaluation repository to compute FID score.

The steps are elaborated below:

Step 1: ImageNet dataset download and preparation

In order to download the ImageNet dataset, register on the ImageNet website [4]. The dataset can only be downloaded after the ImageNet website confirms the registration and sends a confirmation email. Please follow up with ImageNet support if a confirmation email is not received within a couple of days.

Download the tar files ILSVRC2012_img_train.tar, ILSVRC2012_img_val.tar, ILSVRC2012_devkit_t12.tar.gz for the ImageNet dataset.

Once all three tar files are available, extract and preprocess the archives into the directory structure described below.

root_directory (imagenet1k_ilsvrc2012)
├── meta.bin
├── train/
│   ├── n01440764
│   │   ├── n01440764_10026.JPEG
│   │   ├── n01440764_10027.JPEG
│   │   ├── n01440764_10029.JPEG
│   │   ├── ...
│   ├── n01443537
│   │   ├── n01443537_10007.JPEG
│   │   ├── n01443537_10014.JPEG
│   │   ├── n01443537_10025.JPEG
│   │   ├── ...
│   ├── ...
│   └── ...
├── val/
│   ├── n01440764
│   │   ├── ILSVRC2012_val_00000946.JPEG
│   │   ├── ILSVRC2012_val_00001684.JPEG
│   │   └── ...
│   ├── n01443537
│   │   ├── ILSVRC2012_val_00001269.JPEG
│   │   ├── ILSVRC2012_val_00002327.JPEG
│   │   ├── ILSVRC2012_val_00003510.JPEG
│   │   └── ...
│   ├── ...
│   └── ...

In order to arrange the ImageNet dataset in the above format, the PyTorch examples repository provides an easy-to-use script, which can be found here: https://github.com/pytorch/examples/blob/main/imagenet/extract_ILSVRC.sh. Download this script and invoke it as follows to preprocess the ImageNet dataset.

source extract_ILSVRC.sh

We also need a meta.bin file. The simplest way to create it is to initialize torchvision.datasets.ImageNet once:

import torchvision

root_dir = "<path_to_root_dir_imagenet1k_ilsvrc2012_above>"
torchvision.datasets.ImageNet(root=root_dir, split="train")
torchvision.datasets.ImageNet(root=root_dir, split="val")

Once the ImageNet dataset and folder are in the expected format, proceed to Step 2.

Step 2: Checkpoint Conversion of Pre-trained VAE

The next step is to convert the pretrained VAE checkpoint provided by StabilityAI and hosted on Hugging Face into the CS namespace format. This can be done using the script vae_hf_cs.py. The script downloads the pretrained VAE checkpoint from StabilityAI on Hugging Face and converts it to the CS namespace based on the model layers defined in dit/layers/vae. This script only uses the params defined under model.vae_params; no changes are needed other than setting model.vae_params.latent_size correctly.

$ python modelzoo/data_preparation/vision/dit/vae_hf_cs.py -h
usage: vae_hf_cs.py [-h] [--src_ckpt_path SRC_CKPT_PATH] [--dest_ckpt_path DEST_CKPT_PATH]
                    [--params_path PARAMS_PATH]

optional arguments:
  -h, --help            show this help message and exit
  --src_ckpt_path SRC_CKPT_PATH
                        Path to HF Pretrained VAE checkpoint .bin file. If not provided, file is automatically
                        downloaded from https://huggingface.co/stabilityai/sd-vae-ft-
                        mse/resolve/main/diffusion_pytorch_model.bin (default: None)
  --dest_ckpt_path DEST_CKPT_PATH
                        Path to converted modelzoo compatible checkpoint (default:
                        modelzoo/models/vision/dit/checkpoint_converter/mz_stabilityai-sd-vae-ft-mse_ckpt.bin)
  --params_path PARAMS_PATH
                        Path to VAE model params yaml

Command to run:

python modelzoo/data_preparation/vision/dit/vae_hf_cs.py --params_path=/path/to/dit_config.yaml

Step 3: Preprocessing and saving Latent tensors from images and VAE Encoder on GPU

For training the DiT model, we prefetch the latent tensor outputs from a pretrained VAE Encoder using the script create_imagenet_latents.py:

$ python modelzoo/data_preparation/vision/dit/create_imagenet_latents.py -h
usage: create_imagenet_latents.py [-h] [--checkpoint_path CHECKPOINT_PATH] [--params_path PARAMS_PATH] [--horizontal_flip]
                                  --image_height IMAGE_HEIGHT --image_width IMAGE_WIDTH --src_dir SRC_DIR --dest_dir DEST_DIR [--resume]
                                  [--resume_ckpt RESUME_CKPT] [--log_steps LOG_STEPS]
                                  [--batch_size_per_gpu BATCH_SIZE_PER_GPU] [--num_workers NUM_WORKERS]
                                  [--dataset_split {train,val}]

optional arguments:
  -h, --help            show this help message and exit
  --checkpoint_path CHECKPOINT_PATH
                        Path to VAE model checkpoint (default: None)
  --params_path PARAMS_PATH
                        Path to VAE model params yaml (default: modelzoo/models/vision/dit/configs/params_dit_small_patchsize_2x2.yaml)
  --horizontal_flip     If passed, flip image horizontally (default: False)
  --image_height IMAGE_HEIGHT
                        Height of the resized image
  --image_width IMAGE_WIDTH
                        Width of the resized image
  --src_dir SRC_DIR     source data location (default: None)
  --dest_dir DEST_DIR   Latent data location (default: None)
  --resume              If specified, resumes the previous generation process. The dest_dir should point to the previous generation
                        and have a log_checkpoint saved. (default: False)
  --resume_ckpt RESUME_CKPT
                        Log ckpt to resume data generation from. If None, picks the latest from the log dir (default: None)
  --log_steps LOG_STEPS
                        Generation process ckpt and logging frequency (default: 1000)
  --batch_size_per_gpu BATCH_SIZE_PER_GPU
                        batch size of input to be passed to VAE model for encoding (default: 64)
  --num_workers NUM_WORKERS
                        Number of pytorch dataloader workers (default: 4)
  --dataset_split {train,val}
                        Dataset split to process (default: train)

In order to preprocess the ImageNet dataset and create latent tensors, one or more GPUs are required.

A sample command for a single node with 4 GPUs is shown below. It uses an image size of 256 x 256 and saves the latents to the folder specified by --dest_dir. The command also logs every 10 steps and uses a batch size of 16, i.e., 16 images are batched together and passed to the VAE Encoder on each GPU.

a. Create ImageNet Latent Tensors from VAE for train split of dataset
torchrun --nnodes 1 --nproc_per_node 4 modelzoo/data_preparation/vision/dit/create_imagenet_latents.py --image_height=256 --image_width=256 --src_dir=/path/to/imagenet1k_ilsvrc2012 --dest_dir=/path_to_dest_dir --log_steps=10 --dataset_split=train --batch_size_per_gpu=16 --checkpoint_path=/path/to/converted/vae_checkpoint/in_Step2
b. Create ImageNet Latent Tensors from VAE for val split of dataset
torchrun --nnodes 1 --nproc_per_node 4 modelzoo/data_preparation/vision/dit/create_imagenet_latents.py --image_height=256 --image_width=256 --src_dir=/path/to/imagenet1k_ilsvrc2012 --dest_dir=/path_to_dest_dir --log_steps=10 --dataset_split=val --batch_size_per_gpu=16 --checkpoint_path=/path/to/converted/vae_checkpoint/in_Step2

The output folder structure is shown below for reference; it has the same format as in Step 1:

/path_to_dest_dir 
├── train/
│   ├── n01440764
│   │   ├── n01440764_10026.npz
│   │   ├── n01440764_10027.npz
│   │   ├── n01440764_10029.npz
│   │   ├── ...
│   ├── n01443537
│   │   ├── n01443537_10007.npz
│   │   ├── n01443537_10014.npz
│   │   ├── n01443537_10025.npz
│   │   ├── ...
│   ├── ...
│   └── ...
├── val/
│   ├── n01440764
│   │   ├── ILSVRC2012_val_00000946.npz
│   │   ├── ILSVRC2012_val_00001684.npz
│   │   └── ...
│   ├── n01443537
│   │   ├── ILSVRC2012_val_00001269.npz
│   │   ├── ILSVRC2012_val_00002327.npz
│   │   ├── ILSVRC2012_val_00003510.npz
│   │   └── ...
│   ├── ...
│   └── ...

DiT models use horizontal flipping of images as augmentation. The script also supports saving latent tensors from horizontally flipped images by passing the flag --horizontal_flip.

c. Create ImageNet Latent Tensors with horizontal flip from VAE for train split of dataset
torchrun --nnodes 1 --nproc_per_node 4 modelzoo/data_preparation/vision/dit/create_imagenet_latents.py --image_height=256 --image_width=256 --src_dir=/path/to/imagenet1k_ilsvrc2012 --dest_dir=/path_to_hflipped_dest_dir --log_steps=10 --dataset_split=train --batch_size_per_gpu=16 --checkpoint_path=/path/to/converted/vae_checkpoint/in_Step2 --horizontal_flip
d. Create ImageNet Latent Tensors with horizontal flip from VAE for val split of dataset
torchrun --nnodes 1 --nproc_per_node 4 modelzoo/data_preparation/vision/dit/create_imagenet_latents.py --image_height=256 --image_width=256 --src_dir=/path/to/imagenet1k_ilsvrc2012 --dest_dir=/path_to_hflipped_dest_dir --log_steps=10 --dataset_split=val --batch_size_per_gpu=16 --checkpoint_path=/path/to/converted/vae_checkpoint/in_Step2 --horizontal_flip

Step 4: Training the model on CS system or GPU using run.py

IMPORTANT: See the following notes before proceeding further.

Parameter settings in the YAML config file: The config YAML files are located in the configs directory. Before starting a training run, make sure that the YAML config file being used has the following set correctly:

  • The train_input.data_dir parameter points to the correct dataset
  • The train_input.image_size parameter corresponds to the image_size of the dataset.
  • The model.vae.latent_size parameter corresponds to the size of the latent tensors.
    • Set to [32, 32] for an image size of 256 x 256
    • Set to [64, 64] for an image size of 512 x 512
    • In general, set to [floor(H / 8), floor(W / 8)] for an image size of H x W (see the sketch after this list)
  • Set the model.patch_size parameter to use a different patch size
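
As a quick reference, the latent size can be derived from the image size; a minimal sketch of the floor(H / 8), floor(W / 8) rule above:

# Minimal sketch: derive model.vae.latent_size from the image size
# using the floor(H / 8), floor(W / 8) rule.
def latent_size(image_height, image_width, downsample_factor=8):
    return [image_height // downsample_factor, image_width // downsample_factor]

print(latent_size(256, 256))  # [32, 32]
print(latent_size(512, 512))  # [64, 64]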

To use an image size of 512 x 512, make the following changes:

  1. train_input.image_size: [512, 512]
  2. model.vae.latent_size: [64, 64]
  3. train_input.transforms (if any): change the size params under the various transforms to [512, 512]

YAML config files: Details on the configs for this model can be found in the Available Configurations section above.

In the following example run commands, we use /path/to/yaml, /path/to/model_dir, and train as placeholders for user supplied inputs.

  • /path/to/yaml is a path to the YAML config file with model parameters, such as one of the configurations described in Available Configurations.
  • /path/to/model_dir is a path to the directory where we would like to store the logs and other artifacts of the run.
  • --mode specifies the desired mode to run the model in. Change to --mode eval to run in eval mode.

To compile/validate, run train and eval on Cerebras System

Please follow the instructions on our quickstart in the Developer Docs.

To run train and eval on GPU/CPU

If running on a CPU or GPU, activate the environment from Python GPU Environment setup, and simply run:

python run.py {CPU,GPU} --mode train --params /path/to/yaml --model_dir /path/to/model_dir

Step 5: Generating 50K samples from trained checkpoint on GPUs for FID score computation

Diffusion models report the Fréchet inception distance (FID) [7] metric on 50K samples generated from the trained checkpoint. In order to generate samples, we use a DDPM sampler [2] without guidance (model.reverse_process.guidance_scale: 1.0). Setting model.reverse_process.guidance_scale > 1.0 enables classifier-free guidance, which trades off diversity for sample quality.
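
For intuition, classifier-free guidance combines conditional and unconditional noise predictions at every sampling step. Below is a minimal sketch of that combination only, not the actual pipeline code:

import torch

def apply_cfg(eps_cond: torch.Tensor, eps_uncond: torch.Tensor, guidance_scale: float) -> torch.Tensor:
    """Minimal sketch of classifier-free guidance.

    eps_cond / eps_uncond: noise predictions with and without the class label.
    guidance_scale = 1.0 recovers the plain conditional prediction (no guidance);
    values > 1.0 push samples toward the class condition at the cost of diversity.
    """
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)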

The sample generation settings can be found under model.reverse_params in the config YAML. We currently support two samplers: the DDPM sampler [2] and the DDIM sampler [3]. All arguments in the __init__ of the samplers can be set in the model.reverse_params.sampler section of the YAML config.

To generate samples from a trained DiT checkpoint, we use GPUs and sample_generator_dit.py. A sample command for a single node with 4 GPUs, generating 50000 samples from a trained DiT-XL/2 checkpoint, is shown below. Each GPU uses a batch size of 64, i.e., generates 64 samples at a time. --num_fid_samples controls the total number of samples to generate. This script reads the model.reverse_params section of the config YAML; make sure those settings are appropriate.

torchrun --nnodes 1 --nproc_per_node 4 modelzoo/models/vision/dit/sample_generator_dit.py --model_ckpt_path /path/to/trained/dit_checkpoint --vae_ckpt_path /path/to/converted/vae_checkpoint/in_Step2 --params modelzoo/models/vision/dit/configs/params_dit_xlarge_patchsize_2x2.yaml --sample_dir=/path/to/store/samples_generated --num_fid_samples=50000 --batch_size 64

More information can be found by running:

python modelzoo/models/vision/dit/sample_generator_dit.py -h
usage: sample_generator_dit.py [-h] [--seed SEED] [--model_ckpt_path MODEL_CKPT_PATH] [--vae_ckpt_path VAE_CKPT_PATH]
                               --params PARAMS [--variant VARIANT] [--num_fid_samples NUM_FID_SAMPLES] --sample_dir
                               SAMPLE_DIR [--batch_size BATCH_SIZE] [--create_grid]

optional arguments:
  -h, --help            show this help message and exit
  --seed SEED
  --model_ckpt_path MODEL_CKPT_PATH
                        Optional path to a diffusion model checkpoint (default: None)
  --vae_ckpt_path VAE_CKPT_PATH
                        Optional VAE model checkpoint path (default: None)
  --params PARAMS       Path to params to initialize Diffusion model and VAE models (default: None)
  --variant VARIANT     Variant of Diffusion model (default: None)
  --num_fid_samples NUM_FID_SAMPLES
                        number of samples to generate (default: 50000)
  --sample_dir SAMPLE_DIR
                        Directory to store generated samples (default: None)
  --batch_size BATCH_SIZE
                        per-gpu batch size for forward pass (default: None)
  --create_grid         If passed, create a grid from images generated (default: False)

The script generates a .npz file that should be passed as input to the FID score computation. Sample output looks like the following:

2023-09-06 15:55:30,585 INFO[sample_generator.py:49] Saved .npz file to /path/to/store/samples_generated/sample.npz [shape=(`num_fid_samples`, train_input.image_shape[0], train_input.image_shape[1], train_input.image_channels)].
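
Before running the FID evaluation, the generated file can be sanity-checked with numpy. A minimal sketch; the key name stored inside the .npz is determined by the sample generator and is not asserted here:

import numpy as np

# Minimal sketch: verify the generated samples have the expected shape
# (num_fid_samples, image_height, image_width, image_channels).
with np.load("/path/to/store/samples_generated/sample.npz") as f:
    print(f.files)              # key name(s) written by the sample generator
    arr = f[f.files[0]]
    print(arr.shape, arr.dtype)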

To generate samples belonging to specific ImageNet label classes, set model.reverse_params.pipeline.custom_labels to a list of integer ImageNet labels. This will generate samples belonging only to these classes. For example, if model.reverse_params.pipeline.custom_labels: [207, 360], then we will only generate samples belonging to the golden_retriever (label_id=207) and otter (label_id=360) classes.

Step 6: Using OpenAI FID evaluation repository to compute FID score

Now that we have the 50K samples and the .npz file generated in Step 5, we can compute the FID score using the OpenAI ADM evaluation script evaluator.py [6]. To compute the FID score:

a. Set up a conda environment to use OpenAI evaluation script

conda create --name tf python=3.8.16
conda activate tf
conda install -c conda-forge cudatoolkit=11.8.0
pip install nvidia-cudnn-cu11==8.6.0.163

CUDNN_PATH=$(dirname $(python -c "import nvidia.cudnn;print(nvidia.cudnn.__file__)"))
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib/:$CUDNN_PATH/lib:$LD_LIBRARY_PATH
pip install --upgrade pip
pip install tensorflow==2.13.*
conda install scipy
conda install -c conda-forge tqdm
conda install -c anaconda requests
conda install -c anaconda chardet

b. Clone OpenAI guided-diffusion GitHub repository

git clone https://github.com/openai/guided-diffusion.git
cd guided-diffusion/evaluations

c. Download the .npz file corresponding to the reference batch of ImageNet

wget https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz -P /path/to/store/reference_batch

d. Make changes to evaluator.py

Make the following changes in evaluator.py. These are needed to account for NumPy deprecations, i.e., replace instances of np.bool with bool.

[evaluations](main)$ git diff
diff --git a/evaluations/evaluator.py b/evaluations/evaluator.py
index 9590855..6636d0b 100644
--- a/evaluations/evaluator.py
+++ b/evaluations/evaluator.py
@@ -340,8 +340,8 @@ class ManifoldEstimator:
                - precision: an np.ndarray of length K1
                - recall: an np.ndarray of length K2
        """
-        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
-        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
+        features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)
+        features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
        for begin_1 in range(0, len(features_1), self.row_batch_size):
            end_1 = begin_1 + self.row_batch_size
            batch_1 = features_1[begin_1:end_1]

e. Launch FID eval script with the following command

conda activate tf
cd guided-diffusion/evaluations
python evaluator.py /path/to/step(c)/downloaded/VIRTUAL_imagenet256_labeled.npz /path/to/generated/npz/from/step5 2>&1 | tee fid.log

The following changes can be made to use other settings of the DiT model:

  • The model.vae.latent_size parameter corresponds to the size of the latent tensors. This is the only param under model.vae_params that needs to be changed.
    • Set to [32, 32] for an image size of 256 x 256
    • Set to [64, 64] for an image size of 512 x 512
    • Set to [floor(H / 8), floor(W / 8)] for an image size of H x W
  • Set the model.patch_size parameter to use a different patch size

DataLoader Features Dictionary

DiffusionLatentImageNet1KProcessor outputs a features dictionary with the following keys/values (a sketch of how these are consumed in a training step follows the list):

  • input: Noised latent tensor.
    • Shape: (batch_size, model.vae.latent_channels, model.vae.latent_size[0], model.vae.latent_size[1])
    • Type: torch.bfloat16
  • label: Scalar ImageNet labels.
    • Shape: (batch_size, )
    • Type: torch.int32
  • diffusion_noise: Gaussian noise that the model should predict. Also used in creating the value of the input key (the noised latent).
    • Shape: (batch_size, model.vae.latent_channels, model.vae.latent_size[0], model.vae.latent_size[1])
    • Type: torch.bfloat16
  • timestep: Timestep sampled from ~Uniform(0, train_input.num_diffusion_steps).
    • Shape: (batch_size, )
    • Type: torch.int32
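
For reference, these features map onto a training step roughly as follows. This is a minimal sketch with a hypothetical model signature, not the actual DiTModel API:

import torch.nn.functional as F

def training_step(model, batch):
    """Minimal sketch of how the dataloader features are consumed.

    `model` is assumed (hypothetically) to take the noised latent, timestep and
    label, and to return a predicted-noise tensor with the same shape as `input`.
    """
    pred_noise = model(batch["input"], batch["timestep"], batch["label"])
    # The training target is the Gaussian noise provided by the dataloader.
    return F.mse_loss(pred_noise.float(), batch["diffusion_noise"].float())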

Implementation notes

There are a few modifications to the DiT model made in this implementation:

  1. We use ConvTranspose2D instead of a Linear layer to un-patchify the outputs (see the sketch after this list).
  2. While we support gelu with the tanh approximation, we use gelu with no approximation for better performance.
  3. In order to use the exact same model as the StabilityAI pretrained VAE, no changes need to be made to the params under model.vae_params. The only modification we make in our implementation of the VAE model is that we use the Attention layer defined in modelzoo.
  4. We currently do not support the Kullback-Leibler (KL) loss to optimize Σ, hence the output from the DiT model includes only the noise.
  5. We currently support the AdaLN-Zero variant of the DiT model. Support for the In-Context and Cross-Attention variants is planned for future releases.
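
As an illustration of note 1, a ConvTranspose2D with kernel size and stride equal to the patch size maps a grid of token embeddings back to the latent resolution. This is a shape-level sketch only; the sizes below are illustrative, not taken from a specific config:

import torch
import torch.nn as nn

# Shape-level sketch of un-patchifying with ConvTranspose2d (note 1 above).
batch, hidden, patch, out_channels = 2, 1152, 2, 4
grid_h = grid_w = 16                          # e.g., a 32 x 32 latent with 2 x 2 patches

tokens = torch.randn(batch, grid_h * grid_w, hidden)              # DiT block outputs
grid = tokens.transpose(1, 2).reshape(batch, hidden, grid_h, grid_w)

unpatchify = nn.ConvTranspose2d(hidden, out_channels, kernel_size=patch, stride=patch)
print(unpatchify(grid).shape)                 # torch.Size([2, 4, 32, 32])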

Citations

[1] Scalable Diffusion Models with Transformers

[2] Denoising Diffusion Probabilistic Models

[3] Denoising Diffusion Implicit Models

[4] ImageNet Large Scale Visual Recognition Challenge

[5] Diffusion Models Beat GANs on Image Synthesis

[6] Guided Diffusion GitHub

[7] GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash Equilibrium

[8] Auto-Encoding Variational Bayes