
Transformers Provably Learn Chain-of-Thought Reasoning with Length Generalization

Yu Huang, Zixin Wen, Aarti Singh, Yuejie Chi, Yuxin Chen
Department of Statistics and Data Science, Wharton School, University of Pennsylvania; Machine Learning Department, Carnegie Mellon University; Department of Statistics and Data Science, Yale University
arXiv.org (2025)
Reasoning RL Benchmark

📝 Paper Summary

Keywords: Theoretical analysis · Transformers · Chain-of-Thought (CoT) reasoning · Length generalization
Theoretical analysis proving that gradient descent trains one-layer transformers to solve state-tracking tasks via Chain-of-Thought, with algebraic structure dictating whether length generalization happens automatically or requires recursive self-training.
Core Problem
It is unknown whether transformers trained via gradient descent can actually learn to solve inherently sequential reasoning problems (beyond simple TC0 tasks) and whether they can generalize to longer reasoning chains than seen during training.
Why it matters:
  • Current theoretical understanding is limited to expressiveness (what models *can* represent) or simple parallelizable tasks (TC0), leaving a gap in explaining how models *learn* sequential reasoning (NC1)
  • Length generalization is critical for LLMs to solve harder problems via longer CoT, but empirical results are mixed and mechanisms like 'context rot' are poorly understood
  • The distinction between problems that generalize automatically versus those needing specific curricula (like self-training) is not theoretically established
Concrete Example: Consider a 'symmetry' state-tracking task where multiple group elements map state A to state B (e.g., permutations). A model trained on short chains might learn to attend to 'distractor' clauses that happen to work for short lengths but fail for longer ones due to attention dilution. In contrast, 'cyclic' group actions have unique mappings, leading to robust attention concentration.
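The contrast above can be made concrete with a small sketch. This is an illustrative toy, not the paper's code: `cyclic_action` models the simply transitive case (mod-n addition, where each state pair identifies a unique group element) and `symmetry_action` models the permutation case (where many elements map state A to state B, creating distractors).

```python
# Toy state-tracking task: a word of group elements acts on a start
# state, and the ground-truth CoT writes every intermediate state.

def cyclic_action(state, g, n=6):
    """Simply transitive action of C_n: exactly one g maps any
    state a to any state b (namely g = b - a mod n)."""
    return (state + g) % n

def symmetry_action(state, perm):
    """Action of S_n on points {0..n-1}: many permutations send
    state a to state b, so the pair (a, b) does not identify g."""
    return perm[state]

def chain_of_thought(start, word, action):
    """Ground-truth CoT: the list of intermediate states."""
    states = [start]
    for g in word:
        states.append(action(states[-1], g))
    return states

# C6 example: start at 0, apply 3, 5, 4 -> states 0, 3, 2, 0.
print(chain_of_thought(0, [3, 5, 4], cyclic_action))  # [0, 3, 2, 0]

# S5 example: two different permutations both send state 0 to 2,
# illustrating the 'distractor clause' ambiguity on short chains.
p1 = (2, 1, 0, 3, 4)
p2 = (2, 0, 1, 4, 3)
print(symmetry_action(0, p1), symmetry_action(0, p2))  # 2 2
```

The ambiguity in the second example is exactly what lets a short-chain model latch onto spurious clauses: several distinct permutations are consistent with the observed state transition.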
Key Novelty
Algebraic Structure Dictates Length Generalization
  • Proves that for 'simply transitive' group actions (e.g., modular addition), training on short chains automatically leads to strong attention concentration, enabling generalization to much longer sequences.
  • Shows that for 'symmetry' group actions (e.g., permutations), standard training fails to generalize due to attention distractors; however, a recursive self-training curriculum can bootstrap the model to solve maximal lengths.
  • Provides the first optimization guarantee that constant-depth transformers can learn NC1-complete problems (inherently serial tasks) via CoT, surpassing prior limits of TC0.
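The 'simply transitive' property at the heart of these results has a direct brute-force check, sketched below under the standard definition (for every ordered pair of states, exactly one group element maps the first to the second); the function name is illustrative, not from the paper.

```python
from itertools import permutations

def is_simply_transitive(states, elements, action):
    """True iff for every ordered pair of states (a, b) exactly one
    element g satisfies action(a, g) == b."""
    for a in states:
        for b in states:
            movers = [g for g in elements if action(a, g) == b]
            if len(movers) != 1:
                return False
    return True

# C6 acting on Z/6 by addition: simply transitive, so the paper's
# analysis predicts automatic length generalization.
c6 = list(range(6))
assert is_simply_transitive(c6, c6, lambda s, g: (s + g) % 6)

# S3 acting on {0, 1, 2}: transitive but NOT simply transitive
# (6 permutations vs. 3 states, so some pairs have multiple movers),
# mirroring the distractor problem for symmetry tasks like S5.
s3_states = [0, 1, 2]
s3 = list(permutations(s3_states))
assert not is_simply_transitive(s3_states, s3, lambda s, p: p[s])
```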
Evaluation Highlights
  • For simply transitive tasks (Cyclic C6), models trained on length L=10 achieve near 100% accuracy on lengths up to L=100.
  • For symmetry tasks (S5), models trained on length L=10 fail rapidly (accuracy drops to ~0) on lengths >20.
  • Recursive self-training on S5 enables the model to bridge this gap, extending solvable length from L=10 to L=160 with near-perfect accuracy.
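The bootstrapping curriculum in the last bullet can be sketched as a loop: the model labels somewhat longer chains with its own CoT, and those self-labeled examples become the next round's training set. All names here (`train_on`, `self_label`, `sample_words`) are hypothetical stand-ins, and the doubling schedule is an assumption chosen to match the reported L=10 → L=160 progression, not the paper's exact recipe.

```python
def recursive_self_training(model, train_on, self_label, sample_words,
                            start_len=10, max_len=160, growth=2):
    """Bootstrap from chains of length start_len to max_len by
    repeatedly self-labeling chains up to `growth` times longer."""
    length = start_len
    while length < max_len:
        next_len = min(growth * length, max_len)
        # The model generates CoT labels for longer chains it can
        # reach by composing steps it has already mastered.
        pseudo = [(w, self_label(model, w)) for w in sample_words(next_len)]
        model = train_on(model, pseudo)
        length = next_len
    return model

# Toy instantiation: the "model" is just the set of chain lengths it
# has mastered, and training adds the lengths of the pseudo-labeled data.
toy = recursive_self_training(
    model={10},
    train_on=lambda m, data: m | {len(w) for w, _ in data},
    self_label=lambda m, w: ["state"] * len(w),
    sample_words=lambda n: [list(range(n))],
)
print(sorted(toy))  # [10, 20, 40, 80, 160]
```

The toy run reproduces the curriculum's shape: each round extends the solvable length multiplicatively until the maximal length is reached.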
Breakthrough Assessment
9/10
Significant theoretical advance: first optimization proof for learning NC1 tasks (beyond TC0) and a mechanistic explanation of length generalization linked to algebraic structure, validated by experiments.