arXiv 2021
Jun-Hyun Bae, Inchul Choi, Minho Lee
Kyungpook National University
Translated by Claude Opus 4.7

Abstract

Invariant Risk Minimization (IRM) aims to find predictors that generalize across out-of-distribution (OOD) environments by learning invariant correlations rather than spurious ones. We propose meta-IRM, a meta-learning based approach that instantiates the IRM objective using MAML. The key idea is assigning different training environments to the inner and outer optimization loops — since spurious correlations vary across environments, this naturally drives the model toward invariant features. We empirically evaluate meta-IRM on Colored MNIST variants including multi-class, insufficient data, and multiple spurious correlation settings.


Overview

We directly solve the ideal IRM bi-level optimization using the MAML framework, assigning different training environments to inner and outer loops to learn invariant features.

  1. Environment-specific adaptation — Adapt model parameters to each training environment \(e_i\) in the inner loop.
  2. Cross-environment evaluation — Evaluate the adapted parameters \(\theta_i'\) on a different environment \(e_j\) to compute the meta-loss.
  3. Invariant convergence — Combining gradients from different environments cancels spurious correlations, converging toward invariant features.

Figure: Diagram of meta-IRM. Gradients from environment-adapted parameters are combined to converge toward the invariant feature direction ($\theta_{optimal}$).


Method

IRM aims to find a predictor that is invariant across multiple training environments. Meta-IRM approaches this with the MAML framework. Its key design choice is to assign different training environments to the inner loop and the outer loop. Since spurious correlations vary across environments, combining gradients from different environments drives convergence toward invariant features.

Meta-Learning Framework

The inner/outer structure of MAML maps onto the IRM objective. The inner loop adapts to one environment, while the outer loop (meta-loss) evaluates on a different environment.

Inner optimization: Adapt model parameters to each training environment \(e_i\):

\[\theta_i' = \theta - \alpha \nabla_\theta R^{e_i}(f_\theta)\]

Meta-optimization: Update \(\theta\) using meta-losses computed from different environments \(e_j\) with adapted parameters \(\theta_i'\):

\[\theta \leftarrow \theta - \beta \nabla_\theta \left\{ \sum_i \sum_j R^{e_j}(f_{\theta_i'}) + \lambda \sigma \right\}, \quad e_j \sim \mathcal{E}_{tr} \setminus e_i\]
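
The quantity inside the braces above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation: the model is a linear predictor with squared-error risk, \(\sigma\) is taken as the population standard deviation of the meta-losses, and the names `risk`, `risk_grad`, `inner_adapt`, and `meta_objective` are hypothetical. The outer update would differentiate this objective with respect to \(\theta\), which is not shown here.

```python
from statistics import pstdev

def risk(theta, env):
    """Mean squared error of the linear predictor x . theta on one environment."""
    xs, ys = env
    return sum((sum(t * x for t, x in zip(theta, row)) - y) ** 2
               for row, y in zip(xs, ys)) / len(ys)

def risk_grad(theta, env):
    """Analytic gradient of `risk` with respect to theta."""
    xs, ys = env
    grad = [0.0] * len(theta)
    for row, y in zip(xs, ys):
        err = sum(t * x for t, x in zip(theta, row)) - y
        for k, x in enumerate(row):
            grad[k] += 2.0 * err * x / len(ys)
    return grad

def inner_adapt(theta, env, alpha):
    """One inner-loop step: theta_i' = theta - alpha * grad R^{e_i}(theta)."""
    return [t - alpha * g for t, g in zip(theta, risk_grad(theta, env))]

def meta_objective(theta, envs, alpha, lam):
    """Sum of cross-environment meta-losses plus lam * their standard deviation."""
    meta_losses = []
    for i, env_i in enumerate(envs):
        theta_i = inner_adapt(theta, env_i, alpha)  # adapt on e_i
        for j, env_j in enumerate(envs):
            if j != i:  # evaluate only on environments other than e_i
                meta_losses.append(risk(theta_i, env_j))
    # Assumption: sigma is the population std over all (i, j) meta-losses.
    return sum(meta_losses) + lam * pstdev(meta_losses)

# Toy usage: two environments, each a pair (inputs, targets).
envs = [([[1.0], [2.0]], [2.0, 4.0]), ([[3.0]], [6.0])]
value = meta_objective([0.5], envs, alpha=0.01, lam=1.0)
```

With at least two environments there are always at least two meta-losses, so the standard deviation is well defined.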

Figure: Schematic of the meta-IRM learning process. Model parameters $\theta$ are adapted to training environments in the inner optimization; each adapted parameter computes meta-loss from different environments.

Auxiliary Standard Deviation Loss

An auxiliary loss based on the standard deviation \(\sigma\) of meta-losses encourages uniform performance across all environments. This regularization improves training stability and convergence speed.
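
One way to write this term explicitly is shown below. The summary does not state whether the population or sample form is used, so the \(1/N\) normalization is an assumption, as is the shorthand \(\ell_{ij} = R^{e_j}(f_{\theta_i'})\) for the \(N\) meta-losses with \(j \neq i\):

\[\sigma = \sqrt{\frac{1}{N} \sum_{i} \sum_{j \neq i} \left(\ell_{ij} - \bar{\ell}\right)^2}, \quad \bar{\ell} = \frac{1}{N} \sum_{i} \sum_{j \neq i} \ell_{ij}\]

Under this form, \(\sigma = 0\) exactly when every adapted model attains the same loss on every other environment, which is the uniform-performance condition the penalty pushes toward.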


Results

Colored MNIST

Test accuracy (%) on Colored MNIST. $p_e=0.9$ is the OOD environment (reversed color-label correlation).

| Algorithm | \(p_e=0.1\) | \(p_e=0.2\) | \(p_e=0.9\) (OOD) |
|---|---|---|---|
| ERM | 88.6 ± 0.3 | 79.7 ± 0.6 | 16.4 ± 0.8 |
| IRMv1 | 71.4 ± 0.9 | 70.8 ± 1.0 | 66.9 ± 2.5 |
| MM-REx | 70.8 ± 1.5 | 70.4 ± 2.0 | 66.1 ± 4.9 |
| V-REx | 71.5 ± 0.8 | 71.1 ± 0.9 | 68.6 ± 1.2 |
| meta-IRM (Ours) | 70.9 ± 0.9 | 70.8 ± 1.0 | 70.4 ± 0.9 |
| Random | 50 | 50 | 50 |
| ERM (grayscale) | 72.6 ± 0.3 | 72.7 ± 0.3 | 73.0 ± 0.5 |
| Optimal | 75 | 75 | 75 |

meta-IRM (70.4%) surpasses V-REx (68.6%), IRMv1 (66.9%), and MM-REx (66.1%), coming closest to the theoretical optimum (75%). Its OOD standard deviation of 0.9 is also lower than that of the other methods (V-REx 1.2, IRMv1 2.5, MM-REx 4.9), indicating stable invariant feature learning. ERM (grayscale), trained on images without the spurious color feature, serves as a reference point (73.0%); meta-IRM’s proximity to it suggests that the spurious color correlation is almost fully excluded.
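
For readers unfamiliar with the benchmark, the role of \(p_e\) can be sketched as follows. This summary does not spell out how the environments are built, so the sketch assumes the commonly used binary Colored MNIST recipe (digits 0-4 vs 5-9, 25% label noise, color flipped relative to the label with probability \(p_e\)), which is consistent with the 75% optimum reported above. `make_environment` and its parameters are hypothetical names, and only digit values are used in place of images.

```python
import random

def make_environment(digits, p_e, label_noise=0.25, seed=0):
    """Return (digit, color, label) triples for one environment.

    Assumed recipe: the binary label is 1 for digits >= 5, flipped with
    probability `label_noise`; the color id equals the (noisy) label,
    flipped with probability `p_e`.
    """
    rng = random.Random(seed)
    samples = []
    for d in digits:
        label = int(d >= 5)
        if rng.random() < label_noise:
            label = 1 - label          # label noise bounds achievable accuracy
        color = label
        if rng.random() < p_e:
            color = 1 - color          # p_e controls how spurious color is
        samples.append((d, color, label))
    return samples

digits = list(range(10)) * 10
train_a = make_environment(digits, p_e=0.1)           # color agrees ~90% of the time
test_ood = make_environment(digits, p_e=0.9, seed=1)  # color mostly disagrees
```

Under this recipe, color predicts the label better than the digit shape in the training environments, which is why a model that minimizes training error alone tends to latch onto it.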

Figure: Various test environments. Accuracy across test environments ($p_e$ from 0.1 to 0.9).

Multi-Class Problem

As the number of classes increases, the structure of spurious correlations becomes more complex, and meta-IRM’s advantage becomes more pronounced.

Multi-class Colored MNIST ($k$=5, 10) test accuracy (%).

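| Algorithm | # Classes | Train | Test (OOD) |
|---|---|---|---|
| ERM | 5 | 95.2 ± 0.2 | 41.0 ± 0.6 |
| IRMv1 | 5 | 82.2 ± 0.4 | 62.0 ± 2.4 |
| meta-IRM (Ours) | 5 | 76.4 ± 1.4 | 74.0 ± 3.6 |
| Random | 5 | 20 | 20 |
| ERM (grayscale) | 5 | 73.2 ± 0.2 | 71.7 ± 0.4 |
| ERM | 10 | 92.6 ± 0.2 | 39.2 ± 0.9 |
| IRMv1 | 10 | 83.4 ± 0.5 | 58.6 ± 2.5 |
| meta-IRM (Ours) | 10 | 79.5 ± 0.6 | 73.4 ± 3.2 |
| Random | 10 | 10 | 10 |
| ERM (grayscale) | 10 | 73.2 ± 0.1 | 71.9 ± 0.5 |
| Optimal | | 75 | 75 |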

In the 10-class setting, meta-IRM (73.4%) surpasses IRMv1 (58.6%) by 14.8pp and slightly exceeds the ERM (grayscale) reference (71.9%), although that second gap lies within meta-IRM’s ± 3.2 spread. This suggests that IRMv1’s penalty term becomes less effective as the number of classes grows, while meta-IRM’s cross-environment evaluation continues to drive convergence toward invariant features in both the 5-class and 10-class settings.

Insufficient Data

We evaluate how well IRM-family methods maintain performance when training data is halved (25,000 to 12,500 per environment).

Test accuracy (%) with reduced training data (12,500 per environment).

| Algorithm | \(p_e=0.1\) | \(p_e=0.2\) | \(p_e=0.9\) (OOD) |
|---|---|---|---|
| IRMv1 | 73.2 ± 1.6 | 71.7 ± 1.3 | 58.5 ± 2.5 |
| V-REx | 75.0 ± 0.8 | 73.4 ± 0.6 | 61.3 ± 1.4 |
| meta-IRM (Ours) | 70.6 ± 1.5 | 70.5 ± 1.6 | 68.3 ± 2.3 |
| Optimal | 75 | 75 | 75 |

When data is halved, IRMv1 drops from 66.9% to 58.5% (an 8.4pp decline), while meta-IRM drops only from 70.4% to 68.3% (2.1pp). This demonstrates that the meta-learning-based optimization efficiently extracts invariant features even under data-scarce conditions.
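
The percentage-point drops quoted above, plus the corresponding figure for V-REx, can be recomputed directly from the reported OOD accuracies. The snippet only restates numbers from the Colored MNIST and reduced-data tables; the helper name `drop_in_points` is illustrative.

```python
# OOD (p_e = 0.9) test accuracy in %, as reported:
# (25,000 samples per environment, 12,500 samples per environment)
ood_accuracy = {
    "IRMv1": (66.9, 58.5),
    "V-REx": (68.6, 61.3),
    "meta-IRM": (70.4, 68.3),
}

def drop_in_points(full, halved):
    """Absolute accuracy drop in percentage points, rounded to one decimal."""
    return round(full - halved, 1)

drops = {name: drop_in_points(*pair) for name, pair in ood_accuracy.items()}
for name, drop in drops.items():
    print(f"{name}: -{drop} pp")
```

These are differences of mean accuracies only; the reported ± spreads are not propagated.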

Figure: Effect of decreasing training data (25,000 → 12,500 → 6,250).

Two Spurious Features

This setting introduces 2 spurious correlations (color + patch) with only 2 training environments. Under these conditions, the theoretical guarantees of IRMv1 do not hold, because the number of training environments does not exceed the number of spurious features.

| Algorithm | \(p_e=0.1\) | \(p_e=0.2\) | \(p_e=0.9\) (OOD) |
|---|---|---|---|
| IRMv1 | 93.5 ± 0.2 | 86.4 ± 0.3 | 13.4 ± 0.3 |
| V-REx | 93.6 ± 0.4 | 86.3 ± 0.3 | 13.5 ± 0.3 |
| meta-IRM (Ours) | 58.1 ± 3.1 | 56.8 ± 2.9 | 54.5 ± 4.0 |
| Optimal | 75 | 75 | 75 |

IRMv1 and V-REx drop to 13.4–13.5% OOD accuracy, far below random chance (50%). Both methods exploit both spurious features to achieve high in-distribution accuracy, so performance collapses when both correlations reverse simultaneously in the OOD environment. meta-IRM achieves 54.5 ± 4.0%, which is well short of the optimum and only modestly above chance, but indicates that invariant feature learning is at least partially maintained.

Figure: Test accuracy with two spurious features (color + patch) across environments.

PunctuatedSST-2 (NLP)

We verify whether the principles from Colored MNIST transfer to NLP using PunctuatedSST-2, a sentiment analysis task where punctuation marks ("!" vs ".") serve as spurious features.
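
As a rough illustration of how a punctuation mark can act as a spurious feature, the sketch below appends "!" or "." to a sentence so that the mark agrees with the sentiment label with probability \(1 - p_e\). The exact construction of PunctuatedSST-2 is not given here, so the pairing of "!" with the positive label, the use of a flip probability `p_e`, and the function name `punctuate` are all assumptions.

```python
import random

def punctuate(sentence, label, p_e, rng):
    """Append a final mark that agrees with `label` with probability 1 - p_e.

    Assumption: label 1 (positive) pairs with "!" and label 0 with ".".
    """
    mark_id = label if rng.random() >= p_e else 1 - label
    # Drop any existing trailing punctuation so the injected mark is the only cue.
    base = sentence.rstrip(" .!")
    return base + (" !" if mark_id == 1 else " .")

rng = random.Random(0)
aligned = punctuate("a great film", 1, p_e=0.0, rng=rng)   # mark follows the label
reversed_ = punctuate("a great film", 1, p_e=1.0, rng=rng)  # mark contradicts it
```

A classifier that keys on the final token would do well when `p_e` is small and fail when the correlation is reversed, mirroring the role of color in Colored MNIST.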

PunctuatedSST-2 test accuracy (%). $\eta_e$ is the label noise rate.

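| Algorithm | \(\eta_e\) | Test (OOD) |
|---|---|---|
| ERM | 0.25 | 30.7 ± 1.5 |
| IRMv1 | 0.25 | 62.0 ± 1.9 |
| meta-IRM (Ours) | 0.25 | 62.2 ± 1.8 |
| ERM (grayscale) | 0.25 | 62.3 ± 0.5 |
| Optimal | 0.25 | 75 |
| ERM | 0 | 56.2 ± 2.9 |
| IRMv1 | 0 | 67.4 ± 1.4 |
| meta-IRM (Ours) | 0 | 73.0 ± 0.7 |
| ERM (grayscale) | 0 | 76.7 ± 2.7 |
| Optimal | 0 | 100 |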

Without label noise (\(\eta_e=0\)), meta-IRM (73.0%) surpasses IRMv1 (67.4%) by 5.6pp and approaches the ERM (grayscale) reference (76.7%), which is trained on vanilla SST-2 without the punctuation spurious feature. With label noise (\(\eta_e=0.25\)), meta-IRM (62.2%) is effectively tied with both IRMv1 (62.0%) and the ERM (grayscale) reference (62.3%). Label noise caps the information content of the invariant feature itself, so the gap between methods vanishes: all three already reach the accuracy attainable once the punctuation correlation is removed.

Ablation Study

| Variant | \(p_e=0.1\) | \(p_e=0.9\) (OOD) |
|---|---|---|
| w/o std. loss | 73.0 ± 0.8 | 64.2 ± 3.2 |
| First-order approx. | 62.6 ± 5.5 | 59.1 ± 6.3 |
| Same env. (inner = meta) | 89.3 ± 1.6 | 13.6 ± 6.6 |
| meta-IRM (Full) | 70.9 ± 0.9 | 70.4 ± 0.9 |
| Optimal | 75 | 75 |

The same-environment variant (inner = meta) collapses to 13.6% OOD accuracy, an ERM-level failure (ERM itself reaches 16.4% on the same test environment). This is the clearest evidence that assigning different environments to the inner and outer loops is the core mechanism of meta-IRM. The first-order approximation degrades both accuracy and stability (59.1 ± 6.3) because it discards second-order gradient information. Removing the standard deviation loss reduces OOD accuracy by 6.2pp (70.4% to 64.2%) and increases its spread (3.2 vs 0.9).
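
To make the second-order term concrete, apply the chain rule to a single meta-loss, using the inner update \(\theta_i' = \theta - \alpha \nabla_\theta R^{e_i}(f_\theta)\) from the Method section:

\[\nabla_\theta R^{e_j}(f_{\theta_i'}) = \left(I - \alpha \nabla_\theta^2 R^{e_i}(f_\theta)\right) \nabla_{\theta_i'} R^{e_j}(f_{\theta_i'})\]

The usual first-order MAML simplification replaces the factor in parentheses with \(I\), dropping the Hessian of the inner-environment risk. How the ablated variant was implemented is not stated here, so reading it as this simplification is an assumption. Under that reading, the discarded term is the only place where the curvature of \(R^{e_i}\) influences the update on \(e_j\).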