ICONIP 2022
Jun-Hyun Bae*, Taewon Park*, Minho Lee
Kyungpook National University
* Equal Contribution
📄 Paper

Abstract

Learning associative reasoning is necessary to implement human-level artificial intelligence even when a model faces unfamiliar associations of learned components. However, conventional memory augmented neural networks (MANNs) have shown degraded performance on systematically different data since they lack consideration of systematic generalization. In this work, we propose a novel architecture for MANNs which explicitly aims to learn recomposable representations with a modular structure of RNNs. Our method binds learned representations with a Tensor Product Representation (TPR) to manifest their associations and stores the associations into TPR-based external memory. In addition, to demonstrate the effectiveness of our approach, we introduce a new benchmark for evaluating systematic generalization performance on associative reasoning, which contains systematically different combinations of words between training and test data. From the experimental results, our method shows superior test accuracy on systematically different data compared to other models. Furthermore, we validate the models using TPR by analyzing whether the learned representations have symbolic properties.


Overview

Conventional MANNs fail on test data that differs systematically from the training data. To address this, we propose a new architecture that combines a modular encoder with a TPR-based external memory.

  1. Modular encoding: using Recurrent Independent Mechanisms (RIMs), \(N\) independent modules compete to encode the input and learn recomposable representations.
  2. TPR binding: a Tensor Product Representation mathematically binds each role to its filler, making their association explicit: \(T = \sum_{k=1}^N \mathbf{r}_k \otimes \mathbf{f}_k\)
  3. Memory-based recall: the associations are stored in a TPR-based external memory, which allows systematic reasoning even on combinations that were not seen during training.
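The competitive modular encoding in step 1 can be pictured with a heavily simplified sketch of RIMs-style top-k module selection. It is a hypothetical NumPy illustration, not the paper's encoder. Random matrices stand in for learned parameters, and a plain tanh cell replaces each module's gated recurrent cell. The null-input attention and inter-module communication of the actual RIMs are omitted. All names and sizes (`rims_step`, `TOP_K`, `D_ATT`, ...) are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)

N_MODULES, TOP_K = 4, 2            # hypothetical: number of modules, modules allowed to update per step
D_IN, D_H, D_ATT = 8, 6, 5         # hypothetical input, hidden and attention sizes

# Random stand-ins for learned parameters.
W_q = rng.normal(size=(N_MODULES, D_ATT, D_H))        # per-module query projection of its state
W_k = rng.normal(size=(D_ATT, D_IN))                  # shared key projection of the input
W_x = 0.1 * rng.normal(size=(N_MODULES, D_H, D_IN))   # per-module input weights
W_h = 0.1 * rng.normal(size=(N_MODULES, D_H, D_H))    # per-module recurrent weights


def rims_step(h, x):
    """Simplified RIMs-style step: modules compete for the input and only the top-k update."""
    key = W_k @ x                                     # (D_ATT,)
    queries = np.einsum("nah,nh->na", W_q, h)         # (N_MODULES, D_ATT)
    scores = queries @ key                            # how strongly each module attends to x
    active = np.sort(np.argsort(scores)[-TOP_K:])     # winners of the competition
    h_new = h.copy()                                  # losing modules keep their state
    for i in active:
        # plain tanh cell as a stand-in for each module's gated recurrent cell
        h_new[i] = np.tanh(W_x[i] @ x + W_h[i] @ h[i])
    return h_new, active


h = 0.1 * rng.normal(size=(N_MODULES, D_H))
for t in range(3):
    x_t = rng.normal(size=D_IN)
    h, active = rims_step(h, x_t)
    print(t, active)    # which modules encoded this input
```

Because only the winning modules update, different modules tend to specialize in different parts of the input. This specialization is the property the modular encoder relies on to produce recomposable representations.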

Modular TPR Architecture

Overall architecture. At each time step \(t\), the input passes through the modular encoder and is separated into a role \(\mathbf{r}_t\) and a filler \(\mathbf{f}_t\). Their TPR binding (\(\otimes\)) is accumulated into the external memory \(\mathbf{M}_t\). At query time, the same encoder produces a query role \(\mathbf{q}_r\), which unbinds the corresponding filler from the memory.
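A minimal NumPy sketch can illustrate this write/read cycle. It assumes orthonormal role vectors, so the query role itself can serve as the unbinding vector. It also uses a plain additive write, and random vectors stand in for the encoder's outputs. It shows TPR binding and unbinding in general and is not the paper's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_r, d_f, steps = 4, 3, 3

# Stand-ins for encoder outputs: orthonormal roles (rows of a random orthogonal matrix)
# and random fillers, one pair per time step.
Q, _ = np.linalg.qr(rng.normal(size=(d_r, d_r)))
roles = Q[:steps]
fillers = rng.normal(size=(steps, d_f))

# Write: superpose outer(r_t, f_t) into the memory at every step.
M = np.zeros((d_r, d_f))
for r_t, f_t in zip(roles, fillers):
    M = M + np.outer(r_t, f_t)

# Query: unbind with the role of step 2. With orthonormal roles the role itself
# is the unbinding vector, so the read returns exactly the filler of step 2.
q_r = roles[1]
f_read = M.T @ q_r
print(np.allclose(f_read, fillers[1]))   # True
```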


Method

The performance of conventional memory augmented neural networks (MANNs) drops sharply on test data that differs systematically from the training data. The main cause is that the encoder overfits to the combinations seen in training and fails to represent individual components in a recomposable form. To address this, we combine a modular RNN encoder with a TPR-based external memory.

Key components:

  • Recurrent Independent Mechanisms (RIMs): \(N\) RNN modules each learn an independent encoding mechanism through competitive learning.
  • Tensor Product Representation (TPR): associations are bound mathematically as tensor products of roles and fillers: \(T = \sum_{k=1}^N \mathbf{r}_k \otimes \mathbf{f}_k\)
  • TPR-based External Memory: at each time step, role and filler representations are extracted and superposed into the memory. The write rule is a delta-rule update on the filler with write strength \(\beta = \sigma(W_\beta h_t)\): \(\mathbf{M}_t = \mathbf{M}_{t-1} + \beta\, \mathbf{r}_t \otimes (\mathbf{f}_t - \hat{\mathbf{f}}_t)\). Here \(\hat{\mathbf{f}}_t\) is the filler currently stored under \(\mathbf{r}_t\), obtained by unbinding \(\mathbf{M}_{t-1}\) with \(\mathbf{r}_t\).
  • Systematic Associative Recall (SAR): a new benchmark proposed for evaluating systematic generalization.
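The following is a minimal sketch of a delta-rule write. It assumes the standard form \(\mathbf{M}_t = \mathbf{M}_{t-1} + \beta\, \mathbf{r}_t \otimes (\mathbf{f}_t - \hat{\mathbf{f}}_t)\), where \(\hat{\mathbf{f}}_t\) is the filler currently read out of \(\mathbf{M}_{t-1}\) with \(\mathbf{r}_t\). The helper `delta_write`, the weight vector `W_beta` (chosen so that \(\beta\) is a scalar) and the random controller state `h_t` are hypothetical stand-ins.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def delta_write(M, r_t, f_t, beta):
    """Delta-rule TPR write: move the filler stored under r_t toward f_t by a fraction beta."""
    f_old = M.T @ r_t                           # filler currently bound to r_t (unbinding)
    return M + beta * np.outer(r_t, f_t - f_old)


rng = np.random.default_rng(0)
d_r, d_f, d_h = 4, 3, 5
W_beta = rng.normal(size=d_h)      # hypothetical write-strength weights
h_t = rng.normal(size=d_h)         # hypothetical controller hidden state
beta = sigmoid(W_beta @ h_t)       # write strength in (0, 1)

r = np.array([1.0, 0.0, 0.0, 0.0])  # unit-norm role
f_a, f_b = rng.normal(size=d_f), rng.normal(size=d_f)

M = np.zeros((d_r, d_f))
M = delta_write(M, r, f_a, beta=1.0)   # store f_a under r
M = delta_write(M, r, f_b, beta=1.0)   # beta = 1 replaces f_a with f_b instead of adding them
print(np.allclose(M.T @ r, f_b))       # True

M_soft = delta_write(M, r, f_a, beta)  # partial write: reads back (1 - beta) * f_b + beta * f_a
```

With a unit-norm role, reading back after a write gives \((1-\beta)\hat{\mathbf{f}}_t + \beta \mathbf{f}_t\). Setting \(\beta = 1\) therefore overwrites the old filler completely. A plain additive write, \(\mathbf{M}_{t-1} + \mathbf{r}_t \otimes \mathbf{f}_t\), would instead return the sum of both fillers.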

Results

Systematic Associative Recall (SAR) Task

SAR is the benchmark proposed in this paper. It is designed to measure systematic generalization in associative reasoning. It uses three object sets (human names \(S_h\), fruit names \(S_f\), and number names \(S_n\)), and the object combinations are composed so that they differ systematically between training and test data.

Specifically, during training one subset of \(S_h\) (\(S_h^1\)) is associated only with numbers, and another subset (\(S_h^2\)) only with fruits. The remaining subset \(S_h^3\) is associated with both. Test (same) keeps the training associations. In test (different) they are reversed, so \(S_h^1\) is associated with fruits and \(S_h^2\) with numbers. The difficulty parameter \(p = |S_h^3| / |S_h|\) is the fraction of human names associated with both sets. The smaller \(p\) is, the larger the systematic difference between training and test data.
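The SAR split described above can be sketched in plain Python. The helper names and the placeholder names are hypothetical. So are two further choices: the remaining names are divided evenly between \(S_h^1\) and \(S_h^2\), and \(S_h^3\) keeps both associations in every split. These choices are ours, not the benchmark's specification.

```python
import random


def split_humans(humans, p, seed=0):
    """Partition S_h into S_h^1, S_h^2 and S_h^3 with |S_h^3| / |S_h| = p (hypothetical helper)."""
    names = list(humans)
    random.Random(seed).shuffle(names)
    n_both = round(p * len(names))
    s3, rest = names[:n_both], names[n_both:]
    half = len(rest) // 2             # assumption: remaining names split evenly
    return rest[:half], rest[half:], s3


def allowed_objects(s1, s2, s3, split):
    """Object sets each human may be paired with in 'train', 'test_same' or 'test_different'."""
    reversed_split = split == "test_different"
    allowed = {}
    for h in s1:   # numbers during training, fruits in test (different)
        allowed[h] = {"fruit"} if reversed_split else {"number"}
    for h in s2:   # fruits during training, numbers in test (different)
        allowed[h] = {"number"} if reversed_split else {"fruit"}
    for h in s3:   # assumption: associated with both sets in every split
        allowed[h] = {"fruit", "number"}
    return allowed


humans = [f"person{i}" for i in range(10)]     # hypothetical stand-ins for S_h
s1, s2, s3 = split_humans(humans, p=0.3)
train = allowed_objects(s1, s2, s3, "train")
test_diff = allowed_objects(s1, s2, s3, "test_different")
```

A smaller \(p\) leaves more names whose pairing flips between training and test (different). This matches the statement that small \(p\) is the harder setting.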

SAR Results

Training and test accuracy of DNC, FWM, and the proposed method on the SAR task.

DNC and FWM reach high accuracy on test (same) but degrade sharply on test (different). This indicates that they overfit to the combinations seen in training and fail to generalize systematically. The proposed method closes the gap between test (same) and test (different) at \(p=0.3\) and \(p=0.5\). Even at the hardest setting, \(p=0.1\), its gap is markedly smaller than that of the baselines. FWM uses a TPR-based memory yet still fails to generalize systematically. This suggests that a TPR memory alone is not sufficient, and that the key is for the encoder to learn proper symbolic representations.

Concatenated-bAbI (catbAbI)

Whereas SAR targets systematic generalization, catbAbI evaluates general long-range associative reasoning. The task is question answering over an infinitely long story sequence.

| Model | Test Accuracy |
|---|---|
| LSTM | 80.88% |
| Transformer-XL | 87.66% |
| Meta-learned Neural Memory | 88.97% |
| Fast Weight Memory (FWM) | 96.75% |
| FWM (our trial) | 94.94% |
| Ours | 96.63% |

๋™์ผํ•œ ์‹คํ—˜ ์„ธํŒ…(our trial)์—์„œ ๋น„๊ตํ•˜๋ฉด, ์ œ์•ˆ ๋ฐฉ๋ฒ•(96.63%)์ด FWM(94.94%)๋ณด๋‹ค 1.7%p ๋†’๋‹ค. FWM์˜ ๊ณต์‹ ๊ฒฐ๊ณผ(96.75%)์™€๋„ ๊ฑฐ์˜ ๋™๋“ฑํ•œ ์ˆ˜์ค€์ด๋‹ค. ์ด๋Š” modular encoder์˜ ๋„์ž…์ด ์ฒด๊ณ„์  ์ผ๋ฐ˜ํ™” ๋Šฅ๋ ฅ์„ ์ถ”๊ฐ€ํ•˜๋ฉด์„œ๋„, ์ผ๋ฐ˜์ ์ธ ์—ฐ๊ด€ ์ถ”๋ก  ์„ฑ๋Šฅ์—์„œ ๊ธฐ์กด ์ตœ๊ณ  ์ˆ˜์ค€์„ ์œ ์ง€ํ•จ์„ ๋ณด์—ฌ์ค€๋‹ค.

Symbolic Representation Analysis

Two analyses check whether the learned representations have proper symbolic properties.

Role-Unbinding Orthogonality: for correct unbinding in a TPR, each unbinding vector must be orthogonal to every role vector except the one it is paired with. The role-unbinding similarity matrix should therefore be diagonal. FWM's role-unbinding similarity matrix shows off-diagonal interference, whereas the proposed method shows near-perfect orthogonality. This indicates that the modular encoder has learned a separable symbolic representation for each object.

[Figure: role-unbinding similarity matrices. (a) FWM, (b) Ours]

Similarity matrices between role vectors and unbinding vectors. FWM's vectors are not orthogonal, whereas the proposed method shows near-perfect orthogonality.
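A small NumPy sketch makes the orthogonality requirement concrete. With non-orthogonal roles, reusing the roles as unbinding vectors leaves off-diagonal terms in the role-unbinding similarity matrix, so each read mixes in other fillers. Dual vectors from the pseudo-inverse, standing in here for learned unbinding vectors, make the matrix the identity and recover every filler exactly. The sketch uses raw inner products; the figures may use a normalized similarity. All names and sizes are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
n_roles, d_r, d_f = 4, 6, 3

R = rng.normal(size=(n_roles, d_r))   # rows r_k: linearly independent, not orthogonal
F = rng.normal(size=(n_roles, d_f))   # rows f_k
T = R.T @ F                           # T = sum_k outer(r_k, f_k), shape (d_r, d_f)


def role_unbinding_similarity(R, U):
    """S[i, j] = <u_i, r_j>; exact unbinding requires S == I."""
    return U @ R.T


def max_off_diagonal(S):
    """Largest absolute off-diagonal entry, i.e. the worst interference between pairs."""
    return np.abs(S - np.diag(np.diag(S))).max()


U_naive = R                            # reuse the roles themselves as unbinding vectors
U_dual = np.linalg.pinv(R).T           # dual vectors: U_dual @ R.T == I for full-row-rank R

S_naive = role_unbinding_similarity(R, U_naive)
S_dual = role_unbinding_similarity(R, U_dual)

# Reading all fillers at once: row i of U @ T equals S[i] @ F.
F_naive = U_naive @ T                  # mixes in other fillers through off-diagonal terms
F_dual = U_dual @ T                    # recovers F up to floating-point error
print(max_off_diagonal(S_naive), max_off_diagonal(S_dual))
```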

Filler Consistency: reasoning counts as systematic only if querying the same target object returns the same read vector, regardless of the combination in which it appears. FWM's read vectors change with the combination, whereas the proposed method returns nearly identical read vectors for every combination. This is evidence that the model does not memorize specific combinations. Instead, it encodes individual components independently and recombines them to reason.

[Figure: read-vector similarity. (a) FWM, (b) Ours]

Similarity between read vectors for the same fruit object. The proposed method produces consistent outputs regardless of the combination.