ICEIC 2024
Jun-Hyun Bae, Chanwoo Kim, Tae-Young Chang
Kyungpook National University

Abstract

Despite the effectiveness of deep neural networks trained with Empirical Risk Minimization (ERM) in medical imaging tasks, these models often exhibit performance degradation when faced with Out-of-Distribution (OoD) data, owing to potential biases in their predictive accuracy. Invariant Risk Minimization (IRM) seeks to rectify this issue by identifying invariant or causal correlations across various environments. However, its practical application does not consistently deliver the expected generalization performance in real-world scenarios. This paper addresses a potential limitation of the IRM framework, positing that the constraints enforced by IRM might not sufficiently guide the model in learning all causal features. In response, we propose a novel methodology leveraging modular neural networks within the IRM framework. Our approach aims to generate more diverse data representations, thereby enhancing the generalization performance of models trained with IRM. Experimental validation on three tasks (two medical image classification tasks, namely Camelyon17-wilds and CheXpert, and a synthetic task, Colored MNIST) demonstrates significant improvements in generalization performance in both OoD settings and subpopulation shift cases.


Overview

We overcome IRM's tendency to learn only the dominant invariant features by using a modular neural network, improving OoD generalization in medical imaging.

  1. Modular encoder: splits the data representation model into \(N\) modules, encouraging each module to learn a different invariant feature.
  2. Competitive selection: uses multi-head dot-product attention to select the \(k\) modules most relevant to the input.
  3. IRM optimization: optimizes the IRM objective on the weighted representation of the selected modules, achieving OoD generalization that draws on diverse invariant features.
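Step 3 presumably builds on the standard IRMv1 relaxation of Arjovsky et al. (2019), written out here as a sketch. Expressing \(\Phi\) as a classifier head \(h\) over an attention-weighted mixture of module outputs, with weights \(\alpha_n(x)\) that sum to one, is our assumption; the paper may combine modules differently. \(R^e\) is the risk in training environment \(e \in \mathcal{E}_{\mathrm{tr}}\), \(w\) is a fixed scalar dummy classifier, and \(\lambda\) weights the penalty.

```latex
\Phi(x) = h\Big( \sum_{n=1}^{N} \alpha_n(x)\, f_n(x) \Big), \qquad \sum_{n=1}^{N} \alpha_n(x) = 1

\min_{\Phi} \; \sum_{e \in \mathcal{E}_{\mathrm{tr}}} \Big[ R^e(\Phi) + \lambda \, \big\| \nabla_{w \,\mid\, w = 1.0} \, R^e(w \cdot \Phi) \big\|^2 \Big]
```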

Modular IRM Framework

Architecture of the proposed method. The modular data representation is integrated into the IRM framework.


Method

IRM aims to learn invariant (causal) correlations, but in practice it tends to encode only the most dominant invariant feature. We integrate a modular neural network into the IRM framework so that each module is encouraged to learn a different invariant feature.
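To make the IRM penalty concrete, here is a minimal pure-Python sketch of the IRMv1 penalty for binary classification, assuming a logistic (binary cross-entropy) loss and the fixed scalar dummy classifier \(w = 1.0\) of Arjovsky et al. The function names and the `(logits, labels)` environment format are hypothetical illustrations, not the paper's training code.

```python
import math

def sigmoid(t: float) -> float:
    # Numerically stable logistic function.
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)

def softplus(t: float) -> float:
    # Numerically stable log(1 + exp(t)).
    return max(t, 0.0) + math.log1p(math.exp(-abs(t)))

def bce_with_logits(t: float, y: int) -> float:
    # Binary cross-entropy of logit t against label y in {0, 1}.
    return softplus(t) - y * t

def env_risk(logits, labels, w=1.0):
    """Mean BCE of the scaled logits w * z over one environment."""
    return sum(bce_with_logits(w * z, y) for z, y in zip(logits, labels)) / len(logits)

def irmv1_grad(logits, labels):
    """Closed-form d/dw env_risk(logits, labels, w) at w = 1.0: mean((sigmoid(z) - y) * z)."""
    return sum((sigmoid(z) - y) * z for z, y in zip(logits, labels)) / len(logits)

def irm_objective(envs, lam):
    """Sum over environments of risk + lam * squared IRMv1 gradient (hypothetical helper)."""
    return sum(env_risk(z, y) + lam * irmv1_grad(z, y) ** 2 for z, y in envs)
```

Each environment's penalty is zero exactly when \(w = 1.0\) is a stationary point of that environment's risk, that is, when rescaling the logits cannot lower the loss there. Because a predictor relying on a single invariant feature can meet this condition, the penalty alone does not push the model toward using several invariant features.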

  • Split the data representation model \(\Phi\) into \(N\) modules \(\{f_n\}_{n=1}^N\)
  • Perform competitive learning among the modules via multi-head dot-product attention
  • The input itself serves as the query; the module outputs serve as keys and values
  • Select the top-\(k\) modules, while keeping the non-selected modules active through soft selection (to prevent module collapse)
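The selection mechanism in the bullets above can be sketched in pure Python as follows. This is an illustration under several assumptions of ours: the learned query/key/value projections are replaced by identity maps, per-head scaled dot-product scores are averaged into one relevance score per module, and "soft selection" is modeled as keeping a fraction `soft_eps` of each non-selected module's attention weight before renormalizing. `module_scores`, `select_modules`, and `soft_eps` are hypothetical names; the paper's exact rule may differ.

```python
import math
from typing import List, Tuple

Vector = List[float]

def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def module_scores(query: Vector, module_outputs: List[Vector], num_heads: int) -> Vector:
    """Relevance of each module to the input: scaled dot product per head, averaged over heads.
    Identity projections stand in for learned W_q / W_k (simplifying assumption)."""
    d = len(query)
    assert d % num_heads == 0, "embedding size must divide evenly into heads"
    dh = d // num_heads
    scores = []
    for out in module_outputs:
        head_scores = []
        for h in range(num_heads):
            sl = slice(h * dh, (h + 1) * dh)
            dot = sum(q * k for q, k in zip(query[sl], out[sl]))
            head_scores.append(dot / math.sqrt(dh))
        scores.append(sum(head_scores) / num_heads)
    return scores

def select_modules(query: Vector, module_outputs: List[Vector], k: int,
                   num_heads: int = 2, soft_eps: float = 0.1) -> Tuple[Vector, Vector]:
    """Return (module weights, mixed representation).
    Top-k modules keep their full attention weight; the others keep soft_eps times theirs."""
    attn = softmax(module_scores(query, module_outputs, num_heads))
    top = sorted(range(len(attn)), key=lambda i: attn[i], reverse=True)[:k]
    raw = [a if i in top else soft_eps * a for i, a in enumerate(attn)]
    total = sum(raw)
    weights = [r / total for r in raw]  # renormalize so the weights sum to 1
    dim = len(module_outputs[0])
    mixed = [sum(w * out[j] for w, out in zip(weights, module_outputs)) for j in range(dim)]
    return weights, mixed
```

For instance, with four modules and `k=2`, the two modules whose outputs align best with the query receive most of the weight, while the other two keep a small nonzero share and can therefore still receive gradient during training.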

Dataset Examples

Example images from each environment of the Camelyon17-wilds and CheXpert datasets.


Results

Colored MNIST

| Algorithm | Val Accuracy (IID) | Test Accuracy (OoD) | # Params |
|---|---|---|---|
| ERM | 88.6% | 16.4% | 1,198,337 |
| IRM | 73.4% | 60.5% | 1,198,337 |
| Ours | 74.9% | 66.5% | 935,553 |
| Optimal | 75.0% | 75.0% | N/A |

Camelyon17-wilds (OoD Medical Imaging)

| Algorithm | Val Accuracy (IID) | Test Accuracy (OoD) | # Params |
|---|---|---|---|
| ERM | 91.9% | 73.3% | 42.8M |
| IRM | 94.1% | 72.9% | 42.8M |
| Ours (N=4, k=2) | 91.5% | 83.5% | 45.6M |
| Ours (N=2, k=1) | 90.4% | 74.5% | 22.8M |

CheXpert (Subpopulation Shift)

| Algorithm | Average Accuracy | Worst-case Accuracy |
|---|---|---|
| ERM | 86.9% | 50.2% |
| IRM | 89.8% | 34.4% |
| Ours (N=3, k=1) | 80.3% | 59.6% |

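On Camelyon17-wilds, the proposed method improves OoD test accuracy by about 10 percentage points over ERM and IRM (83.5% vs. 73.3% and 72.9%). On CheXpert, it improves worst-case accuracy by 9.4 percentage points over the strongest baseline, ERM (59.6% vs. 50.2%), at the cost of lower average accuracy.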
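The gains quoted above are absolute differences in percentage points. A small sketch that recomputes them from the table values; the dictionary layout and the `gains` helper are our own illustration.

```python
# Accuracies (%) copied from the Camelyon17-wilds and CheXpert tables.
camelyon_ood = {"ERM": 73.3, "IRM": 72.9, "Ours": 83.5}   # Ours = (N=4, k=2)
chexpert_worst = {"ERM": 50.2, "IRM": 34.4, "Ours": 59.6}  # Ours = (N=3, k=1)
chexpert_avg = {"ERM": 86.9, "IRM": 89.8, "Ours": 80.3}

def gains(table, ours="Ours"):
    """Percentage-point difference of `ours` over every baseline, rounded to 0.1."""
    return {name: round(table[ours] - acc, 1) for name, acc in table.items() if name != ours}

print(gains(camelyon_ood))    # {'ERM': 10.2, 'IRM': 10.6}
print(gains(chexpert_worst))  # {'ERM': 9.4, 'IRM': 25.2}
print(gains(chexpert_avg))    # {'ERM': -6.6, 'IRM': -9.5}
```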


BibTeX

@inproceedings{bae2024invariant,
  author    = {Bae, Jun-Hyun and Kim, Chanwoo and Chang, Tae-Young},
  title     = {Invariant Risk Minimization in Medical Imaging with Modular Data Representation},
  booktitle = {International Conference on Electronics, Information, and Communication (ICEIC)},
  year      = {2024},
  doi       = {10.1109/ICEIC61013.2024.10457174}
}