From 167f5455af447f1aa66cba5b9c1d0a1577267d18 Mon Sep 17 00:00:00 2001 From: Elquintas <sebastiao.frazao@gmail.com> Date: Fri, 6 Jan 2023 11:32:43 +0100 Subject: [PATCH] changements 06/01/2023 --- configs/parameters.yaml | 24 +++++++-- data/TRAINING.txt | 8 +++ dataloader/.embedding_extract.py.swp | Bin 0 -> 12288 bytes dataloader/__pycache__/load.cpython-36.pyc | Bin 0 -> 1097 bytes dataloader/load.py | 25 +++++++++ models/__pycache__/model.cpython-36.pyc | Bin 0 -> 1594 bytes models/model.py | 56 +++++++++++++++++++++ train.py | 32 ++++++++++++ 8 files changed, 142 insertions(+), 3 deletions(-) create mode 100644 data/TRAINING.txt create mode 100644 dataloader/.embedding_extract.py.swp create mode 100644 dataloader/__pycache__/load.cpython-36.pyc create mode 100644 dataloader/load.py create mode 100644 models/__pycache__/model.cpython-36.pyc create mode 100644 models/model.py create mode 100644 train.py diff --git a/configs/parameters.yaml b/configs/parameters.yaml index 8ad5eac3..322c2c07 100644 --- a/configs/parameters.yaml +++ b/configs/parameters.yaml @@ -1,13 +1,31 @@ # PATHS wav_path: '../data/wavs/' +data_path: './data/' embedding_path: '../data/embeddings/' +model_path: '../models/model' -# PARAMETERS + +# TRAINING PARAMETERS sampling_rate: 16000 batch_size: 16 learning_rate: 0.001 +epochs: 50 +dropout: 0.2 + +training_set_file: 'TRAINING.txt' +test_set_file: 'TEST.txt' + + + +# MODEL PARAMETERS + +first_layer: 512 +second_layer: 128 +third_layer: 62 + + # Types of embeddings supported: 'ecapa_tdnn' or 'x-vector' -# ecapa_tdnn: dim = 192 -# x-vector: dim = 512 +# ecapa_tdnn: dim = 192 (change first_layer dim) +# x-vector: dim = 512 (change first_layer dim) embedding_type: x-vector diff --git a/data/TRAINING.txt b/data/TRAINING.txt new file mode 100644 index 00000000..1aceb9e7 --- /dev/null +++ b/data/TRAINING.txt @@ -0,0 +1,8 @@ +ANC150_LEC_seg_1_ecapa_emb.pickle,7.0,2.3,1.2,4.5,2.0,2.1 +ANC150_LEC_seg_2_ecapa_emb.pickle,5.0,3.3,2.2,4.1,1.2,8.3 +ANC150_LEC_seg_3_ecapa_emb.pickle,4.0,2.5,4.2,3.5,3.0,8.3 +ANC150_LEC_seg_4_ecapa_emb.pickle,3.0,2.3,1.2,4.1,3.4,5.4 +ANC150_LEC_seg_5_ecapa_emb.pickle,1.2,2.5,6.2,4.1,1.0,2.1 +ANC150_LEC_seg_6_ecapa_emb.pickle,7.3,2.4,1.2,3.5,2.5,2.4 +ANC150_LEC_seg_7_ecapa_emb.pickle,4.0,2.3,3.2,4.3,1.0,3.6 +ANC150_LEC_seg_8_ecapa_emb.pickle,2.0,2.3,2.2,4.2,1.6,2.4 diff --git a/dataloader/.embedding_extract.py.swp b/dataloader/.embedding_extract.py.swp new file mode 100644 index 0000000000000000000000000000000000000000..023dba0bd56f739313356d0c781b1304162d2384 GIT binary patch literal 12288 zcmYc?2=nw+u+TGLU|?VnU|`^Xygm6g_aCNR9t;e{sY!{&C7Fr&AUQlZEHNiDFTV(n z3W!u4+(7-*+@#c$l+3*J_{_YL)SR5m^vtBpoXnC+{glL##GL%Zl++?5MX41fMTyBJ zdIgmblSawW5Eu;s(nElk!Pv+UoYj?;6cvPpLP<AbRL5utjE2By2#kinXb6mkz-S1J zhQMeDjE2CF4S|vZMuvI@1_mam4}zgIBN`3mj#8r`Fd71*Aut*OqaiRF0;3@?8Umvs zFd71*Aut*OqaiRF0z)ta5>pr$Zm}>hII%(I|6%?Acl-<t5BV7wuJJQ4oZ)9+IKa=q z(9O@lP|DB1P{Pl^P|VN35X8^GV8YM9pv}*~@Q;sy;R+uE!#+L+hQ)jg3?+OF48?p5 z3_*Mh3`%?q3_N@c40m}M7#8p{FeLIaFsSh|FmUrSFx=o_V3^Irz!1a3z#z-Rz`)4E z!0?@$f#Ezi1H&F}28P|-3=Ffl85m}AGcZJRGcXu)GcfRTGca&*GcY{lVqn<F#lWzE zi-BP_7X!mAE(V5~Tnr2|xEL6wb1^W~aWOE|axpOYaWODxa4|5bb1^V{<YZv@z{$XH zl9Pd98Yct8L{0{VP)-I0T}}oD9!>@ZZcYY<vm6WzlQ<X{k~kO`^f(w8ezG$#Tw-Tn zn8MD$kipKt5YEoPpvlg_@P&<m;W--v!#XwwhJH2%hFUfThB!6`23Ix)1{XF424^-% zxH^Kub_k}8Q74avz-S1Jh5*43P(TC8`FSOod8MfgXyOV5Nr^>zr3ER8C8-*Qnkdq# zImM|~C?X)ucr<{dOY)17GxUlR%ThJqE6Q{fiWAE~E6WNJOER?cQp@AhP;J7n1gu8E zRw1c0E!rSfTRmP~JGG)jTRlECH%YG`GdVjaRUJc9QfZoktwMgW9>^NK;)0yalGKV4 zjkL_1)M8ZoKn}&w4iW@mh<Y7`<ebFf;>@(n)FMpHaIY!YDnM7l<risaB<JU)Wu`}~ zmnW9RgB+nAtDvnAgJM%=ngZAYz0|yv;_}Rr3=MU?^29QAO+-ki<rhKJC}ie=ypxkz zT#}MmgvE5ICXg`-C50)u8YTH9i8;18sd*YuS7qks#e?nB)Ko|;2FEtUqJpB#yb=wi zm^@d=LcYwrbcN!A)Wq!6A_chbiuLsLlpwxWR&aOqa}9C~aSc}R_j6Tn^>uP}aq;wX zSBN&$G1PJK^o>>UbW?Ei^l?@23|8>;Q}Ffo^9Oq!S!+nRKSFD;LbMTfy}==_L9YG` zVC|qd$uHK+P0UVB$t)_?KuSvxr)8!o<mH!uLJAb-dZ`te#U-f9AYN8ha1L?>dmuQ# z)!Ea{(;49uh2T&p7tbJ&7d%50T>M>w{nSGgTq8V#L%@yz6G%2I*ebvx8y;HFD2Jp} zeMtI5BqAM9LIEWg_&Upy%7Ro4O^7GaqYILd@x(kt8^jF?whEy2neh-FG@XE50|_9A zI9NczRzY1)Pah&utY45=l$e`Zl3G-(SDBcbqt3vPnWhjQpO=`M8Xs?Ks}vufo0ypw zAFpHuE^$CnQIuLzT9k)KBM^tc0|68;@QkCEnwOlPk{X|sSdyHfQJk5cmzV=_5`?Xz zkXEc<3obGeOH(rQ^+4XxC{HW{C0U3<hLqGaP{@H6o8}k6<WO>_6@#*ZvyWqNu&0}+ zYf!L)zf*{#r=P2ff?JTkuYyOYySt~KyPKo4D}%CvYeYzpqjN~GLWqZ}LU4eqqql1i zQYvtA40d%<@b?2thl1D$gCQP;dLJ!bQ}c>5^Yb7oK!H?TSqw^tB^nS<fP^3!PCYa) zJ1@T+Df|>b$v_>V4z<Khttf$*3yQSV<ivtRNSXqf2E$4PMX4o4iJ5t+De<}aDXBTd z`o#s=MXAZUU=1ZHd3m~J`4!2jIjKp|oD9<nqLqpZQd5&Nl0YV)XhdiN7v7+pXNxGv z5WZKiRd5Bzagj43*3*jebK^6>X{%VHIKQ+gITfloH77F-OU3~QL9BwUt%5q(+3_HE zt3z@v$eAc1QBjtfOiDn2)R!Rq0|{x6sW6N!5YRMHH4q2|L4|HvYH~?_k-8Nm@uL+O zpyUoJMuS5G0{nwQTwN4GA_H85trQ}3!(5$1{DXoO9Q|ArT%8>Q9OFY={QMY{6@pwt zLWBGusUA^sLo)+XISWmsu0=)pMG8f!iJ<l13MCn-&^QBEOR%(Bk(^pkqTmXqGV}95 z)eBN{0hUj|H6pn5gatRG>`|~)0F^j;#ffRD@!)Ka$nzkT;5se8AT<xUIMY#3FH#3r zWFW<m^ifh&2?-~VA=rvQP<Z+HJGulbIQ#p#dAf%NIfi)p`zh!}I{Ny6D=Y>E22hd( zSD^}-xdr(}C15s06e)}7A#xi`9Y*SBfEobu9gLk{3}Y2n!kOS=4<-)gfXhGz0Boc} Ac>n+a literal 0 HcmV?d00001 diff --git a/dataloader/__pycache__/load.cpython-36.pyc b/dataloader/__pycache__/load.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8202e853402b9db713c54fca80f0f6ea89ca1d0c GIT binary patch literal 1097 zcmXr!<>k8leS2~t69dCz1|-13z`)?Zz`#&!!oa|g!jQt4!w|&?rkSFcQW#R0a+q^j zqF8cSqgWXk+!<1sQ&?IUQdm-1nwg{6QW%37G+AGQjPuiEy2b92Sdv(rT5^jE#PG>a zOi3+D21&v&6U;Id1_p*yutiZ!DU2yhEeui25Tm15Qdm>iS{R~OQ`mzUG&yc@=Hw@) z#HS>dBqlRM%>lU?gq=a|Ud6z`P{UZm5YJG<RKpO@Si+RWTm$w&FH<dZ4RaPt3PTD* zD^m(n8j~bL4MRLDk{B~oj15VQ1uDjlB*qFA<3JK)gNkt?iLpb)xRAs+pkmx948aVV zoPL^&x7adLL19zG$iTqx5=?+X<|W8~FF_uA$qJHSV_;y=<hjL>mYI{9mzbM+iyagS zDVasLm<m#&IEqpeQ{t10%WkoiCFYc-7T;n4DZIt(=@$~k9PAo)i!tmLW6&+efGDN_ z7fqI1EXAogX+=B?3=BoQ3=9mnm?|n)G8Bn3Ffjab)X&JzP0h(qPR!9SPEATIF3C*H z*H6t&N=-@0%uA2Y%qvOF$;nL5Ov=p3EUDB7xdRj}sYUuAO0S^u7Ds%1W?p7Vd_2h2 z#Uh}XVq#-tW8`BJVMM?xA&`1-Z0f-@B!dC~q#i_rFgRL4#?&xmfr5<DPm?i<GcUe4 zF}ENmwOEs}NEGA=kjuf&FOmRxhdn+%Cp9lV9%OYfC^#7ys|1m41Is3ZWf&M3Kr{%0 zEdvD#$g&znNrp5~ng&@}BmfG3Ca`6<m@`u<AeMox1>49SAD^CDl39|P8y^o6h1kke zC5mh-T!o(|8z_)K5g8wUiz_}pH$SB`CpA9)7EgS9VQFFxM5ag*<Toi0Ap;^n2BTSA z1WF>H6jBTdBo0OnJ`QHEm?rxzwt~dGl*Hm9K2VUbmgE;DXWSA1Q+lN(nK{LJpr|VX fdAtae+Q8b7j0f8W_8f-|#C>+47=c;DD8K{&A|drQ literal 0 HcmV?d00001 diff --git a/dataloader/load.py b/dataloader/load.py new file mode 100644 index 00000000..7d5b1afa --- /dev/null +++ b/dataloader/load.py @@ -0,0 +1,25 @@ +import pandas as pd +import torch +from torch.utils.data import Dataset,DataLoader + +class load_data(Dataset): + def __init__(self, filename, datadir): + + self.filename = filename + self.datadir = datadir + xy = pd.read_csv(filename,header=None) + + self.file = xy.values[:,0] + self.INT = xy.values[:,1] + self.SEV = xy.values[:,2] + self.V = xy.values[:,3] + self.R = xy.values[:,4] + self.P = xy.values[:,5] + self.PD = xy.values[:,6] + + def __len__(self): + return self.n_samples + + def __getitem__(self, idx): + + return (self.file[idx], ) diff --git a/models/__pycache__/model.cpython-36.pyc b/models/__pycache__/model.cpython-36.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80e701b9c286ef9cfe645dcdc7a78d707d370d5e GIT binary patch literal 1594 zcmXr!<>h+udV8`O8w0~*1|-1Ez`)?Zz`#%(!@$6h!jQt4!w?0b8KW4%e5NesC}uE= zC5t(WHHs~jJ%uTSxtXb7Ac`ZEA&WDWqnR;^E1V&PF_0mMA%el3A%!J{wS^&tHI+Gw zyO}wPCxtPXL6hwz$XdTDNj?4K{JgZx^kV&j#G=I9)RNSqV!g`5+?-pCMadvFFw6{P z6ss^WFr+eo+!Mvr&XC5K!kEI;!V$%s!kog=!VtyM&cMPD#Tv|@$$CpTH$NpcCq6Yd zDK#Y}GcP^9I4>`m5o!>KW?^7p02y5Th=GBjgrS77hM}1$i>aA`k)eb+i=~;Nh9RD{ zggJ|?ggu2(lA)Qgh9RD#ggJ||ggu1`%IAXenW21cD4zw&=P6;%;)SR$VNU_6Vya<? z=YxteOEQ$Or-0OgMERkjERqZ<tj!=10SpmA3=ts=5n&7wkrW0{u>0L&PbtbT$S*C4 z;!evfDlUo7NvupQisC6wP0r6tf%3RZGBS&xoR^?*(iFYLT3lL?T2urIty>)N@tJv< zCGqjMnDX*&u>__1gx+E)O3f*~#pwbu-6-W2b8=dG5h#FfaXTfJBxm^L7v&nJ+~Q6G zapLnpJY<d$GROE9n@?t5YGTnX=Cou(5MgwSJuNxDEI+g27F$|!d{JsKSfHRNzxWnw zT5^0r3RpZduSApa7E5tzPMRj;EzbD(<ebFf;`sQL48QF3GxBp&bMliDbM%W-lM;(d zG86Ol5osecuOu}mCo?@WDKjUtq*5Q8c8VdCUO{CMI|Bm)D1#JBff5iS9}^d&0FwZt z0HXkt7^4`I5F;2eRmp%gq9!^$Sb!xnfiyEPFmNz1Ft9Q(FgSy<+ZhH1h8l(}h8o6Z z##*KthAhS!CMb&u&SHkMSl}#HI18IPHn=Q1n8jSfki}fX+{{=DVsU`QSioW|&5X4m z7AIJY6)eWu%vcLzae>9yz+!C8jI|&ZH&~1vEXLl<Sj%3+Uc+3&Qo~xqmIlgtjDDIt zMSKhl3`Ig90+b4hxIip^5FrdA1VDrch!6!4Vjw~sL`X0&FlcfVNir}nL~&*26_l2M z!m{`lW5q4TvLaCU-C``b#Z-`Ti`mmJq(~a12TaH?FfiO=Ps=YVPb^AN0EMdtD6pBh z7||fQ2#C*AC5<iKpr&cEfzo_lVs2`D{4K8d_}u)I(i{+*CqBNgG%*Jv15P4EN+4(O zLQ`IHYEDjkJi?Khj76a2Qv`A=I4^-hr5F@G9E?1oU=dB;TP)xLkhLVgC^_R62bj{! z%PRt<nJ8fxS1+wJFS#T$KQA%o78eqm(d`yXenDzp6nkZEPFiM8>MfR>{KS+Z9+11( oe8DaUTZC{GgavZbEe;!q-|Rryz8Iv3gOP)gi;;(khf#zX0RC2sxBvhE literal 0 HcmV?d00001 diff --git a/models/model.py b/models/model.py new file mode 100644 index 00000000..107e7a04 --- /dev/null +++ b/models/model.py @@ -0,0 +1,56 @@ +import yaml +import torch +import torch.nn as nn +import torch.nn.functional as F + +with open("./configs/parameters.yaml", "r") as ymlfile: + cfg = yaml.load(ymlfile) + + +class model_embedding_snn(nn.Module): + def __init__(self): + super(model_embedding_snn,self).__init__() + + self.relu = nn.ReLU() + self.dropout = nn.Dropout2d(cfg['dropout']) + + self.batch_norm1 = nn.BatchNorm1d(cfg['first_layer']) + self.batch_norm2 = nn.BatchNorm1d(cfg['second_layer']) + self.batch_norm3 = nn.BatchNorm1d(cfg['third_layer']) + + self.fc1 = nn.Linear(cfg['first_layer'],cfg['second_layer']) + self.fc2 = nn.Linear(cfg['second_layer'],cfg['third_layer']) + + self.fc_voix = nn.Linear(cfg['third_layer'],1) + self.fc_res = nn.Linear(cfg['third_layer'],1) + self.fc_pros = nn.Linear(cfg['third_layer'],1) + self.fc_pd = nn.Linear(cfg['third_layer'],1) + + self.fc_int = nn.Linear(cfg['third_layer'],1) + + + def forward(self, input_embs): + + x = self.batch_norm1(input_embs) + x = self.fc1(x) + x = self.dropout(x) + x = self.relu(x) + x = self.batch_norm2(x) + x = self.fc2(x) + x = self.dropout(x) + x = self.relu(x) + x = self.batch_norm3(x) + + v = self.fc_voix(x) + v = self.relu(v) + r = self.fc_res(x) + r = self.relu(r) + p = self.fc_pros(x) + p = self.relu(p) + pd = self.fc_pd(x) + pd = self.relu(pd) + + INT = self.fc_int(x) + INT = self.relu(INT) + + return INT, v, r, p, pd diff --git a/train.py b/train.py new file mode 100644 index 00000000..b8c300e9 --- /dev/null +++ b/train.py @@ -0,0 +1,32 @@ +import sys +import torch +import torch.nn as nn +import yaml +import models.model +import dataloader.load + +def load_config(config_path): + try: + with open(config_path, 'r') as file: + config = yaml.safe_load(file) + return config + except Exception as e: + print('Error reading the config file') + + +if __name__ == "__main__": + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + config_path = './configs/parameters.yaml' + cfg = load_config(config_path) + + model_snn = models.model.model_embedding_snn().cuda() + criterion = nn.MSELoss() + optimizer = torch.optim.Adam(model_snn.parameters()) + + train_filename = cfg['data_path']+cfg['traininig_set_file'] + print(train_filename) + + for ep in range(cfg['epochs']): + print(ep) -- GitLab