import numpy as np
import torch
import torch.utils.data as data
from torch import nn
from torch.optim import Adam
import argparse
from scipy.io import loadmat
from scipy import io
import os

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.ioff()


def parse_args(argv=None):
    """Parse the command-line options that configure the network and training.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``layers``, ``epochs``,
        ``learning_rate``, ``kernel_size``, ``coefS``, ``coefT`` and
        ``pLossT``.
    """
    # The text is passed as ``description``: the first positional parameter of
    # ArgumentParser is ``prog``, which would replace the script name in the
    # usage line instead of describing the program.
    parser = argparse.ArgumentParser(description="Network's parameters")
    parser.add_argument(
        "--layers",
        type=int,
        default=20,
        help="Number of layers/iterations",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=100,
        help="Number of epochs",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=0.004,
        help="Learning rate",
    )
    parser.add_argument(
        "--kernel-size",
        type=int,
        default=9,
        help="Kernel size",
    )

    parser.add_argument(
        "--coefS",
        type=float,
        default=-2,
        help="coef S (lambda)",
    )
    parser.add_argument(
        "--coefT",
        type=float,
        default=5.0,
        help="coef T (mu)",
    )
    parser.add_argument(
        "--pLossT",
        type=float,
        default=0.25,
        help="Percentage of T in Loss function",
    )
    # parser.add_argument("--output_path", type=str,
    #                     default="./Results",
    #                     help='folder to save output data')

    return parser.parse_args(argv)


DEBUG = False
ARGS = parse_args()
# NOTE(review): anomaly detection is switched on unconditionally while DEBUG
# is otherwise unused. PyTorch documents this mode as a debugging aid that
# slows execution -- consider gating it on DEBUG once confirmed safe.
torch.autograd.set_detect_anomaly(True)

# Model parameters
admm_iterations = ARGS.layers         # Number of ADMM iterations during forward
learning_rate = ARGS.learning_rate    # Adam Learning rate
num_epochs = ARGS.epochs  # 200       # Epochs to train
Kernel_Size = ARGS.kernel_size        # Kernel size
CoefS = ARGS.coefS                    # initial lambda
CoefT = ARGS.coefT                    # initial mu
pLossT = ARGS.pLossT                  # Percentage of T in Loss function
#######################################################################################
#                         Data Loading & Utility functions                            #
#######################################################################################

def preprocess(L, S, D):
    """Scale L, S and D by their common largest absolute value.

    Args:
        L, S, D: numpy arrays (any shapes).

    Returns:
        Tuple ``(L, S, D)`` divided by ``max(|L|, |S|, |D|)``. When that
        maximum is 0 the divisor is 1.0, so all-zero inputs are returned
        unscaled rather than producing NaNs.
    """
    A = max(np.max(np.abs(L)), np.max(np.abs(S)), np.max(np.abs(D)))
    if A == 0:
        A = 1.0
    return L / A, S / A, D / A


class ImageDataset(data.Dataset):
    """In-memory dataset of ``(L, S, D)`` volumes read from ``.mat`` files.

    Every file is read through the ``'patch_180'`` key, cast to float32,
    normalised with :func:`preprocess` and reshaped to ``shape``.

    Args:
        NumInstances: Number of samples to allocate and load.
        shape: Shape of one sample (tuple or list).
        train: ``0`` loads ``<data_dir>/fista/train`` and
            ``<data_dir>/D_data/train`` with file indices starting at 0;
            ``1`` loads the ``val`` folders with file indices starting at 300.
            Any other value leaves the tensors filled with zeros.
        transform: Stored on the instance; not applied by this class.
        data_dir: Root folder of the data; defaults to ``DATA_DIR``.
    """

    DATA_DIR = '...'

    def __init__(self, NumInstances, shape, train, transform=None, data_dir=None):
        data_dir = self.DATA_DIR if data_dir is None else data_dir
        self.shape = shape

        # Zero-filled storage, one slot per instance. tuple(shape) lets a list
        # be passed as well as a tuple.
        full_shape = (NumInstances,) + tuple(shape)
        images_L = torch.zeros(full_shape)
        images_S = torch.zeros(full_shape)
        images_D = torch.zeros(full_shape)
        storage = (images_L, images_S, images_D)

        if train == 0:
            # -- TRAIN --
            self._load_split(data_dir, 'train', 0, NumInstances, 100, 'train', storage)
        elif train == 1:
            # -- VALIDATION -- (file numbering starts at 300)
            self._load_split(data_dir, 'val', 300, NumInstances, 50, 'validation', storage)

        self.transform = transform
        self.images_L = images_L
        self.images_S = images_S
        self.images_D = images_D

    def _load_split(self, data_dir, split, first_index, count, report_every, label, storage):
        """Fill ``storage`` with ``count`` samples whose files start at ``first_index``."""
        images_L, images_S, images_D = storage
        for k in range(count):
            if k % report_every == 0:
                print('loading %s set %s' % (label, k))
            n = first_index + k
            L = loadmat(data_dir + '/fista/%s/L_fista%s.mat' % (split, n))['patch_180'].astype('float32')
            S = loadmat(data_dir + '/fista/%s/S_fista%s.mat' % (split, n))['patch_180'].astype('float32')
            D = loadmat(data_dir + '/D_data/%s/D%s.mat' % (split, n))['patch_180'].astype('float32')
            L, S, D = preprocess(L, S, D)

            images_L[k] = torch.from_numpy(L.reshape(self.shape))
            images_S[k] = torch.from_numpy(S.reshape(self.shape))
            images_D[k] = torch.from_numpy(D.reshape(self.shape))

    def __getitem__(self, index):
        """Return the ``(L, S, D)`` tensors stored at ``index``."""
        return self.images_L[index], self.images_S[index], self.images_D[index]

    def __len__(self):
        """Return the number of allocated samples."""
        return len(self.images_L)


#######################################################################################
#                              Model Implementation                                   #
#######################################################################################

def _apply_single_channel_conv(conv, x):
    """Run a 1-in/1-out ``Conv3d`` over an unbatched volume and return a volume.

    The input gets leading batch and channel axes, is moved to the device of
    the convolution weights, and exactly those two axes are stripped from the
    result, so genuine size-1 dimensions of ``x`` survive.
    """
    # clone() keeps the convolution input independent of x; DRPCACell.forward
    # overwrites slices of its state tensor in place after calling this.
    xR = x[None, None].clone().to(conv.weight.device)
    return conv(xR)[0, 0]


class Conv3dC(nn.Module):
    """Single-channel 3-D convolution with kernel ``(k0, k0, k1)``.

    Args:
        kernel: Pair ``(k0, k1)``; padding is ``int((k - 1) / 2)`` per axis.
    """

    def __init__(self, kernel):
        super(Conv3dC, self).__init__()

        pad0 = int((kernel[0] - 1) / 2)
        pad1 = int((kernel[1] - 1) / 2)
        self.convR = torch.nn.Conv3d(1, 1, (kernel[0], kernel[0], kernel[1]), (1, 1, 1), (pad0, pad0, pad1))
        #torch.nn.init.dirac_(self.convR.weight.data)
        #if self.convR.bias is not None:
        #    torch.nn.init.constant_(self.convR.bias.data, 0.0)

    def forward(self, x):
        """Convolve the volume ``x`` and return the resulting volume."""
        return _apply_single_channel_conv(self.convR, x)


class Conv3dCS(nn.Module):
    """Single-channel 3-D convolution with kernel ``(k0 - 4, k0 - 4, k1)``.

    The weights are Dirac-initialised and the bias is set to zero.

    Args:
        kernel: Pair ``(k0, k1)``; the first two kernel extents are
            ``int(k0) - 4``. Padding is ``int((k - 1) / 2)`` per axis.
    """

    def __init__(self, kernel):
        super(Conv3dCS, self).__init__()

        kernel0 = int(kernel[0]) - 4

        pad0 = int((kernel0 - 1) / 2)
        pad1 = int((kernel[1] - 1) / 2)
        self.convR = torch.nn.Conv3d(1, 1, (kernel0, kernel0, kernel[1]), (1, 1, 1), (pad0, pad0, pad1))
        torch.nn.init.dirac_(self.convR.weight.data)
        if self.convR.bias is not None:
            torch.nn.init.constant_(self.convR.bias.data, 0.0)

    def forward(self, x):
        """Convolve the volume ``x`` and return the resulting volume."""
        return _apply_single_channel_conv(self.convR, x)


class DRPCACell(nn.Module):
    """One layer (iteration) of the network.

    Args:
        kernel: Pair ``(k0, k1)`` forwarded to the three convolutions.
        coef_T: Initial tensor for the threshold used in :meth:`SVT`.
        coef_S: Initial tensor for the threshold used in :meth:`Softthresh`
            (passed through a sigmoid there).
        coef_Rho: Initial tensor for the step applied to ``v`` and ``data[6]``.
    """

    def __init__(self, kernel, coef_T, coef_S, coef_Rho):
        super(DRPCACell, self).__init__()
        #self.Rho = 1.0;
        self.conv1 = Conv3dC(kernel)
        self.conv2 = Conv3dCS(kernel)
        self.conv3 = Conv3dCS(kernel)

        # Assigning an nn.Parameter attribute registers it on the module; no
        # separate register_parameter call is needed.
        self.coef_T = nn.Parameter(coef_T)
        self.coef_S = nn.Parameter(coef_S)
        self.coef_Rho = nn.Parameter(coef_Rho)

        self.sig = nn.Sigmoid()

    def forward(self, data):
        """Update the stacked state tensor in place and return it.

        Args:
            data: Tensor whose first axis has 7 slots, each a 3-D volume.
                ``data[0]`` is the input; the others are
                # data[1]=t
                # data[2]=x
                # data[3]=T
                # data[4]=Z
                # data[5]=v gamma2
                # data[6]=w gamma1

        Returns:
            The same ``data`` tensor with slots 1-6 updated.
        """
        x0 = data[0].size(0)
        x1 = data[0].size(1)
        x2 = data[0].size(2)

        T = data[3]
        v = data[5]

        # The statement order below matters: each slot of `data` is read
        # before the in-place assignment that replaces it.
        t = (T - v + data[0] - self.conv1(data[2])) / 2

        data[2] = self.conv2(data[4] - data[6]) + self.conv3(data[0] - t)

        # SVT works on the volume flattened to (x0*x1, x2).
        t = t.view(x0 * x1, x2)
        v = v.view(x0 * x1, x2)

        T = self.SVT(t + v, self.coef_T)

        t = t.view(x0, x1, x2)
        v = v.view(x0, x1, x2)
        T = T.view(x0, x1, x2)

        data[4] = self.Softthresh(data[2] + data[6], self.coef_S)

        v = v + self.coef_Rho * (t - T)
        data[6] = data[6] + self.coef_Rho * (data[2] - data[4])

        data[1] = t
        data[3] = T
        data[5] = v

        return data

    def SVT(self, x, mu):
        """Soft-threshold the singular values of the matrix ``x`` by ``mu``.

        Returns ``(U * S') @ V.t()`` where ``S'`` are the thresholded
        singular values of ``x``.
        """
        try:
            U, S, V = torch.svd(x)
        except RuntimeError:  # torch.svd may have convergence issues for GPU and CPU.
            # Retry once on a slightly perturbed matrix. Only RuntimeError is
            # caught so that e.g. KeyboardInterrupt is not swallowed.
            U, S, V = torch.svd(x + 1e-2 * x.mean() * torch.rand_like(x))

        S = torch.sign(S) * torch.max(torch.zeros_like(S), torch.abs(S) - mu)

        return (U * S) @ V.t()

    def Softthresh(self, x, lamda):
        """Element-wise soft threshold of ``x`` with threshold ``sigmoid(lamda)``."""
        return torch.sign(x) * torch.max(torch.zeros_like(x), torch.abs(x) - self.sig(lamda))


class RPCAnet(nn.Module):
    """Sequence of ``admm_iterations`` :class:`DRPCACell` layers.

    Args:
        kernel: Sequence with one ``(k0, k1)`` pair per layer.
        admm_iterations: Number of layers.
        coefSini: Initial value of every layer's ``coef_S``.
        coefTini: Initial value of every layer's ``coef_T``.

    Every layer's ``coef_Rho`` starts at 0.5.
    """

    def __init__(self, kernel, admm_iterations=5, coefSini=-2.0, coefTini=10.0):
        super(RPCAnet, self).__init__()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.admm_iterations = admm_iterations
        self.kernel = kernel

        # Initial per-layer coefficients; element i seeds the parameters of layer i.
        self.coef_S = torch.zeros(self.admm_iterations).to(device) + coefSini
        self.coef_T = torch.zeros(self.admm_iterations).to(device) + coefTini
        self.coef_Rho = torch.zeros(self.admm_iterations).to(device) + 0.5

        self.filter = self.makelayers()

    def makelayers(self):
        """Build the ``nn.Sequential`` of cells, one per layer."""
        filt = [
            DRPCACell(self.kernel[i], self.coef_T[i], self.coef_S[i], self.coef_Rho[i])
            for i in range(self.admm_iterations)
        ]
        return nn.Sequential(*filt)

    def _device(self):
        """Return the device of the parameters (CUDA-if-available when there are none)."""
        first = next(self.parameters(), None)
        if first is None:
            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return first.device

    def forward(self, y):
        """Run all layers on the volume ``y``.

        Returns:
            Tuple ``(L, S)`` taken from slots 1 and 2 of the final state.
        """
        ## INIT ##
        data = torch.zeros([7] + list(y.shape)).to(self._device())
        data[0] = y
        data = self.filter(data)
        L = data[1]
        S = data[2]

        return L, S

    def _layer_coefs(self, name):
        """Collect the current value of parameter ``name`` from every layer.

        Reading the live parameters (rather than the tensors created in
        ``__init__``) keeps the result correct after the module has been
        converted with ``.to()``, ``.double()`` and similar calls.
        """
        values = [getattr(cell, name).detach().cpu().reshape(()) for cell in self.filter]
        if not values:
            return np.empty(0, dtype=np.float32)
        return torch.stack(values).numpy()

    def getS(self):
        """Return the per-layer ``coef_S`` values as a numpy array."""
        return self._layer_coefs("coef_S")

    def getT(self):
        """Return the per-layer ``coef_T`` values as a numpy array."""
        return self._layer_coefs("coef_T")

    def getRho(self):
        """Return the per-layer ``coef_Rho`` values as a numpy array."""
        return self._layer_coefs("coef_Rho")

    def getIter(self):
        """Return the number of layers."""
        return self.admm_iterations

    def getKernelSize(self):
        """Return the first kernel extent of the first layer."""
        return self.kernel[0][0]
#######################################################################################
#                               Training Functions                                    #
#######################################################################################


def train_epoch(
    model,
    optimizer,
    criterion,
    train_loader,
    device="cpu",
    pLossT=0.2
):
    """Run one training epoch.

    Gradients are accumulated sample by sample over a batch and the optimizer
    steps once per batch.

    Args:
        model: Callable returning ``(outputs_L, outputs_S)`` for one sample.
        optimizer: Optimizer over the model parameters.
        criterion: Loss applied to each output/target pair.
        train_loader: Iterable of ``(L, S, D)`` batches.
        device: Device the samples are moved to.
        pLossT: Weight of the L term; the S term is weighted ``1 - pLossT``.

    Returns:
        Tuple ``(avg_train_loss, avg_train_mse)``: the loss averaged over all
        processed samples, and ``0.0`` (the second value is never updated).
    """
    avg_train_loss = 0.0
    avg_train_mse = 0.0
    n_samples = 0

    model.train()

    for L, S, D in train_loader:
        # set the gradients to zero at the beginning of each batch
        optimizer.zero_grad()
        # Iterate over the real batch length so a short final batch works.
        for ii in range(len(D)):
            inputs = D[ii].to(device)  # "ii"th picture
            targets_L = L[ii].to(device)
            targets_S = S[ii].to(device)

            # Forward + backward + loss
            outputs_L, outputs_S = model(inputs)  # Forward
            # Current loss
            loss = pLossT * criterion(outputs_L, targets_L) + (1 - pLossT) * criterion(outputs_S, targets_S)
            avg_train_loss += loss.item()
            loss.backward()
            n_samples += 1

        optimizer.step()

    if n_samples:
        avg_train_loss = avg_train_loss / n_samples

    return avg_train_loss, avg_train_mse


def val_epoch(model, criterion, val_loader, device="cpu", pLossT=0.2):
    """Return the loss averaged over every sample of ``val_loader``.

    The model is put in eval mode and gradients are disabled. The loss is
    weighted like in :func:`train_epoch`.
    """
    avg_val_mse = 0
    n_samples = 0

    model.eval()

    with torch.no_grad():
        for Lv, Sv, Dv in val_loader:
            for jj in range(len(Dv)):
                inputsv = Dv[jj].to(device)  # "jj"th picture
                targets_Lv = Lv[jj].to(device)
                targets_Sv = Sv[jj].to(device)

                outputs_Lv, outputs_Sv = model(inputsv)  # Forward

                # Current loss
                loss_val = pLossT * criterion(outputs_Lv, targets_Lv) + (1 - pLossT) * criterion(outputs_Sv, targets_Sv)
                avg_val_mse += loss_val.item()
                n_samples += 1

    if n_samples:
        avg_val_mse = avg_val_mse / n_samples

    return avg_val_mse


def val_epoch2(model, criterion, val_loader, device="cpu", output_path="./Results"):
    """Save the model predictions for every validation sample.

    Sample ``i`` (in loader order) is written to
    ``<output_path>/Predicted_<i>.mat`` with keys ``'predS'`` and ``'predL'``.
    ``criterion`` is accepted but not used.

    Returns:
        Always ``0.0``.
    """
    model.eval()
    numberi = 0
    print("Val")
    with torch.no_grad():
        for _, _, Dv in val_loader:
            for jj in range(len(Dv)):
                inputsv = Dv[jj].to(device)  # "jj"th picture

                outputs_Lv, outputs_Sv = model(inputsv)  # Forward

                pred = {'predS': outputs_Sv.detach().cpu().numpy(), 'predL': outputs_Lv.detach().cpu().numpy()}
                io.savemat(os.path.join(output_path, "Predicted_" + str(numberi) + ".mat"), pred)
                numberi = numberi + 1

    avg_val_mse = 0.0

    return avg_val_mse


def _save_layer_plot(xx, values, Nb_iter, ylabel, title, path):
    """Plot one curve per layer from the epoch-major flat array ``values``."""
    _ = plt.figure()
    for i in range(Nb_iter):
        plt.plot(xx, values[i::Nb_iter], label='%dlayer' % (i + 1))
    plt.legend(bbox_to_anchor=(1.0, 0.5), loc="center left")
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.savefig(path)
    plt.close()


def train(
    model,
    optimizer,
    criterion,
    train_loader,
    val_loader,
    epochs,
    device="cpu",
    output_path="./Results"
):
    """Train ``model`` for ``epochs`` epochs and write results to ``output_path``.

    Per epoch: runs :func:`train_epoch` and :func:`val_epoch` (both weighted
    with the module-level ``pLossT``), prints the losses and the per-layer
    coefficients from ``model.getS()`` / ``model.getT()``, saves the state
    dict under ``TrainedModels/`` and rewrites ``Loss.mat``.

    Afterwards: saves the predictions with :func:`val_epoch2`, writes
    ``parameters_out.mat`` and three PNG plots (loss, lambda, mu).
    """
    # torch.save does not create missing folders.
    os.makedirs(os.path.join(output_path, "TrainedModels"), exist_ok=True)

    KernelSize = model.getKernelSize()  # only used by the disabled PSF export below
    Nb_iter = model.getIter()
    trainloss = np.empty(epochs, dtype=float)
    valloss = np.empty(epochs, dtype=float)
    Slist = np.empty(epochs * Nb_iter, dtype=float)
    Tlist = np.empty(epochs * Nb_iter, dtype=float)

    for e in range(epochs):
        avg_train_loss, avg_train_mse = train_epoch(
            model, optimizer, criterion, train_loader, device=device, pLossT=pLossT
        )
        avg_val_mse = val_epoch(model, criterion, val_loader, device=device, pLossT=pLossT)
        trainloss[e] = avg_train_loss
        valloss[e] = avg_val_mse
        S = model.getS()
        T = model.getT()
        Slist[e * Nb_iter:(e + 1) * Nb_iter] = S
        Tlist[e * Nb_iter:(e + 1) * Nb_iter] = T

        # print in each epoch the average train+test MSEs and the generalization error
        print("Epoch: {:d}, Average Train loss: {:.10e}, Average Test MSE: {:.10e}".format(e, avg_train_loss, avg_val_mse))
        print("--------------------------------------")
        print("Current lambda = ")
        print(S)
        print("--------------------------------------")
        print("Current mu = ")
        print(T)
        print("--------------------------------------")

        torch.save(model.state_dict(), os.path.join(output_path, "TrainedModels", "DRPCAnet_%02dth_over_%depochs.pkl" % (e + 1, epochs)))  #ResultsRPCAnet200epochs
        lossss = {'trainloss': trainloss, 'valloss': valloss}
        io.savemat(os.path.join(output_path, "Loss.mat"), lossss)

        # PSF1=np.zeros((KernelSize,KernelSize,Nb_iter),dtype=np.float32)
        # for counter in range(Nb_iter):
        #     bb=list(model.filter[counter].conv1.state_dict().items())
        #     PSF1[:,:,counter]=bb[0][1].cpu().detach().numpy().squeeze() + bb[1][1].cpu().detach().numpy()

        # temppsf1={'PSF1':PSF1}
        # io.savemat(os.path.join(output_path,"EstimatedPSFs", "EstimatedPSF1_%d_epoch.mat" % (e+1)),temppsf1)

        # PSF2=np.zeros((KernelSize-4,KernelSize-4,Nb_iter),dtype=np.float32)
        # for counter in range(Nb_iter):
        #     bb=list(model.filter[counter].conv2.state_dict().items())
        #     PSF2[:,:,counter]=bb[0][1].cpu().detach().numpy().squeeze() + bb[1][1].cpu().detach().numpy()

        # temppsf2={'PSF2':PSF2}
        # io.savemat(os.path.join(output_path,"EstimatedPSFs", "EstimatedPSF2_%d_epoch.mat" % (e+1)),temppsf2)

        # PSF3=np.zeros((KernelSize-4,KernelSize-4,Nb_iter),dtype=np.float32)
        # for counter in range(Nb_iter):
        #     bb=list(model.filter[counter].conv3.state_dict().items())
        #     PSF3[:,:,counter]=bb[0][1].cpu().detach().numpy().squeeze() + bb[1][1].cpu().detach().numpy()

        # temppsf3={'PSF3':PSF3}
        # io.savemat(os.path.join(output_path,"EstimatedPSFs", "EstimatedPSF3_%d_epoch.mat" % (e+1)),temppsf3)

    # NOTE(review): the indentation of this call could not be recovered from
    # the flattened diff. It is placed after the epoch loop, so predictions
    # come from the final weights -- the saved files are the same as when it
    # runs every epoch, since each run overwrites them. Confirm.
    val_epoch2(model, criterion, val_loader, device=device, output_path=output_path)

    #save some parameters
    parameters_out = {'trainloss': trainloss, 'valloss': valloss, 'Slist': Slist, 'Tlist': Tlist}
    io.savemat(os.path.join(output_path, "parameters_out.mat"), parameters_out)

    #Plot final results
    xx = np.arange(1, epochs + 1)
    _ = plt.figure()
    plt.plot(xx, trainloss, "-b", label="trainning")
    plt.plot(xx, valloss, "-r", label="validation")
    plt.legend(loc="upper right")
    plt.xlabel('epoch')
    plt.ylabel('Loss')
    plt.title("Loss functions as function of epochs")
    plt.savefig(output_path + "/Loss_functions_epoch%s.png" % (epochs))
    plt.close()

    #Plot final results S
    _save_layer_plot(
        xx, Slist, Nb_iter, 'S', "lambda as function of epochs",
        output_path + "/lambda_functions_epoch%s.png" % (epochs),
    )

    #Plot final results T
    _save_layer_plot(
        xx, Tlist, Nb_iter, 'T', "mu as function of epochs",
        output_path + "/mu_functions_epoch%s.png" % (epochs),
    )


#######################################################################################
#                                       Main                                          #
#######################################################################################


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    TrainInstances = 300  # Size of training dataset
    ValInstances = 100
    BatchSize = 25
    ValBatchSize = 25
    #num_epochs = 200
    shape_dset = (128, 128, 50)
    Kernel = [(Kernel_Size, 1)] * admm_iterations

    #train
    train_dataset = ImageDataset(round(TrainInstances), shape_dset, train=0)
    train_loader = data.DataLoader(train_dataset, batch_size=BatchSize, shuffle=True)
    #validation
    val_dataset = ImageDataset(round(ValInstances), shape_dset, train=1)
    val_loader = data.DataLoader(val_dataset, batch_size=ValBatchSize, shuffle=False)

    model = RPCAnet(kernel=Kernel, admm_iterations=admm_iterations, coefSini=CoefS, coefTini=CoefT).to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()

    epochs = num_epochs
    # Save output
    output_path = "./Results_eps%s_admm%s_Ker%s_CoefS%.1f_CoefT%.1f_pLossT%.1f_lr%.2e" % (epochs, admm_iterations, Kernel_Size, CoefS, CoefT, pLossT, learning_rate)
    # exist_ok makes a rerun work when the root folder exists but a
    # sub-folder is missing.
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(os.path.join(output_path, "TrainedModels"), exist_ok=True)
    os.makedirs(os.path.join(output_path, "EstimatedPSFs"), exist_ok=True)
    print("----------------------------------------------------------------")
    print('Somme information:')
    print('- Out folder:', output_path)
    print('- Learning rate: %s' % (learning_rate))
    print('- N° epoch: %s' % (epochs))
    print('- N° Layer: %s' % (admm_iterations))
    # Raw strings: "\l" and "\m" are invalid escape sequences in a normal
    # literal. The printed text is unchanged.
    print(r'- Initial value of \lambda (with S): %s' % (CoefS))
    print(r'- Initial value of \mu (with T): %s' % (CoefT))
    print('- Kernel: %s' % (Kernel))
    print('- pLossT: %s' % (pLossT))
    print("----------------------------------------------------------------")
    print('- Model:')
    print(model)
    print("----------------------------------------------------------------")
    train(
        model,
        optimizer,
        criterion,
        train_loader,
        val_loader,
        epochs,
        device=device,
        output_path=output_path
    )