Representation Learning Dynamics · Implicit Bias in Optimization · Grokking and Shortcut Learning
Delayed generalization in neural networks is governed by a norm-hierarchy transition where weight decay drives a slow shift from high-norm shortcut solutions to low-norm structured features.
Core Problem
Neural networks often rely on spurious shortcuts for hundreds of epochs before discovering structured representations, but the mechanism governing when this transition occurs remains poorly understood.
Why it matters:
Current models exploit spurious correlations (e.g., background textures) long before learning causal features, reducing robustness
Grokking (sudden generalization long after overfitting) is observed but lacks a unified predictive theory linking it to standard learning dynamics
Simplicity bias explains why simple features are learned first, but not the timescale of the subsequent transition to structured features
Concrete Example: In CIFAR-10 with spurious borders, a model achieves high training accuracy by memorizing border colors (shortcut) and maintains this for many epochs before suddenly switching to classifying based on object shape (structured feature). Current theory cannot predict the timing of this switch.
Key Novelty
Norm-Hierarchy Transition (NHT) Framework
Conceptualizes learning as a competition between multiple interpolating solutions where weight decay exerts a directed pressure from high-norm (shortcut) to low-norm (structured) representations
Proposes that the delay time scales with the logarithm of the ratio of the shortcut norm to the structured norm
Identifies 'Clean Norm Separation' as the critical condition determining whether transition timing is predictable
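The logarithmic scaling can be illustrated with a minimal sketch. It assumes, as a simplification not taken from the source, that on the interpolation manifold the data gradient is negligible, so each step multiplies the parameter norm by (1 - lr * weight_decay). The function names and the unit prefactor are hypothetical; the paper's exact constants are not given here.

```python
import math


def predicted_delay_steps(v_sc, v_st, lr, weight_decay):
    """Steps for a norm to shrink from v_sc to v_st under pure weight decay.

    Assumption (not from the source): each step multiplies the norm by
    (1 - lr * weight_decay), i.e. the data gradient is negligible.
    """
    if not v_sc > v_st > 0:
        raise ValueError("requires v_sc > v_st > 0")
    shrink = 1.0 - lr * weight_decay
    if not 0.0 < shrink < 1.0:
        raise ValueError("requires 0 < lr * weight_decay < 1")
    return math.log(v_sc / v_st) / -math.log(shrink)


def predicted_delay_continuous(v_sc, v_st, lr, weight_decay):
    """Small-step approximation: log(v_sc / v_st) / (lr * weight_decay)."""
    return math.log(v_sc / v_st) / (lr * weight_decay)
```

Under this assumption the delay depends only on the ratio of the two norms, so rescaling both leaves it unchanged, and squaring the ratio doubles the delay.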
Architecture
Conceptual diagram of the optimization landscape showing Shortcut Manifold (high norm) and Structured Manifold (low norm).
Evaluation Highlights
Predicts transition delay with R^2 > 0.97 across modular arithmetic tasks, validating the logarithmic scaling law
Demonstrates a drop in clean accuracy from 78% to 10% (reversion to shortcuts) as shortcut strength increases in CIFAR-10, matching theoretical predictions
Validates the framework across diverse architectures including ResNet18 with Batch Normalization, showing robustness beyond toy models
Breakthrough Assessment
9/10
Provides a unifying theoretical mechanism that connects grokking, shortcut learning, and emergent abilities. The tight theoretical bounds and multi-domain empirical validation are highly significant.
⚙️ Technical Details
Problem Definition
Setting: Regularized gradient-based training in the overparameterized regime with L2 regularization
Outputs: Learned parameter vector θ converging to structured manifold M_st
Pipeline Flow
Optimization on Shortcut Manifold (M_sc)
Norm-Driven Transition
Convergence to Structured Manifold (M_st)
System Modules
Optimizer
Updates parameters to minimize regularized loss
Model or implementation: SGD or AdamW
Modeling
Base Model: Varies (ResNet18, MLPs for modular arithmetic)
Training Method: Standard supervised learning with L2 regularization (weight decay)
Objective Functions:
Purpose: Minimize training error while penalizing complexity.
Formally: L(θ) = L_train(θ) + (λ/2)||θ||^2
Adaptation: Full training
Trainable Parameters: All weights
Training Data:
CIFAR-10 with spurious borders
Modular arithmetic datasets
CelebA
Waterbirds
Key Hyperparameters:
weight_decay_lambda: Varies (critical control parameter)
learning_rate_eta: Satisfies η ≤ λ/L
Compute: Not explicitly reported in the paper
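The objective L(θ) = L_train(θ) + (λ/2)||θ||^2 and its gradient step can be sketched as follows. This is a plain-SGD illustration on Python lists, not the paper's training code; the function names are hypothetical. AdamW applies weight decay decoupled from the adaptive gradient update, so the step below does not describe it.

```python
def regularized_loss(train_loss, theta, weight_decay):
    """L(theta) = L_train(theta) + (weight_decay / 2) * ||theta||^2."""
    return train_loss + 0.5 * weight_decay * sum(w * w for w in theta)


def sgd_l2_step(theta, grad_train, lr, weight_decay):
    """One plain SGD step on the regularized objective.

    The gradient of the penalty term is weight_decay * theta, so each
    weight moves by -lr * (grad_train + weight_decay * w).
    """
    return [w - lr * (g + weight_decay * w) for w, g in zip(theta, grad_train)]
```

When the training gradient is zero, as at an exact interpolating solution, the step reduces to multiplying every weight by (1 - lr * weight_decay), which is the norm-shrinking pressure the framework relies on.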
Comparison to Prior Work
vs. Soudry et al.: NHT characterizes the *timescale* of the transition, not just the endpoint
vs. Shah et al.: NHT explains the *dynamic transition* from simple to structured, rather than just the initial preference
vs. Power et al.: NHT provides a unified *mechanism* (norm dynamics) explaining grokking as a specific instance of a broader phenomenon
Limitations
Assumes exact interpolation and stationarity, which are idealizations of real training
Predictive power fails when the Clean Norm Separation condition is not met (e.g., Waterbirds)
Requires V_sc > V_st (shortcut norm > structured norm), which may not hold for all tasks
Reproducibility
No code is provided. Detailed theoretical proofs are in the appendices. Experimental setups (datasets such as Waterbirds, CelebA, and Modular Arithmetic) are standard.
📊 Experiments & Results
Evaluation Setup
Controlled experiments on datasets with known spurious/shortcut features to measure transition timing
Benchmarks:
Modular Arithmetic (Algorithmic reasoning)
CIFAR-10 (Spurious) (Image classification with added shortcuts) [New]
CelebA (Face attribute classification)
Waterbirds (Robustness to background correlation)
Metrics:
Clean Accuracy
Transition Delay (epochs/steps)
Parameter Norm
Statistical methodology: R-squared regression analysis for validating scaling laws
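One way such an R-squared check could be run is an ordinary least-squares fit of measured transition delay against the logarithm of the norm ratio. The source does not specify the exact regression form, so the function below is a hypothetical sketch, and any data passed to it in examples is synthetic rather than taken from the paper.

```python
import math


def fit_log_law(norm_ratios, delays):
    """Least-squares fit of delay = slope * ln(ratio) + intercept.

    Returns (slope, intercept, r_squared).
    """
    xs = [math.log(r) for r in norm_ratios]
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(delays) / n
    sxx = sum((x - mean_x) ** 2 for x in xs)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, delays))
    slope = sxy / sxx
    intercept = mean_y - slope * mean_x
    # R^2 = 1 - (residual sum of squares) / (total sum of squares)
    ss_res = sum((y - (slope * x + intercept)) ** 2 for x, y in zip(xs, delays))
    ss_tot = sum((y - mean_y) ** 2 for y in delays)
    return slope, intercept, 1.0 - ss_res / ss_tot
```

An R^2 near 1 indicates that delay is close to linear in the log of the norm ratio. A low value would suggest either a different functional form or that the separation condition does not hold.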
Key Results
Benchmark | Metric | Baseline | This Paper | Δ
--- | --- | --- | --- | ---
Modular Arithmetic | R^2 (Transition Delay vs Theory) | 0 | 0.97 | +0.97
CIFAR-10 (Spurious) | Clean Accuracy | 10.0 | 78.0 | +68.0
CIFAR-10 | Norm Ratio (V_sc / V_st) | 1.0 | 37.0 | +36.0
CelebA | Norm Ratio (V_sc / V_st) | 1.0 | 3.0 | +2.0
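As a rough worked example, suppose the logarithmic scaling law holds with the same proportionality constant and the same effective regularization on both datasets. This is an assumption the source does not state. The norm ratios reported above (37 for CIFAR-10, 3 for CelebA) would then imply the following ratio of transition delays T:

```latex
\frac{T_{\text{CIFAR-10}}}{T_{\text{CelebA}}}
  \approx \frac{\ln 37}{\ln 3}
  \approx \frac{3.611}{1.099}
  \approx 3.29
```

So a twelvefold difference in norm ratio would translate into only about a threefold difference in predicted delay, which is what a logarithmic dependence means in practice.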
Main Takeaways
Transition delay is governed by the ratio of shortcut norm to structured norm and the effective regularization strength.
Three distinct regimes exist: Weak regularization (permanent shortcut), Intermediate (delayed transition/grokking), Strong (no learning).
ResNet18 with BatchNorm exhibits the same peak-then-decay norm dynamics as unnormalized models.
Waterbirds dataset demonstrates the framework's boundary: norm dynamics transfer, but representational transition fails due to lack of clean norm separation.
📚 Prerequisite Knowledge
Prerequisites
Gradient descent dynamics with weight decay
Concept of interpolation manifolds in overparameterized networks
Implicit bias of SGD
Key Terms
shortcut solution: A valid solution to the training objective that relies on spurious features (high norm)
structured solution: A valid solution that captures the true data-generating mechanism (low norm)
grokking: A phenomenon where generalization suddenly improves long after training accuracy has saturated
weight decay: L2 regularization term added to the loss function that penalizes large parameter weights
interpolation manifold: The set of all parameter configurations that achieve zero (or near-zero) training loss
clean norm separation: A condition where the norm of the shortcut solution is strictly distinguishable from that of the structured solution, allowing predictable transitions