Ziniu Li, Tian Xu, Yushun Zhang, Zhihang Lin, Yang Yu, Ruoyu Sun, Zhi-Quan Luo
arXiv (2023)
RL · Benchmark
📝 Paper Summary
LLM Alignment · Reinforcement Learning from Human Feedback (RLHF)
ReMax removes the resource-heavy value model from RLHF by leveraging the deterministic nature of language generation, achieving state-of-the-art performance with significantly lower memory usage than PPO.
Core Problem
The standard PPO algorithm used in RLHF is overly complex and computationally expensive for LLMs because it requires training a separate value model, doubling memory usage and complicating hyperparameter tuning.
Why it matters:
Training the value model consumes ~46% of GPU memory (for a 7B model), often causing Out-Of-Memory errors on standard hardware
PPO introduces sensitive hyperparameters (clipping, GAE coefficients) that are laborious to tune
The computational burden makes RLHF inaccessible for researchers with limited GPU resources
Concrete Example: When training a Llama-2-7B model on A800-80GB GPUs, standard PPO fails due to memory exhaustion unless slow optimizer offloading is used. In contrast, ReMax can train the same model natively on the GPU without offloading.
Key Novelty
ReMax (REINFORCE-based Maximization)
Identifies that RLHF has deterministic transitions and trajectory-level rewards, rendering the complex credit assignment of PPO's value model unnecessary
Replaces PPO's Actor-Critic architecture with a simpler REINFORCE-based gradient estimator
Uses a variance reduction technique to stabilize training without requiring a learned value network
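The variance-reduction idea can be sketched in a few lines: a softmax policy over a tiny vocabulary samples a one-token "response", and the reward of the greedily decoded response serves as the baseline in place of a learned value network. This is a minimal illustrative sketch; the `remax_grad` helper, the 3-token setup, and the reward values are assumptions, not the paper's implementation.

```python
import math
import random

def softmax(logits):
    # Numerically stable softmax over a list of logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def remax_grad(logits, reward_fn, rng):
    """One REINFORCE gradient step with a greedy-decoding baseline."""
    probs = softmax(logits)
    # Sample an action from the policy (the "response").
    a = rng.choices(range(len(probs)), weights=probs)[0]
    # The greedy action supplies the baseline reward: no value model needed.
    a_greedy = max(range(len(probs)), key=lambda i: logits[i])
    advantage = reward_fn(a) - reward_fn(a_greedy)
    # Gradient of log pi(a) w.r.t. logits for a softmax policy: one-hot(a) - probs.
    return [((1.0 if i == a else 0.0) - p) * advantage
            for i, p in enumerate(probs)]

rng = random.Random(0)
g = remax_grad([0.0, 0.5, -0.2], lambda a: [0.1, 1.0, 0.3][a], rng)
```

Note that when the sampled response happens to equal the greedy response, the advantage is zero and no update occurs, which is exactly the variance-reduction effect of the baseline.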
Architecture
Conceptual illustration of RLHF properties: Fast Simulation, Deterministic Transitions, and Trajectory-level Rewards.
Evaluation Highlights
Saves ~46% GPU memory compared to PPO when training a 7B model
Achieves 94.78% win rate on AlpacaEval with Mistral-7B, setting a new SOTA for open-source 7B models
Increases training speed by ~1.6x compared to PPO
Breakthrough Assessment
8/10
Significantly simplifies the standard RLHF pipeline by removing the value model while matching or exceeding SOTA performance and drastically reducing compute requirements.
⚙️ Technical Details
Problem Definition
Setting: Reward maximization formulated as a Markov Decision Process (MDP) for sequence generation
Inputs: Prompt x sampled from distribution ρ
Outputs: Response sequence (a_1, ..., a_T) generated by policy π_θ
Pipeline Flow
Policy Model (LLM) generates response
Reward Model evaluates complete response
Gradient Estimator computes update (without Value Model)
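The three-step flow above can be run end to end on a toy bandit-style problem, where the "policy" is a softmax over three tokens and the "reward model" is a fixed lookup table. All names, rewards, and hyperparameters here are hypothetical stand-ins for the 7B models in the paper:

```python
import math
import random

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def train(rewards, steps=500, lr=0.5, seed=0):
    rng = random.Random(seed)
    logits = [0.0] * len(rewards)  # policy over a tiny vocabulary
    for _ in range(steps):
        probs = softmax(logits)
        a = rng.choices(range(len(probs)), weights=probs)[0]   # 1. generate
        greedy = max(range(len(probs)), key=lambda i: logits[i])
        adv = rewards[a] - rewards[greedy]                     # 2. score
        for i, p in enumerate(probs):                          # 3. update
            logits[i] += lr * ((1.0 if i == a else 0.0) - p) * adv
    return logits

final = train(rewards=[0.1, 0.9, 0.2])
```

After training, the greedy action coincides with the highest-reward token, illustrating that the pipeline optimizes reward with no value model anywhere in the loop.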
System Modules
Policy Model
Generates text responses given a prompt
Model or implementation: Mistral-7B / Llama-2-7B
Reward Model
Assigns a scalar score to the completed response
Model or implementation: Initialized from LLM parameters
Gradient Estimator
Computes policy gradients using variance-reduced REINFORCE
Model or implementation: Algorithmic component (ReMax)
Novel Architectural Elements
Removal of the Value Model (Critic) entirely from the RLHF pipeline
Use of a modified REINFORCE estimator tailored for deterministic text generation environments
Modeling
Base Model: Mistral-7B / Llama-2-7B
Training Method: ReMax (modified REINFORCE)
Objective Functions:
Purpose: Maximize expected reward using unbiased gradient estimation.
Formally: Reward-weighted likelihood maximization with variance reduction baseline.
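Written out under the definitions above, the objective and the ReMax estimator take the following form, where the baseline response \(\bar{a}_{1:T}\) is the greedily decoded response (a reconstruction from this summary, not a verbatim transcription of the paper's equations):

```latex
\max_{\theta}\ \mathbb{E}_{x \sim \rho,\; a_{1:T} \sim \pi_\theta(\cdot \mid x)}
\bigl[ r(x, a_{1:T}) \bigr]

\widetilde{g}(\theta) = \mathbb{E}\Bigl[
  \bigl( r(x, a_{1:T}) - r(x, \bar{a}_{1:T}) \bigr)
  \sum_{t=1}^{T} \nabla_\theta \log \pi_\theta(a_t \mid x,\, a_{1:t-1})
\Bigr]
```

Because the baseline depends only on the prompt (not on the sampled response), subtracting \(r(x, \bar{a}_{1:T})\) leaves the estimator unbiased while reducing variance, which is what allows the learned value network to be dropped.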
Key Hyperparameters:
Eliminated hyperparameters (vs. PPO): importance sampling clipping (ε), GAE coefficient (λ), value model learning rate, off-policy training epochs
Compute: Trainable on A800-80GB GPUs without offloading (unlike PPO). 1.6x faster throughput than PPO.
Comparison to Prior Work
vs. PPO: ReMax eliminates the Value Model, saving ~46% memory and removing 4+ hyperparameters
vs. DPO: ReMax outperforms DPO in win rates (implied by SOTA claim) while maintaining comparable computational efficiency
Limitations
Relies on the assumption that the environment (text generation) is deterministic
A detailed hyperparameter sensitivity analysis is not provided
Evaluation is limited to 7B-scale models
Code is publicly available at https://github.com/liziniu/ReMax; the paper claims the core implementation requires only 6 lines of code. Specific learning rates and other hyperparameters are not detailed here.
📊 Experiments & Results
Evaluation Setup
RLHF fine-tuning of 7B models followed by automated evaluation
Benchmarks:
AlpacaEval (Instruction following / Win-rate evaluation)
MT-bench (Multi-turn conversation quality)
Metrics:
Win rate
MT-bench score
GPU memory usage
Training time
Key Results
| Benchmark | Metric | Baseline | This Paper | Δ |
|---|---|---|---|---|
| AlpacaEval | Win rate (%) | 50.00 | 94.78 | +44.78 |
| MT-bench | Score | Not reported | 7.739 | Not reported |
| Llama-2-7B training | GPU memory savings vs. PPO (%) | 0 | 46 | +46 |
| Llama-2-7B training | Training speedup over PPO (×) | 1.0 | 1.6 | +0.6 |
Experiment Figures
Wall-clock time comparison showing ReMax is faster than PPO.
Main Takeaways
ReMax matches or outperforms PPO without the complexity of a Value Model.
The approach is significantly more resource-efficient, enabling training on hardware where PPO would fail (OOM).
The theoretical properties of RLHF (deterministic, trajectory rewards) make the complex PPO machinery unnecessary.
📚 Prerequisite Knowledge
Prerequisites
Reinforcement Learning from Human Feedback (RLHF)
Proximal Policy Optimization (PPO)
REINFORCE algorithm
Markov Decision Processes (MDP)
Key Terms
RLHF: Reinforcement Learning from Human Feedback—a method to align LLMs with human preferences
PPO: Proximal Policy Optimization—a standard RL algorithm that uses a value model and clipped objective for stability
Value Model: A neural network that estimates the expected future reward from a given state, used in PPO to reduce variance
REINFORCE: A basic policy gradient algorithm that optimizes the policy directly using trajectory returns
SFT: Supervised Fine-Tuning—the initial phase of training on high-quality demonstrations
GAE: Generalized Advantage Estimation—a technique in PPO to balance bias and variance in advantage estimates
DPO: Direct Preference Optimization—an alternative method that optimizes preferences without explicit reward modeling
Trajectory-level reward: A reward signal given only at the completion of a full sequence, rather than at every step