Skip to content
Snippets Groups Projects
Commit f0419fe7 authored by vpustova's avatar vpustova
Browse files

Upload New File

parent ebb2002f
No related branches found
No related tags found
No related merge requests found
import numpy as np
import torch
import torch.utils.data as data
from torch import nn
from torch.optim import Adam
import argparse
from scipy.io import loadmat
from scipy import io
import os
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
plt.ioff()
def parse_args():
parser = argparse.ArgumentParser("Network's parameters")
parser.add_argument(
"--layers",
type=int,
default=20,
help="Number of layers/iterations",
)
parser.add_argument(
"--epochs",
type=int,
default=100,
help="Number of epochs",
)
parser.add_argument(
"--learning-rate",
type = float,
default=0.004,
help = "Learning rate",
)
parser.add_argument(
"--kernel-size",
type = int,
default=9,
help = "Kernel size",
)
parser.add_argument(
"--coefS",
type = float,
default=-2,
help = "coef S (lambda)",
)
parser.add_argument(
"--coefT",
type = float,
default=5.0,
help = "coef T (mu)",
)
parser.add_argument(
"--pLossT",
type = float,
default=0.25,
help = "Percentage of T in Loss function",
)
# parser.add_argument("--output_path", type=str,
# default="./Results",
# help='folder to save output data')
return parser.parse_args()
DEBUG = False
ARGS = parse_args()
torch.autograd.set_detect_anomaly(True)
# Model parameters
admm_iterations = ARGS.layers # Number of ADMM iterations during forward
learning_rate = ARGS.learning_rate # Adam Learning rate
num_epochs = ARGS.epochs #200 # Epochs to train
Kernel_Size = ARGS.kernel_size # Kernel size
CoefS = ARGS.coefS # initial lambda
CoefT = ARGS.coefT # initial mu
pLossT = ARGS.pLossT # Percentage of T in Loss function
#######################################################################################
# Data Loading & Utility functions #
#######################################################################################
def preprocess(L,S,D):
A=max(np.max(np.abs(L)),np.max(np.abs(S)),np.max(np.abs(D)))
if A==0:
A=1.0
L=L/A
S=S/A
D=D/A
return L,S,D
class ImageDataset(data.Dataset):
DATA_DIR='...'
def __init__(self, NumInstances, shape, train, transform=None, data_dir=None):
data_dir = self.DATA_DIR if data_dir is None else data_dir
self.shape=shape
# dummy image loader
images_L = torch.zeros(tuple([NumInstances])+self.shape)
images_S = torch.zeros(tuple([NumInstances])+self.shape)
images_D = torch.zeros(tuple([NumInstances])+self.shape)
# -- TRAIN --
if train == 0:
for n in range(NumInstances):
if np.mod(n, 100) == 0: print('loading train set %s' % (n))
L=loadmat(data_dir + '/fista/train/L_fista%s.mat' % (n))['patch_180'].astype('float32')
S=loadmat(data_dir + '/fista/train/S_fista%s.mat' % (n))['patch_180'].astype('float32')
D=loadmat(data_dir + '/D_data/train/D%s.mat' % (n))['patch_180'].astype('float32')
L,S,D=preprocess(L,S,D)
images_L[n] = torch.from_numpy(L.reshape(self.shape))
images_S[n] = torch.from_numpy(S.reshape(self.shape))
images_D[n] = torch.from_numpy(D.reshape(self.shape))
# -- VALIDATION --
if train == 1:
IndParam = 300
for n in range(IndParam, IndParam + NumInstances):
if np.mod(n - IndParam, 50) == 0: print('loading validation set %s' % (n - IndParam))
L=loadmat(data_dir + '/fista/val/L_fista%s.mat' % (n))['patch_180'].astype('float32')
S=loadmat(data_dir + '/fista/val/S_fista%s.mat' % (n))['patch_180'].astype('float32')
D=loadmat(data_dir + '/D_data/val/D%s.mat' % (n))['patch_180'].astype('float32')
L,S,D=preprocess(L,S,D)
images_L[n-IndParam] = torch.from_numpy(L.reshape(self.shape))
images_S[n-IndParam] = torch.from_numpy(S.reshape(self.shape))
images_D[n-IndParam] = torch.from_numpy(D.reshape(self.shape))
self.transform = transform
self.images_L = images_L
self.images_S = images_S
self.images_D = images_D
def __getitem__(self, index):
L = self.images_L[index]
S = self.images_S[index]
D = self.images_D[index]
return L, S, D
def __len__(self):
return len(self.images_L)
#######################################################################################
# Model Implementation #
#######################################################################################
class Conv3dC(nn.Module):
def __init__(self,kernel):
super(Conv3dC,self).__init__()
pad0=int((kernel[0]-1)/2)
pad1=int((kernel[1]-1)/2)
self.convR=torch.nn.Conv3d(1,1,(kernel[0],kernel[0],kernel[1]),(1,1,1),(pad0,pad0,pad1))
#torch.nn.init.dirac_(self.convR.weight.data)
#if self.convR.bias is not None:
# torch.nn.init.constant_(self.convR.bias.data, 0.0)
def forward(self,x):
device = "cuda" if torch.cuda.is_available() else "cpu"
xR = (x[None, None].clone()).to(device)
xR = self.convR(xR)
x=xR.squeeze()
return x
class Conv3dCS(nn.Module):
def __init__(self,kernel):
super(Conv3dCS,self).__init__()
kernel0 = int(kernel[0])-4
pad0=int((kernel0-1)/2)
pad1=int((kernel[1]-1)/2)
self.convR=torch.nn.Conv3d(1,1,(kernel0,kernel0,kernel[1]),(1,1,1),(pad0,pad0,pad1))
torch.nn.init.dirac_(self.convR.weight.data)
if self.convR.bias is not None:
torch.nn.init.constant_(self.convR.bias.data, 0.0)
def forward(self,x):
device = "cuda" if torch.cuda.is_available() else "cpu"
xR = (x[None, None].clone()).to(device)
xR = self.convR(xR)
x=xR.squeeze()
return x
class DRPCACell(nn.Module):
def __init__(self,kernel,coef_T,coef_S,coef_Rho):
super(DRPCACell,self).__init__()
#self.Rho = 1.0;
self.conv1=Conv3dC(kernel)
self.conv2=Conv3dCS(kernel)
self.conv3=Conv3dCS(kernel)
self.coef_T=nn.Parameter(coef_T)
self.coef_S=nn.Parameter(coef_S)
self.coef_Rho=nn.Parameter(coef_Rho)
self.register_parameter("coef_T", self.coef_T)
self.register_parameter("coef_S", self.coef_S)
self.register_parameter("coef_Rho", self.coef_Rho)
self.sig=nn.Sigmoid()
def forward(self,data):
device = "cuda" if torch.cuda.is_available() else "cpu"
x0=data[0].size(0)
x1=data[0].size(1)
x2=data[0].size(2)
t=(data[1]).to(device)
T=(data[3]).to(device)
v=(data[5]).to(device)
t = (T - v + data[0] - self.conv1(data[2]))/2
data[2] = self.conv2(data[4] - data[6]) + self.conv3(data[0] - t)
t=t.view(x0*x1,x2)
v=v.view(x0*x1,x2)
T=T.view(x0*x1,x2)
T = self.SVT(t+v,self.coef_T)
t=t.view(x0,x1,x2)
v=v.view(x0,x1,x2)
T=T.view(x0,x1,x2)
data[4] = self.Softthresh(data[2]+data[6],self.coef_S)
v = v + self.coef_Rho*(t - T)
data[6] = data[6] + self.coef_Rho*(data[2] - data[4])
data[1]=t
data[3]=T
data[5]=v
# data[1]=t
# data[2]=x
# data[3]=T
# data[4]=Z
# data[5]=v gamma2
# data[6]=w gamma1
return data
def SVT(self,x,mu):
try:
U,S,V=torch.svd(x)
except: # torch.svd may have convergence issues for GPU and CPU.
U,S,V = torch.svd(x + 1e-2*x.mean()*torch.rand_like(x))
device = "cuda" if torch.cuda.is_available() else "cpu"
U=U.to(device)
S=S.to(device)
V=V.to(device)
S=torch.sign(S) * torch.max(torch.zeros_like(S), torch.abs(S) - mu)
x=(U*S)@V.t()
return x
def Softthresh(self,x,lamda):
return torch.sign(x) * torch.max(torch.zeros_like(x), torch.abs(x) - self.sig(lamda))
class RPCAnet(nn.Module):
def __init__(self, kernel, admm_iterations=5,coefSini=-2.0,coefTini=10.0):
super(RPCAnet, self).__init__()
device = "cuda" if torch.cuda.is_available() else "cpu"
self.admm_iterations = admm_iterations
self.kernel = kernel
self.coef_S=torch.zeros(self.admm_iterations).to(device)+coefSini
self.coef_T=torch.zeros(self.admm_iterations).to(device)+coefTini
self.coef_Rho=torch.zeros(self.admm_iterations).to(device)+0.5
self.filter=self.makelayers()
def makelayers(self):
filt=[]
for i in range(self.admm_iterations):
filt.append(DRPCACell(self.kernel[i],self.coef_T[i],self.coef_S[i],self.coef_Rho[i]))
return nn.Sequential(*filt)
def forward(self,y):
## INIT ##
device = "cuda" if torch.cuda.is_available() else "cpu"
data=(torch.zeros([7]+list(y.shape))).to(device)
data[0]=y
data=self.filter(data)
L=data[1]
S=data[2]
return L,S
def getS(self):
exp_S=self.coef_S
if torch.cuda.is_available():
exp_S=exp_S.cpu().detach().numpy()
else:
exp_S=exp_S.detach().numpy()
return exp_S
def getT(self):
exp_T=self.coef_T
if torch.cuda.is_available():
exp_T=exp_T.cpu().detach().numpy()
else:
exp_T=exp_T.detach().numpy()
return exp_T
def getRho(self):
exp_Rho=self.coef_Rho
if torch.cuda.is_available():
exp_Rho=exp_Rho.cpu().detach().numpy()
else:
exp_Rho=exp_Rho.detach().numpy()
return exp_Rho
def getIter(self):
return self.admm_iterations
def getKernelSize(self):
return self.kernel[0][0]
#######################################################################################
# Training Functions #
#######################################################################################
def train_epoch(
model,
optimizer,
criterion,
train_loader,
device="cpu",
pLossT=0.2
):
device = "cuda" if torch.cuda.is_available() else "cpu"
avg_train_loss = 0.0
avg_train_mse = 0.0
outputs_S = torch.zeros([128,128,50]).to(device)
outputs_L = torch.zeros([128,128,50]).to(device)
model.train()
for _,(L,S,D) in enumerate(train_loader):
# set the gradients to zero at the beginning of each epoch
optimizer.zero_grad()
for ii in range(25): # BatchSize
inputs=D[ii].to(device) # "ii"th picture
targets_L=L[ii].to(device)
targets_S=S[ii].to(device)
# Forward + backward + loss
outputs_L,outputs_S=model(inputs) # Forward
# Current loss
loss=pLossT*criterion(outputs_L,targets_L)+(1-pLossT)*criterion(outputs_S,targets_S)
avg_train_loss+=loss.item()
loss.backward()
optimizer.step()
avg_train_loss=avg_train_loss/300.0 #TrainInstances
return avg_train_loss, avg_train_mse
def val_epoch(model, criterion, val_loader, device="cpu", pLossT=0.2):
avg_val_mse = 0
model.eval()
with torch.no_grad():
for _,(Lv,Sv,Dv) in enumerate(val_loader):
for jj in range(25): #ValBatchSize
inputsv=Dv[jj].to(device) # "jj"th picture
targets_Lv=Lv[jj].to(device)
targets_Sv=Sv[jj].to(device)
outputs_Lv,outputs_Sv=model(inputsv) # Forward
# Current loss
loss_val=pLossT*criterion(outputs_Lv,targets_Lv)+(1-pLossT)*criterion(outputs_Sv,targets_Sv)
avg_val_mse+=loss_val.item()
avg_val_mse=avg_val_mse/100.0 # ValInstances
return avg_val_mse
def val_epoch2(model, criterion, val_loader, device="cpu", output_path ="./Results"):
avg_val_mse = 0
model.eval()
numberi = 0
print("Val")
with torch.no_grad():
for _,(Lv,Sv,Dv) in enumerate(val_loader):
for jj in range(25): #ValBatchSize
inputsv=Dv[jj].to(device) # "jj"th picture
outputs_Lv,outputs_Sv=model(inputsv) # Forward
pred={'predS':outputs_Sv.detach().cpu().numpy(), 'predL':outputs_Lv.detach().cpu().numpy() }
io.savemat(os.path.join(output_path, "Predicted_" + str(numberi) + ".mat"),pred)
numberi=numberi+1;
avg_val_mse= 0.0;
return avg_val_mse
def train(
model,
optimizer,
criterion,
train_loader,
val_loader,
epochs,
device="cpu",
output_path ="./Results"
):
KernelSize=model.getKernelSize()
Nb_iter = model.getIter()
trainloss=np.empty(epochs,dtype=float)
valloss=np.empty(epochs,dtype=float)
Slist=np.empty(epochs*Nb_iter,dtype=float)
Tlist=np.empty(epochs*Nb_iter,dtype=float)
for e in range(epochs):
avg_train_loss, avg_train_mse = train_epoch(
model, optimizer, criterion, train_loader, device=device,pLossT=pLossT
)
avg_val_mse = val_epoch(model, criterion, val_loader, device=device,pLossT=pLossT)
trainloss[e] = avg_train_loss;
valloss[e] = avg_val_mse;
S=model.getS()
T=model.getT()
Slist[e*Nb_iter:(e+1)*Nb_iter]=S;
Tlist[e*Nb_iter:(e+1)*Nb_iter]=T;
# print in each epoch the average train+test MSEs and the generalization error
print("Epoch: {:d}, Average Train loss: {:.10e}, Average Test MSE: {:.10e}".format(e,avg_train_loss,avg_val_mse))
print("--------------------------------------")
print("Current lambda = ")
print(S)
print("--------------------------------------")
print("Current mu = ")
print(T)
print("--------------------------------------")
torch.save(model.state_dict(), os.path.join(output_path, "TrainedModels", "DRPCAnet_%02dth_over_%depochs.pkl" % (e+1, epochs))) #ResultsRPCAnet200epochs
lossss={'trainloss':trainloss, 'valloss':valloss}
io.savemat(os.path.join(output_path, "Loss.mat"),lossss)
# PSF1=np.zeros((KernelSize,KernelSize,Nb_iter),dtype=np.float32)
# for counter in range(Nb_iter):
# bb=list(model.filter[counter].conv1.state_dict().items())
# PSF1[:,:,counter]=bb[0][1].cpu().detach().numpy().squeeze() + bb[1][1].cpu().detach().numpy()
# temppsf1={'PSF1':PSF1}
# io.savemat(os.path.join(output_path,"EstimatedPSFs", "EstimatedPSF1_%d_epoch.mat" % (e+1)),temppsf1)
# PSF2=np.zeros((KernelSize-4,KernelSize-4,Nb_iter),dtype=np.float32)
# for counter in range(Nb_iter):
# bb=list(model.filter[counter].conv2.state_dict().items())
# PSF2[:,:,counter]=bb[0][1].cpu().detach().numpy().squeeze() + bb[1][1].cpu().detach().numpy()
# temppsf2={'PSF2':PSF2}
# io.savemat(os.path.join(output_path,"EstimatedPSFs", "EstimatedPSF2_%d_epoch.mat" % (e+1)),temppsf2)
# PSF3=np.zeros((KernelSize-4,KernelSize-4,Nb_iter),dtype=np.float32)
# for counter in range(Nb_iter):
# bb=list(model.filter[counter].conv3.state_dict().items())
# PSF3[:,:,counter]=bb[0][1].cpu().detach().numpy().squeeze() + bb[1][1].cpu().detach().numpy()
# temppsf3={'PSF3':PSF3}
# io.savemat(os.path.join(output_path,"EstimatedPSFs", "EstimatedPSF3_%d_epoch.mat" % (e+1)),temppsf3)
val_epoch2(model, criterion, val_loader, device=device,output_path=output_path)
#save some parameters
parameters_out={'trainloss':trainloss, 'valloss':valloss,'Slist':Slist,'Tlist':Tlist}
io.savemat(os.path.join(output_path, "parameters_out.mat"),parameters_out)
#Plot final results
xx = np.arange(1,epochs+1)
_ = plt.figure()
plt.plot(xx,trainloss, "-b", label="trainning")
plt.plot(xx,valloss, "-r", label="validation")
plt.legend(loc="upper right")
plt.xlabel('epoch')
plt.ylabel('Loss')
plt.title("Loss functions as function of epochs")
plt.savefig(output_path + "/Loss_functions_epoch%s.png" % (epochs))
plt.close()
#Plot final results S
_ = plt.figure()
for i in range(Nb_iter):
plt.plot(xx,Slist[i::Nb_iter],label='%dlayer'%(i+1))
plt.legend(bbox_to_anchor=(1.0,0.5) ,loc="center left")
plt.xlabel('epoch')
plt.ylabel('S')
plt.title("lambda as function of epochs")
plt.savefig(output_path + "/lambda_functions_epoch%s.png" % (epochs))
plt.close()
#Plot final results T
_ = plt.figure()
for i in range(Nb_iter):
plt.plot(xx,Tlist[i::Nb_iter],label='%dlayer'%(i+1))
plt.legend(bbox_to_anchor=(1.0,0.5) ,loc="center left")
plt.xlabel('epoch')
plt.ylabel('T')
plt.title("mu as function of epochs")
plt.savefig(output_path + "/mu_functions_epoch%s.png" % (epochs))
plt.close()
#######################################################################################
# Main #
#######################################################################################
if __name__ == "__main__":
device = "cuda" if torch.cuda.is_available() else "cpu"
TrainInstances = 300 # Size of training dataset
ValInstances = 100
BatchSize = 25
ValBatchSize = 25
#num_epochs = 200
shape_dset = (128,128,50)
Kernel=[(Kernel_Size,1)]*admm_iterations
#train
train_dataset=ImageDataset(round(TrainInstances),shape_dset,train=0)
train_loader=data.DataLoader(train_dataset,batch_size=BatchSize,shuffle=True)
#validation
val_dataset=ImageDataset(round(ValInstances),shape_dset,train=1)
val_loader=data.DataLoader(val_dataset,batch_size=ValBatchSize,shuffle=False)
model = RPCAnet(kernel = Kernel, admm_iterations=admm_iterations,coefSini=CoefS,coefTini=CoefT).to(device)
optimizer = Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
epochs = num_epochs
# Save output
output_path ="./Results_eps%s_admm%s_Ker%s_CoefS%.1f_CoefT%.1f_pLossT%.1f_lr%.2e" % (epochs,admm_iterations,Kernel_Size,CoefS,CoefT,pLossT,learning_rate)
if not os.path.exists(output_path):
os.makedirs(output_path)
os.makedirs(os.path.join(output_path,"TrainedModels"))
os.makedirs(os.path.join(output_path,"EstimatedPSFs"))
print("----------------------------------------------------------------")
print('Somme information:')
print('- Out folder:', output_path)
print('- Learning rate: %s' % (learning_rate))
print('- N° epoch: %s' % (epochs))
print('- N° Layer: %s' % (admm_iterations))
print('- Initial value of \lambda (with S): %s' % (CoefS))
print('- Initial value of \mu (with T): %s' % (CoefT))
print('- Kernel: %s' % (Kernel))
print('- pLossT: %s' % (pLossT))
print("----------------------------------------------------------------")
print('- Model:')
print(model)
print("----------------------------------------------------------------")
train(
model,
optimizer,
criterion,
train_loader,
val_loader,
epochs,
device=device,
output_path=output_path
)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment