
Back Propagation

Machine Learning CSx824/ECEx242

Bert Huang Virginia Tech

Outline

• Logistic regression and perceptron as neural networks

• Likelihood gradient for a two-layer neural network

• General recipe for back propagation

Back Propagation

• Back propagation:

• Compute hidden unit activations: forward propagation

• Compute gradient at output layer: error

• Propagate error back one layer at a time

• Chain rule via dynamic programming

Logistic Squashing Function

\sigma(x) = \frac{1}{1 + \exp(-x)}

\frac{d\,\sigma(x)}{d\,x} = \sigma(x)\left(1 - \sigma(x)\right)
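A minimal NumPy sketch of the squashing function and its derivative, checked against a finite difference (the function names here are illustrative, not from the slides):

import numpy as np

def sigmoid(x):
    # logistic squashing function: 1 / (1 + exp(-x))
    return 1.0 / (1.0 + np.exp(-x))

def sigmoid_grad(x):
    # derivative: sigma(x) * (1 - sigma(x))
    s = sigmoid(x)
    return s * (1.0 - s)

# finite-difference check at x = 0.5
eps = 1e-6
numeric = (sigmoid(0.5 + eps) - sigmoid(0.5 - eps)) / (2 * eps)
print(sigmoid_grad(0.5), numeric)  # both approximately 0.235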

Multi-Layered Perceptron

[Diagram: inputs x_1, x_2, x_3, x_4, x_5 feed hidden units h_1 and h_2, which feed the output y.]

h = [h_1, h_2]^\top

h_1 = \sigma(w_{11}^\top x) \qquad h_2 = \sigma(w_{12}^\top x)

p(y \mid x) = \sigma(w_{21}^\top h) = \sigma\!\left(w_{21}^\top \left[\sigma(w_{11}^\top x),\, \sigma(w_{12}^\top x)\right]^\top\right)
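A minimal sketch of this two-hidden-unit forward pass in NumPy, using the sigmoid defined above; the input and weight values are made up for illustration:

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([1.0, 0.5, -0.3, 2.0, 0.1])      # inputs x_1 ... x_5
w11 = np.array([0.2, -0.1, 0.4, 0.3, -0.5])   # weights into h_1
w12 = np.array([-0.3, 0.6, 0.1, -0.2, 0.4])   # weights into h_2
w21 = np.array([0.7, -0.4])                   # weights from h to the output

h = np.array([sigmoid(w11 @ x), sigmoid(w12 @ x)])  # hidden activations
p_y_given_x = sigmoid(w21 @ h)                      # p(y = 1 | x)
print(h, p_y_given_x)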

Gradients

p(y \mid x) = \sigma(w_{21}^\top h) = \sigma\!\left(w_{21}^\top \left[\sigma(w_{11}^\top x),\, \sigma(w_{12}^\top x)\right]^\top\right)

ll(W) = \sum_{i=1}^{n} \log p(y_i \mid x_i)

\nabla_{w_{21}} ll = \sum_{i=1}^{n} \frac{1}{p(y_i \mid x_i)} \nabla_{w_{21}} p(y_i \mid x_i)

\nabla_{w_{21}} ll = \sum_{i=1}^{n} \frac{1}{p(y_i \mid x_i)} \nabla_{w_{21}} \sigma(w_{21}^\top h)

\nabla_{w_{21}} ll = \sum_{i=1}^{n} \left(I(y_i = 1) - \sigma(w_{21}^\top h)\right) \nabla_{w_{21}} w_{21}^\top h

\nabla_{w_{21}} ll = \sum_{i=1}^{n} \left(I(y_i = 1) - \sigma(w_{21}^\top h)\right) h
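The last line translates directly to code for a single example; a sketch reusing the variables from the forward-pass snippet above, with y assumed to be a 0/1 label:

# gradient of log p(y | x) with respect to w21 for one example
y = 1
error = float(y == 1) - sigmoid(w21 @ h)  # I(y_i = 1) - sigma(w21^T h)
grad_w21 = error * h                      # (I(y_i = 1) - sigma(w21^T h)) h
print(grad_w21)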

Gradients

p(y \mid x) = \sigma(w_{21}^\top h) = \sigma\!\left(w_{21}^\top \left[\sigma(w_{11}^\top x),\, \sigma(w_{12}^\top x)\right]^\top\right)

ll(W) = \sum_{i=1}^{n} \log p(y_i \mid x_i)

\nabla_{w_{11}} ll = \sum_{i=1}^{n} \frac{1}{p(y_i \mid x_i)} \nabla_{w_{11}} p(y_i \mid x_i)

\nabla_{w_{11}} ll = \sum_{i=1}^{n} \left(I(y_i = 1) - \sigma(w_{21}^\top h)\right) \nabla_{w_{11}} w_{21}^\top h

\nabla_{w_{11}} ll = \sum_{i=1}^{n} \left(I(y_i = 1) - \sigma(w_{21}^\top h)\right) w_{21}^\top \left(\nabla_{w_{11}} h\right)

\nabla_{w_{11}} ll = \sum_{i=1}^{n} \left(I(y_i = 1) - \sigma(w_{21}^\top h)\right) w_{21}[1]\, \nabla_{w_{11}} \sigma(w_{11}^\top x_i)

\nabla_{w_{11}} ll = \sum_{i=1}^{n} \left(I(y_i = 1) - \sigma(w_{21}^\top h)\right) w_{21}[1]\, \sigma(w_{11}^\top x_i)\left(1 - \sigma(w_{11}^\top x_i)\right) x_i
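A self-contained sketch that evaluates this final expression for one example and checks it against a finite difference of the log-likelihood (all values are placeholders; w21[0] in 0-indexed code plays the role of w_{21}[1] in the slides):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def log_lik(w11, w12, w21, x, y):
    # log p(y | x) for the two-layer network above
    h = np.array([sigmoid(w11 @ x), sigmoid(w12 @ x)])
    p = sigmoid(w21 @ h)
    return np.log(p) if y == 1 else np.log(1.0 - p)

x = np.array([1.0, 0.5, -0.3, 2.0, 0.1])
w11 = np.array([0.2, -0.1, 0.4, 0.3, -0.5])
w12 = np.array([-0.3, 0.6, 0.1, -0.2, 0.4])
w21 = np.array([0.7, -0.4])
y = 1

# analytic gradient from the derivation above
h = np.array([sigmoid(w11 @ x), sigmoid(w12 @ x)])
error = float(y == 1) - sigmoid(w21 @ h)
grad_w11 = error * w21[0] * sigmoid(w11 @ x) * (1.0 - sigmoid(w11 @ x)) * x

# finite-difference check on the first coordinate of w11
eps = 1e-6
w11_plus, w11_minus = w11.copy(), w11.copy()
w11_plus[0] += eps
w11_minus[0] -= eps
numeric = (log_lik(w11_plus, w12, w21, x, y) - log_lik(w11_minus, w12, w21, x, y)) / (2 * eps)
print(grad_w11[0], numeric)  # should agree to several decimal places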

ll(W) = \sum_{i=1}^{n} \log p(y_i \mid x_i) = \sum_{i=1}^{n} \log \sigma\!\left(w_{21}^\top \left[\sigma(w_{11}^\top x_i),\, \sigma(w_{12}^\top x_i)\right]^\top\right)

\nabla_{w_{11}} ll = \sum_{i=1}^{n} \underbrace{\left(I(y_i = 1) - \sigma(w_{21}^\top h)\right)}_{\text{raw error}} \, \underbrace{w_{21}[1]}_{\text{blame for error}} \, \underbrace{\sigma(w_{11}^\top x_i)\left(1 - \sigma(w_{11}^\top x_i)\right) x_i}_{\text{gradient of blamed error}}

[Diagram: the network x_1, \ldots, x_5 \to h_1, h_2 \to y again, annotated with these three factors: the raw error at the output (where \log \sigma(w_{21}^\top h) is computed), the blame for that error passed back through w_{21}^\top h, and the gradient of the blamed error at h_1 = \sigma(w_{11}^\top x).]

Matrix Form

h_1 = s(W_1 x)

W_1 = \begin{bmatrix} w_{11} & w_{12} & w_{13} & w_{14} & w_{15} \\ w_{21} & w_{22} & w_{23} & w_{24} & w_{25} \\ w_{31} & w_{32} & w_{33} & w_{34} & w_{35} \end{bmatrix}

(rows: # of output units; columns: # of input units. The first row, w_{11} \ldots w_{15}, produces h_1[1].)

s(v) = [s(v_1), s(v_2), s(v_3), \ldots]^\top

Matrix Form

h_1 = s(W_1 x)

h_2 = s(W_2 h_1)

\vdots

h_{m-1} = s(W_{m-1} h_{m-2})

f(x, W) = s(W_m h_{m-1})

J(W) = \ell(f(x, W))
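A minimal matrix-form forward pass in NumPy, assuming sigmoid squashing at every layer; the layer sizes and weights are arbitrary placeholders:

import numpy as np

def s(v):
    # elementwise squashing: s(v) = [s(v_1), s(v_2), ...]^T
    return 1.0 / (1.0 + np.exp(-v))

rng = np.random.default_rng(0)
layer_sizes = [5, 3, 3, 1]  # input, two hidden layers, output
W = [rng.normal(size=(layer_sizes[i + 1], layer_sizes[i]))   # (# output units) x (# input units)
     for i in range(len(layer_sizes) - 1)]

x = rng.normal(size=layer_sizes[0])
h = x
for Wi in W:
    h = s(Wi @ h)   # h_i = s(W_i h_{i-1}); the final pass gives f(x, W)
print(h)            # f(x, W)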

Matrix Gradient Recipe

h_1 = s(W_1 x)

h_2 = s(W_2 h_1)

\vdots

h_{m-1} = s(W_{m-1} h_{m-2})

f(x, W) = s(W_m h_{m-1})

J(W) = \ell(f(x, W))

\delta_m = \ell'(f(x, W)) \qquad \nabla_{W_m} J = \delta_m h_{m-1}^\top

\delta_{m-1} = (W_m^\top \delta_m) \circ s'(W_{m-1} h_{m-2}) \qquad \nabla_{W_{m-1}} J = \delta_{m-1} h_{m-2}^\top

\delta_i = (W_{i+1}^\top \delta_{i+1}) \circ s'(W_i h_{i-1}) \qquad \nabla_{W_i} J = \delta_i h_{i-1}^\top

\nabla_{W_1} J = \delta_1 x^\top

Matrix Gradient Recipe

Feed Forward Propagation:

h_1 = s(W_1 x)

h_i = s(W_i h_{i-1})

f(x, W) = s(W_m h_{m-1})

J(W) = \ell(f(x, W))

Back Propagation:

\delta_m = \ell'(f(x, W))

\delta_i = (W_{i+1}^\top \delta_{i+1}) \circ s'(W_i h_{i-1})

\nabla_{W_i} J = \delta_i h_{i-1}^\top

\nabla_{W_1} J = \delta_1 x^\top
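A sketch of the whole recipe for a single example in NumPy, assuming sigmoid squashing, a single sigmoid output, and the log-likelihood objective from the earlier slides, so that the output delta reduces to I(y = 1) - f(x, W); the \circ in the delta recursion is elementwise multiplication, and the weights and data below are placeholders:

import numpy as np

def s(v):
    return 1.0 / (1.0 + np.exp(-v))

def s_prime(v):
    sv = s(v)
    return sv * (1.0 - sv)

def backprop(W, x, y):
    """Gradients of log p(y | x) for one (x, y) pair: forward pass to get the
    activations, then deltas flow backward and each gradient is an outer product."""
    # feed forward propagation: store pre-activations z_i and activations h_i
    h = [x]
    z = []
    for Wi in W:
        z.append(Wi @ h[-1])
        h.append(s(z[-1]))

    # output delta: for a sigmoid output with the log-likelihood objective,
    # l'(.) combined with the final squashing gives I(y = 1) - f(x, W)
    delta = float(y == 1) - h[-1]

    grads = [None] * len(W)
    for i in reversed(range(len(W))):
        grads[i] = np.outer(delta, h[i])                  # gradient for W[i]: delta times previous activation transposed
        if i > 0:
            delta = (W[i].T @ delta) * s_prime(z[i - 1])  # propagate the delta back one layer
    return grads

# usage with placeholder shapes: 5 inputs -> 3 hidden units -> 1 output
rng = np.random.default_rng(0)
W = [rng.normal(size=(3, 5)), rng.normal(size=(1, 3))]
x, y = rng.normal(size=5), 1
for g in backprop(W, x, y):
    print(g.shape)  # (3, 5) and (1, 3), matching each W_i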

Challenges

• Local minima (non-convex)

• Overfitting

Remedies

• Regularization

• Parameter sharing: convolution

• Pre-training: initializing weights smartly

• Training data manipulation, e.g., dropout, noise, transformations

• Huge data sets