
Transformers Learn to Implement Multi-step Gradient Descent with Chain of Thought

Jianhao Huang, Zixuan Wang, Jason D. Lee
Shanghai Jiao Tong University, Princeton University
International Conference on Learning Representations (2025)
Reasoning Pretraining

📝 Paper Summary

In-Context Learning (ICL) Training Dynamics Linear Regression
Transformers trained with Chain of Thought on linear regression tasks learn to implement multi-step gradient descent, significantly outperforming standard transformers, which are theoretically limited to a single step.
Core Problem
Standard one-layer linear transformers can only implement a single step of gradient descent during in-context learning, which fails to recover the ground-truth weight vector when the number of examples is limited (n ≈ d).
Why it matters:
  • Standard In-Context Learning (ICL) without CoT hits an approximation floor, unable to solve tasks requiring iterative refinement
  • While CoT improves expressivity in theory, the actual training dynamics (how models learn these iterative algorithms via gradient descent) were previously unknown
  • Understanding this mechanism bridges the gap between transformer architecture and iterative optimization algorithms
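The one-step limitation has a concrete algebraic face: starting from w = 0, a single gradient-descent step on the least-squares loss is exactly a weighted sum over the in-context examples, which is the kind of computation one linear-attention layer can express. A minimal numerical sketch of that equivalence (illustrative code, not the paper's construction):

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 10, 20
X = rng.standard_normal((n, d))   # in-context examples x_i
y = X @ rng.standard_normal(d)    # labels y_i
x_q = rng.standard_normal(d)      # query input
lr = 0.1

# One GD step on (1/2n)||Xw - y||^2 starting from w = 0:
w1 = lr * X.T @ y / n
pred_gd = w1 @ x_q

# The same prediction written as a linear-attention-style sum:
# each example contributes y_i * <x_i, x_q>, scaled by lr/n.
pred_attn = (lr / n) * sum(y[i] * (X[i] @ x_q) for i in range(n))

assert np.allclose(pred_gd, pred_attn)
```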
Concrete Example: In a linear regression task with dimension d=10 and n=20 examples, a standard transformer outputs a weight estimate with high error (scaling as d^2/n). A CoT-prompted transformer instead generates intermediate weight updates, iteratively reducing the error to near zero.
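To make the d=10, n=20 example concrete, here is a small numerical sketch (plain NumPy, not a transformer) of the algorithm the CoT model is shown to learn: iterating gradient descent drives the weight-estimation error toward zero, while stopping after one step leaves it large.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 10, 20
w_star = rng.standard_normal(d)   # ground-truth weights w*
X = rng.standard_normal((n, d))   # in-context examples
y = X @ w_star

def gd_estimate(steps, lr=0.1):
    """Weight estimate after `steps` GD updates on (1/2n)||Xw - y||^2."""
    w = np.zeros(d)
    for _ in range(steps):
        w -= lr * X.T @ (X @ w - y) / n
    return w

err_one_step = np.linalg.norm(gd_estimate(1) - w_star)      # single step: large error
err_many_step = np.linalg.norm(gd_estimate(5000) - w_star)  # iterated: near zero
```

With n = 2d the in-context least-squares problem is well-posed, so gradient descent converges linearly and the multi-step estimate recovers w* to high precision, while the one-step estimate remains far from it.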
Key Novelty
Learnable Separation via In-Context Weight Prediction
  • Formalizes 'in-context weight prediction' where the model must output the regression weight vector w* rather than just a label y
  • Proves that training on CoT data (intermediate gradient steps) allows a one-layer transformer to learn multi-step Gradient Descent autoregressively
  • Demonstrates a theoretical separation: Non-CoT models are stuck at 1-step GD, while CoT models converge to near-exact recovery via Gradient Flow
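The CoT supervision the analysis relies on can be pictured as follows: each training sequence carries the in-context examples plus the chain of intermediate GD iterates, and the model is trained to emit the next iterate given the examples and the chain so far. A hedged sketch of generating such a chain (function and variable names are illustrative, not the paper's):

```python
import numpy as np

def make_cot_chain(X, y, num_steps, lr=0.1):
    """Generate intermediate GD iterates w^1, ..., w^T used as CoT
    supervision: the model learns to predict w^{t+1} from the in-context
    examples (X, y) together with the earlier iterates w^1, ..., w^t."""
    n, d = X.shape
    w = np.zeros(d)
    chain = []
    for _ in range(num_steps):
        w = w - lr * X.T @ (X @ w - y) / n  # one GD step = one CoT "thought"
        chain.append(w.copy())
    return chain

rng = np.random.default_rng(1)
d, n, T = 10, 20, 8
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d)
chain = make_cot_chain(X, y, T)
```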
Evaluation Highlights
  • Theoretical Lower Bound: Proves standard one-layer transformers cannot achieve error better than Ω(d^2/n) on the weight prediction task
  • Theoretical Upper Bound: Proves CoT transformers achieve error O(1/poly(d)) with Θ(log d) intermediate steps
  • Empirical Validation: Trained models recover the exact sparse weight structures corresponding to gradient descent operations (verified via heatmaps of the learned attention matrices)
Breakthrough Assessment
7/10
Provides rigorous theoretical grounding for CoT's benefits in simple models, proving a learnable separation. However, the setting (linear regression, one-layer linear attention) is very simplified compared to LLMs.