Jun-Hyun Bae*, Taewon Park*, Minho Lee
Kyungpook National University
* Equal Contribution
Abstract
Learning associative reasoning is necessary to implement human-level artificial intelligence even when a model faces unfamiliar associations of learned components. However, conventional memory augmented neural networks (MANNs) have shown degraded performance on systematically different data since they lack consideration of systematic generalization. In this work, we propose a novel architecture for MANNs which explicitly aims to learn recomposable representations with a modular structure of RNNs. Our method binds learned representations with a Tensor Product Representation (TPR) to manifest their associations and stores the associations into TPR-based external memory. In addition, to demonstrate the effectiveness of our approach, we introduce a new benchmark for evaluating systematic generalization performance on associative reasoning, which contains systematically different combinations of words between training and test data. From the experimental results, our method shows superior test accuracy on systematically different data compared to other models. Furthermore, we validate the models using TPR by analyzing whether the learned representations have symbolic properties.
Overview
๊ธฐ์กด MANN์ด ์ฒด๊ณ์ ์ผ๋ก ๋ค๋ฅธ ํ ์คํธ ๋ฐ์ดํฐ์์ ์คํจํ๋ ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด, modular encoder์ TPR ๊ธฐ๋ฐ ์ธ๋ถ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๊ฒฐํฉํ ์๋ก์ด ์ํคํ ์ฒ๋ฅผ ์ ์ํ๋ค.
- Modular encoding โ Recurrent Independent Mechanisms(RIMs)๋ก ์ ๋ ฅ์ \(N\) ๊ฐ ๋ ๋ฆฝ ๋ชจ๋์ด ๊ฒฝ์์ ์ผ๋ก ์ธ์ฝ๋ฉํ์ฌ ์ฌ์กฐํฉ ๊ฐ๋ฅํ ํํ์ ํ์ตํ๋ค.
- TPR binding โ Tensor Product Representation์ผ๋ก role๊ณผ filler์ ์ฐ๊ด ๊ด๊ณ๋ฅผ ์ํ์ ์ผ๋ก ๋ฐ์ธ๋ฉํ๋ค: \(T = \sum_{k=1}^N \mathbf{r}_k \otimes \mathbf{f}_k\)
- Memory-based recall โ TPR ๊ธฐ๋ฐ ์ธ๋ถ ๋ฉ๋ชจ๋ฆฌ์ ์ฐ๊ด ๊ด๊ณ๋ฅผ ์ ์ฅํ๊ณ , ํ์ตํ์ง ์์ ์กฐํฉ์์๋ ์ฒด๊ณ์ ์ผ๋ก ์ถ๋ก ํ๋ค.
![]()
์ ์ฒด ์ํคํ ์ฒ. ๊ฐ ์์ $t$๋ง๋ค ์ ๋ ฅ์ modular encoder๋ฅผ ํต๊ณผํ์ฌ role $r_t$์ filler $f_t$๋ก ๋ถ๋ฆฌ๋๊ณ , TPR binding($\otimes$)์ผ๋ก ์ธ๋ถ ๋ฉ๋ชจ๋ฆฌ $\mathbf{M}_t$์ ๋์ ๋๋ค. ์ง์ ์์๋ ๋์ผํ encoder๊ฐ query role $q_r$์ ์์ฑํด ๋ฉ๋ชจ๋ฆฌ์์ ํด๋น filler๋ฅผ unbindํ๋ค.
Method
๊ธฐ์กด memory augmented neural network (MANN)๋ ํ์ต ๋ฐ์ดํฐ์ ์ฒด๊ณ์ ์ผ๋ก ๋ค๋ฅธ(systematically different) ํ ์คํธ ๋ฐ์ดํฐ์์ ์ฑ๋ฅ์ด ๊ธ๋ฝํ๋ค. ํต์ฌ ์์ธ์ encoder๊ฐ ํ์ต๋ ์กฐํฉ์ ๊ณผ์ ํฉํ์ฌ, ๊ฐ๋ณ ๊ตฌ์ฑ ์์๋ฅผ ์ฌ์กฐํฉ ๊ฐ๋ฅํ(recomposable) ํํ๋ก ํํํ์ง ๋ชปํ๊ธฐ ๋๋ฌธ์ด๋ค. ์ด๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํด modular RNN encoder + TPR-based external memory๋ฅผ ๊ฒฐํฉํ๋ค.
ํต์ฌ ๊ตฌ์ฑ:
- Recurrent Independent Mechanisms (RIMs): \(N\) ๊ฐ์ RNN ๋ชจ๋์ด competitive learning์ผ๋ก ๊ฐ์ ๋ ๋ฆฝ์ ์ธ ์ธ์ฝ๋ฉ ๋ฉ์ปค๋์ฆ ํ์ต
- Tensor Product Representation (TPR): role๊ณผ filler์ tensor product๋ก ์ฐ๊ด ๊ด๊ณ๋ฅผ ์ํ์ ์ผ๋ก ๋ฐ์ธ๋ฉ โ \(T = \sum_{k=1}^N \mathbf{r}_k \otimes \mathbf{f}_k\)
- TPR-based External Memory: ๊ฐ ์๊ฐ ๋จ๊ณ์์ role/filler ํํ์ ์ถ์ถํ์ฌ ๋ฉ๋ชจ๋ฆฌ์ superpose. ์ฐ๊ธฐ ๊ท์น์ write strength \(\beta = \sigma(W_\beta h_t)\) ๋ฅผ ์ฌ์ฉํ delta-filler ํํ๋ก, \(\mathbf{M}_t = \mathbf{M}_{t-1} + \mathbf{r}_t \otimes (\beta \mathbf{f}_t - (1-\beta) \mathbf{f}_{t-1})\) ์ด๋ค.
- Systematic Associative Recall (SAR): ์ฒด๊ณ์ ์ผ๋ฐํ ํ๊ฐ๋ฅผ ์ํ ์ ๋ฒค์น๋งํฌ ์ ์
Results
Systematic Associative Recall (SAR) Task
SAR์ ๋ณธ ๋ ผ๋ฌธ์์ ์ ์ํ๋ ๋ฒค์น๋งํฌ๋ก, associative reasoning์์์ ์ฒด๊ณ์ ์ผ๋ฐํ๋ฅผ ์ธก์ ํ๊ธฐ ์ํด ์ค๊ณ๋์๋ค. ์ธ ๊ฐ์ง ๊ฐ์ฒด ์งํฉ(์ฌ๋ ์ด๋ฆ \(S_h\) , ๊ณผ์ผ ์ด๋ฆ \(S_f\) , ์ซ์ ์ด๋ฆ \(S_n\) )์ ์ฌ์ฉํ๋ฉฐ, ํ์ต ๋ฐ์ดํฐ์ ํ ์คํธ ๋ฐ์ดํฐ ๊ฐ์ ๊ฐ์ฒด ์กฐํฉ์ ์ฒด๊ณ์ ์ผ๋ก ๋ค๋ฅด๊ฒ ๊ตฌ์ฑํ๋ค.
๊ตฌ์ฒด์ ์ผ๋ก, \(S_h\) ์ ์ผ๋ถ(\(S_h^1\) )๋ ํ์ต ์ ์ซ์์๋ง ์ฐ๊ด๋๊ณ , ๋ค๋ฅธ ์ผ๋ถ(\(S_h^2\) )๋ ๊ณผ์ผ๊ณผ๋ง ์ฐ๊ด๋๋ค. **test (different)**์์๋ ์ด ๊ด๊ณ๊ฐ ์ญ์ ๋์ด, \(S_h^1\) ์ ๊ณผ์ผ๊ณผ, \(S_h^2\) ๋ ์ซ์์ ์ฐ๊ด๋๋ค. ๋์ด๋ ํ๋ผ๋ฏธํฐ \(p = |S_h^3| / |S_h|\) ๋ ๋ ์งํฉ ๋ชจ๋์ ์ฐ๊ด๋๋ ๊ฐ์ฒด์ ๋น์จ๋ก, ๊ฐ์ด ์์์๋ก ํ์ต/ํ ์คํธ ๊ฐ ์ฒด๊ณ์ ์ฐจ์ด๊ฐ ํฌ๋ค.
![]()
SAR ํ์คํฌ์์ DNC, FWM, ์ ์ ๋ฐฉ๋ฒ์ ํ์ต/ํ ์คํธ ์ ํ๋ ๋น๊ต.
DNC์ FWM์ test (same)์์๋ ๋์ ์ ํ๋๋ฅผ ๋ณด์ด์ง๋ง, test (different)์์๋ ํฐ ์ฑ๋ฅ ์ ํ๋ฅผ ๋ณด์ธ๋ค. ์ด๋ ํ์ต๋ ์กฐํฉ์ ๊ณผ์ ํฉํ์ฌ ์ฒด๊ณ์ ์ผ๋ฐํ์ ์คํจํ๋ ๊ฒ์ด๋ค. ์ ์ ๋ฐฉ๋ฒ์ \(p=0.3\) ๋ฐ \(p=0.5\) ์์ test (same)๊ณผ test (different) ๊ฐ์ ๊ฒฉ์ฐจ๋ฅผ ์ฑ๊ณต์ ์ผ๋ก ํด์ํ๋ฉฐ, ๊ฐ์ฅ ์ด๋ ค์ด \(p=0.1\) ์์๋ baseline ๋๋น ํ์ ํ ์์ ๊ฒฉ์ฐจ๋ฅผ ๋ณด์ธ๋ค. FWM์ด TPR ๊ธฐ๋ฐ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ์ฌ์ฉํจ์๋ ๋ถ๊ตฌํ๊ณ ์ฒด๊ณ์ ์ผ๋ฐํ์ ์คํจํ๋ค๋ ๊ฒ์, TPR ๋ฉ๋ชจ๋ฆฌ๋ง์ผ๋ก๋ ์ถฉ๋ถํ์ง ์์ผ๋ฉฐ encoder๊ฐ ์ฌ๋ฐ๋ฅธ symbolic representation์ ํ์ตํ๋ ๊ฒ์ด ํต์ฌ์์ ์์ฌํ๋ค.
Concatenated-bAbI (catbAbI)
SAR์ด ์ฒด๊ณ์ ์ผ๋ฐํ์ ์ด์ ์ ๋ง์ถ ๋ฐ๋ฉด, catbAbI๋ ์ผ๋ฐ์ ์ธ ์ฅ๊ธฐ ์ฐ๊ด ์ถ๋ก ์ฑ๋ฅ์ ํ๊ฐํ๋ค. ๋ฌดํ ๊ธธ์ด์ story sequence์์ ์ง์์๋ต์ ์ํํ๋ ํ์คํฌ์ด๋ค.
| Model | Test Accuracy |
|---|---|
| LSTM | 80.88% |
| Transformer-XL | 87.66% |
| Meta-learned Neural Memory | 88.97% |
| Fast Weight Memory (FWM) | 96.75% |
| FWM (our trial) | 94.94% |
| Ours | 96.63% |
๋์ผํ ์คํ ์ธํ (our trial)์์ ๋น๊ตํ๋ฉด, ์ ์ ๋ฐฉ๋ฒ(96.63%)์ด FWM(94.94%)๋ณด๋ค 1.7%p ๋๋ค. FWM์ ๊ณต์ ๊ฒฐ๊ณผ(96.75%)์๋ ๊ฑฐ์ ๋๋ฑํ ์์ค์ด๋ค. ์ด๋ modular encoder์ ๋์ ์ด ์ฒด๊ณ์ ์ผ๋ฐํ ๋ฅ๋ ฅ์ ์ถ๊ฐํ๋ฉด์๋, ์ผ๋ฐ์ ์ธ ์ฐ๊ด ์ถ๋ก ์ฑ๋ฅ์์ ๊ธฐ์กด ์ต๊ณ ์์ค์ ์ ์งํจ์ ๋ณด์ฌ์ค๋ค.
Symbolic Representation ๋ถ์
ํ์ต๋ ํํ์ด ์ฌ๋ฐ๋ฅธ symbolic property๋ฅผ ๊ฐ๋์ง ๋ ๊ฐ์ง ๋ถ์์ผ๋ก ๊ฒ์ฆํ๋ค.
Role-Unbinding Orthogonality: TPR์์ ์ฌ๋ฐ๋ฅธ unbinding์ ์ํด์๋ role ๋ฒกํฐ์ unbinding ๋ฒกํฐ๊ฐ orthogonalํด์ผ ํ๋ค. FWM์ role-unbinding ์ ์ฌ๋ ํ๋ ฌ์์ off-diagonal ๊ฐ์ญ์ด ๋ํ๋์ง๋ง, ์ ์ ๋ฐฉ๋ฒ์ ๊ฑฐ์ ์๋ฒฝํ orthogonality๋ฅผ ๋ณด์ธ๋ค. ์ด๋ modular encoder๊ฐ ๊ฐ ๊ฐ์ฒด์ ๋ํด ๋ถ๋ฆฌ ๊ฐ๋ฅํ(separable) symbolic representation์ ํ์ตํ์์ ์๋ฏธํ๋ค.
(a) FWM
(b) Ours
Role ๋ฒกํฐ์ unbinding ๋ฒกํฐ ๊ฐ ์ ์ฌ๋ ํ๋ ฌ. FWM์ orthogonalํ์ง ์์ง๋ง, ์ ์ ๋ฐฉ๋ฒ์ ๊ฑฐ์ ์๋ฒฝํ orthogonality๋ฅผ ๋ณด์ธ๋ค.
Filler Consistency: ๋์ผํ ๋์ ๊ฐ์ฒด์ ๋ํด, ์ด๋ค ์กฐํฉ์์ ์ง์ํ๋ ๋์ผํ read ๋ฒกํฐ๊ฐ ๋ฐํ๋์ด์ผ ์ฒด๊ณ์ ์ถ๋ก ์ด๋ผ ํ ์ ์๋ค. FWM์ ์กฐํฉ์ ๋ฐ๋ผ read ๋ฒกํฐ๊ฐ ๋ฌ๋ผ์ง์ง๋ง, ์ ์ ๋ฐฉ๋ฒ์ ์กฐํฉ์ ๊ด๊ณ์์ด ๊ฑฐ์ ๋์ผํ read ๋ฒกํฐ๋ฅผ ๋ฐํํ๋ค. ์ด๋ ๋ชจ๋ธ์ด ํน์ ์กฐํฉ์ ์๊ธฐํ๋ ๊ฒ์ด ์๋๋ผ, ๊ฐ๋ณ ๊ตฌ์ฑ ์์๋ฅผ ๋ ๋ฆฝ์ ์ผ๋ก ์ธ์ฝ๋ฉํ๊ณ ์ฌ์กฐํฉํ์ฌ ์ถ๋ก ํ๊ณ ์์์ ๋ณด์ฌ์ฃผ๋ ์ฆ๊ฑฐ์ด๋ค.
(a) FWM
(b) Ours
๋์ผํ fruit ๊ฐ์ฒด์ ๋ํ read ๋ฒกํฐ ๊ฐ ์ ์ฌ๋. ์ ์ ๋ฐฉ๋ฒ์ ์กฐํฉ์ ๊ด๊ณ์์ด ์ผ๊ด๋ ์ถ๋ ฅ์ ๋ณด์ธ๋ค.