
Implicit Bias and Fast Convergence Rates for Self-attention

Bhavya Vasudeva, Puneesh Deora, Christos Thrampoulidis
University of Southern California, University of British Columbia
Trans. Mach. Learn. Res. (2024)
Pretraining

📝 Paper Summary

Optimization Theory · Implicit Bias · Transformer Interpretability
The paper proves that training self-attention with normalized gradient descent globally converges to a max-margin solution that selects optimal tokens, providing the first finite-time convergence rates for this non-convex setting.
Core Problem
Understanding why and how fast gradient-based optimizers select specific solutions (implicit bias) in the non-convex landscape of self-attention remains theoretically unresolved.
Why it matters:
  • Prior theoretical results were limited to local convergence (dependent on specific initialization) and asymptotic analysis (infinite time), failing to explain practical training behaviors.
  • Transformers are trained in practice with adaptive step-size methods (e.g., Adam, for which normalized GD is a tractable proxy), but existing theory largely focuses on standard Gradient Descent (GD), which behaves differently.
  • Connecting the success of attention mechanisms to rigorous optimization principles (like max-margin separation) is crucial for explaining Transformer generalization.
Concrete Example: In a sentiment classification task where only one token (e.g., 'terrible') determines the label, standard initialization might cause GD to get stuck or converge extremely slowly. This paper proves that adaptive methods (Normalized GD) will always find the attention weights that focus solely on 'terrible' (the max-margin solution) regardless of initialization, and quantifies the speed.
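The phenomenon described above can be reproduced numerically. Below is a minimal sketch with a single toy sequence, a fixed linear decoder, and logistic loss (all of these modeling choices are illustrative assumptions, not the paper's experimental setup): normalized GD on the attention parameter drives the softmax weight on the label-relevant token toward 1, even from an arbitrary initialization.

```python
import numpy as np

# Toy single-sequence setup: 3 tokens, token 0 is the "optimal" one
# (the only token whose decoder score carries the label signal).
X = np.array([[2.0, 0.0],    # optimal token (e.g., 'terrible')
              [0.0, 1.0],    # irrelevant token
              [0.0, -1.0]])  # irrelevant token
u = np.array([1.0, 0.0])     # fixed decoder; per-token scores gamma = [2, 0, 0]
y = 1.0                      # label
gamma = X @ u

def loss_grad_attn(w):
    s = X @ w
    a = np.exp(s - s.max()); a /= a.sum()    # softmax attention weights
    out = a @ gamma                          # model output
    loss = np.log1p(np.exp(-y * out))        # logistic loss
    # chain rule through softmax: d out / d w = X^T (diag(a) - a a^T) gamma
    da = a * (gamma - out)
    grad = (-y / (1.0 + np.exp(y * out))) * (X.T @ da)
    return loss, grad, a

w = np.array([0.5, -1.0])    # arbitrary (non-aligned) initialization
eta = 0.1
for _ in range(2000):
    loss, g, a = loss_grad_attn(w)
    gnorm = np.linalg.norm(g)
    if gnorm < 1e-30:        # gradient underflows once attention saturates
        break
    w -= eta * g / gnorm     # normalized GD step: fixed-length update

_, _, a = loss_grad_attn(w)
print(a[0])  # attention on the optimal token approaches 1
```

Note the role of normalization: as attention saturates, the raw gradient shrinks toward zero, so plain GD would crawl; dividing by the gradient norm keeps the step size constant, which is exactly why the analysis yields finite-time rates.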
Key Novelty
Global Finite-Time Convergence for Self-Attention
  • Establishes that Normalized GD converges to the hard-margin SVM solution from *any* initialization (global), overcoming previous local limitations.
  • Derives explicit convergence rates (e.g., O(t^-1/2)) for the attention weights, showing they align with the direction separating 'optimal' tokens from others.
  • Proves that the attention map becomes sparse (focusing on one token) at an exponential rate.
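The max-margin solution referenced above can be sketched as a hard-margin SVM over tokens. In simplified notation (a single attention parameter vector $p$, token embeddings $x_{i,t}$ for sequence $i$, and $\mathrm{opt}_i$ the index of the optimal token; this notation is an assumption for illustration, not taken verbatim from the paper), the program reads roughly:

```latex
p_{\mathrm{mm}} \;=\; \arg\min_{p}\; \|p\|_2
\quad \text{s.t.} \quad
\langle x_{i,\mathrm{opt}_i} - x_{i,t},\; p \rangle \;\ge\; 1
\qquad \forall\, i,\; \forall\, t \neq \mathrm{opt}_i .
```

The result is directional: the normalized GD iterates $w_t / \|w_t\|$ converge to $p_{\mathrm{mm}} / \|p_{\mathrm{mm}}\|$, the direction that separates optimal tokens from all others with maximum margin.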
Evaluation Highlights
  • Proves that Normalized GD iterates converge in direction to the max-margin solution at a rate of O(t^-1/2) when the decoder is fixed.
  • Demonstrates that softmax attention scores for optimal tokens converge to 1 (sparsification) at an exponential rate O(exp(-ηt)).
  • Shows that joint training of attention and decoder weights converges globally at a rate of O(1/log t), with loss converging at O(exp(-t^1/3)).
Breakthrough Assessment
8/10
Significant theoretical advance: moves self-attention implicit bias analysis from local/asymptotic (prior work) to global/finite-time, bridging the gap between theory and the adaptive optimizers used in practice.