Jun-Hyun Bae, Chanwoo Kim, Tae-Young Chang
Kyungpook National University
Abstract
Despite the effectiveness of deep neural networks trained with Empirical Risk Minimization (ERM) in medical imaging tasks, these models often exhibit performance degradation when faced with Out-of-Distribution (OoD) data, owing to potential biases in their predictive accuracy. Invariant Risk Minimization (IRM) seeks to rectify this issue by identifying invariant or causal correlations across various environments. However, its practical application does not consistently deliver the expected generalization performance in real-world scenarios. This paper addresses a potential limitation of the IRM framework, positing that the constraints enforced by IRM might not sufficiently guide the model in learning all causal features. In response, we propose a novel methodology leveraging modular neural networks within the IRM framework. Our approach aims to generate more diverse data representations, thereby enhancing the generalization performance of models trained with IRM. Experimental validation on three tasks (two medical image classification tasks, namely Camelyon17-wilds and CheXpert, and a synthetic task, Colored MNIST) demonstrates significant improvements in generalization performance in both OoD settings and subpopulation shift cases.
Overview
IRM tends to learn only the dominant invariant feature. We overcome this limitation with a modular neural network, improving OoD generalization in medical imaging.
- Modular encoder: the data representation model is split into \(N\) modules, and each module is encouraged to learn a different invariant feature.
- Competitive selection: multi-head dot-product attention selects the \(k\) modules most relevant to the input.
- IRM optimization: the IRM objective is optimized on the weighted representation of the selected modules, so that OoD generalization draws on diverse invariant features.
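
For reference, the IRM objective in the last step is often written in its IRMv1 form. This page does not say which variant the paper optimizes, so the formula below is the standard relaxation from the IRM literature and is shown as an assumption. The symbols \(\mathcal{E}_{tr}\), \(R^e\), \(w\), \(\lambda\), \(\alpha_n\), and \(S_k\) are notation introduced here for illustration.

\[
\min_{\Phi} \sum_{e \in \mathcal{E}_{tr}} \Big[ R^e(\Phi) + \lambda \, \big\| \nabla_{w \mid w = 1.0} R^e(w \cdot \Phi) \big\|^2 \Big],
\qquad
\Phi(x) = \sum_{n \in S_k(x)} \alpha_n(x) \, f_n(x)
\]

Here \(R^e\) is the risk in training environment \(e\), \(w\) is a fixed scalar dummy classifier, \(\lambda\) weights the invariance penalty, \(f_n\) is the \(n\)-th module, \(S_k(x)\) is the set of \(k\) modules selected for input \(x\), and \(\alpha_n(x)\) is the attention weight given to module \(n\).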

Architecture of the proposed method. The modular data representation is integrated into the IRM framework.
Method
IRM aims to learn invariant (causal) correlations, but in practice it tends to encode only the most dominant invariant feature. We integrate a modular neural network into the IRM framework so that each module is encouraged to learn a different invariant feature.
- The data representation model \(\Phi\) is split into \(N\) modules \(\{f_n\}_{n=1}^N\)
- Multi-head dot-product attention drives competitive learning among the modules
- The input itself acts as the query; the module outputs act as keys and values
- The top-\(k\) modules are selected, while non-selected modules are retained through soft selection (to prevent module collapse)
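
The selection step above can be sketched in a few lines of plain Python. This is a minimal illustration, not the paper's implementation: it uses a single attention head instead of multiple heads, the projection matrices `W_q` and `W_k` and the toy `tanh` modules are hypothetical, and the soft retention of non-selected modules is not modeled because this page does not specify how it works. Non-selected modules simply receive zero weight here.

```python
import math
import random


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def matvec(W, x):
    return [dot(row, x) for row in W]


def random_matrix(rows, cols, rng):
    return [[rng.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


def make_module(in_dim, out_dim, rng):
    """Toy stand-in for one module f_n: a random linear map followed by tanh."""
    W = random_matrix(out_dim, in_dim, rng)
    return lambda x: [math.tanh(v) for v in matvec(W, x)]


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


def select_modules(x, modules, W_q, W_k, k):
    """Return (representation, weights) for one input vector x.

    The query is projected from the input; keys are projected from the
    module outputs, which also serve as the values (single head).
    """
    outputs = [f(x) for f in modules]
    query = matvec(W_q, x)
    keys = [matvec(W_k, h) for h in outputs]
    scale = math.sqrt(len(query))
    scores = [dot(query, key) / scale for key in keys]

    # Hard top-k: keep the k highest scores, mask the rest before softmax.
    top = set(sorted(range(len(scores)), key=lambda n: scores[n], reverse=True)[:k])
    masked = [s if n in top else -math.inf for n, s in enumerate(scores)]
    weights = softmax(masked)

    # Weighted sum of the module outputs forms the representation.
    dim = len(outputs[0])
    rep = [sum(weights[n] * outputs[n][j] for n in range(len(outputs)))
           for j in range(dim)]
    return rep, weights


rng = random.Random(0)
modules = [make_module(6, 4, rng) for _ in range(4)]  # N = 4 modules
W_q = random_matrix(3, 6, rng)  # hypothetical query projection
W_k = random_matrix(3, 4, rng)  # hypothetical key projection
x = [rng.gauss(0.0, 1.0) for _ in range(6)]
rep, weights = select_modules(x, modules, W_q, W_k, k=2)
print(weights)
```

With `N = 4` and `k = 2`, exactly two entries of `weights` are non-zero and they sum to 1. In a trained model the resulting representation would be passed to the classifier and the IRM loss.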

Sample images from each environment of the Camelyon17-wilds and CheXpert datasets.
Results
Colored MNIST
| Algorithm | Val Accuracy (iid) | Test Accuracy (OoD) | # Params |
|---|---|---|---|
| ERM | 88.6% | 16.4% | 1,198,337 |
| IRM | 73.4% | 60.5% | 1,198,337 |
| Ours | 74.9% | 66.5% | 935,553 |
| Optimal | 75.0% | 75.0% | N/A |
Camelyon17-wilds (OoD Medical Imaging)
| Algorithm | Val Accuracy (iid) | Test Accuracy (OoD) | # Params |
|---|---|---|---|
| ERM | 91.9% | 73.3% | 42.8M |
| IRM | 94.1% | 72.9% | 42.8M |
| Ours (N=4, k=2) | 91.5% | 83.5% | 45.6M |
| Ours (N=2, k=1) | 90.4% | 74.5% | 22.8M |
CheXpert (Subpopulation Shift)
| Algorithm | Average Accuracy | Worst-case Accuracy |
|---|---|---|
| ERM | 86.9% | 50.2% |
| IRM | 89.8% | 34.4% |
| Ours (N=3, k=1) | 80.3% | 59.6% |
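On Camelyon17-wilds, OoD test accuracy improves by about 10 percentage points over ERM and IRM (83.5% vs. 73.3% and 72.9%). On CheXpert, worst-case accuracy improves by 9.4 percentage points over ERM (59.6% vs. 50.2%).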
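
The gains quoted here can be checked directly against the tables. The short script below only copies the reported numbers and takes differences in percentage points; the helper name `gains` is introduced here for illustration.

```python
# OoD test accuracy on Camelyon17-wilds and worst-case accuracy on CheXpert,
# copied from the tables above (percent).
camelyon_ood = {"ERM": 73.3, "IRM": 72.9, "Ours (N=4, k=2)": 83.5}
chexpert_worst = {"ERM": 50.2, "IRM": 34.4, "Ours (N=3, k=1)": 59.6}


def gains(table, ours):
    """Percentage-point gain of the `ours` row over every other row."""
    return {name: round(table[ours] - acc, 1)
            for name, acc in table.items() if name != ours}


camelyon_gain = gains(camelyon_ood, "Ours (N=4, k=2)")
chexpert_gain = gains(chexpert_worst, "Ours (N=3, k=1)")
print(camelyon_gain)
print(chexpert_gain)
```

The Camelyon17-wilds gains are 10.2 and 10.6 points over ERM and IRM. On CheXpert the worst-case gain is 9.4 points over ERM and 25.2 points over IRM, while the average accuracy of the proposed model (80.3%) is lower than both baselines.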
BibTeX
@inproceedings{bae2024invariant,
author = {Bae, Jun-Hyun and Kim, Chanwoo and Chang, Tae-Young},
title = {Invariant Risk Minimization in Medical Imaging with Modular Data Representation},
booktitle = {International Conference on Electronics, Information, and Communication (ICEIC)},
year = {2024},
doi = {10.1109/ICEIC61013.2024.10457174}
}