
Making Large Language Models Better Reasoners with Alignment

Peiyi Wang, Lei Li, Liang Chen, Feifan Song, Binghuai Lin, Yunbo Cao, Tianyu Liu, Zhifang Sui
National Key Laboratory for Multimedia Information Processing, Peking University, Tencent Cloud AI, The University of Hong Kong
arXiv (2023)
Reasoning RL

📝 Paper Summary

Mathematical Reasoning · Chain-of-Thought (CoT) Fine-tuning · Alignment / Preference Optimization
Alignment Fine-Tuning (AFT) improves LLM reasoning by calibrating the model's scoring of generated Chain-of-Thought paths using a constraint loss that prevents the degradation of valid but non-optimal reasoning paths.
Core Problem
Vanilla Fine-Tuning (VFT) suffers from 'Assessment Misalignment,' where models assign higher probabilities (scores) to incorrect reasoning paths than to correct non-reference paths because VFT only optimizes the single reference solution.
Why it matters:
  • Standard fine-tuned models cannot accurately assess the quality of their own generated reasoning chains, limiting self-consistency and reranking capabilities.
  • Existing alignment methods like RRHF and PRO degrade reasoning performance because they aggressively down-weight valid but lower-ranked responses without constraints.
Concrete Example: On a math word problem, a VFT model can assign lower perplexity (i.e., a better score) to a candidate answer that gets the arithmetic wrong (e.g., concluding 50 * 0.2 = 100) than to a correct alternative reasoning path that merely differs from the training reference.
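The misalignment above is measured by scoring each candidate path with its (length-normalized) log-likelihood under the model. A minimal sketch of such a scoring function, assuming we already have the model's logits over the candidate's tokens (function name and shapes are illustrative, not from the paper):

```python
import torch
import torch.nn.functional as F

def cot_score(logits: torch.Tensor, target_ids: torch.Tensor) -> float:
    """Score a candidate CoT path as its mean token log-probability.

    Higher score = the model "prefers" this path; a well-calibrated
    model should score correct paths above incorrect ones.
    logits:     [seq_len, vocab_size] model outputs for the candidate
    target_ids: [seq_len] the candidate's token ids
    """
    log_probs = F.log_softmax(logits, dim=-1)
    # Pick out the log-probability of each actual token.
    token_logp = log_probs.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    # Length-normalize so long reasoning chains are not penalized.
    return token_logp.mean().item()
```

A VFT model exhibits assessment misalignment when `cot_score` ranks an incorrect path above a correct non-reference one.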
Key Novelty
Alignment Fine-Tuning (AFT) with Constraint Alignment (CA) Loss
  • Refines a fine-tuned model by generating multiple Chain-of-Thought (CoT) samples and categorizing them as positive (correct answer) or negative.
  • Optimizes the model so that positive CoTs score higher than negatives, but crucially applies a 'constraint' (via gradient detaching or a soft boundary) to prevent the model from crushing the scores of reasonable negative samples.
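The constraint idea in the bullets above can be sketched as a ranking loss with a detached boundary. This is an illustrative simplification under assumed tensor shapes, not the paper's exact CA loss: negatives are only pushed down until they fall a margin `beta` below the lowest positive score, after which they stop receiving gradient:

```python
import torch

def constrained_alignment_loss(
    pos_scores: torch.Tensor,  # [P] scores of correct-answer CoTs
    neg_scores: torch.Tensor,  # [N] scores of incorrect CoTs
    beta: float = 1.0,         # margin below which negatives are left alone
) -> torch.Tensor:
    """Hinge-style ranking loss with a soft lower boundary.

    Positives should outscore negatives, but a negative already more
    than `beta` below the lowest positive is not penalized further --
    this is the 'constraint' that keeps plausible-but-unchosen
    reasoning from being crushed (hypothetical sketch).
    """
    # Boundary is detached so gradients never flow through it.
    boundary = pos_scores.min().detach() - beta
    # Penalize only negatives still above the boundary.
    neg_penalty = torch.clamp(neg_scores - boundary, min=0.0).mean()
    # Every positive should beat every negative (pairwise hinge).
    rank = torch.clamp(
        neg_scores.unsqueeze(0) - pos_scores.unsqueeze(1), min=0.0
    ).mean()
    return rank + neg_penalty
```

Unconstrained ranking losses like RRHF correspond to dropping the boundary, which keeps pushing valid-but-lower-ranked paths to arbitrarily low scores.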
Evaluation Highlights
  • AFT outperforms Vanilla Fine-Tuning (VFT) by +2.57% accuracy on GSM8K using Llama2-7B.
  • In ranking scenarios, AFT achieves 26.08% accuracy on GSM8K-RANK (Llama-7B), whereas unconstrained alignment (RRHF) collapses performance to 7.51%.
  • Generalizes to out-of-domain tasks: AFT improves zero-shot MMLU performance by +1.73% over VFT (Llama-7B).
Breakthrough Assessment
8/10
Identifies a critical flaw in applying standard alignment methods (RLHF/DPO styles) to reasoning: they destroy model capabilities by penalizing 'good enough' reasoning too harshly. The proposed constraint solution is simple and effective.