Y Pan, Y Yuan, Y Yin, J Shi, Z Xu, M Zhang, L Shang…
East China Normal University, Alibaba Group
arXiv, 1/2024
Pretraining
📝 Paper Summary
Efficient Training · Progressive Training
Apollo accelerates deep model training by starting with a shallow network and progressively expanding it using a novel sampling strategy and interpolation method.
Core Problem
Training large Transformers from scratch is slow and resource-intensive, while existing progressive stacking methods often fail to achieve significant acceleration or suffer from instability.
Why it matters:
Training large models consumes massive computational resources and time, leading to high financial and environmental costs.
Relying on pretrained models limits applicability for new architectures where no prior weights exist.
Simple layer stacking (like StackBERT) is often unstable due to large gradients and semantic gaps between layers.
Concrete Example: When training a 12-layer Transformer, StackBERT first trains a 6-layer model and then copies the weights of layers 1-6 onto layers 7-12. This abrupt change causes instability because the lower layers haven't yet learned high-level semantics, leading to large gradients and slow convergence.
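The stacking behavior described above can be shown in two lines (layers are represented by placeholder values; `stack_init` is an illustrative name, not StackBERT's actual API):

```python
def stack_init(trained_layers):
    """StackBERT-style doubling: copy the whole trained stack on top of
    itself, so layer k+i starts from layer i's weights."""
    return list(trained_layers) + list(trained_layers)
```

For example, `stack_init([f"layer{i}" for i in range(1, 7)])` yields a 12-entry stack whose 7th entry is a copy of layer 1 — exactly the semantic gap described above.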
Key Novelty
Apollo (progressive expansion via weight sharing and interpolation)
Low-Value-Prioritized Sampling (LVPS): Randomly selects different network depths during early training, prioritizing shallow depths to save compute while exposing weights to high-level functional requirements.
Weight Sharing: Uses shared weights across the sampled layers to learn features applicable to both low and high layers before expansion.
Layer Interpolation: Expands the model depth by interpolating weights (mathematical blending) rather than direct stacking, ensuring smoother transitions and better stability.
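A minimal sketch of how LVPS-style depth sampling might look. The paper's exact distribution P_LVPS is controlled by a parameter k; the inverse-power decay below is an illustrative stand-in with the same qualitative shape (shallow depths are sampled far more often, which is where the compute savings come from):

```python
import random

def lvps_probs(max_depth, k=2.0):
    """Illustrative shallow-biased distribution over depths 1..max_depth.

    Stand-in for the paper's P_LVPS: probability decays as 1/d^k, so
    shallow depths dominate early training compute.
    """
    raw = [1.0 / (d ** k) for d in range(1, max_depth + 1)]
    total = sum(raw)
    return [w / total for w in raw]

def sample_depth(max_depth, k=2.0, rng=random):
    """Draw a temporary training depth L^(t) for the current step."""
    probs = lvps_probs(max_depth, k)
    return rng.choices(range(1, max_depth + 1), weights=probs, k=1)[0]
```

With k = 2 and max_depth = 12, depth 1 is roughly 100x more likely than depth 12, so most early steps run a very shallow network.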
Architecture
Conceptual illustration of Apollo's training process involving LVPS and expansion.
Evaluation Highlights
Achieves state-of-the-art acceleration, outperforming StackBERT and even methods using pretrained models (like bert2BERT) in efficiency.
Reduces training FLOPs substantially by sampling lower depths more frequently during the early stages via LVPS.
Improves training stability compared to stacking methods, avoiding large gradient spikes during depth expansion.
Breakthrough Assessment
8/10
Offers a universal solution for efficient training from scratch that rivals pretrained initialization methods, addressing a critical bottleneck in training large custom architectures.
⚙️ Technical Details
Problem Definition
Setting: Training an L-layer Transformer f^(L) from scratch by progressively growing the number of layer weights N^(s) across S stages until the target depth L is reached.
Inputs: Training dataset (e.g., text corpus for BERT/GPT)
Outputs: Trained Transformer model with L layers
Pipeline Flow
Stage Initialization: Define S stages with increasing weight counts N^(s)
Training Step (LVPS): Sample a temporary depth L^(t) using LVPS distribution
Weight Sharing: Map the N^(s) available weights to the sampled L^(t) layers
Expansion (Interpolation): At stage transition, expand weights N^(s) to N^(s+1) using interpolation
Final Model: Reach target depth L
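The five steps above can be sketched end to end with toy scalar "layer weights". The cyclic sharing map, the 1/d sampling weights, and all names here are illustrative assumptions, not the paper's exact formulas:

```python
import random

def share_map(n_weights, depth):
    """Step 3 (weight sharing): map n physical weights onto `depth`
    logical layers; illustrative cyclic mapping g(l) = l mod n."""
    return [l % n_weights for l in range(depth)]

def expand_by_interpolation(weights, new_n):
    """Step 4 (expansion): grow len(weights) values to new_n by linearly
    blending neighboring weights instead of copying them."""
    n = len(weights)
    if n == 1 or new_n == 1:
        return [weights[0]] * new_n
    out = []
    for j in range(new_n):
        pos = j * (n - 1) / (new_n - 1)   # fractional position in old stack
        i, frac = int(pos), pos - int(pos)
        nxt = min(i + 1, n - 1)
        out.append((1 - frac) * weights[i] + frac * weights[nxt])
    return out

def apollo_schedule(stage_sizes, target_depth, steps_per_stage, seed=0):
    """Steps 1-5: toy training loop over S stages with depth sampling."""
    rng = random.Random(seed)
    weights = [0.0] * stage_sizes[0]          # stand-in for real layer weights
    for s in range(len(stage_sizes)):
        for _ in range(steps_per_stage):
            # Step 2: shallow-biased depth sampling (stand-in for LVPS)
            depths = range(1, target_depth + 1)
            depth = rng.choices(depths, weights=[1 / d for d in depths], k=1)[0]
            logical_layers = share_map(len(weights), depth)
        if s + 1 < len(stage_sizes):          # stage transition
            weights = expand_by_interpolation(weights, stage_sizes[s + 1])
    return weights                            # step 5: target-depth model
```

In a real run the logical layers drive the forward/backward pass at each step; here they only illustrate how few physical weights serve many sampled depths before each expansion.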
System Modules
LVPS Sampler (Training Strategy)
Determines the depth of the network for the current training step
Model or implementation: Probabilistic sampling function P_LVPS
Weight Manager (Training Strategy)
Maps physical weights to logical layers for forward/backward pass
Model or implementation: Mapping function g(l)
Interpolator
Expands model depth between stages
Model or implementation: Linear interpolation formula
Novel Architectural Elements
Dynamic depth sampling (LVPS) integrated into the training loop
Layer interpolation for depth expansion in Transformers (adapted from ResNet context)
Modeling
Base Model: BERT-Base / GPT-2
Training Method: Progressive training from scratch
Objective Functions:
Purpose: Standard language modeling loss (Masked LM for BERT, Causal LM for GPT).
Code is publicly available. Hyperparameters for BERT and GPT experiments are provided. The method relies on standard optimizers (AdamW) and architectures.
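The objective is ordinary cross-entropy; a minimal pure-Python sketch of the masked-LM variant follows (causal LM is the same loss applied at every position against the next token; the paper itself uses standard framework implementations):

```python
import math

def masked_lm_loss(logits, targets, mask):
    """Cross-entropy averaged over masked positions only (BERT-style MLM).

    logits: per-position lists of raw scores over the vocabulary
    targets: gold token ids; mask: 1 where the position was masked out.
    """
    total, count = 0.0, 0
    for scores, gold, m in zip(logits, targets, mask):
        if not m:
            continue  # unmasked positions contribute no loss in MLM
        mx = max(scores)  # subtract max for numerical stability
        log_z = mx + math.log(sum(math.exp(s - mx) for s in scores))
        total += log_z - scores[gold]
        count += 1
    return total / max(count, 1)
```

A uniform prediction over a 2-token vocabulary gives the expected loss of log 2; a confident correct prediction drives the loss toward zero.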
📊 Experiments & Results
Evaluation Setup
Pretraining BERT and GPT models from scratch on English Wikipedia and BookCorpus.
Benchmarks:
BERT Pretraining (Masked Language Modeling)
GPT Pretraining (Causal Language Modeling)
GLUE (Natural Language Understanding)
Metrics:
Validation Loss
Training FLOPs / Wall-clock time
GLUE Score (Average)
Statistical methodology: Not explicitly reported in the paper
Experiment Figures
Comparison of different sampling distributions for LVPS (controlled by parameter k).
Comparison between Stacking and Interpolation expansion methods.
Main Takeaways
Apollo consistently achieves lower validation loss than training from scratch and StackBERT for the same computational budget.
Apollo rivals or exceeds the efficiency of methods that start from pretrained models (like bert2BERT and LiGO).
Layer interpolation provides smoother convergence curves compared to stacking, indicating better training stability.
LVPS is the most efficient sampling strategy compared to Uniform, Edge, or Full sampling.
📚 Prerequisite Knowledge
Prerequisites
Transformer architecture (MHSA, FFN)
Progressive training / Stacking (StackBERT)
Function-preserving transformations (Net2Net)
Key Terms
LVPS: Low-Value-Prioritized Sampling—a strategy to sample layer depths during training where shallower depths are chosen with higher probability to save compute.
StackBERT: A baseline method that trains a shallow model and progressively stacks layers to initialize a deeper model.
MHSA: Multi-Head Self-Attention—a key component of Transformers that captures dependencies between different words in a sequence.
FFN: Feed-Forward Network—a neural network layer within the Transformer block that processes information after attention.
Layer Interpolation: A method to initialize new layers by mathematically blending existing weights (e.g., linear interpolation) rather than just copying them, used to smooth the expansion process.