Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhiguo shang
issue351_Sammer_Wrist X-ray iamge CNN prediction
Commits
9988d3bd
Commit
9988d3bd
authored
Sep 17, 2020
by
zhiguo shang
Browse files
Dr.Sammer_train_test
parent
6803dce3
Changes
1
Hide whitespace changes
Inline
Side-by-side
CNNWrist_20200917.m
0 → 100644
View file @
9988d3bd
%% data loading
fnTr_Orientation
=
'/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Train Set for Orientation.xlsx'
;
fnTe_Orientation
=
'/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Test Set for Orientation.xlsx'
;
fnTr_Screw
=
'/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Train Set for Plate Screw.xlsx'
;
fnTe_Screw
=
'/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Test Set for Plate Screw.xlsx'
;
trO
=
readtable
(
fnTr_Orientation
,
'readvariablenames'
,
false
);
teO
=
readtable
(
fnTe_Orientation
,
'readvariablenames'
,
false
);
trS
=
readtable
(
fnTr_Screw
,
'readvariablenames'
,
false
);
teS
=
readtable
(
fnTe_Screw
,
'readvariablenames'
,
false
);
Target_O
=
[
trO
.
Var2
;
teO
.
Var2
];
filelist_O
=
[
trO
.
Var1
;
teO
.
Var1
];
Target_S
=
[
trS
.
Var2
;
teS
.
Var2
];
filelist_S
=
[
trS
.
Var1
;
teS
.
Var1
];
fnmrcS
=
'datatSn.mrc'
;
fnmrcO
=
'datatOn.mrc'
;
addpath
(
genpath
(
'~/code_shang'
))
chan
=
1
;
%data = ReadMRC('datafS0.mrc');
fname
=
struct
();
fname
(
1
)
.
fn
=
'datatO.mrc'
;
fname
(
2
)
.
fn
=
'datatSn_centerR.mrc'
;
%ind_remove = [240 246 258 318 319 320 330 1268 1269]; %for HW data
%ind_remove = [240 246 258 318 319 320 330]; %for Orientation data
%data(:,:,ind_remove)=[];
%Target_O(ind_remove)=[];
for
fn_i
=
2
%1:2
fnmr
=
fname
(
fn_i
)
.
fn
;
data
=
ReadMRC
(
fnmr
);
ns
=
size
(
data
);
datan
=
zeros
(
ns
(
1
),
ns
(
2
),
chan
,
size
(
Target_S
,
1
));
%datatOn_centerR.mrc; (segmented-datafS0.mrc; datafS1.mrc),datatSn_centerR.mrc
% raw data (datatS.mrc, datatO.mrc)
%% apply a mask to dataset
ns
=
size
(
data
);
mask
=
fuzzymask
(
ns
(
1
),
2
,
0.48
*
ns
(
1
),
1
);
for
i
=
1
:
ns
(
3
)
datan
(:,:,
1
,
i
)
=
mask
.*
(
data
(:,:,
i
));
end
data
=
datan
;
clear
datan
;
%imresize(imr,[28 28])
%% ML
%indr = [240 282];
box
=
300
;
YTrain
=
categorical
([
trS
.
Var2
;
teS
.
Var2
]);
%
XTrain
=
zeros
(
box
,
box
,
chan
,
size
(
YTrain
,
1
));
XTrain
(:,:,:,:)
=
data
(
150
-
box
/
2
+
1
:
150
+
box
/
2
,
150
-
box
/
2
+
1
:
150
+
box
/
2
,:,
1
:
size
(
YTrain
,
1
));
%XTrain is a 28-by-28-by-1-by-5000 array
%YTrain is a categorical vector containing the labels for each observation.
ind1
=
find
(
YTrain
==
'0'
);
ind2
=
find
(
YTrain
==
'1'
);
ind1x
=
randperm
(
length
(
ind1
),
floor
(
0.2
*
length
(
ind1
)));
ind2x
=
randperm
(
length
(
ind2
),
floor
(
0.2
*
length
(
ind2
)));
ind3
=
find
(
YTrain
==
'2'
);
ind3x
=
randperm
(
length
(
ind3
),
floor
(
0.2
*
length
(
ind3
)));
idx
=
[
ind1
(
ind1x
)
' ind2(ind2x)'
ind3
(
ind3x
)
'
];
%save('idx_O.mat','idx');
load
(
'idx_S.mat'
);
%Xtest = XTrain(:,:,:,idx);
%XTrain(:,:,:,idx) = [];
%Ytest = YTrain(idx);
%YTrain(idx) = [];
%use Dr. Sammers test data set
Xtest
=
XTrain
(:,:,:,
1
+
size
(
trS
,
1
):
end
);
XTrain
(:,:,:,
1
+
size
(
trS
,
1
):
end
)
=
[];
Ytest
=
YTrain
(
1
+
size
(
trS
,
1
):
end
);
YTrain
(
1
+
size
(
trS
,
1
):
end
)
=
[];
ind1
=
find
(
YTrain
==
'0'
);
ind2
=
find
(
YTrain
==
'1'
);
ind1x
=
randperm
(
length
(
ind1
),
floor
(
0.2
*
length
(
ind1
)));
ind2x
=
randperm
(
length
(
ind2
),
floor
(
0.2
*
length
(
ind2
)));
ind3
=
find
(
YTrain
==
'2'
);
ind3x
=
randperm
(
length
(
ind3
),
floor
(
0.2
*
length
(
ind3
)));
idx
=
[
ind1
(
ind1x
)
' ind2(ind2x)'
ind3
(
ind3x
)
'
];
XValidation
=
XTrain
(:,:,:,
idx
);
XTrain
(:,:,:,
idx
)
=
[];
YValidation
=
YTrain
(
idx
);
YTrain
(
idx
)
=
[];
%Create an imageDataAugmenter object that specifies preprocessing options
%for image augmentation, such as resizing, rotation, translation, and reflection.
imageAugmenter
=
imageDataAugmenter
(
...
'RandRotation'
,[
-
22
,
22
],
...
'RandXReflection'
,
true
,
...
'RandYReflection'
,
true
,
...
'RandXTranslation'
,[
0
0
],
...
'RandYTranslation'
,[
0
0
]);
imageSize
=
[
box
box
chan
];
augimds
=
augmentedImageDatastore
(
imageSize
,
XTrain
,
YTrain
,
'DataAugmentation'
,
...
imageAugmenter
);
%Define Network Architecture
%Define the convolutional neural network architecture.
layers
=
[
imageInputLayer
(
imageSize
)
convolution2dLayer
(
3
,
16
,
'Padding'
,
'same'
)
batchNormalizationLayer
reluLayer
maxPooling2dLayer
(
2
,
'Stride'
,
2
)
convolution2dLayer
(
3
,
16
,
'Padding'
,
'same'
)
batchNormalizationLayer
reluLayer
maxPooling2dLayer
(
2
,
'Stride'
,
2
)
convolution2dLayer
(
3
,
16
,
'Padding'
,
'same'
)
batchNormalizationLayer
reluLayer
maxPooling2dLayer
(
2
,
'Stride'
,
2
)
convolution2dLayer
(
3
,
16
,
'Padding'
,
'same'
)
batchNormalizationLayer
reluLayer
maxPooling2dLayer
(
2
,
'Stride'
,
2
)
convolution2dLayer
(
3
,
32
,
'Padding'
,
'same'
)
batchNormalizationLayer
reluLayer
fullyConnectedLayer
(
2
)
softmaxLayer
classificationLayer
];
%
opts
=
trainingOptions
(
'sgdm'
,
...
'InitialLearnRate'
,
0.01
,
...
'MaxEpochs'
,
20
,
...
'MiniBatchSize'
,
128
,
...
'Plots'
,
'training-progress'
,
...
'Shuffle'
,
'every-epoch'
,
...
'ValidationFrequency'
,
8
,
...
'ValidationPatience'
,
50
,
...
'Verbose'
,
false
,
...
'ValidationData'
,{
XValidation
,
YValidation
});
%% add fearure parameter to CNN net
%% start training
infot
=
struct
();
for
i
=
1
%:20
disp
(
i
)
[
net
,
info
]
=
trainNetwork
(
augimds
,
layers
,
opts
);
%infot(i).info = info;
end
%save([fnmr(1:end-4),'_info_100.mat'],'net','infot','-v7.3');
end
cgt
=
double
(
Ytest
);
[
clabel
,
scores
]
=
classify
(
net
,
Xtest
);
%diffscore = scores(:,2)-max(scores(:,1),scores(:,3));
%[X, Y, T, AUC1, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,diffscore,2);
[
X
,
Y
,
T
,
AUC1
,
OPTROCPT
,
SUBY
,
SUBYNAMES
]
=
perfcurve
(
cgt
,
scores
(:,
1
),
1
);
cgt
=
double
(
YTrain
);
[
clabel
,
scores
]
=
classify
(
net
,
XTrain
);
%diffscore = scores(:,2)-max(scores(:,1),scores(:,3));
%[Xt, Yt, T, AUC2, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,diffscore,2);
[
Xt
,
Yt
,
T
,
AUC2
,
OPTROCPT
,
SUBY
,
SUBYNAMES
]
=
perfcurve
(
cgt
,
scores
(:,
1
),
1
);
cgt
=
double
(
YValidation
);
[
clabel
,
scores
]
=
classify
(
net
,
XValidation
);
%diffscore = scores(:,2)-max(scores(:,1),scores(:,3));
%[Xv, Yv, T, AUC3, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,diffscore,2);
[
Xv
,
Yv
,
T
,
AUC3
,
OPTROCPT
,
SUBY
,
SUBYNAMES
]
=
perfcurve
(
cgt
,
scores
(:,
1
),
1
);
figure
(
1
);
plot
(
Xt
,
Yt
,
'r-'
,
'linewidth'
,
2
);
hold
on
;
plot
(
X
,
Y
,
'b-'
,
'linewidth'
,
2
);
hold
on
;
plot
(
Xv
,
Yv
,
'k-'
,
'linewidth'
,
2
);
grid
xlabel
(
'False positive rate'
)
ylabel
(
'True positive rate'
)
title
(
'ROC for Classification of HW of Wrist X-ray'
);
set
(
gca
,
'fontweight'
,
'bold'
);
set
(
gca
,
'fontsize'
,
14
,
'linewidth'
,
3
);
set
(
gcf
,
'color'
,
'w'
);
legend
(
'Train'
,
'test'
,
'Validation'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment