
The Deep Ritz method: A deep learning-based numerical algorithm for solving variational problems

Weinan E¹ and Bing Yu²

¹ The Beijing Institute of Big Data Research, Department of Mathematics and PACM, Princeton University, School of Mathematical Sciences and BICMR, Peking University

² School of Mathematical Sciences, Peking University

October 3, 2017

Abstract We propose a deep learning-based method, the Deep Ritz method, for numerically solving variational problems, particularly the ones that arise from partial differential equations. The Deep Ritz method is naturally nonlinear, naturally adaptive and has the potential to work in rather high dimensions. The framework is quite simple and fits well with the stochastic gradient descent method used in deep learning. We illustrate the method on several problems, including some eigenvalue problems.

Keywords Deep Ritz Method · Variational problems · PDE · Eigenvalue problems

Mathematical Subject Classification 35Q68

1 Introduction

Deep learning has had great success in computer vision and other artificial intelligence tasks [1]. Underlying this success is a new way to approximate functions: a shift from the additive construction commonly used in approximation theory to the compositional construction used in deep neural networks. The compositional construction seems to be particularly powerful in high dimensions. This suggests that deep neural network-based models can be useful in other contexts that involve constructing functions, such as solving partial differential equations, molecular modeling and model reduction. These aspects have been explored recently in [2, 3, 4, 5, 6, 7].

arXiv:1710.00211v1 [cs.LG] 30 Sep 2017


In this paper, we continue this line of work and propose a new algorithm for solving variational problems. We call this new algorithm the Deep Ritz method, since it is based on using the neural network representation of functions in the context of the Ritz method. The Deep Ritz method has a number of interesting and promising features, which we explore later in the paper.

2 The Deep Ritz Method

An explicit example of the kind of variational problems we are interested in is [8]

$$
\min_{u \in H} I(u) \tag{1}
$$

where

$$
I(u) = \int_\Omega \left( \frac{1}{2}|\nabla u(x)|^2 - f(x)\,u(x) \right) dx \tag{2}
$$

and H is the set of admissible functions (also called trial functions, here represented by u), and f is a given function representing the external forcing on the system under consideration. Problems of this type are fairly common in the physical sciences. The Deep Ritz method is based on the following set of ideas:

1. Deep neural network based approximation of the trial function.

2. A numerical quadrature rule for the functional.

3. An algorithm for solving the final optimization problem.
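The functional (2) ties directly into the Poisson problems of Section 3. As a standard worked step that the text leaves implicit (assuming, for illustration only, that H consists of sufficiently smooth functions vanishing on ∂Ω, e.g. $H = H_0^1(\Omega)$; the text does not fix H), setting the first variation of (2) to zero and integrating by parts gives

$$
\frac{d}{d\varepsilon} I(u+\varepsilon v)\Big|_{\varepsilon=0}
= \int_\Omega \left(\nabla u\cdot\nabla v - f v\right) dx
= \int_\Omega \left(-\Delta u - f\right) v\,dx = 0
\quad \text{for all } v \in H_0^1(\Omega),
$$

so a minimizer satisfies $-\Delta u = f$ in Ω with $u = 0$ on ∂Ω. The boundary term from integration by parts drops out because v vanishes on ∂Ω.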

2.1 Building trial functions

The basic component of the Deep Ritz method is a nonlinear transformation $x \mapsto z_\theta(x) \in \mathbb{R}^m$ defined by a deep neural network. Here θ denotes the parameters, typically the weights in the neural network, that help to define this transformation. In the architecture that we use, the network is constructed by stacking several blocks. Each block consists of two linear transformations, two activation functions and a residual connection; both the input s and the output t of the block are vectors in $\mathbb{R}^m$. The i-th block can be expressed as:

$$
t = f_i(s) = \phi\big(W_{i,2}\,\phi(W_{i,1}s + b_{i,1}) + b_{i,2}\big) + s \tag{3}
$$

where $W_{i,1}, W_{i,2} \in \mathbb{R}^{m\times m}$ and $b_{i,1}, b_{i,2} \in \mathbb{R}^m$ are parameters associated with the block, and φ is the (scalar) activation function [1], applied componentwise.

Our experience suggests that the smoothness of the activation function φ plays a key role in the accuracy of the algorithm. To balance simplicity and accuracy, we have decided to use

$$
\phi(x) = \max\{x^3, 0\} \tag{4}
$$
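A minimal plain-Python sketch of one block (3) with the activation (4) follows. It uses lists as vectors and only the standard library; the function names are our own, and it illustrates the formulas rather than reproducing the authors' code.

```python
def phi(x: float) -> float:
    """Activation (4): phi(x) = max{x^3, 0}."""
    return max(x ** 3, 0.0)


def matvec(W, s):
    """Matrix-vector product W s, with W given as a list of rows."""
    return [sum(w * v for w, v in zip(row, s)) for row in W]


def residual_block(s, W1, b1, W2, b2):
    """Block (3): t = phi(W2 phi(W1 s + b1) + b2) + s, phi applied componentwise."""
    hidden = [phi(z + b) for z, b in zip(matvec(W1, s), b1)]
    return [phi(z + b) + s_j for z, b, s_j in zip(matvec(W2, hidden), b2, s)]


# Usage: with W1 = W2 = I and zero biases, s = (1, -1) gives
# hidden = (phi(1), phi(-1)) = (1, 0), then t = (phi(1) + 1, phi(0) - 1) = (2, -1).
I2 = [[1.0, 0.0], [0.0, 1.0]]
t = residual_block([1.0, -1.0], I2, [0.0, 0.0], I2, [0.0, 0.0])
print(t)  # [2.0, -1.0]
```

With all weights and biases at zero the block reduces to the identity map, which is the residual connection at work.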


The last term in (3), the residual connection, makes the network much easier to train, since it helps to avoid the vanishing gradient problem [9]. The structure of the two blocks, including two residual connections, is shown in Figure 1.

[Figure 1 diagram: input x → FC layer (size m) + activation → FC layer (size m) + activation, bridged by a residual connection → a second block of the same form with its own residual connection → FC layer (size 1) → output u]

Figure 1: The figure shows a network with two blocks and an output linear layer. Each block consists of two fully-connected layers and a skip connection.

The full network, composed of n blocks, can now be expressed as:

$$
z_\theta(x) = f_n \circ \cdots \circ f_1(x) \tag{5}
$$

Here θ denotes the set of all the parameters in the whole network. Note that the input x of the first block is in $\mathbb{R}^d$, not $\mathbb{R}^m$. To handle this discrepancy, we can either pad x with a zero vector when d < m, or apply a linear transformation to x when d > m. Having $z_\theta$, we obtain u by

$$
u(x;\theta) = a \cdot z_\theta(x) + b \tag{6}
$$

Here, on the left-hand side and in what follows, we use θ to denote the full parameter set {θ, a, b}. Substituting this into the form of I, we obtain a function of θ, which we should minimize.
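The parameter counts quoted in Section 3 can be checked with a little arithmetic. The sketch below assumes (our inference; the text does not spell it out) that each block contributes two m × m weight matrices and two bias vectors, that the first linear map acts on x ∈ ℝ^d directly (an m × d matrix, so zero-padded input coordinates carry no trainable weights), and that (6) adds m + 1 output parameters. Under these assumptions the formula reproduces the 591/811/1031/1251 counts of Table 1 (d = 2, m = 10) and, taking m = 10 (our guess, since m is not stated there), the 671 parameters of Section 3.2.

```python
def count_parameters(d: int, m: int, n_blocks: int) -> int:
    """Parameter count of an n-block residual network plus the output layer (6).

    Assumed layout (our inference, not stated in the text):
      * each block has two m x m weight matrices and two bias vectors in R^m;
      * for d <= m, the first weight matrix acts on x in R^d directly (m x d), so the
        zero-padded input coordinates carry no trainable weights;
      * the output layer u = a . z + b adds m + 1 parameters.
    """
    if d > m:
        raise ValueError("this sketch only covers d <= m")
    per_block = 2 * (m * m + m)
    narrow_input_saving = m * (m - d)
    return n_blocks * per_block - narrow_input_saving + (m + 1)


print([count_parameters(2, 10, n) for n in (3, 4, 5, 6)])  # [591, 811, 1031, 1251]
print(count_parameters(10, 10, 3))                          # 671
```

Each extra block adds 2(m² + m) = 220 parameters for m = 10, which matches the constant step between consecutive rows of Table 1.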

For the functional that occurs in (2), denote:

$$
g(x;\theta) = \frac{1}{2}|\nabla_x u(x;\theta)|^2 - f(x)\,u(x;\theta) \tag{7}
$$


Then we are left with the optimization problem:

$$
\min_\theta L(\theta), \qquad L(\theta) = \int_\Omega g(x;\theta)\,dx \tag{8}
$$

2.2 The stochastic gradient descent algorithm and the quadrature rule

To finish describing the algorithm, we need to furnish the remaining two components: the optimization algorithm and the discretization of the integral in I in (2) or L in (8). The latter is necessary, since computing the integral in I (or L) explicitly for functions of the form (6) is essentially impossible.

In machine learning, the optimization problem that one encounters often takes the form:

$$
\min_{\theta} L(\theta) := \frac{1}{N}\sum_{i=1}^{N} L_i(\theta), \tag{9}
$$

where each term on the right-hand side corresponds to one data point. N, the number of data points, is typically very large. For this problem, the algorithm of choice is the stochastic gradient descent (SGD) method, which can be described as follows:

$$
\theta^{k+1} = \theta^{k} - \eta\,\nabla L_{\gamma_k}(\theta^{k}). \tag{10}
$$

Here $\gamma_k$ are i.i.d. random variables uniformly distributed over $\{1, 2, \dots, N\}$, and η is the learning rate. This is the stochastic version of the gradient descent (GD) algorithm. The key idea is that, instead of computing the full sum when evaluating the gradient of L, we simply choose one term of the sum at random. Compared with GD, SGD therefore requires only one gradient evaluation per iteration instead of N. In practice, instead of picking one term, one chooses a "mini-batch" of terms at each step.
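A minimal sketch of the mini-batch form of (10) on a toy finite sum. The quadratic losses $L_i(\theta) = (\theta - a_i)^2/2$, the data, the step size and the batch size are our choices for illustration; with $a_i = 0, 1, \dots, 9$, L is minimized at their mean, 4.5.

```python
import random


def sgd_finite_sum(grads, theta0, eta, steps, batch_size=1, seed=0):
    """Mini-batch SGD for L(theta) = (1/N) sum_i L_i(theta), cf. (9)-(10).

    grads[i](theta) returns dL_i/dtheta. Each step draws batch_size indices
    i.i.d. uniformly from {0, ..., N-1} and averages only those gradients."""
    rng = random.Random(seed)
    n = len(grads)
    theta = theta0
    for _ in range(steps):
        batch = [rng.randrange(n) for _ in range(batch_size)]
        theta -= eta * sum(grads[i](theta) for i in batch) / batch_size
    return theta


# Toy finite sum: L_i(theta) = (theta - a_i)^2 / 2, minimized at mean(a) = 4.5.
data = [float(i) for i in range(10)]
grads = [lambda t, a=a: t - a for a in data]
theta = sgd_finite_sum(grads, theta0=0.0, eta=0.05, steps=2000, batch_size=10)
print(theta)  # near 4.5; constant-step SGD keeps some residual noise
```

Each step touches only `batch_size` of the N terms, which is the saving over GD described above; the price is that, with a constant η, the iterate fluctuates around the minimizer instead of settling on it exactly.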

At first sight, our problem seems different from the ones that occur in machine learning, since there are no data involved. The connection becomes clear once we view the integral in I as a continuous sum: each point in Ω then becomes a data point. Therefore, at each step of the SGD iteration, one chooses a mini-batch of points to discretize the integral. These points are chosen randomly, and the same quadrature weight is used at every point.

Note that if we use standard quadrature rules to discretize the integral, then we are bound to choose a fixed set of nodes. In this case, we run the risk that the integrand is minimized on these fixed nodes while the functional itself is far from being minimized. It is fortunate that SGD fits naturally with the numerical integration needed in this context.
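As a small illustration of this equal-weight random quadrature (our example, not one from the paper), take Ω = (0, 1), f = 1 and the fixed function u(x) = x(1 − x)/2. By hand, $I(u) = \int_0^1 \tfrac12(\tfrac12 - x)^2\,dx - \int_0^1 \tfrac12 x(1-x)\,dx = \tfrac{1}{24} - \tfrac{1}{12} = -\tfrac{1}{24}$. Each call below draws a fresh set of nodes, so no fixed set of points is ever privileged.

```python
import random


def energy_density(x: float) -> float:
    """g(x) of (7) in 1D, for f = 1 and the fixed function u(x) = x(1 - x)/2."""
    u = 0.5 * x * (1.0 - x)
    du = 0.5 - x                    # u'(x)
    return 0.5 * du * du - 1.0 * u  # |u'|^2 / 2 - f u


def mc_integral(g, n: int, rng: random.Random) -> float:
    """Equal-weight quadrature on (0, 1): n fresh uniform nodes, weight 1/n each."""
    return sum(g(rng.random()) for _ in range(n)) / n


rng = random.Random(0)
estimates = [mc_integral(energy_density, 20_000, rng) for _ in range(3)]
print(estimates, -1.0 / 24.0)  # each estimate is close to the exact value
```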

In summary, the SGD in this context is given by:

$$
\theta^{k+1} = \theta^{k} - \eta\,\nabla_\theta \frac{1}{N}\sum_{j=1}^{N} g(x_{j,k};\theta^{k}) \tag{11}
$$


where, for each k, $\{x_{j,k}\}$ is a set of points in Ω sampled at random from the uniform distribution. To accelerate the training of the neural network, we use the Adam optimizer version of SGD [10].
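The sketch below runs update (11) end to end on a one-dimensional analogue of (2), under loudly stated simplifications: the deep network is replaced by a hypothetical two-parameter polynomial trial function so that $\nabla_\theta g$ can be written by hand, f = 1 on Ω = (0, 1), and plain SGD with a constant step replaces Adam. It therefore illustrates the stochastic-quadrature loop, not the neural approximation itself.

```python
import random

# Hypothetical two-parameter trial function standing in for the network:
#   u(x; th) = th[0] * x(1 - x) + th[1] * x^2 (1 - x),  which is 0 at x = 0 and x = 1.
# Problem: minimize I(u) = int_0^1 (u'(x)^2 / 2 - f u(x)) dx with f = 1.
# The exact minimizer is u*(x) = x(1 - x)/2, i.e. th = (0.5, 0).


def basis(x):
    return (x * (1 - x), x * x * (1 - x))


def dbasis(x):
    return (1 - 2 * x, 2 * x - 3 * x * x)


def grad_g(x, th, f=1.0):
    """Gradient in th of g(x; th) = u'(x)^2 / 2 - f u(x), cf. (7)."""
    p, dp = basis(x), dbasis(x)
    du = th[0] * dp[0] + th[1] * dp[1]
    return [du * dp[k] - f * p[k] for k in range(2)]


def ritz_sgd(steps=4000, batch=128, eta=0.2, seed=0):
    """Update (11): fresh uniform nodes x_{j,k} every step, equal weights 1/N."""
    rng = random.Random(seed)
    th = [0.0, 0.0]
    for _ in range(steps):
        acc = [0.0, 0.0]
        for _ in range(batch):
            gx = grad_g(rng.random(), th)
            acc[0] += gx[0]
            acc[1] += gx[1]
        th = [th[k] - eta * acc[k] / batch for k in range(2)]
    return th


th = ritz_sgd()
grid = [i / 100 for i in range(101)]
err = max(abs(th[0] * basis(x)[0] + th[1] * basis(x)[1] - 0.5 * x * (1 - x)) for x in grid)
print(th, err)
```

Because u is linear in θ here, the expected loss is a convex quadratic and SGD hovers in a small noise ball around θ = (0.5, 0). With a neural network the loss is non-convex, which is one of the difficulties listed in Section 4.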

3 Numerical Results

3.1 The Poisson equation in two dimensions

Consider the Poisson equation:

$$
\begin{cases}
-\Delta u(x) = 1, & x \in \Omega,\\
u(x) = 0, & x \in \partial\Omega,
\end{cases} \tag{12}
$$

where Ω = (−1, 1) × (−1, 1) \ [0, 1) × {0}. The solution to this problem suffers from the well-known "corner singularity" caused by the nature of the domain [11]. A simple asymptotic analysis shows that, at the origin, the solution behaves as $u(x) = u(r,\theta) \sim r^{1/2}\sin\frac{\theta}{2}$ [11].

Models of this type have been extensively used to help develop and test adaptive finite element methods.

The network we used to solve this problem is a stack of four blocks (eight fully-connected layers) and an output layer, with m = 10. There are a total of 811 parameters in the model. As far as we can tell, this network structure is not special in any way; it is simply the one that we used.

The boundary condition causes some problems. Here, for simplicity, we use a penalty method and consider the modified functional

$$
I(u) = \int_\Omega \left( \frac{1}{2}|\nabla_x u(x)|^2 - f(x)\,u(x) \right) dx + \beta \int_{\partial\Omega} u(x)^2\, ds \tag{13}
$$

We choose β = 500. The results from the Deep Ritz method are shown in Figure 2(a). For comparison, we also plot the result of the finite difference method with Δx₁ = Δx₂ = 0.1 (1,681 degrees of freedom) in Figure 2(b).
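The text does not say how the boundary points for the penalty term in (13) were drawn. One plausible sketch (our assumption throughout) samples ∂Ω uniformly by arc length, treating the slit [0, 1) × {0} as one extra segment of length 1, and scales each sample mean by the measure of its domain (area 4, boundary length 9). The functions `u`, `grad_u` and `f` passed in are placeholders for the trial function, its gradient and the forcing.

```python
import random


def sample_interior(n, rng):
    """Uniform points in Omega; the slit [0,1) x {0} has zero area, so sampling
    the square (-1,1)^2 suffices."""
    return [(rng.uniform(-1.0, 1.0), rng.uniform(-1.0, 1.0)) for _ in range(n)]


def sample_boundary(n, rng):
    """Uniform-by-arc-length points on dOmega: square perimeter (length 8) plus
    the slit [0,1) x {0} counted once (length 1). This weighting is our assumption."""
    pts = []
    for _ in range(n):
        s = rng.uniform(0.0, 9.0)
        if s < 8.0:
            side, t = divmod(s, 2.0)      # which edge, and position along it
            c = t - 1.0                   # coordinate in [-1, 1)
            edges = [(c, -1.0), (1.0, c), (-c, 1.0), (-1.0, -c)]
            pts.append(edges[int(side)])
        else:
            pts.append((s - 8.0, 0.0))    # point on the slit
    return pts


def penalized_loss(u, grad_u, f, beta, n_in, n_bd, rng):
    """Monte Carlo estimate of (13): area * mean(g) + beta * length * mean(u^2)."""
    area, length = 4.0, 9.0
    vol = 0.0
    for x1, x2 in sample_interior(n_in, rng):
        g1, g2 = grad_u(x1, x2)
        vol += 0.5 * (g1 * g1 + g2 * g2) - f(x1, x2) * u(x1, x2)
    bd = sum(u(x1, x2) ** 2 for x1, x2 in sample_boundary(n_bd, rng))
    return area * vol / n_in + beta * length * bd / n_bd


# Usage: for u = 1, f = 1 the estimate is exact: 4 * (-1) + 500 * 9 * 1 = 4496.
rng = random.Random(0)
print(penalized_loss(lambda x1, x2: 1.0, lambda x1, x2: (0.0, 0.0),
                     lambda x1, x2: 1.0, 500.0, 50, 50, rng))
```

With β = 500, even a modest boundary mismatch dominates the loss, which is how the penalty pushes the trial function toward the Dirichlet condition.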

To analyze the error more quantitatively, we consider the following problem:

$$
\begin{cases}
\Delta u(x) = 0, & x \in \Omega,\\
u(x) = u(r,\theta) = r^{1/2}\sin\frac{\theta}{2}, & x \in \partial\Omega,
\end{cases} \tag{14}
$$

where Ω = (−1, 1) × (−1, 1) \ [0, 1) × {0}. This problem has the explicit solution $u^*(x) = r^{1/2}\sin\frac{\theta}{2}$ in polar coordinates. The error $e = \max |u^*(x) - u_h(x)|$, where $u^*$ and $u_h$ are the exact and approximate solutions respectively, is shown in Table 1 for both the Deep Ritz method and the finite difference method (on uniform grids). We can see that, with fewer parameters, the Deep Ritz method gives a more accurate solution than the finite difference method.

[Figure 2 panels, each over [−1, 1] × [−1, 1] with a color scale from 0.00 to 0.16: (a) Solution of Deep Ritz method, 811 parameters; (b) Solution of finite difference method, 1,681 parameters]

Figure 2: Solutions computed by two different methods.

Table 1: Error of Deep Ritz method (DRM) and finite difference method (FDM)

Method   Blocks   Num. parameters   Relative L2 error
DRM      3        591               0.0079
DRM      4        811               0.0072
DRM      5        1031              0.00647
DRM      6        1251              0.0057
FDM      -        625               0.0125
FDM      -        2401              0.0063

Being a naturally nonlinear variational method, the Deep Ritz method is also naturally adaptive. We believe that this contributes to the better accuracy of the Deep Ritz method.
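A sketch for evaluating the exact solution of (14) and the error e on a uniform grid. Measuring θ in [0, 2π) from the positive $x_1$-axis (our convention) makes $u^*$ vanish on both sides of the slit; the 101 × 101 evaluation grid and the finite-difference harmonicity check are our additions, not part of the paper's procedure.

```python
import math


def u_exact(x1, x2):
    """Exact solution of (14): u* = r^(1/2) sin(theta/2), theta in [0, 2*pi).

    Measuring theta from the positive x1-axis makes u* vanish on the slit."""
    r = math.hypot(x1, x2)
    theta = math.atan2(x2, x1) % (2.0 * math.pi)
    return math.sqrt(r) * math.sin(theta / 2.0)


def max_error(u_h, n=101):
    """e = max |u*(x) - u_h(x)|, evaluated on an n x n uniform grid of [-1, 1]^2."""
    grid = [-1.0 + 2.0 * i / (n - 1) for i in range(n)]
    return max(abs(u_exact(a, b) - u_h(a, b)) for a in grid for b in grid)


def laplacian_fd(u, x1, x2, h=1e-3):
    """Five-point finite-difference Laplacian of u at (x1, x2)."""
    return (u(x1 + h, x2) + u(x1 - h, x2) + u(x1, x2 + h) + u(x1, x2 - h)
            - 4.0 * u(x1, x2)) / (h * h)


print(laplacian_fd(u_exact, -0.5, 0.3))  # close to 0: harmonic away from origin and slit
print(max_error(lambda a, b: 0.0))       # e for the trivial approximation u_h = 0
```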

3.2 Poisson equation in high dimensions

Experience in computer vision and other artificial intelligence tasks suggests that deep learning-based methods are particularly powerful in high dimensions. This has been confirmed by the results of the Deep BSDE method [3]. In this subsection, we investigate the performance of the Deep Ritz method in relatively high dimensions.


Consider (d = 10)

$$
\begin{cases}
-\Delta u = 0, & x \in (0,1)^{10},\\
u(x) = \sum_{k=1}^{5} x_{2k-1}x_{2k}, & x \in \partial(0,1)^{10}.
\end{cases} \tag{15}
$$

The solution of this problem is simply $u(x) = \sum_{k=1}^{5} x_{2k-1}x_{2k}$, and we will use the exact solution to compute the error of our model later.

For the network structure, we stack six fully-connected layers with three skip connections and a final linear layer; there are a total of 671 parameters. For numerical integration, at each step of the SGD iteration, we sample 1,000 points in Ω and 100 points on each hyperplane that composes ∂Ω. We set β = 10³. After 50,000 iterations, the relative L2 error was reduced to about 0.4%. The training process is shown in Figure 3(a).
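The sampling just described can be sketched as follows: for each of the 2d = 20 faces {x_k = 0} and {x_k = 1} of the cube, draw 100 points uniformly on the face. How the relative L2 error was evaluated is not stated; estimating it on random points, as below, is our assumption.

```python
import random


def sample_interior(n, d, rng):
    """n uniform points in the unit cube of dimension d."""
    return [[rng.random() for _ in range(d)] for _ in range(n)]


def sample_faces(n_per_face, d, rng):
    """n_per_face uniform points on each of the 2d faces {x_k = 0} and {x_k = 1}."""
    pts = []
    for k in range(d):
        for value in (0.0, 1.0):
            for _ in range(n_per_face):
                x = [rng.random() for _ in range(d)]
                x[k] = value
                pts.append(x)
    return pts


def u_exact(x):
    """Exact solution of (15); 0-based indices pair (x[0], x[1]), (x[2], x[3]), etc."""
    return sum(x[2 * k] * x[2 * k + 1] for k in range(len(x) // 2))


def relative_l2_error(u_h, points):
    """Sample-based estimate of ||u_h - u*||_2 / ||u*||_2 (our choice of estimator)."""
    num = sum((u_h(x) - u_exact(x)) ** 2 for x in points)
    den = sum(u_exact(x) ** 2 for x in points)
    return (num / den) ** 0.5


rng = random.Random(0)
interior = sample_interior(1000, 10, rng)
boundary = sample_faces(100, 10, rng)
print(len(interior), len(boundary))  # 1000 2000
```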

[Figure 3 panels: (a) ln e and ln loss_boundary, d = 10 (x-axis 0 to 50,000 iterations); (b) ln e and ln loss_boundary, d = 100 (x-axis 0 to 20,000 iterations)]

Figure 3: The total error and the error on the boundary during the training process. The x-axis represents the iteration steps. The blue curves show the relative error of u. The red curves show the relative error on the boundary.

Also shown in Figure 3(b) is the training process for the problem

$$
\begin{cases}
-\Delta u = -200, & x \in (0,1)^{d},\\
u(x) = \sum_{k} x_k^2, & x \in \partial(0,1)^{d},
\end{cases} \tag{16}
$$

with d = 100, using a similar network structure (a stack of 3 blocks of size m = 100). The solution of this problem is $u(x) = \sum_k x_k^2$. After 50,000 iterations, the relative error is reduced to about 2.2%.
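Both stated solutions are quick to verify: the terms of (15) are products of distinct coordinates, so each is harmonic, while $\Delta \sum_k x_k^2 = 2d = 200$ for d = 100, matching $-\Delta u = -200$ in (16). A finite-difference sanity check (our addition) confirms this at random points.

```python
import random


def laplacian_fd(u, x, h=1e-3):
    """Sum over coordinates of central second differences: an approximation of Delta u(x)."""
    total = 0.0
    for k in range(len(x)):
        xp, xm = list(x), list(x)
        xp[k] += h
        xm[k] -= h
        total += (u(xp) - 2.0 * u(x) + u(xm)) / (h * h)
    return total


def u15(x):
    """sum_{k=1}^{5} x_{2k-1} x_{2k} for d = 10 (0-based indexing here)."""
    return sum(x[2 * k] * x[2 * k + 1] for k in range(5))


def u16(x):
    """sum_k x_k^2."""
    return sum(c * c for c in x)


rng = random.Random(0)
x10 = [rng.random() for _ in range(10)]
x100 = [rng.random() for _ in range(100)]
print(laplacian_fd(u15, x10))   # about 0, so -Delta u = 0 as in (15)
print(laplacian_fd(u16, x100))  # about 200, so -Delta u = -200 as in (16)
```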


3.3 An example with the Neumann boundary condition

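Consider:

$$
\begin{cases}
-\Delta u + \pi^2 u = 2\pi^2 \sum_{k} \cos(\pi x_k), & x \in [0,1]^{d},\\
\dfrac{\partial u}{\partial n}\Big|_{\partial[0,1]^{d}} = 0, & x \in \partial[0,1]^{d}.
\end{cases} \tag{17}
$$

The exact solution is $u(x) = \sum_k \cos(\pi x_k)$.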

[Figure 4 panels: (a) ln e, d = 5; (b) ln e, d = 10 (x-axis 0 to 50,000 iterations)]

Figure 4: The error during the training process (d = 5 and d = 10).

In this case, we can simply use

$$
I(u) = \int_\Omega \left( \frac{1}{2}\left(|\nabla u(x)|^2 + \pi^2 u(x)^2\right) - f(x)\,u(x) \right) dx
$$

without any penalty term for the boundary. With a similar network structure, the relative L2 error reaches 1.3% for d = 5 and 1.9% for d = 10. The training process is shown in Figure 4.
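Why no penalty is needed here follows from a standard computation, sketched for smooth u. Integrating the first variation by parts leaves a boundary term, and because the trial functions are not constrained on ∂Ω, the direction v is arbitrary there too:

$$
\frac{d}{d\varepsilon} I(u+\varepsilon v)\Big|_{\varepsilon=0}
= \int_\Omega \left(-\Delta u + \pi^2 u - f\right) v\,dx + \int_{\partial\Omega} \frac{\partial u}{\partial n}\, v\,ds = 0 \quad \text{for all } v,
$$

so a minimizer satisfies the equation in (17) and, in addition, $\partial u/\partial n = 0$ on the boundary. The Neumann condition is therefore a natural boundary condition: the minimization itself produces it, rather than it being imposed on the trial functions.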

3.4 Transfer learning

An important component of the training process is the initialization. Here we investigate the benefit of transferring weights in the network when the forcing function f is changed.


Consider the problem:

$$
\begin{cases}
-\Delta u(x) = 6(1+x_1)(1-x_1)x_2 + 2(1+x_2)(1-x_2)x_2, & x \in \Omega,\\
u(x) = r^{1/2}\sin\frac{\theta}{2} + (1+x_1)(1-x_1)(1+x_2)(1-x_2)x_2, & x \in \partial\Omega,
\end{cases} \tag{18}
$$

where Ω = (−1, 1) × (−1, 1) \ [0, 1) × {0}. Here we use a mixture of rectangular and polar coordinates. The exact solution is

$$
u(x) = r^{1/2}\sin\frac{\theta}{2} + (1+x_1)(1-x_1)(1+x_2)(1-x_2)x_2.
$$

The network consists of a stack of 3 blocks with m = 10, that is, six fully-connected layers and three residual connections, followed by a final linear transformation layer to obtain u. We show how the error and the weights in the layers change during the training period in Figure 5.
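A numerical sanity check (our addition) that the stated exact solution satisfies the PDE in (18). The polar part is harmonic away from the origin and the slit, so $-\Delta$ acts only on the polynomial part. The test points are arbitrary choices away from the slit.

```python
import math


def u18(x1, x2):
    """Stated exact solution of (18), with theta in [0, 2*pi) from the positive x1-axis."""
    r = math.hypot(x1, x2)
    theta = math.atan2(x2, x1) % (2.0 * math.pi)
    bump = (1 + x1) * (1 - x1) * (1 + x2) * (1 - x2) * x2
    return math.sqrt(r) * math.sin(theta / 2.0) + bump


def rhs18(x1, x2):
    """Right-hand side of the PDE in (18)."""
    return 6 * (1 + x1) * (1 - x1) * x2 + 2 * (1 + x2) * (1 - x2) * x2


def minus_laplacian_fd(u, x1, x2, h=1e-3):
    """-Delta u at (x1, x2) by the five-point finite-difference stencil."""
    return -(u(x1 + h, x2) + u(x1 - h, x2) + u(x1, x2 + h) + u(x1, x2 - h)
             - 4.0 * u(x1, x2)) / (h * h)


points = [(-0.5, 0.3), (0.4, -0.6), (-0.2, -0.7)]
for p in points:
    print(p, minus_laplacian_fd(u18, *p), rhs18(*p))  # the two values agree closely
```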

We also transfer the weights from the problem:

$$
\begin{cases}
-\Delta u(x) = 0, & x \in \Omega,\\
u(x) = r^{1/2}\sin\frac{\theta}{2}, & x \in \partial\Omega,
\end{cases} \tag{19}
$$

where Ω = (−1, 1) × (−1, 1) \ [0, 1) × {0}. The error and the weights during the training period are also shown in Figure 5. We see that transferring weights speeds up the training process considerably during the initial stage of training. This suggests that transferring weights is a particularly effective procedure if the accuracy requirement is not very stringent.

3.5 Eigenvalue problems

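Consider the following problem:

$$
\begin{cases}
-\Delta u + v \cdot u = \lambda u, & x \in \Omega,\\
u|_{\partial\Omega} = 0.
\end{cases} \tag{20}
$$

Problems of this kind occur often in quantum mechanics, where v is the potential function. There is a well-known variational principle for the smallest eigenvalue:

$$
\min_u \frac{\int_\Omega |\nabla u|^2\,dx + \int_\Omega v u^2\,dx}{\int_\Omega u^2\,dx} \quad \text{s.t. } u|_{\partial\Omega} = 0 \tag{21}
$$

The functional we minimize here is called the Rayleigh quotient.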
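The Rayleigh quotient (21) can be estimated with the same equal-weight random quadrature used throughout. The plain-Python sketch below is our own example. It uses the infinite potential well of Example 1 (v = 0 on [0, 1]^d, u = 0 on the boundary), where $u = \prod_k \sin(\pi x_k)$ is the ground state with quotient exactly dπ², and compares it with a trial function that is not an eigenfunction.

```python
import math
import random


def rayleigh_quotient_mc(u, grad_u, v, d, n, rng):
    """Equal-weight Monte Carlo estimate of the Rayleigh quotient (21) on [0, 1]^d.

    Numerator and denominator share the same uniform nodes; the volume factor
    |Omega| = 1 would cancel in the ratio anyway."""
    num = den = 0.0
    for _ in range(n):
        x = [rng.random() for _ in range(d)]
        ux = u(x)
        num += sum(g * g for g in grad_u(x)) + v(x) * ux * ux
        den += ux * ux
    return num / den


def u_sin(x):
    """Trial function prod_k sin(pi x_k), zero on the boundary of [0, 1]^d."""
    return math.prod(math.sin(math.pi * c) for c in x)


def grad_u_sin(x):
    s = [math.sin(math.pi * c) for c in x]
    return [math.pi * math.cos(math.pi * c) * math.prod(s[:k] + s[k + 1:])
            for k, c in enumerate(x)]


def zero_potential(x):
    return 0.0  # infinite well: v = 0 inside [0, 1]^d


rng = random.Random(0)
for d in (1, 3):
    est = rayleigh_quotient_mc(u_sin, grad_u_sin, zero_potential, d, 50_000, rng)
    print(d, est, d * math.pi ** 2)  # estimate vs exact d * pi^2

# A trial function that is not an eigenfunction gives a larger quotient:
# for u = x(1 - x) in d = 1 the exact value is (1/3) / (1/30) = 10 > pi^2.
poly = rayleigh_quotient_mc(lambda x: x[0] * (1 - x[0]), lambda x: [1 - 2 * x[0]],
                            zero_potential, 1, 50_000, rng)
print(poly)
```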


[Figure 5 panels: (a) ln err; (b) ln ||ΔW||₂² (x-axis 0 to 50,000 iterations)]

Figure 5: The red curves show the results of the training process with weight transfer. The blue curves show the results of the training process with random initialization. The left figure shows how the natural logarithm of the error changes during training. The right figure shows how the natural logarithm of ||ΔW||₂² changes during training, where W is the weight matrix and ΔW is the change in W after 100 training steps.

To avoid getting the trivial minimizer u = 0, instead of using the functional

$$
L_0(u) = \frac{\int_\Omega |\nabla u|^2\,dx + \int_\Omega v u^2\,dx}{\int_\Omega u^2\,dx} + \beta \int_{\partial\Omega} u(x)^2\,ds
$$

we use

$$
\min_u \frac{\int_\Omega |\nabla u|^2\,dx + \int_\Omega v u^2\,dx}{\int_\Omega u^2\,dx} \quad \text{s.t. } \int_\Omega |\nabla u|^2\,dx = 1, \quad u|_{\partial\Omega} = 0 \tag{22}
$$

In practice, we use

$$
L(u(x;\theta)) = \frac{\int_\Omega |\nabla u|^2\,dx + \int_\Omega v u^2\,dx}{\int_\Omega u^2\,dx} + \beta \int_{\partial\Omega} u(x)^2\,ds + \gamma \left( \int_\Omega u^2\,dx - 1 \right)^2 \tag{23}
$$

One might suggest that, with the last penalty term, the denominator in the Rayleigh quotient is no longer necessary. In practice, however, we found that this term still helps in two ways: (1) in the presence of this denominator, there is no need to choose a large value of γ. For the harmonic oscillator in d = 5, we choose β = 2000 and γ = 100, and this seems to be large enough. (2) This term helps to speed up the training process.

To solve this problem, we build a deep neural network much like DenseNet [12]. There are skip connections between every pair of layers, which help gradients flow through the whole network. The network structure is shown in Figure 6.

$$
y_i = \phi(W_{i-1}x_{i-1} + b_{i-1}) \tag{24}
$$

$$
x_i = [x_{i-1}; y_i] \tag{25}
$$

We use the activation function φ(x) = max(0, x)². With the activation function used before, we found that the gradients can become quite large, and we may face the gradient explosion problem.
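A plain-Python sketch of the forward pass (24)-(25) with the activation φ(x) = max(0, x)². The layer widths and the random initialization are hypothetical choices of ours. Following Figure 6, a final size-1 fully-connected layer would then map the concatenated vector to u; that layer is omitted here.

```python
import random


def relu_squared(x: float) -> float:
    """Activation used for the eigenvalue problem: phi(x) = max(0, x)^2."""
    return max(0.0, x) ** 2


def dense_forward(x, layers):
    """Forward pass of (24)-(25): y_i = phi(W_{i-1} x_{i-1} + b_{i-1}), x_i = [x_{i-1}; y_i].

    `layers` is a list of (W, b) pairs; each W needs as many columns as the
    current (growing) feature vector. Returns the final concatenated vector."""
    feats = list(x)
    for W, b in layers:
        y = [relu_squared(sum(w * f for w, f in zip(row, feats)) + bias)
             for row, bias in zip(W, b)]
        feats = feats + y  # concatenation acts as the skip connection
    return feats


def make_layers(d, m, n_layers, rng):
    """Hypothetical random initialization; the input width grows by m per layer."""
    layers, width = [], d
    for _ in range(n_layers):
        W = [[rng.gauss(0.0, width ** -0.5) for _ in range(width)] for _ in range(m)]
        layers.append((W, [0.0] * m))
        width += m
    return layers


rng = random.Random(0)
d, m, n_layers = 5, 8, 3
z = dense_forward([0.1 * k for k in range(d)], make_layers(d, m, n_layers, rng))
print(len(z))  # d + n_layers * m = 29
```

Because every layer sees the raw input and all earlier outputs, the input x itself is carried unchanged to the output layer, which is the gradient-flow property described above.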

[Figure 6 diagram: input x → FC layer (size m) + activation → FC layer (size m) + activation → FC layer (size 1) → output u, with concatenation skip connections between the layers]

Figure 6: Network structure used for the eigenvalue problem. There are skip connections between every pair of layers. The triangles denote concatenation operations.

The remaining components of the algorithm are very much the same as before.

Example 1: Infinite potential well. Consider the potential function

$$
v(x) =
\begin{cases}
0, & x \in [0,1]^d,\\
\infty, & x \notin [0,1]^d.
\end{cases} \tag{26}
$$


The problem is then equivalent to solving:

$$
\begin{cases}
-\Delta u = \lambda u, & x \in [0,1]^d,\\
u(x) = 0, & x \in \partial[0,1]^d.
\end{cases} \tag{27}
$$

The smallest eigenvalue is λ₀ = dπ². The results of the Deep Ritz method in different dimensions are shown in Table 2.

Table 2: Error of the Deep Ritz method

Dimension d   Exact λ₀   Approximate   Error
1             9.87       9.85          0.20%
5             49.35      49.29         0.11%
10            98.70      92.35         6.43%

Example 2: The harmonic oscillator. The potential function in ℝ^d is v(x) = |x|². For simplicity, we truncate the computational domain from ℝ^d to [−3, 3]^d. Obviously, there are better strategies, but we leave improvements to later work.

The results in different dimensions are shown in Table 3.

Table 3: Error of the Deep Ritz method

Dimension d   Exact λ₀   Approximate   Error
1             1          1.0016        0.16%
5             5          5.0814        1.6%
10            10         11.26         12.6%

The results deteriorate substantially as the dimension is increased. We believe thatthere is still a lot of room for improving the results. We will leave this to future work.

4 Discussion

We proposed a variational method based on representing the trial functions by deep neural networks. Our limited experience with this method suggests that it has the following advantages:

1. It is naturally adaptive.


2. It is less sensitive to the dimensionality of the problem and has the potential to work in rather high dimensions.

3. The method is reasonably simple and fits well with the stochastic gradient descent framework commonly used in deep learning.

We also see a number of disadvantages that need to be addressed in future work:

1. The variational problem that we obtain at the end is not convex even when the initialproblem is. The issue of local minima and saddle points is non-trivial.

2. At the present time, there is no consistent conclusion about the convergence rate.

3. The treatment of the essential boundary condition is not as simple as for the traditional methods.

In addition, there are still interesting issues regarding the choice of the network structure, the activation function and the minimization algorithm. The present paper is far from being the last word on the subject.

Acknowledgement: We are grateful to Professor Ruo Li and Dr. Zhanxing Zhu for very helpful discussions. The work of E and Yu is supported in part by the National Key Basic Research Program of China 2015CB856000, Major Program of NNSFC under grant 91130005, DOE grant de-sc0009248, and ONR grant N00014-13-1-0338.

References

[1] I. Goodfellow, Y. Bengio and A. Courville, Deep Learning. MIT Press, 2016.

[2] W. E, “A proposal for machine learning via dynamical systems”, Communications in Mathematics and Statistics, March 2017, Volume 5, Issue 1, pp. 1-11.

[3] J. Q. Han, A. Jentzen and W. E, “Overcoming the curse of dimensionality: Solving high-dimensional partial differential equations using deep learning”, submitted, arXiv:1707.02568.

[4] W. E, J. Q. Han and A. Jentzen, “Deep learning-based numerical methods for high-dimensional parabolic partial differential equations and backward stochastic differential equations”, submitted, arXiv:1706.04702.

[5] C. Beck, W. E and A. Jentzen, “Machine learning approximation algorithms for high-dimensional fully nonlinear partial differential equations and second-order backward stochastic differential equations”, submitted, arXiv:1709.05963.


[6] J. Q. Han, L. Zhang, R. Car and W. E, “Deep Potential: A general and “first-principle” representation of the potential energy”, submitted, arXiv:1707.01478.

[7] L. Zhang, J. Q. Han, H. Wang, R. Car and W. E, “Deep Potential Molecular Dynamics: A scalable model with the accuracy of quantum mechanics”, submitted, arXiv:1707.09571.

[8] L. C. Evans, Partial Differential Equations, 2nd ed. American Mathematical Society, 2010.

[9] K. M. He, X. Y. Zhang, S. Q. Ren and J. Sun, “Deep residual learning for image recognition”, 2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 770-778, 2016, doi:10.1109/CVPR.2016.90.

[10] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization”, arXiv preprint arXiv:1412.6980, 2014.

[11] G. Strang and G. Fix, An Analysis of the Finite Element Method. Prentice-Hall, 1973.

[12] G. Huang, Z. Liu, K. Q. Weinberger and L. van der Maaten, “Densely connected convolutional networks”, arXiv preprint arXiv:1608.06993, 2016.
