Yu Wang, Dmitry Krotov, Yuanzhe Hu, Yifan Gao, Wangchunshu Zhou, Julian McAuley, Dan Gutfreund, Rogério Feris, Zexue He
International Conference on Machine Learning
(2025)
Memory, Pretraining, QA
📝 Paper Summary
Latent-Space Memory, Long-Context LLMs
M+ extends MemoryLLM by offloading older latent memory states to CPU-based long-term memory and using a co-trained retriever to fetch relevant states during generation, enabling knowledge retention beyond 160k tokens.
Core Problem
Existing latent-space memory models like MemoryLLM compress context into a fixed-size GPU memory pool, causing information loss after roughly 20k tokens as older states are discarded.
Why it matters:
Retaining information from the distant past is critical for long-book understanding and extended conversations
Current approaches either drop information (sliding windows) or use separate, high-latency retrievers for every query head (like SnapKV)
Scaling context windows usually incurs prohibitive GPU memory costs
Concrete Example: MemoryLLM handles sequences of up to about 16k tokens well, but on a 160k-token sequence it fails to recall knowledge injected at the beginning, because its fixed-size memory pool (1B parameters) forces early information to be evicted.
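A back-of-the-envelope calculation shows why a fixed-size pool forgets gradually. The sketch assumes a MemoryLLM-style pool of N memory tokens per layer in which each update writes K new tokens and drops K existing tokens chosen uniformly at random. The summary above states only that the pool is fixed-size; the drop rule is an assumption here. Under that rule, the expected fraction of one chunk's memory tokens surviving t later updates is (1 - K/N)^t. The pool size, K, and chunk length below are hypothetical.

```python
def surviving_fraction(n_pool: int, k_new: int, updates_since: int) -> float:
    """Expected fraction of one chunk's memory tokens still in the pool.

    Assumes each update drops k_new of the n_pool tokens uniformly at random,
    so a given token survives one update with probability 1 - k_new / n_pool.
    """
    keep_prob = 1.0 - k_new / n_pool
    return keep_prob ** updates_since


def tokens_ago(updates_since: int, chunk_len: int) -> int:
    """Distance in input tokens that corresponds to a number of later updates."""
    return updates_since * chunk_len


# Hypothetical sizes, chosen only to show the trend.
N_POOL, K_NEW, CHUNK_LEN = 10_000, 256, 512

for t in (10, 40, 160, 320):
    frac = surviving_fraction(N_POOL, K_NEW, t)
    print(f"{tokens_ago(t, CHUNK_LEN):>7} tokens ago: {frac:.4f} of memory left")
```

Retention decays exponentially with distance, so information from the distant past fades even though nothing is removed all at once.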
Key Novelty
Scalable Latent-Space Long-Term Memory with Co-Trained Retrieval
Instead of deleting tokens evicted from the GPU short-term memory, M+ moves them to a CPU-based long-term memory (LTM), preserving them indefinitely.
Integrates a lightweight retriever trained jointly with the LLM to fetch relevant latent states from the CPU LTM back to the GPU during generation.
Uses separate LoRA adapters for the memory 'update' (writing/compressing) and 'generate' (reading/loading) phases to optimize each distinct task.
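The first novelty item (offloading instead of deleting) can be illustrated with a toy single-layer store. In this hypothetical sketch, which is not the M+ code, a bounded deque plays the GPU short-term memory and a growing list plays the CPU long-term memory. Following the Memory Updater description, the oldest vectors leave STM first. In a real system, the STM tensors would sit on the GPU and the LTM on the CPU. Here, both are plain NumPy arrays.

```python
from collections import deque
from dataclasses import dataclass, field

import numpy as np


@dataclass
class TwoTierMemory:
    """Toy short-/long-term memory for a single layer (hypothetical, not M+ code).

    stm: fixed-capacity pool standing in for the GPU short-term memory.
    ltm: growing list standing in for the CPU long-term memory.
    """
    capacity: int
    dim: int
    stm: deque = field(default_factory=deque)
    ltm: list = field(default_factory=list)

    def update(self, new_states: np.ndarray) -> None:
        """Write new memory vectors, offloading the oldest ones once STM is full."""
        if new_states.ndim != 2 or new_states.shape[1] != self.dim:
            raise ValueError("new_states must have shape (k, dim)")
        for vec in new_states:
            if len(self.stm) == self.capacity:
                # A discard-only pool would drop this vector; here it moves to LTM.
                self.ltm.append(self.stm.popleft())
            self.stm.append(vec.copy())

    def stm_array(self) -> np.ndarray:
        """Current STM as a (len(stm), dim) array."""
        return np.stack(list(self.stm))


mem = TwoTierMemory(capacity=4, dim=2)
mem.update(np.ones((3, 2)))   # chunk 1: 3 memory vectors
mem.update(np.zeros((3, 2)))  # chunk 2: STM overflows by 2
print(mem.stm_array().shape, len(mem.ltm))  # (4, 2) 2
```

STM stays at a constant size, so its memory cost does not grow. Only the LTM grows, and in M+ that growth lands in CPU memory.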
Architecture
The Update and Generate processes of M+. It illustrates how tokens dropped from the GPU Short-Term Memory (STM) during the Update phase are moved to the CPU Long-Term Memory (LTM). It also shows the Generate phase where relevant tokens are retrieved from LTM and concatenated with STM.
Evaluation Highlights
Extends effective knowledge retention from <20k tokens (MemoryLLM) to >160k tokens.
Maintains similar GPU memory overhead to MemoryLLM by storing long-term history on CPU.
Retrieves memory once per layer for all heads, improving efficiency compared to per-head retrieval methods like H2O or SnapKV.
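The efficiency claim concerns how many retrieval lookups run per decoding step. Llama-3.1-8B has 32 decoder layers and 32 query heads. A per-head scheme therefore issues one lookup per (layer, head) pair, while a layer-wise retriever issues one per layer. This sketch counts lookups only. Wall-clock cost also depends on LTM size and CPU-to-GPU transfer.

```python
def retrieval_calls(num_layers: int, num_heads: int, per_head: bool) -> int:
    """Number of memory lookups per decoding step."""
    return num_layers * num_heads if per_head else num_layers


LAYERS, QUERY_HEADS = 32, 32  # Llama-3.1-8B configuration

per_head = retrieval_calls(LAYERS, QUERY_HEADS, per_head=True)
per_layer = retrieval_calls(LAYERS, QUERY_HEADS, per_head=False)
print(per_head, per_layer, per_head // per_layer)  # 1024 32 32
```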
Breakthrough Assessment
8/10
Significantly extends the utility of latent-space memory models by solving the 'forgetting' problem via CPU offloading, bridging the gap between fixed-context models and RAG.
⚙️ Technical Details
Problem Definition
Setting: Causal language modeling with augmented latent memory
Inputs: Input text sequence x split into chunks
Outputs: Next token probabilities utilizing both immediate context and retrieved latent history
Pipeline Flow
Input Encoder (Llama-3.1-8B) -> STM Update (GPU) -> Eviction to LTM (CPU) -> Retrieval (CPU to GPU) -> Generation (GPU)
System Modules
Base LLM
Core language modeling and hidden state generation
Model or implementation: Llama-3.1-8B
Memory Updater
Compresses each new chunk into STM and evicts the oldest K tokens to LTM
Model or implementation: LoRA Adapter (Update-specific)
Latent Retriever
Selects relevant tokens from LTM based on current query hidden states
Model or implementation: 2-layer MLP Projectors (Query and Key)
Generator
Generates text using combined STM and retrieved LTM
Model or implementation: LoRA Adapter (Generate-specific)
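The Latent Retriever and Generator rows can be illustrated with a small NumPy sketch. The design choices here are illustrative assumptions:

- Each projector is a 2-layer MLP with a ReLU.
- The current hidden states are mean-pooled into one query.
- Relevance is the dot product between the projected query and the projected LTM keys.
- The top-k entries are concatenated in front of the STM.

Retrieval runs once per layer, and the result is shared by all heads. The weights are random here, whereas M+ trains them jointly with the LLM. The paper's exact pooling, scoring, and ordering may differ.

```python
import numpy as np

rng = np.random.default_rng(0)


def make_mlp(d_in: int, d_hidden: int, d_out: int):
    """Random 2-layer MLP projector (stand-in for a trained one)."""
    w1 = rng.normal(scale=d_in ** -0.5, size=(d_in, d_hidden))
    w2 = rng.normal(scale=d_hidden ** -0.5, size=(d_hidden, d_out))
    return lambda x: np.maximum(x @ w1, 0.0) @ w2


def retrieve(query_states, ltm, q_proj, k_proj, top_k):
    """Return the top_k LTM vectors most relevant to the current queries."""
    q = q_proj(query_states.mean(axis=0))   # (d_r,) pooled, projected query
    k = k_proj(ltm)                         # (n_ltm, d_r) projected keys
    scores = k @ q                          # one relevance score per LTM entry
    idx = np.argsort(scores)[::-1][:top_k]  # highest scores first
    idx = np.sort(idx)                      # assumption: keep original order
    return ltm[idx], idx


D, D_R = 16, 8
q_proj, k_proj = make_mlp(D, 32, D_R), make_mlp(D, 32, D_R)

stm = rng.normal(size=(10, D))     # short-term memory ("GPU")
ltm = rng.normal(size=(100, D))    # long-term memory ("CPU")
queries = rng.normal(size=(4, D))  # hidden states of the current tokens

retrieved, idx = retrieve(queries, ltm, q_proj, k_proj, top_k=5)
context_memory = np.concatenate([retrieved, stm], axis=0)  # what the generator attends to
print(context_memory.shape)  # (15, 16)
```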
Novel Architectural Elements
Dual LoRA weights: distinct adapters for 'Update' (write) and 'Generate' (read) phases
Hybrid GPU/CPU memory hierarchy for latent states (STM on GPU, LTM on CPU)
Layer-wise latent retriever shared across all query heads (unlike per-head retrieval in SnapKV)
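The dual-LoRA element pairs one frozen base weight with two different low-rank updates, selected by phase. Below is a minimal NumPy sketch of a single linear layer. It assumes the standard LoRA form W x + (alpha / r) B A x with B initialized to zero. The class name and the "update"/"generate" dictionary keys are hypothetical, not the M+ code.

```python
import numpy as np


class DualLoRALinear:
    """Frozen weight W plus one LoRA adapter per phase (hypothetical sketch)."""

    def __init__(self, d_in, d_out, rank=4, alpha=8.0, seed=0):
        rng = np.random.default_rng(seed)
        self.w = rng.normal(size=(d_out, d_in))  # frozen base weight
        self.scale = alpha / rank
        self.adapters = {
            phase: {
                "A": rng.normal(scale=0.01, size=(rank, d_in)),
                "B": np.zeros((d_out, rank)),    # standard LoRA init: B = 0
            }
            for phase in ("update", "generate")
        }

    def __call__(self, x, phase):
        """Apply W x + scale * B A x using the adapter for `phase`."""
        ad = self.adapters[phase]
        return self.w @ x + self.scale * (ad["B"] @ (ad["A"] @ x))


layer = DualLoRALinear(d_in=6, d_out=3)
x = np.ones(6)
# Fresh adapters start with B = 0, so both phases match the base layer ...
print(np.allclose(layer(x, "update"), layer(x, "generate")))  # True
# ... until one adapter is trained; simulate that by setting it by hand.
layer.adapters["update"]["A"][:] = 0.1
layer.adapters["update"]["B"][:] = 0.1
print(np.allclose(layer(x, "update"), layer(x, "generate")))  # False
```

Only the small A and B matrices differ between phases. The "write" (compress into memory) and "read" (generate from memory) behaviours can therefore specialize without duplicating the base model.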
Code is publicly available at https://github.com/wangyu-ustc/MemoryLLM. The model is built on Llama-3.1-8B, and the paper details the training stages and data-mix ratios.
📊 Experiments & Results
Evaluation Setup
Long-context evaluation involving knowledge retention and understanding tasks.
Benchmarks:
Knowledge Retention (Recall of injected information at varying distances) [New]
Long-book understanding (Comprehension of lengthy narratives)
Metrics:
Effective Context Length
Knowledge Retention Rate
Statistical methodology: Not explicitly reported in the paper
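This summary does not define how the two metrics are computed. A hypothetical sketch follows one plausible reading:

- Knowledge retention rate is the fraction of injected facts recalled correctly at a given distance.
- Effective context length is the largest distance up to which that rate stays at or above a chosen threshold.

The 0.5 threshold and the sample outcomes are invented for illustration and are not results from the paper.

```python
def retention_rate(correct_flags):
    """Fraction of injected facts recalled correctly."""
    return sum(correct_flags) / len(correct_flags)


def effective_context_length(rate_by_distance, threshold=0.5):
    """Largest distance such that it and every shorter distance reach threshold."""
    effective = 0
    for distance in sorted(rate_by_distance):
        if rate_by_distance[distance] < threshold:
            break
        effective = distance
    return effective


# Invented per-distance outcomes (1 = fact recalled), for illustration only.
outcomes = {
    10_000: [1, 1, 1, 0],
    50_000: [1, 1, 0, 0],
    100_000: [1, 0, 0, 0],
}
rates = {d: retention_rate(flags) for d, flags in outcomes.items()}
print(rates)                            # {10000: 0.75, 50000: 0.5, 100000: 0.25}
print(effective_context_length(rates))  # 50000
```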
Key Results
Benchmark | Metric | Baseline (MemoryLLM) | This Paper (M+) | Δ
--- | --- | --- | --- | ---
Knowledge Retention | Retained Token Limit (tokens) | 20,000 | 160,000 | +140,000
Main Takeaways
M+ successfully decouples memory capacity from GPU memory limits by leveraging CPU storage for long-term history.
The co-trained retriever effectively identifies relevant latent states without the high computational cost of per-head retrieval mechanisms.
Using separate LoRA weights for memory updates and memory generation facilitates better learning of the distinct 'compression' and 'reading' tasks.
📚 Prerequisite Knowledge
Prerequisites
Transformer Architecture (Decoder-only)
Latent Space Representation
Key-Value (KV) Caching
Low-Rank Adaptation (LoRA)
Key Terms
Latent-Space Memory: Storing information as high-dimensional vectors (hidden states) rather than raw text or discrete tokens.
Short-Term Memory (STM): A fixed-size memory pool kept on the GPU (denoted θ) containing the most recent context states.
Long-Term Memory (LTM): A scalable memory pool kept on the CPU (denoted Θ) containing older states evicted from STM.
Co-trained Retriever: A retrieval module trained simultaneously with the language model to select relevant hidden states based on the current context.
LoRA (Low-Rank Adaptation): A parameter-efficient fine-tuning technique that freezes the pre-trained weights and injects trainable low-rank decomposition matrices.
SlimPajama: A large-scale, deduplicated corpus for LLM pretraining, used here for long-context curriculum training.
Contrastive Loss: A training objective that pulls positive pairs (relevant memory) closer and pushes negative pairs (irrelevant memory) apart in the embedding space.
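The contrastive objective for the retriever can be sketched as an InfoNCE-style loss over one positive (relevant) memory and several negatives. This is a generic formulation offered as an assumption. The exact loss, temperature, and negative sampling used to co-train the M+ retriever are not given in this summary.

```python
import numpy as np


def info_nce(query, positive, negatives, temperature=0.1):
    """InfoNCE loss: negative log softmax probability of the positive candidate.

    query: (d,) projected query; positive: (d,); negatives: (n, d).
    Vectors are L2-normalised, so scores are cosine similarities.
    """
    def normalise(v):
        return v / np.linalg.norm(v, axis=-1, keepdims=True)

    q = normalise(query)
    cands = normalise(np.vstack([positive[None, :], negatives]))  # positive is row 0
    logits = cands @ q / temperature
    logits -= logits.max()  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum())
    return -log_probs[0]


rng = np.random.default_rng(0)
q = rng.normal(size=8)
negs = rng.normal(size=(5, 8))
aligned = info_nce(q, q + 0.01 * rng.normal(size=8), negs)  # positive close to query
opposite = info_nce(q, -q, negs)                            # positive points away
print(aligned, opposite)
```

Minimizing this loss raises the score of the relevant memory relative to the irrelevant ones. This is the "pull closer, push apart" behaviour described above.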