diff --git a/DRPCA_Net/Test.py b/DRPCA_Net/Test.py new file mode 100644 index 0000000000000000000000000000000000000000..fba50affab94da5a1255d64fa524ded0d4f507e9 --- /dev/null +++ b/DRPCA_Net/Test.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +""" +Created on Wed Apr 23 17:44:09 2025 + +@author: vpustova +""" + +from admm_DRPCANet_v8 import RPCAnet +import torch +import numpy as np +from scipy.io import savemat +from scipy.io import loadmat + + + +device = 'cpu' +model=RPCAnet(kernel = [(9,1)]*20, admm_iterations=20) +state_dict=torch.load('DRPCAnet_100th_over_100epochs.pkl',map_location=device) +model.load_state_dict(state_dict) +model.eval() + +data=loadmat('D325.mat')['patch_180'] +data=data/np.max(np.abs(data)) +data = torch.tensor(data).to(device) +outputs_Lv,outputs_Sv=model(data) # Forward + +pred_test={'predS':outputs_Sv.cpu().detach().numpy(), 'predL':outputs_Lv.cpu().detach().numpy() } +savemat('D325processed.mat',pred_test) + +