Commit 9988d3bd authored by zhiguo shang's avatar zhiguo shang
Browse files

Dr.Sammer_train_test

parent 6803dce3
%% data loading
fnTr_Orientation = '/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Train Set for Orientation.xlsx';
fnTe_Orientation = '/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Test Set for Orientation.xlsx';
fnTr_Screw = '/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Train Set for Plate Screw.xlsx';
fnTe_Screw = '/project/BICF/BICF_Core/shared/bicf_helpdesk/issue1078_Sammer/data/Test Set for Plate Screw.xlsx';
trO= readtable(fnTr_Orientation,'readvariablenames',false);
teO= readtable(fnTe_Orientation,'readvariablenames',false);
trS= readtable(fnTr_Screw,'readvariablenames',false);
teS= readtable(fnTe_Screw,'readvariablenames',false);
Target_O = [trO.Var2; teO.Var2];
filelist_O = [trO.Var1; teO.Var1];
Target_S = [trS.Var2; teS.Var2];
filelist_S = [trS.Var1; teS.Var1];
fnmrcS = 'datatSn.mrc';
fnmrcO = 'datatOn.mrc';
addpath(genpath('~/code_shang'))
chan = 1;
%data = ReadMRC('datafS0.mrc');
fname = struct();
fname(1).fn = 'datatO.mrc';
fname(2).fn = 'datatSn_centerR.mrc';
%ind_remove = [240 246 258 318 319 320 330 1268 1269]; %for HW data
%ind_remove = [240 246 258 318 319 320 330]; %for Orientation data
%data(:,:,ind_remove)=[];
%Target_O(ind_remove)=[];
for fn_i =2%1:2
fnmr = fname(fn_i).fn;
data = ReadMRC(fnmr);
ns = size(data);
datan = zeros(ns(1),ns(2),chan,size(Target_S,1));
%datatOn_centerR.mrc; (segmented-datafS0.mrc; datafS1.mrc),datatSn_centerR.mrc
% raw data (datatS.mrc, datatO.mrc)
%% apply a mask to dataset
ns = size(data);
mask = fuzzymask(ns(1),2,0.48*ns(1),1);
for i = 1:ns(3)
datan(:,:,1,i) = mask.*(data(:,:,i));
end
data = datan;
clear datan;
%imresize(imr,[28 28])
%% ML
%indr = [240 282];
box = 300;
YTrain = categorical([trS.Var2; teS.Var2]);%
XTrain = zeros(box,box,chan,size(YTrain,1));
XTrain(:,:,:,:) = data(150-box/2+1:150+box/2,150-box/2+1:150+box/2,:,1:size(YTrain,1));
%XTrain is a 28-by-28-by-1-by-5000 array
%YTrain is a categorical vector containing the labels for each observation.
ind1 = find(YTrain=='0');
ind2 = find(YTrain=='1');
ind1x = randperm(length(ind1),floor(0.2*length(ind1)));
ind2x = randperm(length(ind2),floor(0.2*length(ind2)));
ind3 = find(YTrain=='2');
ind3x = randperm(length(ind3),floor(0.2*length(ind3)));
idx = [ind1(ind1x)' ind2(ind2x)' ind3(ind3x)'];
%save('idx_O.mat','idx');
load('idx_S.mat');
%Xtest = XTrain(:,:,:,idx);
%XTrain(:,:,:,idx) = [];
%Ytest = YTrain(idx);
%YTrain(idx) = [];
%use Dr. Sammers test data set
Xtest = XTrain(:,:,:,1+size(trS,1):end);
XTrain(:,:,:,1+size(trS,1):end) = [];
Ytest = YTrain(1+size(trS,1):end);
YTrain(1+size(trS,1):end) = [];
ind1 = find(YTrain=='0');
ind2 = find(YTrain=='1');
ind1x = randperm(length(ind1),floor(0.2*length(ind1)));
ind2x = randperm(length(ind2),floor(0.2*length(ind2)));
ind3 = find(YTrain=='2');
ind3x = randperm(length(ind3),floor(0.2*length(ind3)));
idx = [ind1(ind1x)' ind2(ind2x)' ind3(ind3x)'];
XValidation = XTrain(:,:,:,idx);
XTrain(:,:,:,idx) = [];
YValidation = YTrain(idx);
YTrain(idx) = [];
%Create an imageDataAugmenter object that specifies preprocessing options
%for image augmentation, such as resizing, rotation, translation, and reflection.
imageAugmenter = imageDataAugmenter( ...
'RandRotation',[-22,22], ...
'RandXReflection', true, ...
'RandYReflection', true, ...
'RandXTranslation',[0 0], ...
'RandYTranslation',[0 0]);
imageSize = [box box chan];
augimds = augmentedImageDatastore(imageSize,XTrain,YTrain,'DataAugmentation',...
imageAugmenter);
%Define Network Architecture
%Define the convolutional neural network architecture.
layers = [
imageInputLayer(imageSize)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,16,'Padding','same')
batchNormalizationLayer
reluLayer
maxPooling2dLayer(2,'Stride',2)
convolution2dLayer(3,32,'Padding','same')
batchNormalizationLayer
reluLayer
fullyConnectedLayer(2)
softmaxLayer
classificationLayer];
%
opts = trainingOptions('sgdm', ...
'InitialLearnRate',0.01, ...
'MaxEpochs',20, ...
'MiniBatchSize',128, ...
'Plots','training-progress', ...
'Shuffle','every-epoch', ...
'ValidationFrequency',8, ...
'ValidationPatience',50,...
'Verbose',false, ...
'ValidationData',{XValidation,YValidation});
%% add fearure parameter to CNN net
%% start training
infot=struct();
for i=1%:20
disp(i)
[net, info] = trainNetwork(augimds,layers,opts);
%infot(i).info = info;
end
%save([fnmr(1:end-4),'_info_100.mat'],'net','infot','-v7.3');
end
cgt = double(Ytest);
[clabel,scores] = classify(net,Xtest);
%diffscore = scores(:,2)-max(scores(:,1),scores(:,3));
%[X, Y, T, AUC1, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,diffscore,2);
[X, Y, T, AUC1, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,scores(:,1),1);
cgt = double(YTrain);
[clabel,scores] = classify(net,XTrain);
%diffscore = scores(:,2)-max(scores(:,1),scores(:,3));
%[Xt, Yt, T, AUC2, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,diffscore,2);
[Xt, Yt, T, AUC2, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,scores(:,1),1);
cgt = double(YValidation);
[clabel,scores] = classify(net,XValidation);
%diffscore = scores(:,2)-max(scores(:,1),scores(:,3));
%[Xv, Yv, T, AUC3, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,diffscore,2);
[Xv, Yv, T, AUC3, OPTROCPT, SUBY, SUBYNAMES] = perfcurve(cgt,scores(:,1),1);
figure(1);
plot(Xt,Yt,'r-','linewidth',2);
hold on;
plot(X,Y,'b-','linewidth',2);
hold on;
plot(Xv,Yv,'k-','linewidth',2);
grid
xlabel('False positive rate')
ylabel('True positive rate')
title('ROC for Classification of HW of Wrist X-ray');
set(gca,'fontweight','bold'); set(gca,'fontsize',14,'linewidth',3);
set(gcf,'color','w');
legend('Train','test','Validation')
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment