Two analyses of gradient-based optimization
for wide two-layer neural networks
Lénaïc Chizat∗, joint work with Francis Bach+ and Édouard Oyallon§
Nov. 5th 2019 - Theory of Neural Networks Seminar - EPFL
∗CNRS and Université Paris-Sud  +INRIA and ENS Paris  §Centrale Paris
Introduction
Setting
Supervised machine learning
• given input/output training data (x^(1), y^(1)), …, (x^(n), y^(n))
• build a function f such that f(x) ≈ y for unseen data (x, y)
Gradient-based learning paradigm
• choose a parametric class of functions f(w, ·) : x ↦ f(w, x)
• a convex loss ℓ to compare outputs: squared/logistic/hinge…
• starting from some w_0, update parameters using gradients
Example: Stochastic Gradient Descent with step-sizes (η^(k))_{k≥1}
w^(k) = w^(k−1) − η^(k) ∇_w ℓ(f(w^(k−1), x^(k)), y^(k))
[Refs]:
Robbins, Monro (1951). A Stochastic Approximation Method.
LeCun, Bottou, Bengio, Haffner (1998). Gradient-Based Learning Applied to Document Recognition.
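As an illustrative aside (not from the slides), the SGD update above can be sketched in a few lines of NumPy; the linear model, squared loss, and 1/√k step-size schedule below are arbitrary choices for the demo:

```python
import numpy as np

def sgd(w0, grad_loss, data, lr):
    """Run w^(k) = w^(k-1) - eta^(k) * grad_w loss(f(w^(k-1), x^(k)), y^(k))."""
    w = w0.copy()
    for k, (x, y) in enumerate(data, start=1):
        w -= lr(k) * grad_loss(w, x, y)
    return w

# Demo: least squares with a linear model f(w, x) = w . x, so the gradient
# of the squared loss 0.5 * (w . x - y)**2 is (w . x - y) * x.
rng = np.random.default_rng(0)
w_true = np.array([1.0, -2.0])
X = rng.normal(size=(500, 2))
Y = X @ w_true  # noiseless targets
grad = lambda w, x, y: (w @ x - y) * x
w_hat = sgd(np.zeros(2), grad, zip(X, Y), lr=lambda k: 0.1 / np.sqrt(k))
print(w_hat)  # close to w_true
```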
Models
Linear in the parameters
Linear regression, prior/random features, kernel methods:
f(w, x) = w · φ(x)
• convex optimization
Neural networks
Vanilla NN with activation σ & parameters (W_1, b_1), …, (W_L, b_L):
f(w, x) = W_L^T σ(W_{L−1}^T σ(… σ(W_1^T x + b_1) …) + b_{L−1}) + b_L
• interacting & interchangeable units/filters
• compositional
Wide two-layer neural networks
Two-layer neural networks
[Figure: diagram of a two-layer network — input layer (x[1], x[2]), hidden layer, output layer (y)]
• With activation σ, define φ(w_i, x) = c_i σ(a_i · x + b_i) and
f_m(w, x) = (1/m) Σ_{i=1}^m φ(w_i, x),  with w_i = (a_i, b_i, c_i) ∈ R^p
• Estimate the parameters w = (w_1, …, w_m) by solving
min_w F_m(w) := R(f_m(w, ·)) [empirical/population risk] + G_m(w) [regularization]
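A small sketch (my assumptions: ReLU activation, squared loss, regularizer G_m omitted) of how f_m and the empirical risk could be evaluated:

```python
import numpy as np

# f_m(w, x) = (1/m) * sum_i phi(w_i, x) with phi(w_i, x) = c_i * relu(a_i . x + b_i)
def f_m(w, x):
    return np.mean([c * max(a @ x + b, 0.0) for (a, b, c) in w])

# Empirical risk R(f_m(w, .)) with the squared loss (regularizer G_m omitted)
def F_m(w, X, Y):
    preds = np.array([f_m(w, x) for x in X])
    return 0.5 * np.mean((preds - Y) ** 2)

rng = np.random.default_rng(0)
m, d, n = 50, 3, 10
w = [(rng.normal(size=d), rng.normal(), rng.normal()) for _ in range(m)]
X, Y = rng.normal(size=(n, d)), rng.normal(size=n)
print(F_m(w, X, Y))  # a non-negative scalar
```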
Infinitely wide two-layer networks
• Parameterize the predictor with a probability measure µ ∈ P(R^p):
f(µ, x) = ∫_{R^p} φ(w, x) dµ(w)
• Estimate the measure µ by solving
min_µ F(µ) = R(f(µ, ·)) [empirical/population risk] + G(µ) [regularization]
• lifted version of “convex” neural networks
[Refs]:
Bengio et al. (2006). Convex neural networks.
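Not in the slides, but a quick Monte Carlo sanity check of this lift: drawing the m units i.i.d. from µ makes f_m an unbiased estimate of f(µ, x). Here µ is a standard Gaussian over (a, b, c) and φ(w, x) = c·tanh(ax + b) for scalar x, chosen so that f(µ, x) = 0 by symmetry in c:

```python
import numpy as np

# Units w_i = (a_i, b_i, c_i) drawn i.i.d. from mu (standard Gaussian here)
# turn f_m into a Monte Carlo estimate of f(mu, x) = integral of phi(w, x) dmu(w).
# With phi(w, x) = c * tanh(a * x + b), that integral is 0 by symmetry in c.
rng = np.random.default_rng(1)

def f_m(ws, x):
    a, b, c = ws[:, 0], ws[:, 1], ws[:, 2]
    return np.mean(c * np.tanh(a * x + b))

x = 0.7
vals = {}
for m in [100, 10000]:
    ws = rng.normal(size=(m, 3))
    vals[m] = f_m(ws, x)
    print(m, vals[m])  # both close to f(mu, x) = 0; fluctuation ~ 1/sqrt(m)
```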
Adaptivity of neural networks
Goal: estimate a 1-Lipschitz function y : R^d → R given n i.i.d. samples from ρ ∈ P(R^d). Error bound on ∫ (f(x) − y(x))² dρ(x)?
• Ω(n^{−1/d}) (curse of dimensionality)
Same question, if moreover y(x) = g(Ax) for some A ∈ R^{s×d}?
• O(n^{−1/d}) for kernel methods (some lower bounds too)
• O(d^{1/2} n^{−1/(s+3)}) for 2-layer ReLU networks with weight decay,
obtained with G(µ) = ∫ V dµ where V(w) = ‖a‖² + |b|² + |c|²
no a priori bound on the number m of units required
connecting theory and practice:
Is it related to the predictor learnt by gradient descent?
[Refs]:
Barron (1993). Approximation and estimation bounds for artificial neural networks.
Bach (2014). Breaking the curse of dimensionality with convex neural networks.
Mean-field dynamics and global convergence
Continuous time dynamics
Gradient flow
Initialize w(0) = (w1(0), . . . ,wm(0)).
Small step-size limit of (stochastic) gradient descent:
w(t + η) = w(t) − η ∇F_m(w(t))  ⇒ (η → 0)  (d/dt) w(t) = −m ∇F_m(w(t))
Measure representation
Corresponding dynamics in the space of probabilities P(R^p):
µ_{t,m} = (1/m) Σ_{i=1}^m δ_{w_i(t)}
Technical note: in what follows P2(Rp) is the Wasserstein space
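As a toy illustration (my assumptions: 1-d inputs, tanh units without biases, squared loss on a fixed grid), an explicit Euler discretization of the flow (d/dt) w = −m∇F_m(w):

```python
import numpy as np

# Explicit Euler steps for d/dt w = -m * grad F_m(w), with
# f_m(w, x) = (1/m) sum_i c_i * tanh(a_i * x) and squared loss on a grid.
rng = np.random.default_rng(0)
m, eta, steps = 200, 0.1, 2000
a, c = rng.normal(size=m), rng.normal(size=m)
X = np.linspace(-2.0, 2.0, 20)
Y = np.sin(X)  # target function sampled on the grid

def risk(a, c):
    preds = np.tanh(np.outer(X, a)) @ c / m
    return 0.5 * np.mean((preds - Y) ** 2)

risk0 = risk(a, c)
for _ in range(steps):
    T = np.tanh(np.outer(X, a))                   # unit outputs, shape (n, m)
    err = (T @ c / m - Y) / len(X)                # dR/dpred at each grid point
    grad_c = T.T @ err / m                        # dF_m/dc_i
    grad_a = (1 - T ** 2).T @ (err * X) * c / m   # dF_m/da_i
    a -= eta * m * grad_a                         # note the extra factor m
    c -= eta * m * grad_c
print(risk0, risk(a, c))  # risk decreases from its initial value
```

Each particle (a_i, c_i) follows its own gradient; the empirical measure of the particles is exactly the µ_{t,m} of the slide.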
Many-particle / mean-field limit
Theorem
Assume that w1(0),w2(0), . . . are such that µ0,m → µ0 in P2(Rp)
and some regularity. Then µt,m → µt in P2(Rp), uniformly on
[0,T ], where µt is the unique Wasserstein gradient flow of F
starting from µ0.
Wasserstein gradient flows are characterized by the continuity equation
∂_t µ_t = −div(v_t µ_t)  with velocity field  v_t = −∇F′_{µ_t},
where F′_µ ∈ C¹(R^p) is the Fréchet derivative of F at µ.
[Refs]:
Nitanda, Suzuki (2017). Stochastic particle gradient descent for infinite ensembles.
Mei, Montanari, Nguyen (2018). A Mean Field View of the Landscape of Two-Layers Neural Networks.
Rotskoff, Vanden-Eijnden (2018). Parameters as Interacting Particles [...].
Sirignano, Spiliopoulos (2018). Mean Field Analysis of Neural Networks.
Chizat, Bach (2018). On the Global Convergence of Gradient Descent for Over-parameterized Models [...]
Global convergence (Chizat & Bach 2018)
Theorem (2-homogeneous case)
Assume that φ is positively 2-homogeneous and some regularity. If
the support of µ0 covers all directions (e.g. Gaussian) and if
µt → µ∞ in P2(Rp), then µ∞ is a global minimizer of F .
Non-convex landscape: initialization matters
Corollary
Under the same assumptions, if at initialization µ_{0,m} → µ_0, then
lim_{t→∞} lim_{m→∞} F(µ_{t,m}) = lim_{m→∞} lim_{t→∞} F(µ_{t,m}) = inf F.
Generalization properties, if F is ...
• the regularized empirical risk: statistical adaptivity!
• the population risk: need convergence speed (?)
• the unregularized empirical risk: need implicit bias (?)
[Refs]:
Chizat, Bach (2018). On the Global Convergence of Gradient Descent for Over-parameterized Models [...].
Numerical Illustrations
ReLU, d = 2, optimal predictor has 5 neurons (population risk)
[Figure: particle gradient flow in the plane for m = 5, 10, 100 neurons, showing the optimal positions and the limit measure; as m grows, the particles converge to the 5 optimal positions]
Performance
Population risk at convergence vs m
ReLU, d = 100, optimal predictor has 20 neurons
[Figure: log-log plot of population risk at convergence vs. m, comparing particle gradient flow with convex minimization; beyond some m0, the gradient flow falls below the optimization error]
Convex optimization on measures
Sparse deconvolution on T² with the Dirichlet kernel
[Figure: (white) sources, (red) particles]
Computational guarantees for the regularized case, but m
exponential in the dimension d
[Refs]:
Chizat (2019). Sparse Optimization on Measures with Over-parameterized Gradient Descent.
Lazy Training
Neural Tangent Kernel (Jacot et al. 2018)
Infinite width limit of standard neural networks
For infinitely wide fully connected neural networks of any depth
with “standard” initialization and no regularization: the gradient
flow implicitly performs kernel ridge(less) regression with the
neural tangent kernel
K(x, x′) = lim_{m→∞} ⟨∇_w f_m(w_0, x), ∇_w f_m(w_0, x′)⟩.
Reconciling the two views:
f_m(w, x) = (1/√m) Σ_{i=1}^m φ(w_i, x)  vs.  f_m(w, x) = (1/m) Σ_{i=1}^m φ(w_i, x)
This behavior is not intrinsically due to over-parameterization
but to an exploding scale
[Refs]:
Jacot, Gabriel, Hongler (2018). Neural Tangent Kernel: Convergence and Generalization in Neural Networks.
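A sketch (not from the talk) of the empirical tangent kernel for a toy 1-d model f_m(w, x) = (1/√m) Σ_i c_i ReLU(a_i x), with the parameter gradients written out by hand; the inner product stabilizes as m grows:

```python
import numpy as np

# Empirical tangent kernel <grad_w f_m(w0, x), grad_w f_m(w0, x')> for the
# toy model f_m(w, x) = (1/sqrt(m)) * sum_i c_i * relu(a_i * x), 1-d input x,
# with the gradient w.r.t. (a, c) written out by hand.
rng = np.random.default_rng(0)

def tangent_kernel(x, xp, m):
    a, c = rng.normal(size=m), rng.normal(size=m)
    def grad(z):  # gradient of f_m(w0, z) w.r.t. all 2m parameters
        act = np.maximum(a * z, 0.0)            # relu(a_i * z)
        dact = (a * z > 0.0).astype(float)      # relu'(a_i * z)
        return np.concatenate([c * dact * z, act]) / np.sqrt(m)
    return grad(x) @ grad(xp)

ks = {m: tangent_kernel(1.0, 0.5, m) for m in [10, 100, 10000]}
for m, k in ks.items():
    print(m, k)  # fluctuates less and less around the m -> infinity limit
```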
Linearized model and scale
• let h(w) = f(w, ·) be a differentiable model
• let h̄(w) = h(w_0) + Dh_{w_0}(w − w_0) be its linearization at w_0
Compare two training trajectories starting from w_0, with scale α > 0:
• w_α(t): gradient flow of F_α(w) = R(α h(w))/α²
• w̄_α(t): gradient flow of F̄_α(w) = R(α h̄(w))/α²
If h(w_0) ≈ 0 and α is large, then w_α(t) ≈ w̄_α(t)
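A scalar toy example (my assumptions: h(w) = tanh(w), w_0 = 0, target y⋆ = 1, gradient descent standing in for the flow) comparing training at scale α on h and on its linearization h̄(w) = w:

```python
import numpy as np

# F_alpha(w) = R(alpha * h(w)) / alpha**2 with R(y) = 0.5 * (y - 1)**2, so
# dF_alpha/dw = (alpha * h(w) - 1) * h'(w) / alpha. Plain gradient descent
# stands in for the flow, applied to h = tanh and to its linearization.
def train(alpha, model, dmodel, steps=1000, eta=0.1):
    w = 0.0  # w0 = 0, so h(w0) = 0
    for _ in range(steps):
        w -= eta * (alpha * model(w) - 1.0) * dmodel(w) / alpha
    return alpha * model(w)  # the scaled prediction alpha * h(w)

gaps = {}
for alpha in [1.0, 100.0]:
    out = train(alpha, np.tanh, lambda w: 1.0 - np.tanh(w) ** 2)
    out_lin = train(alpha, lambda w: w, lambda w: 1.0)
    gaps[alpha] = abs(out - out_lin)
    print(alpha, out, out_lin)  # the two predictions agree at large alpha
```

At α = 100 the parameter barely moves from w_0 (the "lazy" regime), so the nonlinear and linearized trajectories produce nearly identical scaled predictions; at α = 1 they visibly differ.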
Lazy training theorems
Theorem (Non-asymptotic)
If h(w_0) = 0 and R is possibly non-convex, then for any T > 0,
lim_{α→∞} sup_{t∈[0,T]} ‖α h(w_α(t)) − α h̄(w̄_α(t))‖ = 0
Theorem (Strongly convex)
If h(w_0) = 0 and R is strongly convex, then
lim_{α→∞} sup_{t≥0} ‖α h(w_α(t)) − α h̄(w̄_α(t))‖ = 0
• instance of implicit bias: lazy because parameters hardly move
• may replace the model by its linearization
[Refs]:
Chizat, Oyallon, Bach (2018). On Lazy Training in Differentiable Programming.
When does lazy training occur (without α)?
Relative scale criterion
For R(y) = ½‖y − y⋆‖², the relative error at normalized time t is
err ≲ t² κ_h(w_0)  where  κ_h(w_0) := (‖h(w_0) − y⋆‖ / ‖∇h(w_0)‖) · (‖∇²h(w_0)‖ / ‖∇h(w_0)‖)
Examples (h(w) = f(w, ·)):
• Homogeneous models with f(w_0, ·) = 0:
if f(λw, x) = λ^L f(w, x) for λ > 0, then κ_f(w_0) ≍ 1/‖w_0‖^L
• Wide two-layer NNs with i.i.d. weights, E Φ(w_i, ·) = 0:
if f(w, x) = α(m) Σ_{i=1}^m Φ(w_i, x), then κ_f(w_0) ≍ (m α(m))^{−1}
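A quick numerical check of the homogeneous case (my example, not the talk's): for the 2-homogeneous scalar model h(w) = w₁² − w₂² with balanced initialization w_0 = (r, r), so that h(w_0) = 0, the criterion κ_h(w_0) scales as 1/r²:

```python
import numpy as np

# kappa_h(w0) = (|h(w0) - y*| / |grad h(w0)|) * (|Hess h(w0)| / |grad h(w0)|)
# for h(w) = w[0]**2 - w[1]**2 at w0 = (r, r): h(w0) = 0, the Hessian is the
# constant diag(2, -2) (spectral norm 2), and grad h(w0) = (2r, -2r).
y_star = 1.0
hess_norm = 2.0
kappas = {}
for r in [1.0, 2.0, 4.0]:
    w0 = np.array([r, r])
    grad = np.array([2 * w0[0], -2 * w0[1]])
    kappas[r] = abs(0.0 - y_star) * hess_norm / (grad @ grad)
    print(r, kappas[r])  # scales as 1/r**2, i.e. 1/||w0||^L with L = 2
```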
Numerical Illustrations
Training paths (ReLU, d = 2, optimal predictor has m = 3 neurons)
[Figure: training paths (legend: circle of radius 1, gradient flow (+), gradient flow (−)) — (a) Lazy, (b) Not lazy]
Performance
Teacher-student setting, generalization in 100-d vs init. scale τ :
[Figure: test loss vs. initialization scale τ, at the end of training and best throughout training]
Perf. of ConvNets (VGG11) for image classification (CIFAR10):
[Figure 3: VGG-11 on CIFAR10 — train accuracy, test accuracy and stability of activations (%) vs. scale of the model]
• similar gaps observed for widened ConvNets & ResNets
• CKNs (Convolutional Kernel Networks) and tailored NTKs (Neural Tangent Kernels) perform well on this (not so hard) task
Conclusion
• Gradient descent on infinitely wide 2-layer networks converges
to global minimizers
• Generalization behavior depends on initialization, loss,
stopping time, signal scale, regularization...
• For the regularized empirical risk, it breaks the statistical
curse of dimensionality
• But not (yet) the computational curse of dimensionality
[Refs]:
- Chizat, Bach (2018). On the Global Convergence of Gradient Descent for Over-parameterized Models using Optimal Transport.
- Chizat, Oyallon, Bach (2019). On Lazy Training in Differentiable Programming.
- Chizat (2019). Sparse Optimization on Measures with Over-parameterized Gradient Descent.