Fully Sharded Data Parallel (FSDP)
Introduction
Training large language models on a multi-node, multi-GPU system can be a challenging task. Fully Sharded Data Parallel (FSDP) is an approach designed to enhance training efficiency by optimizing memory usage and scalability. It achieves this by sharding, or partitioning, the model's parameters across all available GPUs in a distributed computing environment. Each GPU holds only a fraction of the total model parameters, which reduces memory requirements significantly compared to traditional data parallelism. In this blog post, I will explain the mechanism behind FSDP and walk through a practical implementation of fine-tuning the Llama-3 8B model using FSDP + 4-bit quantization + PEFT.
Comparison with Other Parallelism Strategies
- Data Parallel (DP): In DP, the entire model is replicated across each GPU, and each GPU processes different batches of data. While straightforward, DP’s efficiency drops as the model size increases to the point where the model itself might not fit into GPU memory.
- Model Parallel: Model parallelism involves splitting the model’s layers across multiple GPUs. This method can handle larger models but often suffers from significant communication overhead and underutilization of computational resources due to the sequential dependency of layers.
- Tensor Parallel: Similar to model parallel, tensor parallel splits the model’s layers at a finer granularity (tensor operations are split across GPUs). This approach is useful for large models but requires sophisticated splitting strategies to minimize communication overhead.
- Pipeline Parallel: Pipeline parallelism divides the model into several stages or segments, each processed on different GPUs. While it allows continuous model training and can utilize multiple GPUs efficiently, it introduces complexities in managing pipeline stalls and data batching.
FSDP differs from these techniques in its ability to reduce each GPU’s memory load by only assigning a shard of the entire model’s parameters and optimizer states, thus enabling the training of much larger models.
Key Terminologies in FSDP
- Sharding: Partitioning model parameters across multiple GPUs.
- NCCL (NVIDIA Collective Communications Library): A library that supports multi-GPU and multi-node collective communication operations that FSDP utilizes for efficiency.
NCCL Operations: All-Gather and Reduce-Scatter
- All-Gather: An operation in which every GPU collects parameter shards from all other GPUs to temporarily form a complete set of parameters, which is needed for the computations in the forward and backward passes.
- Reduce-Scatter: After backpropagation, gradients must be returned to update the respective parameters. Reduce-scatter aggregates these gradients across GPUs and then scatters each segment to the GPU that owns those model parameters (a minimal sketch of both collectives follows).
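To make these two collectives concrete, here is a minimal sketch that calls them directly through torch.distributed; FSDP issues equivalent calls internally. It assumes a single-node multi-GPU launch with torchrun so an NCCL process group can be initialized, and the tensor sizes are arbitrary placeholders.
import torch
import torch.distributed as dist
# Assumes: torchrun --nproc_per_node=<num_gpus> this_script.py (single node)
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
world_size = dist.get_world_size()
torch.cuda.set_device(rank)
# Each rank owns one shard of a (conceptual) parameter tensor.
shard = torch.full((4,), float(rank), device="cuda")
# All-gather: every rank temporarily reconstructs the full tensor from all shards.
gathered = [torch.empty_like(shard) for _ in range(world_size)]
dist.all_gather(gathered, shard)
full_params = torch.cat(gathered)
# Reduce-scatter: per-rank gradient contributions are summed element-wise, and each
# rank receives only the slice corresponding to the shard it owns.
grad_contributions = list(torch.ones(world_size * 4, device="cuda").chunk(world_size))
my_grad_shard = torch.empty(4, device="cuda")
dist.reduce_scatter(my_grad_shard, grad_contributions, op=dist.ReduceOp.SUM)
dist.destroy_process_group()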
Detailed Example with Workflow
Consider training a Transformer model with FSDP on n GPUs. The model consists of multiple transformer layers, which are memory-intensive due to their parameters and the states stored during training.
Workflow:
- Initialization: The model is initialized and its parameters are sharded across the n GPUs. Each GPU receives a portion of the parameters, reducing the individual memory requirement.
- Forward Pass: Each GPU performs an all-gather operation to collect the necessary parameters from the other GPUs and temporarily reconstruct the full set of weights. The forward computation is then carried out on each GPU for its own batch of data. After the computation, each GPU discards the gathered parameters and retains only its shard.
- Backward Pass: The all-gather operation is repeated to reassemble the parameters needed for gradient computation. Gradients are computed on each GPU based on the loss function. A reduce-scatter operation follows, in which the computed gradients are aggregated and each segment is sent to the GPU responsible for those model parameters.
- Parameter Update: Each GPU updates its shard of the model parameters using the received gradients. The model is then ready for the next iteration or batch of data. A minimal code sketch of this loop is shown below.
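The sketch below mirrors this workflow on a toy model. The all-gather, reduce-scatter, and re-sharding steps all happen implicitly inside FSDP's forward and backward passes; this is only an illustrative example assuming a single-node torchrun launch, with placeholder layer sizes and random data.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
dist.init_process_group(backend="nccl")
rank = dist.get_rank()  # single-node assumption: rank doubles as the local GPU index
torch.cuda.set_device(rank)
# Toy stand-in for a transformer; after wrapping, each rank stores only a shard of the weights.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)).cuda()
model = FSDP(model)  # 1) initialization: parameters are sharded across GPUs
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for step in range(10):
    batch = torch.randn(8, 1024, device="cuda")  # each rank processes its own batch
    loss = model(batch).pow(2).mean()            # 2) forward: all-gather params, compute, discard
    loss.backward()                              # 3) backward: all-gather again, then reduce-scatter grads
    optimizer.step()                             # 4) each rank updates only its own parameter shard
    optimizer.zero_grad()
dist.destroy_process_group()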
Efficiency of FSDP
- Reduced Memory Load: By sharding parameters, each GPU holds only a portion of the model's parameters, gradients, and optimizer states, significantly reducing the memory requirement per GPU (a rough per-GPU estimate follows after this list).
- Enhanced Scalability: With reduced memory load per GPU, FSDP can scale to train larger models or use larger batches, which is often not feasible with other methods.
- Optimized Communication: Using NCCL's all-gather and reduce-scatter collectives keeps the required data exchange fast and bandwidth-efficient.
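As a rough illustration of the memory side, here is a back-of-the-envelope estimate only, assuming mixed-precision training with Adam at roughly 16 bytes per parameter for weights, gradients, and optimizer states, and ignoring activations:
# Rough per-GPU memory for weights + gradients + Adam states (activations excluded).
# Assumption: ~16 bytes/param = 2 (bf16 weights) + 2 (bf16 grads) + 12 (fp32 master copy + moments).
num_params = 8e9        # e.g. an 8B-parameter model
bytes_per_param = 16
num_gpus = 8
total_gb = num_params * bytes_per_param / 1e9
print(f"Replicated (plain data parallel): ~{total_gb:.0f} GB per GPU")                   # ~128 GB
print(f"FSDP FULL_SHARD across {num_gpus} GPUs: ~{total_gb / num_gpus:.0f} GB per GPU")  # ~16 GB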
Implementation
There are several ways to use FSDP in your code. The simplest is to wrap your model with PyTorch's FSDP API:
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy, CPUOffload

fsdp_model = FullyShardedDataParallel(
    model,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    cpu_offload=CPUOffload(offload_params=True),
)
The two key configurations are:
- CPU Offload: FSDP can offload both parameters and gradients to the CPU, which helps manage GPU memory more efficiently. This is beneficial for large models, where even sharded parameters can consume a significant amount of GPU memory. CPU offload is activated by passing cpu_offload=CPUOffload(offload_params=True) when initializing FSDP, as in the example above.
- Sharding Strategies: FSDP offers several sharding strategies. FULL_SHARD shards parameters, gradients, and optimizer states across GPUs, minimizing memory usage but requiring parameters to be re-gathered for every forward and backward computation. SHARD_GRAD_OP shards only gradients and optimizer states and keeps the full parameters in GPU memory between the forward and backward passes, reducing the number of all-gathers and potentially improving throughput at the cost of higher memory usage. An example of the SHARD_GRAD_OP variant is shown below.
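For example, a SHARD_GRAD_OP wrap looks like this (a sketch only, reusing the same model variable and imports as above):
from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy

# Gradients and optimizer states are sharded; parameters stay gathered between
# forward and backward, trading extra memory for fewer all-gather operations.
fsdp_model = FullyShardedDataParallel(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
)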
Below is the fine-tuning code for the Llama 3 8B model. I used FSDP for distributed training, 4-bit quantization to reduce the model's memory footprint, and Parameter-Efficient Fine-Tuning (PEFT) with Low-Rank Adaptation (LoRA) so that only a small set of adapter weights is trained. Modify the model name, dataset, and training arguments as needed to match your use case and configuration.
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from accelerate import FullyShardedDataParallelPlugin, Accelerator
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
# Initialize FSDP with configurations for state dictionary and optimizer state handling
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)
# Setup the Accelerator with the FSDP plugin
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
# Base model configuration for 4-bit quantization
base_model_id = "meta-llama/Meta-Llama-3-8B"
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
# Load the pretrained model with quantization and device mapping
model = AutoModelForCausalLM.from_pretrained(base_model_id,
quantization_config=bnb_config,
device_map="auto",
cache_dir='')
# Initialize tokenizer with special tokens
tokenizer = AutoTokenizer.from_pretrained(
base_model_id,
add_eos_token=True,
add_bos_token=True,
)
# Load dataset
dataset_name = "hf_dataset"
dataset = load_dataset(dataset_name, split="train")
# Enable gradient checkpointing to save memory
model.gradient_checkpointing_enable()
# Prepare model for training with K-bit quantization techniques
model = prepare_model_for_kbit_training(model)
# Configuration for LoRA augmentation
config = LoraConfig(
r=8,
lora_alpha=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"w1",
"w2",
"w3",
"lm_head",
],
bias="none",
lora_dropout=0.05,
task_type="CAUSAL_LM",
)
# Apply LoRA and other modifications for PEFT
model = get_peft_model(model, config)
# Prepare model for distributed training with Accelerator
model = accelerator.prepare_model(model)
# Enable model parallelism if multiple GPUs are available
if torch.cuda.device_count() > 1:
    model.is_parallelizable = True
    model.model_parallel = True
# Ensure padding token is set for the tokenizer
tokenizer.pad_token = tokenizer.eos_token
# Set up the Trainer with training arguments
trainer = transformers.Trainer(
model=model,
train_dataset=dataset,
args=transformers.TrainingArguments(
output_dir='./results',
warmup_steps=2,
per_device_train_batch_size=16,
gradient_checkpointing=True,
gradient_accumulation_steps=8,
num_train_epochs=3,
learning_rate=2.5e-5,
logging_steps=1,
fp16=True,
optim="paged_adamw_8bit",
logging_dir="./logs",
save_strategy="steps",
save_steps=100,
save_total_limit=2,
),
data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# Disable caching to manage memory better during training
model.config.use_cache = False
# Start training
trainer.train()
To run the code using Hugging Face's Accelerate library for distributed training, you first need to configure your environment. Run the accelerate config command to set the number of GPUs, nodes, and other settings for your system. If you are setting up a multi-node environment, it is important to do this on each node.
Once the configuration is complete, you can run your training script across the specified nodes and GPUs with accelerate launch run.py. If you are using SLURM, you can streamline the process with a single command:
srun accelerate launch run.py
You can find more information in Hugging Face's Accelerate documentation.