Jun-Hyun Bae, Minho Lee, Heechul Jung
Kyungpook National University
Abstract
Training deep neural networks with empirical risk minimization (ERM) often captures dataset biases, hindering generalization to new or unseen data. Previous solutions either require prior knowledge of biases or utilize training intentionally biased models as auxiliaries; however, they still suffer from multiple biases. To address this, we introduce Adaptive Bias Discovery (ABD), a novel learning framework designed to mitigate the impact of multiple unknown biases. ABD trains an auxiliary model to be adapted to biases based on the debiased parameters from the debiasing phase, allowing it to navigate through multiple biases. Then, samples are reweighted based on the discovered biases to update debiased parameters. Extensive evaluations of synthetic experiments and real-world datasets demonstrate that ABD consistently outperforms existing methods, particularly in real-world applications where multiple unknown biases are prevalent.
Overview
์ฌ์ ๋ฐ์ด์ด์ค ์ ๋ณด ์์ด ๋ฐ์ดํฐ์ ์กด์ฌํ๋ ์ฌ๋ฌ ๋ฐ์ด์ด์ค๋ฅผ ์์ฐจ์ ์ผ๋ก ๋ฐ๊ฒฌํ๊ณ ์ ๊ฑฐํ๋ ํ์ต ํ๋ ์์ํฌ๋ฅผ ์ ์ํ๋ค.
- Bias-adapted model โ Debiased ํ๋ผ๋ฏธํฐ \(\theta\) ์์ 1-step gradient descent๋ก ๋ฐ์ด์ด์ค์ ๋ฏผ๊ฐํ ๋ณด์กฐ ๋ชจ๋ธ \(f_\phi\) ๋ฅผ ์์ฑํ๋ค.
- Adaptive group formation โ \(f_\phi\) ์ ์์ธก์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ฐ์ด์ด์ค ์ ๋ ฌ ๊ทธ๋ฃน(\(G^\odot\) )๊ณผ ๋น์ ๋ ฌ ๊ทธ๋ฃน(\(G^\otimes\) )์ผ๋ก ๋ถํ ํ๋ค.
- Iterative debiasing โ Group DRO๋ก worst-case ๊ทธ๋ฃน ์์ค์ ์ต์ํํ๋ฉฐ, \(\theta\) ๊ฐ ํ ๋ฐ์ด์ด์ค์ ๊ฐ๊ฑดํด์ง๋ฉด \(\phi\) ๊ฐ ์์ฐ์ค๋ฝ๊ฒ ๋ค์ ๋ฐ์ด์ด์ค๋ฅผ ๋ฐ๊ฒฌํ๋ค.

ABD ํ๋ ์์ํฌ ๊ฐ์. ๋ ๊ฐ์ง ๋ฐ์ด์ด์ค(Bias1, Bias2)์ ๋ ํ์ต ์คํ ์ ์์๋ก ๋์ํ.
Method
ERM์ผ๋ก ํ์ต๋ ๋ชจ๋ธ์ ๋ฐ์ดํฐ์ ์กด์ฌํ๋ spurious correlation์ ์ฝ๊ฒ ํฌ์ฐฉํ์ฌ ์ผ๋ฐํ ์ฑ๋ฅ์ด ์ ํ๋๋ค. ๊ธฐ์กด ๋ฐฉ๋ฒ๋ค์ ๋ฐ์ด์ด์ค ์ ๋ณด๋ฅผ ์ฌ์ ์ ์๊ณ ์์ด์ผ ํ๊ฑฐ๋(Group DRO), ๋จ์ผ ๋ฐ์ด์ด์ค๋ง ์ฒ๋ฆฌํ ์ ์๋ค๋(PI, JTT) ํ๊ณ๊ฐ ์๋ค.
ABD๋ ๋ ๋จ๊ณ๋ก ๊ตฌ์ฑ๋๋ค. ๋จผ์ debiased ํ๋ผ๋ฏธํฐ \(\theta\) ์์ ํ ์คํ gradient descent๋ก bias-adapted ํ๋ผ๋ฏธํฐ \(\phi = \theta - \alpha \nabla_\theta \mathcal{L}(f_\theta)\) ๋ฅผ ์ป๋๋ค. ์ด \(f_\phi\) ๋ ๋ฐ์ดํฐ์ ํ๋ฉด์ ํจํด์ ๋ฏผ๊ฐํ๊ฒ ๋ฐ์ํ๋ฏ๋ก, ์์ธก ๊ฒฐ๊ณผ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ๋ฐ์ด์ด์ค ์ ๋ ฌ ๊ทธ๋ฃน(\(G^\odot\) )๊ณผ ๋น์ ๋ ฌ ๊ทธ๋ฃน(\(G^\otimes\) )์ผ๋ก ๋ถํ ํ๋ค. ์ดํ group DRO๋ฅผ ํตํด worst-case ๊ทธ๋ฃน์ ์์ค์ ์ต์ํํ๋๋ก \(\theta\) ๋ฅผ ์ ๋ฐ์ดํธํ๋ค.
ํต์ฌ์ \(\phi\) ๊ฐ ๋งค ์คํ ๋ง๋ค \(\theta\) ๋ก๋ถํฐ ์ฌ์์ฑ๋๋ค๋ ์ ์ด๋ค. \(\theta\) ๊ฐ ์ฒซ ๋ฒ์งธ ๋ฐ์ด์ด์ค์ ๋ํด ๊ฐ๊ฑดํด์ง๋ฉด, \(\phi\) ๊ฐ ๋ค์์ผ๋ก ๋๋๋ฌ์ง ๋ฐ์ด์ด์ค๋ฅผ ํฌ์ฐฉํ๊ฒ ๋๋ค. ์ด MAML ์ ์ฌ ๊ตฌ์กฐ ๋๋ถ์ ์ฌ์ ๋ฐ์ด์ด์ค ์ ๋ณด ์์ด๋ ์ฌ๋ฌ ๋ฐ์ด์ด์ค๋ฅผ ์์ฐจ์ ์ผ๋ก ๋ฐ๊ฒฌํ๊ณ ์ ๊ฑฐํ ์ ์๋ค.
์๋ GradCAM ์๊ฐํ๋ biased model \(f_\phi\) ์ attention์ด ํ์ต์ด ์งํ๋จ์ ๋ฐ๋ผ ๋ค๋ฅธ ์์ญ์ผ๋ก ์ด๋ํ๋ ๊ฒ์ ๋ณด์ฌ์ค๋ค. ์ด๋ ABD๊ฐ ํ์ต ๊ณผ์ ์์ ๋ค์ํ ๋ฐ์ด์ด์ค๋ฅผ ์ ์์ ์ผ๋ก ๋ฐ๊ฒฌํจ์ ์๊ฐ์ ์ผ๋ก ํ์ธํ ์ ์๋ ์ฆ๊ฑฐ์ด๋ค.

ERM ๋ชจ๋ธ๊ณผ ABD์ biased model $f_\phi$์ GradCAM ์๊ฐํ. ํ์ต ์คํ ์ด ์งํ๋๋ฉด์ $f_\phi$์ attention์ด ๋ค๋ฅธ ๋ฐ์ด์ด์ค ํน์ง์ผ๋ก ์ด๋ํ๋ค.
Results
Colored MNIST โ ๋ณต์ ๋ฐ์ด์ด์ค ์ฒ๋ฆฌ
๊ธฐ์กด ๋ฐฉ๋ฒ๊ณผ์ ์ฐจ์ด๋ฅผ ๊ฐ์ฅ ์ ๋ณด์ฌ์ฃผ๋ ์คํ์ด๋ค. ๋ฐ์ด์ด์ค๊ฐ Color ํ๋์ผ ๋์ Color + Patch ๋ ๋ค ์กด์ฌํ ๋๋ฅผ ๋น๊ตํ๋ค.
OoD test accuracy (%). ๋ฐ์ด์ด์ค๊ฐ ํ๋์ผ ๋์ ๋ ๋ค ์กด์ฌํ ๋์ ๋น๊ต.
| Algorithm | Color (OoD) | Color & Patch (OoD) |
|---|---|---|
| ERM | 16.4 | 14.0 |
| IRM | 66.9 | 13.4 |
| Group DRO | 13.6 | 14.1 |
| PI | 70.2 | 15.3 |
| ABD (Ours) | 70.7 | 62.3 |
| Optimal | 75.0 | 75.0 |
PI๋ Color ๋ฐ์ด์ด์ค๋ง ๋ฐ๊ฒฌํ๊ณ Patch๋ ํฌ์ฐฉํ์ง ๋ชปํ์ฌ, ๋ฐ์ด์ด์ค๊ฐ ๋ ๊ฐ๊ฐ ๋๋ฉด 15.3%๋ก ์ฌ์ค์ ์คํจํ๋ค (54.9%p ํ๋ฝ). IRM๊ณผ Group DRO ์ญ์ ๋ณต์ ๋ฐ์ด์ด์ค ํ๊ฒฝ์์ ERM ์์ค์ ๋จธ๋ฌธ๋ค. ABD๋ Color โ Patch ์์ผ๋ก ๋ฐ์ด์ด์ค๋ฅผ ์์ฐจ์ ์ผ๋ก ๋ฐ๊ฒฌํ์ฌ, ๋ณต์ ๋ฐ์ด์ด์ค ํ๊ฒฝ์์ 62.3%์ OoD ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ ์ ์ผํ ๋ฐฉ๋ฒ์ด๋ค.


๊ทธ๋ฃน ๋ด Pearson ์๊ด๊ณ์ ๋น๊ต.
Real-World Tasks
๋ฐ์ด์ด์ค annotation์ด ์๋ ๋ฐ์ดํฐ์ โ CivilComments & MultiNLI
Worst-case test accuracy (%). CivilComments์ Group ์ด์ ๊ฐ ์๊ณ ๋ฆฌ์ฆ์ด ๊ทธ๋ฃนํ์ ์ฌ์ฉํ demographic ์ ๋ณด๋ฅผ, MultiNLI์ Group DRO*๋ prior bias ์ ๋ณด๋ก ์์์ ์ ์๋ ๊ทธ๋ฃน์ ์ฌ์ฉํ oracle ์ค์ ์ด๋ค.
| Algorithm | CivilComments | Group (CC) | MultiNLI |
|---|---|---|---|
| ERM | 56.0 | None | 61.8 |
| IRM | 66.3 | (label ร Black) | โ |
| Group DRO | 69.1 | (label) | 62.7 |
| Group DRO | 70.0 | (label ร Black) | โ |
| JTT | 69.3 | None | 63.2 |
| PI | 61.1 | None | 61.5 |
| ABD (Ours) | 71.1 | None | 67.1 |
| Group DRO* (oracle) | โ | โ | 67.5 |
๋ฐ์ด์ด์ค annotation์ด ์๋ ๋ฐ์ดํฐ์ โ Camelyon17 & FMoW (WILDS)
Camelyon17-wilds๋ OoD average accuracy, FMoW-wilds๋ worst-region accuracy (%).
| Algorithm | Camelyon17 | FMoW |
|---|---|---|
| ERM | 70.3 ยฑ 6.4 | 32.3 ยฑ 1.3 |
| IRM | 59.5 ยฑ 7.7 | 31.7 ยฑ 1.2 |
| Group DRO | 68.4 ยฑ 7.3 | 30.8 ยฑ 0.8 |
| CORAL | 59.5 ยฑ 7.7 | 32.8 ยฑ 0.7 |
| JTT | 63.8 ยฑ 1.4 | 33.4 ยฑ 0.9 |
| PI | 71.7 ยฑ 7.5 | 31.2 ยฑ 0.3 |
| CGD | 69.4 ยฑ 7.9 | 32.0 ยฑ 2.3 |
| LISA | 77.1 ยฑ 6.5 | 35.5 ยฑ 0.7 |
| ABD (Ours) | 81.1 ยฑ 4.8 | 34.1 ยฑ 2.5 |
CivilComments์์ ABD(71.1%)๋ ๋ฐ์ด์ด์ค annotation์ ์ฌ์ฉํ๋ Group DRO(label ร Black, 70.0%)๋ฅผ annotation ์์ด ๋ฅ๊ฐํ๋ค. MultiNLI์์๋ ABD(67.1%)๊ฐ ์์์ ๊ทธ๋ฃน์ ์ฌ์ฉํ๋ oracle Group DRO*(67.5%)์ 0.4%p ์ด๋ด๋ก ๊ทผ์ ํ๋ฉฐ, ์ฌ์ ๋ฐ์ด์ด์ค ์ ๋ณด ์์ด ์ด์ ํ์ ํ๋ ์ฑ๋ฅ์ ๋ฌ์ฑํ๋ค. Camelyon17์์๋ LISA ๋๋น +4.0%p ํฅ์(81.1% vs 77.1%)์ ๋ณด์ด๋ฉฐ, FMoW์์๋ LISA์ ๊ทผ์ ํ ๊ฒฝ์๋ ฅ ์๋ ์ฑ๋ฅ์ ์ ์งํ๋ค. IRM, CORAL, CGD ๋ฑ distribution alignment ๋๋ invariance ๊ธฐ๋ฐ ์ ๊ทผ์ WILDS ๋ฒค์น๋งํฌ์์ ERM ์์ค์ ๋จธ๋ฌผ๊ฑฐ๋ ์คํ๋ ค ํ๋ฝํ๋ค.

MultiNLI์์ ์ค๋ถ๋ฅ ๊ทธ๋ฃน $G^\otimes$์ ๋ฐ์ด์ด์ค ๊ตฌ์ฑ ๋ณํ. Negation ๋ฐ์ด์ด์ค ๋ฐ๊ฒฌ ํ Overlap ๋ฐ์ด์ด์ค๊ฐ ์ ์ฐจ ๋๋ฌ๋๋ค.
MetaShift โ Distributional Distance์ ๋ฐ๋ฅธ ์ฑ๋ฅ ๋ณํ
MetaShift์์๋ training๊ณผ test ๊ฐ distributional distance๋ฅผ ์กฐ์ ํ์ฌ ๊ฐ ๋ฐฉ๋ฒ์ ๊ฐ๊ฑด์ฑ์ ํ๊ฐํ๋ค. ABD์ ์ฐ์๋ distributional distance๊ฐ ์ปค์ง์๋ก ๋์ฑ ๋๋๋ฌ์ง๋ค.
MetaShift test accuracy (%). Distance๊ฐ ํด์๋ก ๋ถํฌ ์ฐจ์ด๊ฐ ํฌ๋ค.
| Distance | 0.44 | 0.71 | 1.12 | 1.43 |
|---|---|---|---|---|
| ERM | 80.1 | 68.4 | 52.1 | 33.2 |
| IRM | 79.5 | 67.4 | 51.8 | 32.0 |
| Group DRO | 77.0 | 68.9 | 51.9 | 34.2 |
| LISA | 81.3 | 69.7 | 54.2 | 37.5 |
| ABD (Ours) | 80.4 | 71.8 | 55.2 | 41.8 |
Distance๊ฐ ๊ฐ์ฅ ์์ 0.44์์๋ LISA(81.3%)๊ฐ ๊ทผ์ํ๊ฒ ์์์ง๋ง, distance๊ฐ ์ปค์ง์๋ก ABD์ ์ฐ์๊ฐ ๋๋ ทํด์ ธ 1.43์์๋ ABD(41.8%)๊ฐ LISA(37.5%)๋ฅผ 4.3%p ์ํํ๋ค. ์ด๋ ABD๊ฐ ๋ถํฌ ์ฐจ์ด๊ฐ ํด์๋ก ํจ๊ณผ์ ์ธ, ์ค์ OoD ํ๊ฒฝ์ ์ ํฉํ ๋ฐฉ๋ฒ์์ ๋ณด์ฌ์ค๋ค.

MetaShift ํ ์คํธ ๋ฐ์ดํฐ์ GradCAM ์๊ฐํ. ERM์ ๋ฐฐ๊ฒฝ์ ์์กดํ์ง๋ง, ABD๋ ๋์ ๊ฐ์ฒด์ ์ง์คํ๋ค.