Leo Feng, Frederick Tung, Hossein Hajimirsadeghi, Yoshua Bengio, Mohamed Osama Ahmed
arXiv
(2023)
Memory · Benchmark
📝 Paper Summary
Memory Efficient Models · Meta-learning
Constant Memory Attentive Neural Processes (CMANPs) achieve memory efficiency and fast updates through a specialized attention block whose fixed latent queries support exact, incremental cross-attention updates.
Core Problem
State-of-the-art Neural Processes rely on expensive attention mechanisms (e.g., Transformers) whose memory cost scales linearly or quadratically with context size, making them too memory-intensive for low-resource devices.
Why it matters:
High memory costs limit deployment on battery-powered edge devices (IoT, mobile robots) where energy efficiency is crucial.
Existing methods require re-processing the entire context dataset from scratch when new data arrives, which is computationally wasteful.
Current approaches struggle to scale to large context datasets because of their $O(N^2)$ or $O(Nk)$ memory complexity, where $N$ is the number of context points and $k$ the number of latents.
Concrete Example: When a robot receives 10 new observations, a standard Transformer Neural Process must recompute attention over its entire history of thousands of points, at a cost that scales quadratically. CMANP instead updates its representation using only the 10 new points, in constant memory and without revisiting the full history.
Key Novelty
Constant Memory Attention Block (CMAB) with Exact Updates
Uses a fixed-size set of learned latent vectors as the queries in Cross Attention, decoupling the memory cost from the size of the input context dataset.
Reformulates Cross Attention as a rolling average operation (using log-sum-exp), enabling the model to update attention weights with new data without re-processing old data.
Ensures permutation invariance and stackability while maintaining constant memory complexity throughout the conditioning phase.
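The rolling-average reformulation can be illustrated with a minimal sketch. It assumes plain dot-product attention with $1/\sqrt{d}$ scaling and no learned projections, and it applies the log-sum-exp trick in its "online softmax" form (a running maximum plus rescaled sums). The class name `IncrementalCrossAttention` is hypothetical, and the paper's appendix may organize the update differently. The final check confirms that folding the context in batches of 8 reproduces full cross attention over all points at once.

```python
import math
import random
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


class IncrementalCrossAttention:
    """Hypothetical sketch: cross attention from fixed latent queries to a growing context.

    Per query it keeps a running max score, a rescaled softmax denominator and a
    rescaled weighted sum of values, so the state is O(|L| * d) no matter how
    many context points have been folded in.
    """

    def __init__(self, queries: List[Vector], value_dim: int):
        self.queries = queries
        self.scale = 1.0 / math.sqrt(len(queries[0]))
        self.max_score = [-math.inf] * len(queries)
        self.denom = [0.0] * len(queries)
        self.numer = [[0.0] * value_dim for _ in queries]

    def update(self, keys: List[Vector], values: List[Vector]) -> None:
        """Fold in a batch of new context points without touching older ones."""
        if not keys:
            return
        for j, q in enumerate(self.queries):
            scores = [dot(q, k) * self.scale for k in keys]
            new_max = max(self.max_score[j], max(scores))
            # Re-express the old sums relative to the new max; exp(-inf) == 0.0 on the first batch.
            shrink = math.exp(self.max_score[j] - new_max)
            self.denom[j] *= shrink
            self.numer[j] = [acc * shrink for acc in self.numer[j]]
            for s, v in zip(scores, values):
                w = math.exp(s - new_max)  # never overflows: s - new_max <= 0
                self.denom[j] += w
                self.numer[j] = [acc + w * x for acc, x in zip(self.numer[j], v)]
            self.max_score[j] = new_max

    def output(self) -> List[Vector]:
        """Attention output per latent query (call after at least one update)."""
        return [[acc / d for acc in num] for num, d in zip(self.numer, self.denom)]


def full_cross_attention(queries: List[Vector], keys: List[Vector], values: List[Vector]) -> List[Vector]:
    """Reference: standard softmax cross attention over the whole context at once."""
    scale = 1.0 / math.sqrt(len(queries[0]))
    out = []
    for q in queries:
        scores = [dot(q, k) * scale for k in keys]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[i] for w, v in zip(weights, values)) / total
                    for i in range(len(values[0]))])
    return out


# Usage: 50 random context points folded in batches of 8 vs. one full pass.
rng = random.Random(0)
d, n_latents, n_context, batch_size = 4, 3, 50, 8
latents = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n_latents)]
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n_context)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n_context)]

attn = IncrementalCrossAttention(latents, value_dim=d)
for start in range(0, n_context, batch_size):
    attn.update(keys[start:start + batch_size], values[start:start + batch_size])

full = full_cross_attention(latents, keys, values)
max_err = max(abs(a - b) for row_i, row_f in zip(attn.output(), full) for a, b in zip(row_i, row_f))
print(f"max |incremental - full| = {max_err:.2e}")
```

Because each update reads only the new batch plus the per-query running statistics, the work per update grows with the batch size and $|L|$, not with how much context has already been absorbed.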
Architecture
The Constant Memory Attention Block (CMAB) architecture.
Breakthrough Assessment
8/10
The theoretical contribution of an exact, constant-memory update for Cross Attention is significant for efficient meta-learning, directly addressing the primary bottleneck of attention-based Neural Processes.
⚙️ Technical Details
Problem Definition
Setting: Meta-learning for predictive uncertainty estimation using Neural Processes.
Outputs: Predictive distribution $p(y_T|x_T, z_C)$.
Pipeline Flow
Input Batcher (splits context data into batches)
CMAB Stack (iterative encoding of context)
Query Decoder (makes predictions using final latents)
System Modules
Input Batcher
Splits large context dataset into smaller batches to fit in constant memory
Model or implementation: Algorithmic splitting
Constant Memory Attention Block (CMAB)
Compresses context data into fixed-size latent representations
Model or implementation: Custom Attention Block
Query Decoder
Generates predictions for target points using the encoded latents
Model or implementation: Attention-based Decoder
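The control flow of these three modules (batch the context, update a constant-size state, then decode) can be shown with a toy sketch. To keep it short and self-contained, the CMAB stack is replaced by a hypothetical stand-in, `SufficientStatsEncoder`, whose fixed-size state is the set of running sums needed for a least-squares line fit, and the Query Decoder is replaced by evaluating that line. Only the pipeline shape mirrors the description above; the learned attention itself is not modeled.

```python
from typing import Iterable, Iterator, List, Tuple

Point = Tuple[float, float]  # one (x, y) context pair


def batch_context(points: Iterable[Point], b_c: int) -> Iterator[List[Point]]:
    """Input Batcher: yield consecutive chunks of at most b_c context points."""
    if b_c < 1:
        raise ValueError("b_c must be a positive integer")
    batch: List[Point] = []
    for p in points:
        batch.append(p)
        if len(batch) == b_c:
            yield batch
            batch = []
    if batch:
        yield batch


class SufficientStatsEncoder:
    """Hypothetical stand-in for the CMAB stack: a fixed-size state updated per batch."""

    def __init__(self) -> None:
        self.n = 0.0
        self.sx = self.sy = self.sxx = self.sxy = 0.0

    def update(self, batch: List[Point]) -> None:
        # Only the new batch is read; earlier points are never revisited.
        for x, y in batch:
            self.n += 1
            self.sx += x
            self.sy += y
            self.sxx += x * x
            self.sxy += x * y

    def decode(self, x_target: float) -> float:
        """Stand-in for the Query Decoder (needs at least two distinct x values)."""
        slope = (self.n * self.sxy - self.sx * self.sy) / (self.n * self.sxx - self.sx ** 2)
        intercept = (self.sy - slope * self.sx) / self.n
        return slope * x_target + intercept


# Usage: condition on y = 2x + 1 in batches of b_c = 3, then query x = 10.
context = [(float(x), 2.0 * x + 1.0) for x in range(10)]
encoder = SufficientStatsEncoder()
peak_batch = 0
for batch in batch_context(context, b_c=3):
    peak_batch = max(peak_batch, len(batch))
    encoder.update(batch)
prediction = encoder.decode(10.0)
print(f"largest batch held: {peak_batch}, prediction at x=10: {prediction}")
```

The encoder state never grows with the number of context points, and the order in which batches arrive does not change the result, which are the two properties the real CMAB-based pipeline is designed to provide with learned latents.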
Novel Architectural Elements
Constant Memory Attention Block (CMAB) containing fixed internal learned latents $L_B$ specifically designed to allow exact cross-attention updates.
Recursive conditioning pipeline where context is processed in batches using a rolling average update rule rather than all-at-once.
Modeling
Base Model: CMANPs (Constant Memory Attentive Neural Processes)
Training Method: Meta-learning via ELBO maximization
Objective Functions:
Purpose: Maximize the likelihood of target data given context data (stochastic NP variant).
Formally: Maximizing Evidence Lower Bound (ELBO).
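For reference, the objective commonly used for latent-variable Neural Processes is written below; it is the standard form from the NP literature, and it is an assumption that CMANPs use exactly this variant. Here $z$ is the global latent, $D_T$ the target set, and the intractable conditional prior $p(z \mid D_C)$ is replaced by the encoder's own $q(z \mid D_C)$, so the expression is an approximate rather than an exact lower bound on $\log p(y_T \mid x_T, D_C)$.

$$
\mathcal{L}_{\text{ELBO}} = \mathbb{E}_{q(z \mid D_C \cup D_T)}\left[\sum_{(x, y) \in D_T} \log p(y \mid x, z)\right] - \mathrm{KL}\left(q(z \mid D_C \cup D_T) \,\|\, q(z \mid D_C)\right)
$$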
Key Hyperparameters:
$b_C$: Batch size for context processing (a constant that determines memory usage)
$k$: Number of latent variables (scales with task difficulty)
$L$: Set of learned latent query vectors
Compute: Memory complexity is $O(1)$ with respect to the total context size $|D_C|$ (specifically $O(|L|)$ after batching).
Comparison to Prior Work
vs. TNPs: CMANP needs $O(1)$ memory in the number of context points, versus $O((N+M)^2)$ for TNPs, where $N$ and $M$ are the numbers of context and target points.
vs. LBANPs: CMANP enables efficient updates without re-computation, whereas LBANPs/Perceivers require re-computing attention when context changes.
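To make the asymptotic comparison concrete, the sketch below counts only the attention-score entries held in memory at once, ignoring hidden dimensions, heads and constant factors. The function names and the example sizes ($N$ up to 100,000 context points, $M = 100$ targets, $|L| = 128$ latents, $b_C = 100$) are illustrative assumptions, not figures from the paper; the CMANP count assumes one batch of $b_C$ points is scored against the $|L|$ latents at a time.

```python
def tnp_attention_entries(n_context: int, n_target: int) -> int:
    """Entries in one (N + M) x (N + M) attention map over context and target tokens."""
    tokens = n_context + n_target
    return tokens * tokens


def cmanp_attention_entries(n_latents: int, b_c: int) -> int:
    """Entries in one |L| x b_C score matrix while folding a context batch into the latents."""
    return n_latents * b_c


# Illustrative sizes (assumptions): M = 100 targets, |L| = 128 latents, b_C = 100.
for n in (1_000, 10_000, 100_000):
    tnp = tnp_attention_entries(n, 100)
    cmanp = cmanp_attention_entries(128, 100)
    print(f"N={n:>7,}: TNP {tnp:>14,} entries vs. CMANP {cmanp:,} entries")
```

In this accounting the CMANP figure stays fixed as $N$ grows, which is the sense in which its conditioning memory is constant in the context size; querying adds a separate cost that depends on the number of targets rather than on $N$.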
Limitations
Numerical stability relies on the log-sum-exp trick; naive implementations of the rolling average update are numerically unstable.
The constant memory benefit depends on the hyperparameter $b_C$; very small batches save memory but increase total runtime, because the context must then be processed in more sequential update steps.
The specific performance trade-off (accuracy vs. memory) compared to full-attention models is not visible in the provided text snippet.
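The instability noted in the first limitation is easy to reproduce. In the sketch below (function names are illustrative), a naive log-sum-exp overflows on scores around 1000, while the max-shifted version does not; `log_add_exp` merges the log-sum-exp values of two batches, which is the kind of step a rolling update over context batches needs. It shows the generic trick only, not the paper's exact update equations.

```python
import math
from typing import Sequence


def naive_logsumexp(xs: Sequence[float]) -> float:
    # exp(1000) exceeds the float range, so math.exp raises OverflowError.
    return math.log(sum(math.exp(x) for x in xs))


def stable_logsumexp(xs: Sequence[float]) -> float:
    """log(sum(exp(x))) computed as m + log(sum(exp(x - m))) with m = max(xs)."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_add_exp(a: float, b: float) -> float:
    """Combine two log-sum-exp values: log(exp(a) + exp(b)) without overflow."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log1p(math.exp(-abs(a - b)))


scores = [1000.0, 999.0, 998.0]
try:
    naive_logsumexp(scores)
    naive_overflowed = False
except OverflowError:
    naive_overflowed = True

whole = stable_logsumexp(scores)
# Rolling form: summarize the first batch, then fold in the second one.
rolled = log_add_exp(stable_logsumexp(scores[:2]), stable_logsumexp(scores[2:]))
print(naive_overflowed, whole, rolled)
```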
Reproducibility
The paper provides detailed derivations of the efficient update rules in its Appendix. No code URL is provided in the text.
📊 Experiments & Results
Evaluation Setup
Meta-learning for predictive uncertainty
Benchmarks:
Standard NP benchmarks (regression and classification, as implied by context)
Metrics:
Log-likelihood
Memory Usage
Statistical methodology: Not explicitly reported in the provided text
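The log-likelihood metric is usually reported as the average per-target-point log density of the ground truth under the model's predictive distribution. The sketch below assumes a Gaussian predictive distribution with one mean and one standard deviation per target point, which is common in NP regression benchmarks but is not stated in this summary; classification tasks would use a categorical likelihood instead. The function name is illustrative.

```python
import math
from typing import Sequence


def gaussian_log_likelihood(y: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> float:
    """Mean over target points of log N(y_i | mean_i, std_i^2)."""
    if not y or not (len(y) == len(mean) == len(std)):
        raise ValueError("y, mean and std must be non-empty and the same length")
    total = 0.0
    for yi, mi, si in zip(y, mean, std):
        if si <= 0:
            raise ValueError("std must be positive")
        z = (yi - mi) / si
        total += -0.5 * z * z - math.log(si) - 0.5 * math.log(2.0 * math.pi)
    return total / len(y)


# Usage with made-up predictions for three target points.
y_true = [0.1, -0.4, 1.2]
mean = [0.0, -0.5, 1.0]
std = [0.2, 0.2, 0.5]
print(f"mean log-likelihood: {gaussian_log_likelihood(y_true, mean, std):.3f}")
```

Higher is better; the metric rewards both accurate means and well-calibrated uncertainty, since an overconfident (too small) standard deviation is penalized heavily when the prediction misses.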
Main Takeaways
Theoretical Analysis: The paper proves that Cross Attention can be updated exactly in $O(|D_U||L|)$ computation and $O(|L|)$ memory, where $D_U$ is the set of newly arrived data points, independent of the historical context size $|D_C|$.
Theoretical Analysis: CMANPs reduce the memory complexity of the conditioning phase from linear or quadratic (in prior state-of-the-art methods) to constant, $O(1)$, with respect to the number of context points.
The proposed update rule enables autoregressive extensions that keep constant memory complexity, unlike prior Not-Diagonal extensions, which required quadratic memory.
Neural Processes: A family of meta-learning models that combine neural networks with stochastic processes to estimate uncertainty by conditioning on context data.
CMAB: Constant Memory Attention Block—the proposed architectural unit that compresses context data into fixed latents using efficient cross-attention.
Cross Attention: An attention mechanism in which each element of a query set computes relevance weights over a context set (the keys) and returns the correspondingly weighted sum of the context values.
ELBO: Evidence Lower Bound—a proxy objective function maximized during training to approximate the log-likelihood of the data.
Permutation Invariance: A property where the model's output remains the same regardless of the order in which context data points are processed.
Log-Sum-Exp Trick: A mathematical technique used to compute the logarithm of a sum of exponentials numerically stably, preventing overflow or underflow.