Towards Understanding Long Short Term Memory Networks

HMI, 1/28/2019

Jordan Rodu
Department of Statistics
University of Virginia

with Joao Sedoc
Department of Computer Science
University of Pennsylvania

Mapping LSTMs to state space models

• The goal is not to interpret results on specific data

• Rather, map LSTMs onto reasonable models to understand the space of sequences captured by LSTMs

• Preliminary work, basic ideas

Hidden Markov Models

[Figure: graphical model with hidden states at times t, t+1, t+2 linked by the transition operator T, each emitting an observation through the operator O]

$p(x_{t+1} \mid x_{1:t}) = p(x_{t+1} \mid x_t)$

$p(y_t \mid x_t)$

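To make the two distributions above concrete, here is a minimal sketch that samples a short sequence from a small discrete HMM. The transition matrix T, emission matrix O, and initial distribution are made-up illustrative values, not parameters from the talk.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative 2-state, 3-symbol HMM (made-up parameters).
T = np.array([[0.9, 0.1],       # T[i, j] = p(x_{t+1} = j | x_t = i)
              [0.3, 0.7]])
O = np.array([[0.7, 0.2, 0.1],  # O[i, k] = p(y_t = k | x_t = i)
              [0.1, 0.3, 0.6]])
pi = np.array([0.5, 0.5])       # initial state distribution

def sample_hmm(n_steps):
    """Sample (states, observations) using the Markov property
    p(x_{t+1} | x_{1:t}) = p(x_{t+1} | x_t) and emissions p(y_t | x_t)."""
    x = rng.choice(2, p=pi)
    states, obs = [], []
    for _ in range(n_steps):
        states.append(x)
        obs.append(rng.choice(3, p=O[x]))   # emit y_t given x_t
        x = rng.choice(2, p=T[x])           # transition to x_{t+1}
    return states, obs

print(sample_hmm(10))
```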
Hidden Markov Models - flavors

• Output: discrete or continuous; low dimensional or high dimensional

• States: discrete or continuous; low dimensional or high dimensional

• Time: discrete or continuous

Hidden Markov Models

States as indicator (one-hot) vectors, e.g. $(1, \dots, 0)^\top$, $(0, \dots, 1)^\top$, $\dots$

Belief state over the $k$ states: $b = (b_1, \dots, b_k)^\top$

Incorporating an observation $y$: $(p(x_1 \mid y), \dots, p(x_k \mid y))^\top$

Updated belief: $\tilde{b} = (\tilde{b}_1, \dots, \tilde{b}_k)^\top$

Hidden Markov Models - belief states

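A minimal sketch of the belief-state recursion sketched above: condition the current belief on the observation, then push it through the transition operator. The matrices T and O are the same made-up illustrative parameters as in the earlier sampling sketch, not values from the talk.

```python
import numpy as np

T = np.array([[0.9, 0.1],       # p(x_{t+1} = j | x_t = i)
              [0.3, 0.7]])
O = np.array([[0.7, 0.2, 0.1],  # p(y_t = k | x_t = i)
              [0.1, 0.3, 0.6]])

def belief_update(b, y):
    """One step of the forward (filtering) recursion.

    b: current belief (b_1, ..., b_k)
    y: observed symbol at time t
    Returns the next belief b~ after conditioning on y and
    propagating through the transition operator T.
    """
    posterior = b * O[:, y]        # proportional to p(x_t | y)
    posterior /= posterior.sum()
    return T.T @ posterior         # b~: push the posterior through T

b = np.array([0.5, 0.5])
for y in [0, 0, 2, 1]:
    b = belief_update(b, y)
    print(b)
```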
A few related architectures

[Figure: HMM-like graphical models in which the transition operators $T_t, T_{t+1}, T_{t+2}$ and observation operators $O_t, O_{t+1}, O_{t+2}$ are indexed by time]

LSTM

[Figure: LSTM cell with gate nonlinearities σ, σ, tanh, σ, pointwise multiplications and an addition along the cell state, input $y_t$, and hidden state $h_t$]

Gates: $\sigma(W_h h_{t-1} + W_y y_t + b)$

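A minimal sketch of the gate computation above in NumPy, keeping the talk's notation ($y_t$ is the input, $h_t$ the hidden state). The sizes, random weights, and the grouping of parameters into a dictionary are illustrative assumptions.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

k, d = 4, 3                      # hidden size, input size (illustrative)
rng = np.random.default_rng(0)

# One (W_h, W_y, b) triple per gate (forget, input, output) plus the tanh
# candidate; each follows the form sigma(W_h h_{t-1} + W_y y_t + b).
params = {name: (rng.normal(size=(k, k)), rng.normal(size=(k, d)), np.zeros(k))
          for name in ["forget", "input", "output", "candidate"]}

def gates(h_prev, y_t):
    out = {}
    for name, (W_h, W_y, b) in params.items():
        pre = W_h @ h_prev + W_y @ y_t + b
        out[name] = np.tanh(pre) if name == "candidate" else sigmoid(pre)
    return out

h_prev, y_t = np.zeros(k), rng.normal(size=d)
print(gates(h_prev, y_t))
```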
Hidden Markov Models - reminder

Belief update: $b = (b_1, \dots, b_k)^\top \;\rightarrow\; (p(x_1 \mid y), \dots, p(x_k \mid y))^\top \;\rightarrow\; \tilde{b} = (\tilde{b}_1, \dots, \tilde{b}_k)^\top$

LSTM

[LSTM cell figure as before, annotated to relate the update to the HMM belief recursion: one part is labelled "prior (T)", another "posterior (T and O)"]

$\sigma(W_h h_{t-1} + W_y y_t + b)$

Hidden states: $2^k$ states for $k$ hidden nodes, with dependencies between them specified by the weights in $W_h$.

Incorporating memory cell

[LSTM cell figure as before]

$C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t$

$\sigma(W_h h_{t-1} + W_y y_t + b)$

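A minimal sketch of the memory cell update above. The slides show only the cell update, so the hidden-state readout $h_t = o_t \ast \tanh(C_t)$ is an assumption taken from the standard LSTM convention, and the gate values below are fixed numbers chosen purely to illustrate keeping versus overwriting the cell.

```python
import numpy as np

def cell_step(C_prev, f_t, i_t, o_t, C_tilde):
    """Memory cell update C_t = f_t * C_{t-1} + i_t * C~_t (elementwise),
    followed by the usual hidden-state readout h_t = o_t * tanh(C_t)."""
    C_t = f_t * C_prev + i_t * C_tilde
    h_t = o_t * np.tanh(C_t)
    return C_t, h_t

C_prev  = np.array([0.8, -0.5,  0.0])
f_t     = np.array([1.0,  0.0,  0.5])   # 1 = keep the old cell, 0 = forget it
i_t     = np.array([0.0,  1.0,  0.5])   # 1 = write the new candidate
o_t     = np.array([1.0,  1.0,  1.0])
C_tilde = np.array([0.3,  0.9, -0.9])   # tanh output, so in (-1, 1)

C_t, h_t = cell_step(C_prev, f_t, i_t, o_t, C_tilde)
print(C_t, h_t)
```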
[Figure: lattice of discretized hidden-node values, each node taking values in $\{-1, 0, 1\}$]

State space size $3^k$, with a special symmetry that modulates excitation and inhibition.

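To make the $3^k$ count concrete, the sketch below discretizes each tanh-squashed hidden node to $\{-1, 0, 1\}$ and enumerates the joint states. The threshold of 0.5 is an arbitrary illustrative choice, not a value from the talk.

```python
import numpy as np
from itertools import product

def discretize(h, threshold=0.5):
    """Map each hidden node (in (-1, 1) after tanh) to -1, 0, or +1."""
    return np.where(h > threshold, 1, np.where(h < -threshold, -1, 0))

k = 3
# All joint discrete states for k nodes: 3**k of them.
states = list(product([-1, 0, 1], repeat=k))
print(len(states), "==", 3 ** k)

h = np.array([0.9, -0.05, -0.7])
print(discretize(h))   # e.g. [ 1  0 -1]
```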
Partially Observable Markov Decision Process

[Figure (relaxed visualization): hidden states at times t, t+1, t+2 with transition operators T, observation operators O, and actions $a_t, a_{t+1}, a_{t+2}$]

POMDP representation

[LSTM cell figure as before, annotated with the component that plays the role of the policy]

$\sigma(W_h h_{t-1} + W_y y_t + b)$

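One way to read the policy annotation is as a map from the LSTM's hidden (belief-like) state to a distribution over actions. The sketch below is a generic softmax policy over a small discrete action set; the linear map W_a, the action set, and the sizes are assumptions for illustration, not details given in the talk.

```python
import numpy as np

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

k, n_actions = 4, 3                      # illustrative sizes
rng = np.random.default_rng(0)
W_a = rng.normal(size=(n_actions, k))    # hypothetical policy weights

def policy(h_t):
    """Distribution over actions given the LSTM hidden state h_t."""
    return softmax(W_a @ h_t)

h_t = rng.normal(size=k)
print(policy(h_t))
```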
Recall that our goal here is not to learn a POMDP or to approximate a POMDP using LSTMs, but rather to wrap LSTMs in models for which we have a more robust understanding.

Thanks!