AWS re:Invent 2016: Automatic Grading of Diabetic Retinopathy through Deep Learning (MAC403)


© 2016, Amazon Web Services, Inc. or its Affiliates. All rights reserved.

Advisors: Robert Chang, Jeff Ullman, Andreas Paepcke

November 30, 2016

Automatic Grading of Diabetic Retinopathy through Deep Learning

Apaar Sadhwani, Leo Tam, and Jason Su

MAC403

Problem, Data and Motivation

Motivation:
- Affects ~100M people, many in the developed world; ~45% of diabetics
- Make the screening process faster, assist ophthalmologists, enable self-help
- Widespread disease; enable early diagnosis and care

Task:
- Given a fundus image, rate the severity of Diabetic Retinopathy
- 5 classes: 0 (normal), 1, 2, 3, 4 (severe)
- Hard classification (though it may be solved as an ordinal problem)
- Metric: quadratic weighted kappa (QWK), with a (pred - real)^2 penalty

Data:
- From Kaggle (California Healthcare Foundation, EyePACS)
- ~35,000 training images, ~54,000 test images
- High resolution: variable, more than 2560 x 1920
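The quadratic weighted kappa metric can be computed directly from a confusion matrix and the marginal class histograms; a minimal NumPy sketch (the function name and layout are mine, not from the talk):

```python
import numpy as np

def quadratic_weighted_kappa(y_true, y_pred, n_classes=5):
    """QWK = 1 - sum(W*O) / sum(W*E), with quadratic penalty weights."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    # Observed agreement matrix O
    O = np.zeros((n_classes, n_classes))
    for t, p in zip(y_true, y_pred):
        O[t, p] += 1
    # Quadratic penalty: (i - j)^2, normalized by the maximum disagreement
    idx = np.arange(n_classes)
    W = (idx[:, None] - idx[None, :]) ** 2 / (n_classes - 1) ** 2
    # Expected matrix E under chance agreement, from the marginals
    E = np.outer(O.sum(axis=1), O.sum(axis=0)) / len(y_true)
    return 1.0 - (W * O).sum() / (W * E).sum()
```

Perfect agreement scores 1.0; predictions far from the true grade are penalized quadratically, so confusing class 0 with class 4 costs far more than confusing adjacent grades.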


Example images

[Figure: fundus images, Class 0 (normal) and Class 4 (severe)]


Challenges

- High resolution images: atypical in vision, GPU batch size issues
- Discriminative features are small
- Grading criteria not clear (EyePACS guidelines); learn from data
- Incorrect labeling
- Artifacts in ~40% of images
- Optimizing approach to QWK
- Severe class imbalance: class 0 dominates
- Too few training examples

Image size    Batch size
224 x 224     128
2K x 2K       2
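The batch sizes in the table are roughly what memory scaling predicts: activation memory grows with the pixel count per image, so the feasible batch shrinks accordingly (my back-of-the-envelope arithmetic, not from the talk):

```python
# Activation memory scales roughly linearly with pixels per image.
small = 224 * 224          # standard vision resolution
large = 2048 * 2048        # ~2K x 2K fundus resolution
scale = large / small      # ~84x more activation memory per image
print(round(128 / scale))  # a 128-image batch at 224x224 shrinks to ~2
```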





Incorrect labeling: mentioned in the problem statement and confirmed with doctors.



Optimizing for QWK: hard classification is non-differentiable, which makes backpropagation difficult.

[Chart: penalty/loss vs. predicted class for truth labels 0-4]



Squared error is a differentiable approximation of the QWK penalty, and it permits fractional predictions (e.g., 2.5).

[Chart: squared-error penalty/loss vs. predicted class]


Handling class imbalance:
- Naive training collapses to a 3-class problem, or predicts all zeros
- Learn all classes separately: one-vs-all?
- Balanced sampling while training; but what about test time?


Big learning models need more data. Can the test set be harnessed?

Conventional Approaches

Literature survey:
- Hand-designed features to pick out each component
- Clean images, small datasets
- Optic disk and exudate segmentation fail due to artifacts
- SVM: poor performance


Our Approach

1. Registration, pre-processing
2. Convolutional neural nets (CNNs)
3. Hybrid architecture

Step 1: Pre-processing

- Registration: Hough circles, remove the portion outside the circle
- Downsize to a common size (224 x 224, 1K x 1K)
- Color correction
- Normalization (mean, variance)
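A minimal NumPy sketch of the downsize-and-normalize steps. The Hough-circle registration itself would use something like OpenCV's `HoughCircles`; the helper below assumes the fundus circle is already centered and the image square, which is my simplification:

```python
import numpy as np

def preprocess(img, out_size=224):
    """Mask outside the inscribed circle, block-average downsample,
    then standardize each channel to zero mean / unit variance."""
    h, w, _ = img.shape
    yy, xx = np.mgrid[:h, :w]
    r = min(h, w) / 2
    mask = (yy - h / 2) ** 2 + (xx - w / 2) ** 2 <= r ** 2
    img = img * mask[:, :, None]
    # Block-average downsample (assumes side divisible by out_size)
    f = h // out_size
    img = img.reshape(out_size, f, out_size, f, 3).mean(axis=(1, 3))
    # Per-channel mean/variance normalization
    mu = img.mean(axis=(0, 1))
    sd = img.std(axis=(0, 1)) + 1e-8
    return (img - mu) / sd
```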

Step 2: CNNs

Architecture (input image -> class probabilities):
- 3 conv layers (depth 96), MaxPool (stride 2)
- 3 conv layers (depth 256), MaxPool (stride 2)
- 3 conv layers (depth 384), MaxPool (stride 2)
- 3 conv layers (depth 1024), MaxPool (stride 2)
- AvgPool

- Network in Network architecture, 7.5M parameters
- No FC layers; spatial average pooling instead
- Transfer learning (ImageNet); variable learning rates: low for the "ImageNet" layers, with a schedule
- Combat lack of data and over-fitting: dropout, early stopping, data augmentation (flips, rotation)


Step 2: CNNs

Architecture (input image -> class probabilities):
- 3 conv layers (depth 96), MaxPool (stride 2)
- 3 conv layers (depth 256), MaxPool (stride 2)
- 3 conv layers (depth 384), MaxPool (stride 2)
- 3 conv layers (depths 384, 64, 5), MaxPool (stride 2)
- AvgPool

- Network in Network architecture, 2.2M parameters (final block reduced from depth 1024 to depths 384, 64, 5)
- No FC layers; spatial average pooling instead
- Transfer learning (ImageNet); variable learning rates: low for the "ImageNet" layers, with a schedule
- Combat lack of data and over-fitting: dropout, early stopping, data augmentation (flips, rotation)
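The "no FC layers" head works by giving the last conv block as many channels as classes (depth 5 here), spatially average-pooling each channel, and applying softmax; a NumPy sketch of that final step (shapes and names are illustrative, not from the talk):

```python
import numpy as np

def gap_head(feature_map):
    """feature_map: (5, H, W) conv output, one channel per DR grade.
    Global average pooling replaces fully connected layers."""
    logits = feature_map.mean(axis=(1, 2))   # (5,) spatial average pool
    e = np.exp(logits - logits.max())        # numerically stable softmax
    return e / e.sum()                       # class probabilities

probs = gap_head(np.random.randn(5, 7, 7))
```

This keeps the parameter count low and lets the same head accept different input resolutions.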


Step 2: CNN Experiments

- What image size to use? Strategize using 224 x 224, then extend to 1024 x 1024
- What loss function? Mean squared error (MSE), negative log likelihood (NLL), linear combination (annealing)
- Class imbalance: even sampling -> true sampling
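The "linear combination (annealing)" loss can be sketched as an interpolation between NLL and an MSE on the expected grade, with the weight annealed over training; the schedule and function below are my illustration, not the talk's exact recipe:

```python
import numpy as np

def combined_loss(probs, label, alpha):
    """probs: softmax output over the 5 grades; label: true grade.
    alpha=1 -> pure NLL, alpha=0 -> pure MSE on the expected grade."""
    nll = -np.log(probs[label])
    expected_grade = np.dot(np.arange(len(probs)), probs)
    mse = (expected_grade - label) ** 2
    return alpha * nll + (1 - alpha) * mse

# Anneal from NLL toward the differentiable QWK surrogate (MSE)
for epoch in range(5):
    alpha = max(0.0, 1.0 - epoch / 4)
    loss = combined_loss(np.array([0.1, 0.2, 0.4, 0.2, 0.1]), 2, alpha)
```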

Step 2: CNN Experiments (image size: 224 x 224)

Training the top layers only:

Loss Function    Result
MSE              Fails to learn
MSE              Fails to learn
NLL              Kappa < 0.1
NLL              Kappa = 0.29

Fine-tuning the pre-trained layers at 0.01x step size:

Loss Function    Result
NLL              Kappa = 0.42
NLL              Kappa = 0.51
MSE              Kappa = 0.56

Step 2: CNN Results


Computing Setup

- Amazon EC2: GPU nodes, VPC, Amazon EBS-optimized instances
- Single-GPU nodes for 224 x 224 (g2.2xlarge); multi-GPU nodes for 1K x 1K (g2.8xlarge)
- Storage: EBS, Amazon S3
- Python for processing; Torch library (Lua) for training

Computing Setup

[Diagram: model experiments run on 1- or 4-GPU nodes on EC2 inside a VPC; each master is an EBS-optimized central node serving ~10 model experiments; data lives on EBS (gp2) volumes replicated via S3 snapshots, at ~200 MB/s to each GPU node; the setup scales to multiple masters (Model 1-10, Model 11-20, ...)]

Computing Setup

g2.2xlarge (1-GPU node on EC2), 4 GB GPU memory:
- Batch size: 128 images of 224 x 224
- !! Batch size: 8 images of 1024 x 1024 !!

g2.8xlarge (4-GPU node on EC2), 16 GB GPU memory:
- Data parallelism; batch size: ~28 images of 1024 x 1024
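Data parallelism here means each GPU processes a shard of the batch and the resulting gradients are averaged. A NumPy simulation of why averaging shard gradients reproduces the full-batch gradient (a linear model for simplicity; all names are mine):

```python
import numpy as np

def mse_grad(w, X, y):
    """Gradient of mean squared error for a linear model."""
    return 2 * X.T @ (X @ w - y) / len(y)

rng = np.random.default_rng(0)
X, y, w = rng.normal(size=(8, 3)), rng.normal(size=8), rng.normal(size=3)

# Split the batch across 4 "GPUs" and average the shard gradients
shards = zip(np.split(X, 4), np.split(y, 4))
avg_grad = np.mean([mse_grad(w, Xs, ys) for Xs, ys in shards], axis=0)

full_grad = mse_grad(w, X, y)   # matches the averaged shard gradients
```

Because the loss is a mean over samples, equal-sized shards average back to exactly the full-batch gradient, so the 4-GPU node behaves like one GPU with 4x the memory.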

Step 3: Hybrid Architecture

[Diagram: a 2048 x 1024 image is split into 64 tiles of 256 x 256, which feed a Lesion Detector; the full image feeds the Main Network; the two outputs are fused to produce class probabilities]

Lesion Detector

- Web viewer and annotation tool
- Lesion annotation
- Extract image patches
- Train lesion classifier

Viewer and Lesion Annotation


Lesion Annotation

Extracted Image Patches

Train Lesion Detector

- Only hemorrhages so far
- Positives: 1866 extracted patches from 216 images/subjects
- Negatives: ~25k class-0 images
- Pre-processing/augmentation: crop a random 256 x 256 image from the input, plus flips
- Pre-trained Network in Network architecture
- Accuracy: 99% for negatives, 76% for positives
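The crop-and-flip augmentation can be sketched in NumPy (the patch size comes from the slide; the RNG plumbing and function name are mine):

```python
import numpy as np

def augment(img, size=256, rng=np.random.default_rng()):
    """Random size x size crop plus random horizontal/vertical flips."""
    h, w, _ = img.shape
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    patch = img[top:top + size, left:left + size]
    if rng.random() < 0.5:
        patch = patch[:, ::-1]   # horizontal flip
    if rng.random() < 0.5:
        patch = patch[::-1, :]   # vertical flip
    return patch
```

Random crops and flips multiply the effective number of positive patches, which matters given only 1866 positives.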


Hybrid Architecture

[Diagram: the 2048 x 1024 image is split into 64 tiles of 256 x 256; the Main Network produces a 64 x 31 x 31 feature map, and the Lesion Detector a 2 x 31 x 31 map (2 x 56 x 56 per tile); concatenated to 66 x 31 x 31 and passed through 2 conv layers in the Fuse stage to produce class probabilities]
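The Fuse step concatenates the two feature maps along the channel axis before the final conv layers; with the shapes from the slide (variable names are mine):

```python
import numpy as np

main_features = np.random.randn(64, 31, 31)   # Main Network feature map
lesion_map = np.random.randn(2, 31, 31)       # assembled Lesion Detector map

# Channel-wise concatenation: 64 + 2 = 66 channels feed the fuse conv layers
fused = np.concatenate([main_features, lesion_map], axis=0)
print(fused.shape)   # (66, 31, 31)
```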

Training Hybrid Architecture

[Diagram: the same pipeline as above (64 tiles -> Lesion Detector; full image -> Main Network; Fuse -> class probabilities), with backpropagation flowing through the fused network]

Other Insights

- Supervised-unsupervised learning
- Distillation
- Hard-negative mining
- Other lesion detectors
- Attention CNNs
- Both eyes
- Ensemble

Clinical Importance

- 3-class problem
- True "4" problem
- Combining imaging modalities (OCT)
- Longitudinal analysis

Many thanks to…

- Amazon Web Services: AWS Educate, AWS Cloud Credits for Research
- Robert Chang, Jeff Ullman, Andreas Paepcke

Thank you!

Remember to complete your evaluations!