ICEIC 2024
Jun-Hyun Bae, Chanwoo Kim, Tae-Young Chang
Kyungpook National University
Translated by Claude Opus 4.6

Abstract

Despite the effectiveness of deep neural networks trained with Empirical Risk Minimization (ERM) in medical imaging tasks, these models often exhibit performance degradation when faced with Out-of-Distribution (OoD) data, owing to potential biases in their predictive accuracy. Invariant Risk Minimization (IRM) seeks to rectify this issue by identifying invariant or causal correlations across various environments. However, its practical application does not consistently deliver the expected generalization performance in real-world scenarios. This paper addresses a potential limitation of the IRM framework, positing that the constraints enforced by IRM might not sufficiently guide the model in learning all causal features. In response, we propose a novel methodology leveraging modular neural networks within the IRM framework. Our approach aims to generate more diverse data representations, thereby enhancing the generalization performance of models trained with IRM. Experimental validation on three tasks (two medical image classification tasks, namely Camelyon17-wilds and CheXpert, and a synthetic task, Colored MNIST) demonstrates significant improvements in generalization performance in both OoD settings and subpopulation shift cases.


Overview

We overcome IRM’s limitation of learning only dominant invariant features by integrating modular neural networks, improving OoD generalization in medical imaging.

  1. Modular encoder: Split the data representation model into \(N\) modules, each learning distinct invariant features.
  2. Competitive selection: Select the \(k\) most relevant modules via multi-head dot product attention.
  3. IRM optimization: Optimize the IRM objective with weighted representations from selected modules, achieving OoD generalization through diverse invariant features.
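For step 3, a minimal sketch of an IRM objective for a binary task is shown below. It assumes the widely used IRMv1 penalty of Arjovsky et al. (the squared gradient of each environment's risk with respect to a fixed dummy classifier scale \(w = 1\)) and logistic loss. The function names, the weight `lam`, and the toy inputs are hypothetical. In the method described here, the logits would come from a classifier applied to the weighted module representations.

```python
import math


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def bce_with_logits(z: float, y: float) -> float:
    """Binary cross-entropy of a raw logit z against a label y in {0, 1}."""
    return max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))


def env_risk(logits, labels, w=1.0):
    """Mean loss of one environment, with logits scaled by a dummy classifier w."""
    return sum(bce_with_logits(w * z, y) for z, y in zip(logits, labels)) / len(labels)


def irmv1_penalty(logits, labels):
    """Squared derivative of env_risk with respect to w, evaluated at w = 1.

    d/dw BCE(w*z, y) = (sigmoid(w*z) - y) * z, so at w = 1 the derivative of
    the mean risk is mean((sigmoid(z) - y) * z).
    """
    grad = sum((sigmoid(z) - y) * z for z, y in zip(logits, labels)) / len(labels)
    return grad ** 2


def irm_objective(envs, lam):
    """Sum over environments of (risk + lam * IRMv1 penalty)."""
    return sum(env_risk(lg, lb) + lam * irmv1_penalty(lg, lb) for lg, lb in envs)


if __name__ == "__main__":
    # Two toy environments of (logits, labels); the values are made up.
    envs = [([2.0, -1.0, 0.5], [1.0, 0.0, 0.0]), ([0.3, -0.2], [0.0, 1.0])]
    print(irm_objective(envs, lam=1.0))
```

The penalty is zero only when rescaling the classifier cannot lower any environment's risk. This is the per-environment optimality condition that IRM uses as a proxy for invariance.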

Modular IRM Framework

Overview of the proposed method: modular data representations integrated within the IRM framework.


Method

While IRM attempts to learn invariant/causal correlations, in practice it tends to encode only the most dominant invariant feature. We integrate a modular neural network into the IRM framework so that each module learns distinct invariant features.

Key components:

  • The data representation model \(\Phi\) is split into \(N\) modules \(\{f_n\}_{n=1}^N\)
  • Competitive learning via multi-head dot product attention across modules
  • Input itself serves as the query; module outputs serve as keys/values
  • Top-\(k\) modules are selected, with non-selected modules maintained via soft selection to prevent module collapse
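The competitive selection in the list above can be sketched in pure Python. This is a minimal illustration under stated assumptions, not the authors' implementation. The input is projected to a query, each module output is projected to a key and a value, and modules are ranked by attention weights averaged over heads. The top \(k\) keep full weight, while non-selected modules are multiplied by a small factor `soft_eps`. That factor is a hypothetical stand-in for the soft selection, whose exact form is not specified here. The function names, dimensions, and random toy modules are also hypothetical.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(m, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [dot(row, v) for row in m]


def softmax(xs):
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def competitive_combine(x, module_outputs, heads, k, soft_eps=0.01):
    """Input-as-query multi-head attention over module outputs with top-k selection.

    x:              input features, used as the query
    module_outputs: list of N vectors f_n(x), used as keys and values
    heads:          list of (W_q, W_k, W_v) projection matrices, one per head
    k:              number of modules selected
    soft_eps:       hypothetical weight multiplier for non-selected modules
    Returns (representation, module relevance, selected module indices).
    """
    n = len(module_outputs)
    attn_per_head = []
    for w_q, w_k, _ in heads:
        q = matvec(w_q, x)
        keys = [matvec(w_k, o) for o in module_outputs]
        # scaled dot-product scores between the query and each module's key
        attn_per_head.append(softmax([dot(q, key) / math.sqrt(len(q)) for key in keys]))
    # rank modules by attention averaged over heads
    relevance = [sum(a[i] for a in attn_per_head) / len(heads) for i in range(n)]
    selected = sorted(range(n), key=lambda i: relevance[i], reverse=True)[:k]
    # soft selection: non-selected modules keep a small, nonzero weight
    mask = [1.0 if i in selected else soft_eps for i in range(n)]
    representation = []
    for (_, _, w_v), attn in zip(heads, attn_per_head):
        weights = [a * m for a, m in zip(attn, mask)]
        total = sum(weights)
        weights = [w / total for w in weights]
        values = [matvec(w_v, o) for o in module_outputs]
        dim = len(values[0])
        head_out = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
        representation.extend(head_out)  # concatenate head outputs
    return representation, relevance, selected


def random_matrix(rng, rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    rng = random.Random(0)
    d_in, d_mod, d_head, n_modules, k, n_heads = 4, 3, 2, 4, 2, 2
    # hypothetical modules f_n: a random linear map followed by tanh
    mats = [random_matrix(rng, d_mod, d_in) for _ in range(n_modules)]
    modules = [lambda x, m=m: [math.tanh(v) for v in matvec(m, x)] for m in mats]
    heads = [(random_matrix(rng, d_head, d_in), random_matrix(rng, d_head, d_mod),
              random_matrix(rng, d_head, d_mod)) for _ in range(n_heads)]
    x = [0.5, -1.0, 0.25, 2.0]
    rep, relevance, selected = competitive_combine(x, [f(x) for f in modules], heads, k)
    print(len(rep), selected)
```

Keeping a nonzero weight on non-selected modules means every module still receives some gradient during training. This is one simple way to avoid a few modules winning every competition while the rest never learn.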

Dataset Examples

Example images from the Camelyon17-wilds and CheXpert datasets across different environments.


Results

Colored MNIST

| Algorithm | Val Accuracy (IID) | Test Accuracy (OoD) | # Params |
|---|---|---|---|
| ERM | 88.6% | 16.4% | 1,198,337 |
| IRM | 73.4% | 60.5% | 1,198,337 |
| Ours | 74.9% | 66.5% | 935,553 |
| Optimal | 75.0% | 75.0% | N/A |

Camelyon17-wilds (OoD Medical Imaging)

| Algorithm | Val Accuracy (IID) | Test Accuracy (OoD) | # Params |
|---|---|---|---|
| ERM | 91.9% | 73.3% | 42.8M |
| IRM | 94.1% | 72.9% | 42.8M |
| Ours (N=4, k=2) | 91.5% | 83.5% | 45.6M |
| Ours (N=2, k=1) | 90.4% | 74.5% | 22.8M |

CheXpert (Subpopulation Shift)

| Algorithm | Average Accuracy | Worst-case Accuracy |
|---|---|---|
| ERM | 86.9% | 50.2% |
| IRM | 89.8% | 34.4% |
| Ours (N=3, k=1) | 80.3% | 59.6% |
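The CheXpert table reports two metrics. The sketch below shows how they are commonly computed under subpopulation shift. It assumes that "worst-case accuracy" is the lowest accuracy over predefined subgroups and that "average accuracy" is accuracy over all test samples (an assumption, since it could also be the mean of the group accuracies). The subgroup labels are hypothetical, because this page does not say how the CheXpert subgroups are defined.

```python
from collections import defaultdict


def average_and_worst_group_accuracy(preds, labels, groups):
    """Return (accuracy over all samples, lowest per-group accuracy, per-group dict)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    per_group = {g: correct[g] / total[g] for g in total}
    overall = sum(correct.values()) / sum(total.values())
    return overall, min(per_group.values()), per_group


if __name__ == "__main__":
    # Toy predictions with two hypothetical subgroups "a" and "b".
    preds = [1, 1, 0, 0, 1, 0]
    labels = [1, 1, 0, 1, 0, 0]
    groups = ["a", "a", "a", "b", "b", "b"]
    overall, worst, per_group = average_and_worst_group_accuracy(preds, labels, groups)
    print(f"average={overall:.3f}  worst-case={worst:.3f}  per-group={per_group}")
```

The two metrics can move in opposite directions. A model that fits the majority subgroup well can score high on average while failing a minority subgroup, which is why a method can raise worst-case accuracy while lowering average accuracy.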

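On Camelyon17-wilds, our method (N=4, k=2) improves OoD test accuracy by about 10 percentage points over both ERM (+10.2) and IRM (+10.6). On CheXpert, worst-case accuracy improves by 9.4 percentage points over ERM and 25.2 over IRM, while average accuracy is lower than both baselines.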
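The gains above are absolute differences in percentage points, not relative changes. The small check below uses only the numbers in the Results tables and contrasts the two measures. The helper names are illustrative.

```python
def pp_gain(ours: float, baseline: float) -> float:
    """Absolute difference between two accuracies, in percentage points."""
    return round(ours - baseline, 1)


def relative_change(new: float, old: float) -> float:
    """Relative change of `new` with respect to `old`, in percent."""
    return round(100.0 * (new - old) / old, 1)


# Accuracies (%) transcribed from the Results tables.
camelyon_ood = {"ERM": 73.3, "IRM": 72.9, "Ours": 83.5}    # Ours = N=4, k=2
chexpert_worst = {"ERM": 50.2, "IRM": 34.4, "Ours": 59.6}  # Ours = N=3, k=1
chexpert_avg = {"ERM": 86.9, "IRM": 89.8, "Ours": 80.3}    # Ours = N=3, k=1

if __name__ == "__main__":
    for name, table in [("Camelyon17-wilds OoD", camelyon_ood),
                        ("CheXpert worst-case", chexpert_worst),
                        ("CheXpert average", chexpert_avg)]:
        for base in ("ERM", "IRM"):
            print(f"{name}, Ours vs {base}: {pp_gain(table['Ours'], table[base]):+.1f} pp")
    # Relative (not absolute) change, for contrast
    print(relative_change(camelyon_ood["Ours"], camelyon_ood["ERM"]))  # relative OoD gain, %
    print(relative_change(935_553, 1_198_337))  # Colored MNIST parameter-count change, %
```

Measured relatively, the Camelyon17-wilds gain over ERM is about 13.9%. The Colored MNIST model in the table also uses about 21.9% fewer parameters than the ERM/IRM baselines.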


BibTeX

@inproceedings{bae2024invariant,
  author    = {Bae, Jun-Hyun and Kim, Chanwoo and Chang, Tae-Young},
  title     = {Invariant Risk Minimization in Medical Imaging with Modular Data Representation},
  booktitle = {International Conference on Electronics, Information, and Communication (ICEIC)},
  year      = {2024},
  doi       = {10.1109/ICEIC61013.2024.10457174}
}