Muhammad Khalifa, Lajanugen Logeswaran, Moontae Lee, Honglak Lee, Lu Wang
University of Michigan,
LG AI Research,
University of Illinois at Chicago
arXiv
(2023)
Reasoning · RL · QA
📝 Paper Summary
Chain-of-Thought Reasoning · Controlled Decoding
GRACE improves multi-step reasoning by utilizing a contrastively trained discriminator to verify and select correct reasoning steps during decoding, training this discriminator without human annotations via a novel alignment algorithm.
Core Problem
Language models frequently assign high likelihood to incorrect reasoning steps, and standard decoding strategies (like greedy) or post-hoc filtering (like self-consistency) often fail to recover correct solutions once the generation goes off-track.
Why it matters:
Current oversampling methods (Self-Consistency, Verifiers) are inefficient as they generate full solutions before checking, wasting compute on doomed paths
Supervised fine-tuning on gold solutions can lead to overfitting, where valid alternative reasoning paths are penalized
Existing step-level reward models often require expensive, non-scalable human annotations
Concrete Example: When prompting LLaMA-13B on a GSM8K math problem with a correct prefix, the model assigns higher probability to incorrect next steps than to the correct one. Standard decoding picks the incorrect step, derailing the entire solution.
Key Novelty
Stepwise Discriminator-Guided Decoding
Intervenes *during* generation: instead of filtering complete answers, it scores candidate steps at each point using a discriminator and selects the best one to proceed
Self-supervised alignment: Creates training data for the discriminator by automatically aligning sampled incorrect solutions with correct references using the Needleman-Wunsch algorithm, avoiding the need for human step-level labels
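The alignment idea can be sketched as below. This is an illustrative Needleman-Wunsch implementation, not the paper's code: the `cost` callable stands in for the semantic-similarity cost (the paper computes similarity with ROSCOE-style embeddings), and `gap_penalty` and the name `align_steps` are assumptions.

```python
from typing import Callable, List, Tuple

def align_steps(
    sampled: List[str],
    reference: List[str],
    cost: Callable[[str, str], float],
    gap_penalty: float = 1.0,
) -> List[Tuple[int, int]]:
    """Needleman-Wunsch global alignment between a sampled (possibly
    incorrect) solution and the reference solution.  Returns matched
    (sampled_idx, ref_idx) pairs; unmatched sampled steps can be labeled
    extra/incorrect, unmatched reference steps missing."""
    n, m = len(sampled), len(reference)
    # dp[i][j] = minimal cost of aligning sampled[:i] with reference[:j]
    dp = [[0.0] * (m + 1) for _ in range(n + 1)]
    for i in range(1, n + 1):
        dp[i][0] = i * gap_penalty
    for j in range(1, m + 1):
        dp[0][j] = j * gap_penalty
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            dp[i][j] = min(
                dp[i - 1][j - 1] + cost(sampled[i - 1], reference[j - 1]),  # match
                dp[i - 1][j] + gap_penalty,   # extra sampled step
                dp[i][j - 1] + gap_penalty,   # missing reference step
            )
    # Trace back to recover the matched pairs.
    pairs, i, j = [], n, m
    while i > 0 and j > 0:
        if dp[i][j] == dp[i - 1][j - 1] + cost(sampled[i - 1], reference[j - 1]):
            pairs.append((i - 1, j - 1))
            i, j = i - 1, j - 1
        elif dp[i][j] == dp[i - 1][j] + gap_penalty:
            i -= 1
        else:
            j -= 1
    return pairs[::-1]
```

Steps of the sampled solution that align to no reference step (or align with high cost) become negatives for the discriminator, which is exactly what removes the need for human step-level labels.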
Architecture
The complete GRACE pipeline: (Top) Discriminator Learning process involving Negative Sampling, Alignment, and Learning. (Bottom) Guided Decoding process.
Evaluation Highlights
Reduces solution error rate on GSM8K human evaluation by 44% (from 9.0% error rate with greedy decoding to 5.0% with GRACE)
Outperforms greedy decoding on GSM8K by 7.4 accuracy points with FLAN-T5-Large and 5.4 points with LLaMA-7B
When combined with self-consistency, outperforms vanilla self-consistency by 15.7 points on MultiArith
Breakthrough Assessment
7/10
Strong methodological contribution in automated alignment for discriminator training, addressing the supervision bottleneck. Significant gains over standard baselines like Self-Consistency.
⚙️ Technical Details
Problem Definition
Setting: Multi-step reasoning, where a question q is answered via a sequence of steps s_1, ..., s_T
Inputs: Question q and a correct solution prefix r = (s_1, ..., s_{t-1})
Outputs: The next correct reasoning step s_t
Pipeline Flow
Generator (samples next-step candidates)
Discriminator (scores candidates)
Selector (picks best step)
External Calculator (optional, executes math)
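The four modules above can be wired into a step-level decoding loop roughly as follows. This is a minimal sketch: the callables (`sample_step`, `lm_logprob`, `disc_score`, `is_final`) stand in for the paper's components, and the linear score combination is an assumption about the form of Equation 6, not its exact statement.

```python
from typing import Callable, List

def grace_decode(
    question: str,
    sample_step: Callable[[str, List[str]], str],        # generator: nucleus-sample one next step
    lm_logprob: Callable[[str, List[str], str], float],  # log p(step | question, prefix)
    disc_score: Callable[[str, List[str], str], float],  # discriminator correctness score
    is_final: Callable[[str], bool],                     # does this step end the solution?
    num_candidates: int = 10,  # J in the paper
    beta: float = 0.5,         # weight on the discriminator (assumed)
    max_steps: int = 12,
) -> List[str]:
    """Step-level guided decoding: at each step, sample J candidate next
    steps, score each by combining LM likelihood with the discriminator,
    and commit to the best-scoring candidate before continuing."""
    prefix: List[str] = []
    for _ in range(max_steps):
        candidates = [sample_step(question, prefix) for _ in range(num_candidates)]
        best = max(
            candidates,
            key=lambda s: (1 - beta) * lm_logprob(question, prefix, s)
                          + beta * disc_score(question, prefix, s),
        )
        prefix.append(best)
        if is_final(best):
            break
    return prefix
```

The key contrast with self-consistency is visible in the loop: scoring happens per step, so an off-track candidate is discarded before any compute is spent completing it.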
System Modules
Generator
Generate J candidate next steps using nucleus sampling
Model or implementation: FLAN-T5-Large or LLaMA (7B/13B)
Discriminator
Predict correctness score of a candidate step given context
Model or implementation: FLAN-T5-Large encoder
Selector
Select the optimal next step based on combined LM likelihood and Discriminator score
Model or implementation: Algorithmic selection (Equation 6)
Novel Architectural Elements
Integration of a step-level discriminator D(q, r, s) directly into the decoding loop to modulate next-step probabilities
Automated alignment module (during training phase) that maps incorrect generated solutions to correct references to synthesize supervisory signals
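One way the discriminator can "modulate next-step probabilities" is to fold its scores into the candidate distribution and renormalize. The sketch below is an assumption for illustration (a softmax over linearly combined scores); the paper's exact modulation may differ.

```python
import math
from typing import Dict

def modulate_step_probs(
    step_logprobs: Dict[str, float],  # candidate step -> LM log-probability
    disc_scores: Dict[str, float],    # candidate step -> discriminator score
    beta: float = 0.5,                # assumed mixing weight
) -> Dict[str, float]:
    """Renormalize candidate next-step probabilities after mixing
    discriminator scores into the LM log-probabilities."""
    combined = {s: (1 - beta) * lp + beta * disc_scores[s]
                for s, lp in step_logprobs.items()}
    z = sum(math.exp(v) for v in combined.values())  # softmax partition
    return {s: math.exp(v) / z for s, v in combined.items()}
```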
Modeling
Base Model: FLAN-T5-Large (778M) and LLaMA (7B, 13B)
Training Method: Discriminator training via Contrastive Learning on synthesized data
Objective Functions:
Purpose: Distinguish correct from incorrect steps.
vs. ORM/PRM (OpenAI) [not cited in paper]: GRACE constructs step labels automatically via alignment, whereas Process Reward Models typically rely on human feedback
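A max-margin pairwise loss of the kind used in such contrastive discriminator training can be sketched as below; the margin value, the averaging, and the pairing scheme are assumptions, not the paper's exact objective.

```python
from typing import List, Tuple

def max_margin_loss(pairs: List[Tuple[float, float]], margin: float = 1.0) -> float:
    """Average hinge loss over (positive, negative) discriminator score
    pairs: each term is zero once the correct step outscores the
    incorrect one by at least `margin`."""
    return sum(max(0.0, margin - (pos - neg)) for pos, neg in pairs) / len(pairs)
```

Here the positive steps come from the reference solution and the negatives from the alignment procedure, so the loss needs no human step labels.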
Limitations
Requires an external correct reference solution for the alignment phase during training (cannot learn from scratch without gold chains)
Inference cost is higher than greedy decoding due to sampling J candidates and running the discriminator at every step
Discriminator quality depends on the alignment algorithm's ability to correctly map semantic equivalence
Code is publicly available. Method relies on sampling ~100K solutions for training data, which requires significant inference compute. Discriminator uses a specific pretrained encoder (FLAN-T5-Large).
📊 Experiments & Results
Evaluation Setup
Multi-step reasoning tasks (Math and Symbolic) evaluated on final answer accuracy and reasoning chain correctness.
Benchmarks:
GSM8K (Math Word Problems)
MathQA-Gain (Math Word Problems)
SVAMP (Elementary Math)
MultiArith (Elementary Math)
Coin Flip (Symbolic Reasoning)
Tracking Shuffled Objects (Symbolic Reasoning)
Metrics:
Final Answer Accuracy
Solution Error Rate (Human Eval)
Statistical methodology: Not explicitly reported in the paper
Key Results
Benchmark | Metric | Baseline | This Paper | Δ
GSM8K | Solution Error Rate (%) | 9.0 | 5.0 | -4.0
Experiment Figures
A motivating example using LLaMA-13B on GSM8K.
Visual representation of the Alignment cases: Missing Step, Extra Step, and Comparable Steps.
Main Takeaways
Discriminator guidance significantly improves accuracy over greedy decoding across both FLAN-T5 and LLaMA model families.
GRACE provides synergistic gains when combined with Self-Consistency (SC), outperforming vanilla SC by large margins (e.g., +15.7 points on MultiArith).
The automated alignment method effectively creates training data for discriminators without human step-level annotation.
Human evaluation confirms that GRACE improves not just the final answer, but the correctness of the intermediate reasoning steps.
📚 Prerequisite Knowledge
Prerequisites
Chain-of-Thought (CoT) prompting
Language Model decoding strategies (Greedy, Nucleus Sampling)
Contrastive Learning
Key Terms
Chain-of-Thought: A prompting technique where the model generates intermediate reasoning steps before the final answer
Self-Consistency: A decoding strategy that samples multiple reasoning paths and selects the most consistent final answer via majority voting
Needleman-Wunsch: A dynamic programming algorithm originally for biological sequence alignment, used here to align reasoning steps based on semantic similarity
ROSCOE: A metric/embedding model for evaluating step-by-step reasoning, used here to compute similarity costs for alignment
Max-margin loss: A loss function that enforces a margin of separation between the scores of correct (positive) and incorrect (negative) examples
Nucleus sampling: A text generation method that samples from the smallest set of top tokens whose cumulative probability exceeds a threshold p
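For concreteness, a minimal nucleus (top-p) sampler over an explicit token distribution; representing the distribution as a dict and renormalizing within the nucleus are illustrative choices, not a specific library's API.

```python
import random
from typing import Dict

def nucleus_sample(probs: Dict[str, float], p: float = 0.9,
                   rng: random.Random = random.Random(0)) -> str:
    """Sample from the smallest set of highest-probability tokens whose
    cumulative probability reaches p (top-p / nucleus sampling)."""
    nucleus, cum = [], 0.0
    for token, prob in sorted(probs.items(), key=lambda kv: -kv[1]):
        nucleus.append((token, prob))
        cum += prob
        if cum >= p:
            break
    total = sum(prob for _, prob in nucleus)  # renormalize within the nucleus
    r = rng.random() * total
    for token, prob in nucleus:
        r -= prob
        if r <= 0:
            return token
    return nucleus[-1][0]
```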