From 8dc363bdcaa599ae2f16b8d503947715a01284ae Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Thu, 5 May 2022 13:56:59 +0200
Subject: [PATCH] starting train

---
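Notes: this commit strips out the sentence-to-symbols supertagging scaffolding
so that training can start on the Linker alone. It deletes
SuperTagger/EncoderDecoder.py together with its stale compiled .pyc and old
TensorBoard event files under Output/, removes the NormCrossEntropy loss, the
measure_supertagging_accuracy helper and the checkpoint save/load utilities
that only the deleted model used, and adds three placeholder methods
(predict_axiom_links, eval_batch, eval_epoch) to Linker. train.py now imports
only the SinkhornLoss criterion and the format_time/read_csv_pgbar utilities.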
 ...nts.out.tfevents.1651476773.montana.6860.1 | Bin 370 -> 0 bytes
 ...nts.out.tfevents.1651476773.montana.6860.2 | Bin 370 -> 0 bytes
 ...nts.out.tfevents.1651474623.montana.6860.0 | Bin 565 -> 0 bytes
 ...ts.out.tfevents.1651488612.montana.35996.1 | Bin 172 -> 0 bytes
 ...ts.out.tfevents.1651488612.montana.35996.2 | Bin 172 -> 0 bytes
 ...ts.out.tfevents.1651484213.montana.35996.0 | Bin 250 -> 0 bytes
 ...ts.out.tfevents.1651239629.montana.55449.1 | Bin 172 -> 0 bytes
 ...ts.out.tfevents.1651239629.montana.55449.2 | Bin 172 -> 0 bytes
 SuperTagger/EncoderDecoder.py                 | 104 ------------------
 SuperTagger/Linker/Linker.py                  |  17 ++-
 .../__pycache__/EncoderDecoder.cpython-38.pyc | Bin 5035 -> 0 bytes
 SuperTagger/eval.py                           |  41 +------
 SuperTagger/utils.py                          |  30 +---
 train.py                                      |   6 +-
 14 files changed, 18 insertions(+), 180 deletions(-)
 delete mode 100644 Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1
 delete mode 100644 Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2
 delete mode 100644 Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0
 delete mode 100644 Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1
 delete mode 100644 Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2
 delete mode 100644 Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0
 delete mode 100644 Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1
 delete mode 100644 Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2
 delete mode 100644 SuperTagger/EncoderDecoder.py
 delete mode 100644 SuperTagger/__pycache__/EncoderDecoder.cpython-38.pyc

diff --git a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1 b/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1
deleted file mode 100644
index c5715fc44ef387ce7abb6b3341d37769ed78cd7c..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651476773.montana.6860.1 and /dev/null differ
diff --git a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2 b/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2
deleted file mode 100644
index e5f81db34af3cd706eea16ef9d3bdc047ce0cae0..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_08-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651476773.montana.6860.2 and /dev/null differ
diff --git a/Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0 b/Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0
deleted file mode 100644
index efbfc7d8a55d2cb0f5ed2b1178f473f0e8e93be7..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_08-57/logs/events.out.tfevents.1651474623.montana.6860.0 and /dev/null differ
diff --git a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1 b/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1
deleted file mode 100644
index 5cfc526c1a26ffcabcf14d00f42cc75bf018a1c4..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651488612.montana.35996.1 and /dev/null differ
diff --git a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2 b/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2
deleted file mode 100644
index 138bd57e18ec66a2f43191cb7b7af73fe3d56277..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_11-36/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651488612.montana.35996.2 and /dev/null differ
diff --git a/Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0 b/Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0
deleted file mode 100644
index 9883309c9bb3dda787beef437a6a4df6ae931b3f..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_02-05_11-36/logs/events.out.tfevents.1651484213.montana.35996.0 and /dev/null differ
diff --git a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1 b/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1
deleted file mode 100644
index ae9e1c819bafc59a81b9616301304b4eb3485cb8..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Training/events.out.tfevents.1651239629.montana.55449.1 and /dev/null differ
diff --git a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2 b/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2
deleted file mode 100644
index 95b5a29a10dc72589e9a0cd1bee8aac5c20574f2..0000000000000000000000000000000000000000
Binary files a/Output/Tranning_29-04_14-57/logs/Training vs. Validation Loss_Validation/events.out.tfevents.1651239629.montana.55449.2 and /dev/null differ
diff --git a/SuperTagger/EncoderDecoder.py b/SuperTagger/EncoderDecoder.py
deleted file mode 100644
index 36311d5..0000000
--- a/SuperTagger/EncoderDecoder.py
+++ /dev/null
@@ -1,104 +0,0 @@
-import torch
-from torch.nn import Dropout
-from torch.nn import Module
-
-from Configuration import Configuration
-from SuperTagger.Decoder.RNNDecoderLayer import RNNDecoderLayer
-from SuperTagger.Encoder.EncoderLayer import EncoderLayer
-from SuperTagger.eval import measure_supertagging_accuracy
-
-
-class EncoderDecoder(Module):
-    """
-    A standard Encoder-Decoder architecture. Base for this and many
-    other models.
-
-    decoder : instance of Decoder
-    """
-
-    def __init__(self, BASE_TOKENIZER, BASE_MODEL, symbols_map):
-        super(EncoderDecoder, self).__init__()
-
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_symbols_in_sentence = int(Configuration.datasetConfig['max_symbols_in_sentence'])
-        self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder'])
-
-        self.symbols_map = symbols_map
-        self.sents_padding_id = BASE_TOKENIZER.pad_token_id
-        self.sents_space_id = BASE_TOKENIZER.bos_token_id
-
-        self.encoder = EncoderLayer(BASE_MODEL)
-        self.decoder = RNNDecoderLayer(self.symbols_map)
-
-        self.dropout = Dropout(0.1)
-
-    def forward(self, sents_tokenized_batch, sents_mask_batch, symbols_tokenized_batch):
-        r"""Training the translation from sentence to symbols
-
-        Args:
-            sents_tokenized_batch: [batch_size, max_len_sentence] the tokenized sentences
-            sents_mask_batch : mask output from the encoder tokenizer
-            symbols_tokenized_batch: [batch_size, max_symbols_in_sentence] the true symbols for each sentence.
-        """
-        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        last_hidden_state = self.dropout(last_hidden_state)
-        return self.decoder(symbols_tokenized_batch, last_hidden_state, pooler_output)
-
-    def decode_greedy_rnn(self, sents_tokenized_batch, sents_mask_batch):
-        r"""Predicts the symbols for each sentence in sents_tokenized_batch.
-
-        Args:
-            sents_tokenized_batch: [batch_size, max_len_sentence] the tokenized sentences
-            sents_mask_batch : mask output from the encoder tokenizer
-        """
-        last_hidden_state, pooler_output = self.encoder([sents_tokenized_batch, sents_mask_batch])
-        last_hidden_state = self.dropout(last_hidden_state)
-
-        predictions = self.decoder.predict_rnn(last_hidden_state, pooler_output)
-
-        return predictions
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        r"""Calls the evaluating methods after predicting the symbols from the sentence contained in batch
-
-        Args:
-            batch: contains the tokenized sentences, their masks and the true symbols.
- """ - b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") - b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu") - b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu") - - type_predictions = self.decode_greedy_rnn(b_sents_tokenized, b_sents_mask) - - pred = torch.argmax(type_predictions, dim=2) - - predict_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in pred[0]] - true_trad = [{v: k for k, v in self.symbols_map.items()}[int(i)] for i in b_symbols_tokenized[0]] - l = len([i for i in true_trad if i != '[PAD]']) - print("\nsub true (", l, ") : ", - [token for token in true_trad if token != '[PAD]']) - print("\nsub predict (", len([i for i in predict_trad if i != '[PAD]']), ") : ", - [token for token in predict_trad if token != '[PAD]']) - - return measure_supertagging_accuracy(pred, b_symbols_tokenized, - ignore_idx=self.symbols_map["[PAD]"]), float( - cross_entropy_loss(type_predictions, b_symbols_tokenized)) - - def eval_epoch(self, dataloader, cross_entropy_loss): - r"""Average the evaluation of all the batch. - - Args: - dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols - """ - s_total, s_correct, w_total, w_correct = (0.1,) * 4 - - for step, batch in enumerate(dataloader): - batch = batch - batch_output, loss = self.eval_batch(batch, cross_entropy_loss) - ((bs_correct, bs_total), (bw_correct, bw_total)) = batch_output - s_total += bs_total - s_correct += bs_correct - w_total += bw_total - w_correct += bw_correct - - return s_correct / s_total, w_correct / w_total, loss diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index 6a39ae7..1f43920 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -44,7 +44,7 @@ class Linker(Module): LayerNorm(self.dim_polarity_transfo, eps=1e-12) ) - def make_decoder_mask(self, atoms_batch) : + def make_decoder_mask(self, atoms_batch): decoder_attn_mask = torch.ones_like(atoms_batch, dtype=torch.float64) decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0 return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1) @@ -83,10 +83,12 @@ class Linker(Module): # to do select with list of list pos_encoding = pad_sequence( [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) - for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, + padding_value=0) neg_encoding = pad_sequence( [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) - for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, + padding_value=0) # pos_encoding = self.pos_transformation(pos_encoding) # neg_encoding = self.neg_transformation(neg_encoding) @@ -95,3 +97,12 @@ class Linker(Module): link_weights.append(sinkhorn(weights, iters=3)) return link_weights + + def predict_axiom_links(self): + return None + + def eval_batch(self): + return None + + def eval_epoch(self): + return None diff --git a/SuperTagger/__pycache__/EncoderDecoder.cpython-38.pyc b/SuperTagger/__pycache__/EncoderDecoder.cpython-38.pyc deleted file mode 100644 index 7745cbbcf4e476115694b33969f29399433509b5..0000000000000000000000000000000000000000 GIT binary patch 
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index 372e68c..7a14ac5 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -3,45 +3,6 @@ from torch import Tensor
 from torch.nn import Module
 from torch.nn.functional import nll_loss, cross_entropy
 
-# Another from Kokos function to calculate the accuracy of our predictions vs labels
-def measure_supertagging_accuracy(pred, truth, ignore_idx=0):
-    r"""Evaluation of the decoder's output and the true symbols without considering the padding in the true targets.
-
-    Args:
-        pred: [batch_size, max_symbols_in_sentence] prediction of symbols for each symbols.
-        truth: [batch_size, max_symbols_in_sentence] true values of symbols
-    """
-    correct_symbols = torch.ones(pred.size())
-    correct_symbols[pred != truth] = 0
-    correct_symbols[truth == ignore_idx] = 1
-    num_correct_symbols = correct_symbols.sum().item()
-    num_masked_symbols = len(truth[truth == ignore_idx])
-
-    correct_sentences = correct_symbols.prod(dim=1)
-    num_correct_sentences = correct_sentences.sum().item()
-
-    return (num_correct_sentences, pred.shape[0]), \
-           (num_correct_symbols - num_masked_symbols, pred.shape[0] * pred.shape[1] - num_masked_symbols)
-
-
-def count_sep(xs, sep_id, dim=-1):
-    return xs.eq(sep_id).sum(dim=dim)
-
-
-class NormCrossEntropy(Module):
-    r"""Loss based on the cross entropy, it considers the number of words and ignore the padding.
-    """
-
-    def __init__(self, ignore_index, sep_id, weights=None):
-        super(NormCrossEntropy, self).__init__()
-        self.ignore_index = ignore_index
-        self.sep_id = sep_id
-        self.weights = weights
-
-    def forward(self, predictions, truths):
-        return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights,
-                             reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id)
-
 
 class SinkhornLoss(Module):
     def __init__(self):
@@ -49,4 +10,4 @@ class SinkhornLoss(Module):
 
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
-                   for link, perm in zip(predictions, truths))
\ No newline at end of file
+                   for link, perm in zip(predictions, truths))
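Note: SinkhornLoss, the one criterion kept above, consumes what Linker.forward
returns — a list with one link-weight matrix per atom type — paired with a
list of gold permutation indices. A self-contained sketch with stand-in
tensors; the shapes and the names scores/link_weights/true_links are
illustrative only, not from this repository:

    import torch
    from SuperTagger.eval import SinkhornLoss

    criterion = SinkhornLoss()

    # One atom type, batch of 2 sentences, 5 candidate atoms each.
    # nll_loss expects log-probabilities, so normalise the raw scores.
    scores = torch.randn(2, 5, 5, requires_grad=True)
    link_weights = [torch.log_softmax(scores, dim=2)]
    true_links = [torch.randint(5, (2, 5))]       # gold atom pairings

    loss = criterion(link_weights, true_links)    # mean nll per atom type, summed
    loss.backward()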
\n") - print("#" * 10) - except Exception as e: - print("\nCan't load checkpoint model because : " + str(e) + "\n\nUse default model \n") - print("#" * 15) - return model, opt, epoch, loss + return str(datetime.timedelta(seconds=elapsed_rounded)) \ No newline at end of file diff --git a/train.py b/train.py index 25154db..9287436 100644 --- a/train.py +++ b/train.py @@ -1,7 +1,5 @@ -import os import pickle import time -from datetime import datetime import numpy as np import torch @@ -13,8 +11,8 @@ from transformers import get_cosine_schedule_with_warmup from Configuration import Configuration from SuperTagger.Linker.Linker import Linker from SuperTagger.Linker.atom_map import atom_map -from SuperTagger.eval import NormCrossEntropy, SinkhornLoss -from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load +from SuperTagger.eval import SinkhornLoss +from SuperTagger.utils import format_time, read_csv_pgbar from torch.utils.tensorboard import SummaryWriter -- GitLab