ACCV 2024
Jun-Hyun Bae, Minho Lee, Heechul Jung
Kyungpook National University
📄 Paper

Abstract

Training deep neural networks with empirical risk minimization (ERM) often captures dataset biases, hindering generalization to new or unseen data. Previous solutions either require prior knowledge of biases or train intentionally biased models as auxiliaries; however, they still struggle when multiple biases are present. To address this, we introduce Adaptive Bias Discovery (ABD), a novel learning framework designed to mitigate the impact of multiple unknown biases. ABD trains an auxiliary model that adapts to biases starting from the debiased parameters of the debiasing phase, allowing it to navigate through multiple biases. Samples are then reweighted based on the discovered biases to update the debiased parameters. Extensive evaluations on synthetic experiments and real-world datasets demonstrate that ABD consistently outperforms existing methods, particularly in real-world applications where multiple unknown biases are prevalent.


Overview

We propose a learning framework that sequentially discovers and removes multiple biases present in the data, without any prior bias information.

  1. Bias-adapted model: starting from the debiased parameters \(\theta\), one step of gradient descent produces an auxiliary model \(f_\phi\) that is sensitive to biases.
  2. Adaptive group formation: the predictions of \(f_\phi\) split the data into a bias-aligned group (\(G^\odot\)) and a bias-conflicting group (\(G^\otimes\)).
  3. Iterative debiasing: Group DRO minimizes the worst-case group loss. Once \(\theta\) becomes robust to one bias, \(\phi\) naturally discovers the next one.

ABD Framework

Overview of the ABD framework, illustrated with two biases (Bias1, Bias2) and two training steps.


Method

A model trained with ERM readily picks up spurious correlations in the data, which degrades its generalization performance. Existing methods are limited: they either need the bias information in advance (Group DRO) or can handle only a single bias (PI, JTT).

ABD consists of two phases. First, one step of gradient descent from the debiased parameters \(\theta\) yields the bias-adapted parameters \(\phi = \theta - \alpha \nabla_\theta \mathcal{L}(f_\theta)\). Because this \(f_\phi\) responds strongly to superficial patterns in the data, its predictions are used to split the data into a bias-aligned group (\(G^\odot\)) and a bias-conflicting group (\(G^\otimes\)). Then \(\theta\) is updated with group DRO so as to minimize the loss of the worst-case group.
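The two phases can be written compactly as below. This is a sketch under stated assumptions: the split is taken to be by whether \(f_\phi\) classifies a sample correctly, the per-sample loss \(\ell\) and the predicted label \(\hat{y}_\phi\) are notation introduced here, and the last line is the standard group DRO objective restricted to the two discovered groups, which may differ in detail from the objective used in the paper.

\[
\begin{aligned}
\phi &= \theta - \alpha \nabla_\theta \mathcal{L}(f_\theta), \\
G^\odot &= \{(x, y) : \hat{y}_\phi(x) = y\}, \qquad
G^\otimes = \{(x, y) : \hat{y}_\phi(x) \neq y\}, \\
\min_\theta \; &\max_{g \in \{G^\odot,\, G^\otimes\}} \;
\frac{1}{|g|} \sum_{(x, y) \in g} \ell\big(f_\theta(x), y\big).
\end{aligned}
\]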

The key point is that \(\phi\) is regenerated from \(\theta\) at every step. Once \(\theta\) becomes robust to the first bias, \(\phi\) picks up the next most salient bias. This MAML-like structure makes it possible to discover and remove multiple biases sequentially without prior bias information.
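A single training step of this kind can be sketched in plain Python. This is a toy illustration, not the authors' implementation: it assumes a two-feature logistic model, a hypothetical batch in which feature 0 is a shortcut and feature 1 is the true signal, and it replaces group DRO's group reweighting with a hard choice of the higher-loss group. The names `prob`, `mean_loss`, `mean_grad`, and `abd_step` are made up for this example.

```python
import math

def prob(w, x):
    """Logistic model f_w(x): probability that the label is 1."""
    z = sum(wi * xi for wi, xi in zip(w, x))
    return 1.0 / (1.0 + math.exp(-z))

def mean_loss(w, batch):
    """Mean binary cross-entropy of f_w over a batch of (x, y) pairs."""
    total = 0.0
    for x, y in batch:
        p = prob(w, x)
        total -= math.log(p) if y == 1 else math.log(1.0 - p)
    return total / len(batch)

def mean_grad(w, batch):
    """Gradient of mean_loss with respect to w: the mean of (p - y) * x."""
    grad = [0.0] * len(w)
    for x, y in batch:
        err = prob(w, x) - y
        for j, xj in enumerate(x):
            grad[j] += err * xj / len(batch)
    return grad

def abd_step(theta, batch, alpha=1.0, lr=0.1):
    # Phase 1: regenerate phi from the current theta with one gradient step.
    phi = [t - alpha * g for t, g in zip(theta, mean_grad(theta, batch))]
    # Split the batch by whether f_phi classifies each sample correctly.
    aligned, conflicting = [], []
    for x, y in batch:
        pred = 1 if prob(phi, x) >= 0.5 else 0
        (aligned if pred == y else conflicting).append((x, y))
    # Phase 2 (simplified assumption): descend on the group whose mean loss
    # under theta is larger, instead of full group DRO reweighting.
    groups = [g for g in (conflicting, aligned) if g]
    worst = max(groups, key=lambda g: mean_loss(theta, g))
    new_theta = [t - lr * g for t, g in zip(theta, mean_grad(theta, worst))]
    return new_theta, phi, aligned, conflicting

# Hypothetical batch: feature 0 is a shortcut, feature 1 determines the label.
batch = (
    [([2.0, 1.0], 1)] * 3 + [([-2.0, -1.0], 0)] * 3  # shortcut agrees with label
    + [([-2.0, 1.0], 1), ([2.0, -1.0], 0)]           # shortcut contradicts label
)
theta0 = [0.2, 0.0]
theta1, phi, aligned, conflicting = abd_step(theta0, batch)
print(len(aligned), len(conflicting))
print(theta0, "->", theta1)
```

Calling `abd_step` repeatedly with the returned `theta1` recomputes `phi` from scratch each time, so the split always reflects whatever shortcut the current debiased parameters are still most easily pulled toward.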

The GradCAM visualization below shows that the attention of the biased model \(f_\phi\) moves to different regions as training progresses. This gives visual evidence that ABD adaptively discovers different biases over the course of training.

Biased Model Evolution

GradCAM visualization of the ERM model and ABD's biased model \(f_\phi\). As training steps progress, the attention of \(f_\phi\) shifts to different bias features.


Results

Colored MNIST: handling multiple biases

This experiment shows the difference from existing methods most clearly. It compares the case where Color is the only bias with the case where both Color and Patch biases are present.

OoD test accuracy (%). Comparison between a single bias and both biases present.

| Algorithm | Color (OoD) | Color & Patch (OoD) |
|---|---|---|
| ERM | 16.4 | 14.0 |
| IRM | 66.9 | 13.4 |
| Group DRO | 13.6 | 14.1 |
| PI | 70.2 | 15.3 |
| ABD (Ours) | 70.7 | 62.3 |
| Optimal | 75.0 | 75.0 |

PI discovers only the Color bias and fails to capture Patch, so with two biases it effectively fails at 15.3% (a drop of 54.9 percentage points from 70.2%). IRM and Group DRO also remain at ERM level in the multi-bias setting. ABD discovers the biases sequentially, Color first and then Patch, and is the only method that stays well above ERM level in the multi-bias setting, reaching 62.3% OoD accuracy.
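The drops quoted above can be reproduced directly from the table. The snippet below is only arithmetic on the reported numbers; the dictionary and variable names are chosen for this example.

```python
# OoD accuracy (%) copied from the Colored MNIST table:
# (Color only, Color & Patch)
results = {
    "ERM": (16.4, 14.0),
    "IRM": (66.9, 13.4),
    "Group DRO": (13.6, 14.1),
    "PI": (70.2, 15.3),
    "ABD": (70.7, 62.3),
}
optimal = 75.0

# Accuracy lost when the second bias (Patch) is added, in percentage points.
drops = {name: round(single - both, 1) for name, (single, both) in results.items()}

# Remaining gap between each method and the optimal accuracy with both biases.
gap_to_optimal = {name: round(optimal - both, 1) for name, (_, both) in results.items()}

for name in results:
    print(f"{name:10s} drop={drops[name]:5.1f}  gap_to_optimal={gap_to_optimal[name]:5.1f}")
```

On these numbers PI loses 54.9 points when Patch is added, while ABD loses 8.4 points and ends 12.7 points below the optimal 75.0%.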

PI Baseline
(a) PI: discovers only Color and fails to capture Patch.
Bias Discovery - ABD
(b) ABD: discovers Color and then Patch as training progresses.

Comparison of within-group Pearson correlation coefficients.

Real-World Tasks

Datasets with bias annotations: CivilComments & MultiNLI

Worst-case test accuracy (%). For CivilComments, the Group column lists the demographic information each algorithm used for grouping. For MultiNLI, Group DRO* is an oracle setting that uses groups hand-defined from prior bias information.

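| Algorithm | CivilComments | Group (CC) | MultiNLI |
|---|---|---|---|
| ERM | 56.0 | None | 61.8 |
| IRM | 66.3 | (label × Black) | - |
| Group DRO | 69.1 | (label) | 62.7 |
| Group DRO | 70.0 | (label × Black) | - |
| JTT | 69.3 | None | 63.2 |
| PI | 61.1 | None | 61.5 |
| ABD (Ours) | 71.1 | None | 67.1 |
| Group DRO* (oracle) | - | - | 67.5 |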
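For readers unfamiliar with the metric, worst-case (worst-group) accuracy is the accuracy on whichever evaluation group the model handles worst. A minimal sketch follows, assuming each test sample carries a group identifier; the function name and the toy inputs are hypothetical and are not taken from the benchmarks' evaluation code.

```python
from collections import defaultdict

def worst_group_accuracy(preds, labels, groups):
    """Return (worst group id, its accuracy, accuracy of every group)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    per_group = {g: correct[g] / total[g] for g in total}
    worst = min(per_group, key=per_group.get)
    return worst, per_group[worst], per_group

# Toy example with two hypothetical groups "a" and "b".
preds = [1, 1, 0, 0, 1, 0]
labels = [1, 1, 0, 1, 1, 1]
groups = ["a", "a", "a", "b", "b", "b"]

worst, worst_acc, per_group = worst_group_accuracy(preds, labels, groups)
average_acc = sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
print(worst, round(worst_acc, 3), round(average_acc, 3))
```

In the toy example the average accuracy is 4/6 while group "b" is only 1/3 correct, which is why a worst-group metric exposes failures that an average hides.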

Datasets without bias annotations: Camelyon17 & FMoW (WILDS)

Camelyon17-wilds reports OoD average accuracy and FMoW-wilds reports worst-region accuracy (%).

| Algorithm | Camelyon17 | FMoW |
|---|---|---|
| ERM | 70.3 ± 6.4 | 32.3 ± 1.3 |
| IRM | 59.5 ± 7.7 | 31.7 ± 1.2 |
| Group DRO | 68.4 ± 7.3 | 30.8 ± 0.8 |
| CORAL | 59.5 ± 7.7 | 32.8 ± 0.7 |
| JTT | 63.8 ± 1.4 | 33.4 ± 0.9 |
| PI | 71.7 ± 7.5 | 31.2 ± 0.3 |
| CGD | 69.4 ± 7.9 | 32.0 ± 2.3 |
| LISA | 77.1 ± 6.5 | 35.5 ± 0.7 |
| ABD (Ours) | 81.1 ± 4.8 | 34.1 ± 2.5 |

On CivilComments, ABD (71.1%) outperforms Group DRO with bias annotations (label × Black, 70.0%) without using any annotation. On MultiNLI, ABD (67.1%) comes within 0.4 percentage points of the oracle Group DRO* (67.5%), which uses hand-defined groups, and so matches it without prior bias information. On Camelyon17, ABD improves on LISA by 4.0 percentage points (81.1% vs. 77.1%), and on FMoW it remains competitive, close to LISA (34.1% vs. 35.5%). Approaches based on distribution alignment or invariance, such as IRM, CORAL, and CGD, stay at ERM level or fall below it on the WILDS benchmarks.

MultiNLI Analysis

Change in the bias composition of the misclassified group \(G^\otimes\) on MultiNLI. After the Negation bias is discovered, the Overlap bias gradually emerges.

MetaShift: performance as a function of distributional distance

On MetaShift, the distributional distance between training and test data is varied to evaluate the robustness of each method. ABD's advantage becomes more pronounced as the distributional distance grows.

MetaShift test accuracy (%). A larger distance means a larger distribution gap.

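| Distance | 0.44 | 0.71 | 1.12 | 1.43 |
|---|---|---|---|---|
| ERM | 80.1 | 68.4 | 52.1 | 33.2 |
| IRM | 79.5 | 67.4 | 51.8 | 32.0 |
| Group DRO | 77.0 | 68.9 | 51.9 | 34.2 |
| LISA | 81.3 | 69.7 | 54.2 | 37.5 |
| ABD (Ours) | 80.4 | 71.8 | 55.2 | 41.8 |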

At the smallest distance, 0.44, LISA (81.3%) is narrowly ahead of ABD (80.4%). ABD leads at every larger distance (by 2.1, 1.0, and 4.3 percentage points at 0.71, 1.12, and 1.43), and at 1.43 ABD (41.8%) exceeds LISA (37.5%) by 4.3 percentage points. This suggests that ABD is most effective when the distribution gap is large, which suits realistic OoD settings.

GradCAM

GradCAM visualization on MetaShift test data. ERM relies on the background, whereas ABD focuses on the target object.