Yuxi Xie, Anirudh Goyal, Wenyue Zheng, Min-Yen Kan, Timothy Lillicrap, Kenji Kawaguchi, Michael Shieh
National University of Singapore
Google DeepMind
arXiv.org
(2024)
Reasoning · RL
📝 Paper Summary
LLM Reasoning · Reinforcement Learning from Human Feedback (RLHF) · Preference Learning
Enhances LLM reasoning capabilities by using Monte Carlo Tree Search to iteratively generate granular step-level preference data, which is then used to update the policy via Direct Preference Optimization.
Core Problem
Standard preference learning relies on sparse instance-level supervision (only final answer correctness) and static offline datasets, failing to provide granular feedback for complex multi-step reasoning.
Why it matters:
Instance-level rewards lose important information about intermediate reasoning steps, limiting model improvement.
Static offline RLHF does not allow the model to adapt continuously or correct its own specific errors.
Existing MCTS methods often require training a separate, complex value function (critic), which is difficult to maintain.
Concrete Example: In a multi-step math problem, a model might make a small error in step 2 that causes the final answer to be wrong. Standard methods penalize the whole reasoning chain. This approach uses MCTS to identify that step 2 specifically led to a low-value node compared to a better alternative, creating a targeted training signal.
Uses Monte Carlo Tree Search (MCTS) as a dynamic data generator that explores reasoning paths and labels them with preferences based on Q-values (expected future rewards).
Extracts step-level preference pairs (a good step vs. a bad step at the same decision point) from the MCTS tree rather than just final-outcome pairs.
Updates the LLM policy using Direct Preference Optimization (DPO) on this self-generated data in an iterative loop, removing the need for a separate frozen reward model.
Architecture
Overview of the MCTS-Enhanced Iterative Preference Learning framework.
Evaluation Highlights
+5.9% accuracy improvement on GSM8K (81.8%) compared to the Mistral-7B SFT baseline.
+5.8% accuracy improvement on MATH (34.7%) compared to the Mistral-7B SFT baseline.
+15.8% accuracy improvement on ARC-C (76.4%) compared to the Mistral-7B SFT baseline.
Breakthrough Assessment
8/10
Strong combination of AlphaZero-style iteration with modern DPO, addressing the key bottleneck of sparse rewards in reasoning tasks. Significant empirical gains on hard benchmarks (MATH, GSM8K).
⚙️ Technical Details
Problem Definition
Setting: Iterative policy optimization for reasoning tasks.
Inputs: A prompt dataset D_P and an initial policy pi_theta(0).
Outputs: An improved policy pi_theta.
Pipeline Flow
Current Policy -> MCTS Sampling (Tree Construction)
MCTS Tree -> Step-level Preference Extraction
Preference Data -> DPO Update -> New Policy
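The pipeline above can be sketched as a toy loop. All three functions are simplified placeholders for illustration, not the authors' implementation:

```python
# Toy sketch of the iterative loop: MCTS sampling -> step-level pair
# extraction -> DPO update. Every function here is a stand-in.

def run_mcts(policy, prompt):
    # Stand-in: return a flat "tree" as (step, q_value) candidates.
    return [(f"{prompt}-step-A", 0.9), (f"{prompt}-step-B", 0.2)]

def extract_step_pairs(tree):
    # Prefer the highest-Q step over the lowest-Q step at the same node.
    best = max(tree, key=lambda x: x[1])
    worst = min(tree, key=lambda x: x[1])
    return [(best[0], worst[0])]

def dpo_update(policy, pairs):
    # Stand-in for a DPO gradient step: just track training progress.
    return {"version": policy["version"] + 1,
            "pairs_seen": policy["pairs_seen"] + len(pairs)}

def iterative_preference_learning(policy, prompts, num_iterations):
    for _ in range(num_iterations):
        pairs = []
        for prompt in prompts:
            tree = run_mcts(policy, prompt)    # 1. MCTS sampling (tree construction)
            pairs += extract_step_pairs(tree)  # 2. step-level preference extraction
        policy = dpo_update(policy, pairs)     # 3. DPO update -> new policy
    return policy

policy = iterative_preference_learning({"version": 0, "pairs_seen": 0},
                                       ["p1", "p2"], num_iterations=3)
```

The key design point is that data generation is on-policy: each iteration's preference pairs come from the policy produced by the previous DPO update.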
System Modules
Selection (MCTS) (Data Generation)
Traverse the tree to find a leaf node to expand, balancing exploration and exploitation.
Model or implementation: PUCT Algorithm
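A minimal sketch of PUCT selection over a node's children; the constant `c_puct` and the child-dict layout are illustrative assumptions, not the paper's exact values:

```python
import math

def puct_select(children, c_puct=1.0):
    """children: list of dicts with Q (mean value), P (policy prior), N (visits).
    Returns the index of the child maximizing the PUCT score."""
    total_n = sum(ch["N"] for ch in children)
    def score(ch):
        # exploitation (Q) + exploration bonus scaled by the prior P
        return ch["Q"] + c_puct * ch["P"] * math.sqrt(total_n) / (1 + ch["N"])
    return max(range(len(children)), key=lambda i: score(children[i]))

children = [
    {"Q": 0.5, "P": 0.6, "N": 10},  # well-explored, decent value
    {"Q": 0.4, "P": 0.3, "N": 1},   # barely explored -> large exploration bonus
]
```

With `c_puct=1.0` the under-visited second child wins despite its lower Q, which is exactly the exploration/exploitation balance the selection phase provides.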
Expansion (MCTS) (Data Generation)
Generate new reasoning steps from the selected leaf node.
Model or implementation: Current Policy pi_theta
Evaluation (MCTS) (Data Generation)
Assign a reward to the new state based on outcome correctness and self-evaluation.
Model or implementation: Current Policy (Self-Evaluation) + Environment (Outcome)
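One plausible reading of the evaluation step, as a sketch: mix outcome correctness with the model's self-evaluation score. The equal weighting and the scalar `self_eval_score` input are assumptions for illustration:

```python
def evaluate_leaf(answer, ground_truth, self_eval_score,
                  w_outcome=0.5, w_self=0.5):
    """Combine outcome correctness (from the environment) with the policy's
    own self-evaluation of the partial/final output. Weights are illustrative."""
    outcome = 1.0 if ground_truth is not None and answer == ground_truth else 0.0
    return w_outcome * outcome + w_self * self_eval_score
```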
Backup (MCTS) (Data Generation)
Propagate rewards up the tree to update Q-values and visit counts.
Model or implementation: Mathematical Update
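The backup phase is standard MCTS arithmetic: propagate the leaf reward along the visited path, updating each node's running-mean Q-value and visit count. A minimal sketch (node layout assumed):

```python
def backup(path, reward):
    """path: list of node dicts with Q (mean value) and N (visit count),
    ordered from root to leaf. Updates Q as a running mean over visits."""
    for node in path:
        node["Q"] = (node["Q"] * node["N"] + reward) / (node["N"] + 1)
        node["N"] += 1

path = [{"Q": 0.0, "N": 0}, {"Q": 1.0, "N": 1}]
backup(path, 0.5)
```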
Preference Learner
Update the LLM policy using the collected preference pairs.
Model or implementation: DPO (Direct Preference Optimization)
Novel Architectural Elements
Integration of MCTS as a dynamic 'annotator' for DPO, creating an online iterative loop.
Derivation of step-level preference labels directly from MCTS Q-values (highest Q vs lowest Q children).
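The extraction rule above (highest-Q vs. lowest-Q child at each node) can be sketched over a toy tree; the node layout is an illustrative assumption:

```python
def extract_preference_pairs(node, pairs=None):
    """At each expanded node, treat the highest-Q child as the 'chosen' step
    and the lowest-Q child as 'rejected'. Returns (state, chosen, rejected)."""
    if pairs is None:
        pairs = []
    children = node.get("children", [])
    if len(children) >= 2:
        best = max(children, key=lambda c: c["Q"])
        worst = min(children, key=lambda c: c["Q"])
        if best is not worst:
            pairs.append((node["state"], best["step"], worst["step"]))
    for child in children:
        extract_preference_pairs(child, pairs)
    return pairs

tree = {
    "state": "problem",
    "children": [
        {"state": "problem+A", "step": "A", "Q": 0.8, "children": [
            {"state": "problem+A+C", "step": "C", "Q": 0.9, "children": []},
            {"state": "problem+A+D", "step": "D", "Q": 0.1, "children": []},
        ]},
        {"state": "problem+B", "step": "B", "Q": 0.2, "children": []},
    ],
}
pairs = extract_preference_pairs(tree)
```

Note that a single search tree yields multiple pairs, one per branching point, which is where the step-level granularity comes from.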
Modeling
Base Model: Mistral-7B
Training Method: Iterative DPO with MCTS-generated data
Objective Functions:
Purpose: Optimize policy to prefer higher-value steps.
Formally: DPO loss minimized over expectation of (x, y_w, y_l) drawn from MCTS data, utilizing adaptive label smoothing alpha based on visit counts.
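One plausible way to write this out, following the standard label-smoothed (conservative) DPO form, where α is the visit-count-based smoothing weight, β the DPO temperature, and π_ref the reference policy (the exact form in the paper may differ):

```latex
\mathcal{L}_{\mathrm{DPO}}(\theta)
  = -\,\mathbb{E}_{(x,\, y_w,\, y_l) \sim \mathcal{D}_{\mathrm{MCTS}}}
    \Big[ (1-\alpha)\,\log \sigma\big(\beta \Delta_\theta\big)
          + \alpha\,\log \sigma\big(-\beta \Delta_\theta\big) \Big],
\qquad
\Delta_\theta
  = \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
  - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
```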
Key Hyperparameters:
label_smoothing: Adaptive based on visit counts N(x, y)
tree_depth: T (average steps per sample)
search_breadth: Annealed from b1 to b2
Compute: Not reported in the paper
Comparison to Prior Work
vs. AlphaZero: Uses DPO instead of RL actor-critic; derives preferences from tree statistics directly.
vs. Standard RLHF: Iterative online data generation vs static offline data; step-level vs instance-level.
vs. Offline DPO: Generates its own preference data iteratively via MCTS instead of using a fixed dataset.
Limitations
Computational cost of MCTS during training (multiple rollouts per step).
Relies on the ability to define discrete 'steps' in reasoning chains.
Requires ground truth or a reliable self-evaluation signal to calculate rewards at leaf nodes.
Code is publicly available at https://github.com/YuxiXie/MCTS-DPO. The method uses Mistral-7B as the base. Step definitions for tasks are detailed in Appendices C and D (referenced in text).
📊 Experiments & Results
Evaluation Setup
Iterative training and evaluation on reasoning benchmarks.
Benchmarks:
GSM8K (Arithmetic Reasoning)
MATH (Advanced Mathematics)
ARC-C (Commonsense Reasoning)
SciQ (Science Question Answering)
Metrics:
Accuracy
Statistical methodology: Not explicitly reported in the paper
Key Results
| Benchmark | Metric | Baseline | This Paper | Δ |
|---|---|---|---|---|
| GSM8K | Accuracy | 75.9 | 81.8 | +5.9 |
| MATH | Accuracy | 28.9 | 34.7 | +5.8 |
| ARC-C | Accuracy | 60.6 | 76.4 | +15.8 |
Main Takeaways
The proposed MCTS-enhanced iterative preference learning significantly outperforms Supervised Fine-Tuning (SFT) across arithmetic and commonsense reasoning tasks.
Gains are particularly large in commonsense reasoning (ARC-C/SciQ) compared to arithmetic tasks.
The method demonstrates that on-policy sampling (generating data with the current model) is crucial for self-improvement.
Using step-level signals derived from MCTS look-ahead provides more granular and effective supervision than instance-level outcome labels.
📚 Prerequisite Knowledge
Prerequisites
Monte Carlo Tree Search (MCTS)
Reinforcement Learning from Human Feedback (RLHF)
Direct Preference Optimization (DPO)
Markov Decision Processes (MDP)
Key Terms
MCTS: Monte Carlo Tree Search—a heuristic search algorithm for decision processes that builds a search tree to find optimal moves by simulating future outcomes.
DPO: Direct Preference Optimization—a method to fine-tune language models to human preferences by directly optimizing the policy on preference pairs without training an explicit reward model.
PUCT: Predictor + Upper Confidence bounds applied to Trees—a selection strategy in MCTS that balances exploration (trying less visited nodes) and exploitation (visiting high-value nodes), scaled by a prior probability.
Q-value: The expected cumulative reward of taking a specific action from a specific state.
SFT: Supervised Fine-Tuning—The initial phase of training where the model learns to mimic a dataset of high-quality examples.
Step-level preference: Preference signals assigned to individual reasoning steps (intermediate tokens) rather than the entire generated response.
Self-evaluation: A process where the LLM itself estimates the correctness or quality of its generated partial or final outputs.