From 85e492e2dee7654dc35baee0872ef0ffe932fbd1 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Thu, 5 May 2022 13:53:04 +0200 Subject: [PATCH] starting train --- SuperTagger/Decoder/RNNDecoderLayer.py | 180 -------------- .../RNNDecoderLayer.cpython-38.pyc | Bin 5324 -> 0 bytes SuperTagger/Encoder/EncoderInput.py | 18 -- SuperTagger/Encoder/EncoderLayer.py | 67 ----- .../__pycache__/EncoderInput.cpython-38.pyc | Bin 1182 -> 0 bytes .../__pycache__/EncoderLayer.cpython-38.pyc | Bin 2675 -> 0 bytes SuperTagger/Linker/AtomTokenizer.py | 23 +- SuperTagger/Linker/Linker.py | 41 ++-- SuperTagger/Linker/Sinkhorn.py | 1 - .../__pycache__/Sinkhorn.cpython-38.pyc | Bin 687 -> 687 bytes SuperTagger/Linker/utils.py | 1 - SuperTagger/Symbol/SymbolEmbedding.py | 12 - SuperTagger/Symbol/SymbolTokenizer.py | 53 ---- .../SymbolEmbedding.cpython-38.pyc | Bin 858 -> 0 bytes .../SymbolTokenizer.cpython-38.pyc | Bin 2862 -> 0 bytes .../__pycache__/symbol_map.cpython-38.pyc | Bin 563 -> 0 bytes SuperTagger/Symbol/symbol_map.py | 28 --- SuperTagger/eval.py | 12 +- SuperTagger/utils.py | 20 ++ test.py | 59 +++-- train.py | 232 ++++-------------- 21 files changed, 150 insertions(+), 597 deletions(-) delete mode 100644 SuperTagger/Decoder/RNNDecoderLayer.py delete mode 100644 SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc delete mode 100644 SuperTagger/Encoder/EncoderInput.py delete mode 100644 SuperTagger/Encoder/EncoderLayer.py delete mode 100644 SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc delete mode 100644 SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc delete mode 100644 SuperTagger/Symbol/SymbolEmbedding.py delete mode 100644 SuperTagger/Symbol/SymbolTokenizer.py delete mode 100644 SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc delete mode 100644 SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc delete mode 100644 SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc delete mode 100644 SuperTagger/Symbol/symbol_map.py diff --git a/SuperTagger/Decoder/RNNDecoderLayer.py b/SuperTagger/Decoder/RNNDecoderLayer.py deleted file mode 100644 index 93e96a6..0000000 --- a/SuperTagger/Decoder/RNNDecoderLayer.py +++ /dev/null @@ -1,180 +0,0 @@ -import random - -import torch -import torch.nn.functional as F -from torch.nn import (Module, Dropout, Linear, LSTM) - -from Configuration import Configuration -from SuperTagger.Symbol.SymbolEmbedding import SymbolEmbedding - - -class RNNDecoderLayer(Module): - def __init__(self, symbols_map): - super(RNNDecoderLayer, self).__init__() - - # init params - self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) - self.dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder']) - dropout = float(Configuration.modelDecoderConfig['dropout']) - self.num_rnn_layers = int(Configuration.modelDecoderConfig['num_rnn_layers']) - self.teacher_forcing = float(Configuration.modelDecoderConfig['teacher_forcing']) - self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence']) - self.symbols_vocab_size = int(Configuration.datasetConfig['symbols_vocab_size']) - - self.bidirectional = False - self.use_attention = True - self.symbols_map = symbols_map - self.symbols_padding_id = self.symbols_map["[PAD]"] - self.symbols_sep_id = self.symbols_map["[SEP]"] - self.symbols_start_id = self.symbols_map["[START]"] - self.symbols_sos_id = self.symbols_map["[SOS]"] - - # Different layers - # Symbols Embedding - self.symbols_embedder = 
SymbolEmbedding(self.dim_decoder, self.symbols_vocab_size, - padding_idx=self.symbols_padding_id) - # For hidden_state - self.dropout = Dropout(dropout) - # rnn Layer - if self.use_attention: - self.rnn = LSTM(input_size=self.dim_encoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers, - dropout=dropout, - bidirectional=self.bidirectional, batch_first=True) - else: - self.rnn = LSTM(input_size=self.dim_decoder, hidden_size=self.dim_encoder, num_layers=self.num_rnn_layers, - dropout=dropout, - bidirectional=self.bidirectional, batch_first=True) - - # Projection on vocab_size - if self.bidirectional: - self.proj = Linear(self.dim_encoder * 2, self.symbols_vocab_size) - else: - self.proj = Linear(self.dim_encoder, self.symbols_vocab_size) - - self.attn = Linear(self.dim_decoder + self.dim_encoder, self.max_len_sentence) - self.attn_combine = Linear(self.dim_decoder + self.dim_encoder, self.dim_encoder) - - def sos_mask(self, y): - return torch.eq(y, self.symbols_sos_id) - - def forward(self, symbols_tokenized_batch, last_hidden_state, pooler_output): - r"""Training the translation from encoded sentences to symbols - - Args: - symbols_tokenized_batch: [batch_size, max_len_sentence] the true symbols for each sentence. - last_hidden_state: [batch_size, max_len_sentence, dim_encoder] Sequence of hidden-states at the output of the last layer of the model. - pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task - """ - batch_size, sequence_length, hidden_size = last_hidden_state.shape - - # y_hat[batch_size, max_len_sentence, vocab_size] init with probability pad =1 - y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - y_hat[:, :, self.symbols_padding_id] = 1 - - decoded_i = torch.ones(batch_size, 1, dtype=torch.long, - device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id - - sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu") - - # hidden_state goes through multiple linear layers - hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1) - - c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - - use_teacher_forcing = True if random.random() < self.teacher_forcing else False - - # for each symbol - for i in range(self.max_len_sentence): - # teacher-forcing training : we pass the target value in the embedding, not a created vector - symbols_embedding = self.symbols_embedder(decoded_i) - symbols_embedding = self.dropout(symbols_embedding) - - output = symbols_embedding - if self.use_attention: - attn_weights = F.softmax( - self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2) - attn_applied = torch.bmm(attn_weights, last_hidden_state) - - output = torch.cat((symbols_embedding, attn_applied), 2) - output = self.attn_combine(output) - output = F.relu(output) - - # rnn layer - output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state)) - - # Projection of the output of the rnn omitting the last probability (which is pad) so we dont predict PAD - proj = self.proj(output)[:, :, :-2] - - if use_teacher_forcing: - decoded_i = symbols_tokenized_batch[:, 
i].unsqueeze(1) - else: - decoded_i = torch.argmax(F.softmax(proj, dim=2), dim=2) - - # Calculate sos and pad - sos_mask_i = self.sos_mask(torch.argmax(F.softmax(proj, dim=2), dim=2)[:, -1]) - y_hat[~sos_mask, i, self.symbols_padding_id] = 0 - y_hat[~sos_mask, i, :-2] = proj[~sos_mask, -1, :] - sos_mask = sos_mask_i | sos_mask - - # Stop if every sentence says padding or if we are full - if not torch.any(~sos_mask): - break - - return y_hat - - def predict_rnn(self, last_hidden_state, pooler_output): - r"""Predicts the symbols from the output of the encoder. - - Args: - last_hidden_state: [batch_size, max_len_sentence, dim_encoder] the output of the encoder - pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task - """ - batch_size, sequence_length, hidden_size = last_hidden_state.shape - - # contains the predictions - y_hat = torch.zeros(batch_size, self.max_len_sentence, self.symbols_vocab_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - y_hat[:, :, self.symbols_padding_id] = 1 - # input of the embedder, a created vector that replace the true value - decoded_i = torch.ones(batch_size, 1, dtype=torch.long, - device="cuda" if torch.cuda.is_available() else "cpu") * self.symbols_start_id - - sos_mask = torch.zeros(batch_size, dtype=torch.bool, device="cuda" if torch.cuda.is_available() else "cpu") - - hidden_state = pooler_output.unsqueeze(0).repeat(self.num_rnn_layers * (1 + self.bidirectional), 1, 1) - - c_state = torch.zeros(self.num_rnn_layers * (1 + self.bidirectional), batch_size, hidden_size, - dtype=torch.float, device="cuda" if torch.cuda.is_available() else "cpu") - - for i in range(self.max_len_sentence): - symbols_embedding = self.symbols_embedder(decoded_i) - symbols_embedding = self.dropout(symbols_embedding) - - output = symbols_embedding - if self.use_attention: - attn_weights = F.softmax( - self.attn(torch.cat((symbols_embedding, hidden_state[0].unsqueeze(1)), 2)), dim=2) - attn_applied = torch.bmm(attn_weights, last_hidden_state) - - output = torch.cat((symbols_embedding, attn_applied), 2) - output = self.attn_combine(output) - output = F.relu(output) - - output, (hidden_state, c_state) = self.rnn(output, (hidden_state, c_state)) - - proj_softmax = F.softmax(self.proj(output)[:, :, :-2], dim=2) - decoded_i = torch.argmax(proj_softmax, dim=2) - - # Set sos and pad - sos_mask_i = self.sos_mask(decoded_i[:, -1]) - y_hat[~sos_mask, i, self.symbols_padding_id] = 0 - y_hat[~sos_mask, i, :-2] = proj_softmax[~sos_mask, -1, :] - sos_mask = sos_mask_i | sos_mask - - # Stop if every sentence says padding or if we are full - if not torch.any(~sos_mask): - break - - return y_hat diff --git a/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc b/SuperTagger/Decoder/__pycache__/RNNDecoderLayer.cpython-38.pyc deleted file mode 100644 index cd9f43c2b2c435f545d90c5f77a9e9e4ce5480f2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5324 zcmWIL<>g{vU|?AIAUi2bnStRkh=Yt-7#J8F7#J9e^B5QyQW#Pga~Pr^G-DJan9r2U z%*4RRkjoOq$_SES%3;f8k7Cc|h~miQjN;7YisH)Uj^fVciQ)mvGw1N;@<#E3*(^DH zx%^T5j12A!DXb}MEet7asT|GBQ35H9!3>)0FF`Kx(`38F?oyOrkY8GIi_JGbr8Fn? 
z78iu}$t*6p#paWlmzr2~i^V56#P=4ndr+t*<1Jq2{JgZx^wOfllFa-(O~zaN!Iim5 z`8lq+NvSC*nR)5SAhVG%Gn7+o!oa|g$^i0R6jM7x8e<A$3R4S56mtr53QG$^6br<q zQLHKKDI6^fQEcrDEDTZX!3>(5xA=qn{9IC#^HWlbd=e{Di;|h3hJq-N1sn_v3?RCA z5+egc2}22E4MQ{20;Yuwj0`2rB`jI2DU6Z~&5Sh+@oZ2&6O_*m<ugP198f+Bl+Ov} zvqJe?CEQ^1*g*VRCa5esh*!f9&jXd^fbw}$IBOW<`BJ!Q7~=VBK=upNFk}g&a7!}O zFvJUj**svj5SYyiW($Mad|<XniD(T&7JCgt7FP{JmS75hFLMn;yjY2N4HMWu5;Y82 z{1X_9#7ZP<7_!))n5TxRh9QeDg)2p%m#u~&UMf!xEDtgVqy|GCq<R8lQBa9A$b=L@ z5eATr%}g~6@iKV=5K)8(m=88f9HfpDtWFk9lnpE@m%<Rtpef{ci#sJVH$F8F9ElK4 z3M58viRUI(#215;QE_}`UVL$CUP)?Ra_TJskVH-@l86vYSy_H^Vp4o@W>xAf_7rF$ z=F2P1jW5c}i_Zb2k>XqYC8>$Y8L36_Y57ITnR)4MAyur=0gf)QRjkp$t^u)C?9ss? zjzJ+{zJGA6CgUycf<$mCj?YY~&}6&Cm6=yiS^}|#J0mkC1>_cxom?QRAl4KKFfcIO z;!Vm-$t+4u2IY#xoLk&Ui6zMy@oAYw#U(F6iB(hn7He^7L26MEs9d<k5g(tKmst`Y ze~URYucQc+d~XTm=BK3QK$AHnn-y_0FfbJHfCyd?!3QGvLH4nx<>V)p6bXU^gg}Hi zNR+p<I5j@8q$D*D<c?e1u+Yj)EVu;^I)oRa_+Sc(Qw!oVQ*H^s`6Y=(B_I)miu~gE z%#<QdkTq~+so-3nT2ut`A~+z5*gy)Hi}Lb{SV1h7f};GaTP%qsC3&}aKxBM!er{4` zUaBVZEtcZcoU|eZkn2HVxRUV}XMB8ePGWI!eEdp=Utand`MIh3$tkG?`K3iA`c9=q zsfnffDXFOi=|zc|Iq^Y$e)_?nFbzpePfsn<hX$QKT4AAAP+26+z`y_s#$pdpiNVOn zEWikZd>|GR8zTtvFoGcq7bDAm4rVSU4n{skCPo$prhi;4Fct?Z2csC10FxMFl?1v? zuvkSYpFm*)D(;;@#XTqnOBfa~g35v{riDzw44RC7n#{LYOY)17Gj1`Z78WTpFfeE` z6={NkkFl~y2NXVFLKh^$QJh~KpPN{mZ3Qw8WHtj9hrkpigDeJl8H7a`7#KjQ5A2XK zCI*HYrYwdS=316o)>^g_#w?~9mKugE<{E}9mJ-G+))K}nwgv198H#jL81tA?m{OQ~ zS!>x-SZdf)SW{SAnIL2hLl#FGvkgNDV-{x(OA1IQ7hGo=Tqnq6ZV`rB4v0M^j9ENz zRXuQ3%(a{~%nNu^*cLJ_-~$zGDIk3U;tV1TDJ&^03z=&{W`l}0euPX7OAW~N2>0HB zo5xzqT_U(Za3QFW63S*M`j*0;$CScR%Tod>Vwsy6T^J@X#%9&>)-Ws(so_}2Sj$(# zut2ni56t4L<*#AL5}UwS)KVgzB~in$K(d5!fm97&4Oa@AB*OyU6qbdIX-p}c3z-%& zrf{Zk^)l58lrS!kt`VqVTF6u@2vsczQ3En<flLkmLdIJD67dDH5LOMt0=XLg8m=1d zG^T}2F-*1mF-)~wwcIrf3*<p6AfXH=xm%gim?ar%g=%;vFc#gaVT8!k@TM`rTvW@E zCsD$fr2uxb5X4;&vX-NUVSyrcIfywaJRoxxGS&*Gu+^~FaPl+M@WNuaMhL`D;en_G zxdNslg*652E*pj#;TldShFZ=N#sx|s)e9LXFcz~gGE879ESkVr$WjBU&v_FKm>3vB ziV`#PGV{_EN-|OvN{SNmigUni3Wc<y{9FY{MUkQas|$)1O7az;r4tvI0u(qFr59Vl zc_2ZklK7JR?9{x>s??Nta7k>X5DlimMY)aws9u4oM5=INp>8Qng;}MLmS3a*szu-q z(L=E=C$YFB9#%1wB$lLNu?V>W(@{XG1!ENyf>R4iK~<eXewqTr7~Nu!F~tgrC1A(o zmzETimVl%|!345S0bK3CM8I_hs+$V(^K(**;vqWG+=F7fPhxQi)HGzfVMc>1F9mQw z!i0)p&eZ@H)0t_R$&fe$tI||ROe;w(Qb;Q;D#=JKQYa|OPfjf^hQx7Eerb9J*h^sd z7b}z&r>1~|7v!<T(u&NS%*3Kfg@U5g5_rl;EY5}p_bry>(v-wo%*h3%E17Pwrj%3` zq~2mnNiEAvPJPM9z`)>F1R_J+iX<5r7*;agVou4-efjVI|Nk%l|NsAAQ~4HaaYkZ6 zYLOnO7-g+WEy^!00@tdb3bjZMRLmNIq<Aum;}gpgGjkG?a#D*FK|(C~d8x&>SaS07 z(u<V90!jJ#Ikz}V^NI^gQ&X!_Z?P4n7No)|WN=-~R+N~RlAl`ys-17K7A5ATrxvMz zECJPZMe-n17~O8M7w4yy<R(@Wse=TVlM_pBF(>8b7HNREEJdj~rA4YBO==)MTVhc< zNUb4=!<?8`sVM@fCd6Th85|FYWPeK#lwwd4=q+B9^l^&|kuPrX!{RqSCp9m<B%??a zWD#p+d`4o)Elx;tEhRp)2y8zOB+bNw)9WqvWC$Zl7}NwoYk1va%)BKCb0Vb8keQc$ ziw%-Qq3we5)XemZl42+?v7jI)Gc~0M>~AhubrzqAULCTh<rkGF7NyvNYQr#4Eyuyb z%*V*WAn>1!NrH)oQQ$udBg=mlW(8&iCNYROvlz1g6Bi>F6AL2;BMYM(1Ji#tCLSgM zMlMD!W(j5mCN@S1CJAOOMj=KOMh<2^MlPl*arCMlYHu>A*8#E?gh5dauVA%66)YpT zf(2Kx;EE4axq>jLnnYAs$Q3MG3QG+eq=<!(u;Le7!7>x9VA&z|;H+RdYMAg=u$*B1 z3|Yt(EOQOWtso3(5Fl5uTqU4h46G_O0asi?HS7x+Yk6uI76{kyfLT1CDmF`G0%K8L ziD;Hs4Z{NQ62=7*H9R$(C{-ytxGH5&;pk<m<tt%aAX&p#!?=*CmLIB`KaDAceIXOb zv;|T%ybBp?c}qkWNJCgP3=3pxcxyOoxS&-jZwyl{XDt`hMp=+5NSMM&PDphM>T*IV z?<~0*0Vt0hTm{1#J6zC;Q80zAhOLH!pP_~u7Be*hAbtuL#8hy_2vw27ngX`LhM`8V zhQkTm{8^v?QoWF|_#L=fR4aT}!xqe-$(`84$iNU#l$w&6Tv7}!D4{t9QXzs$Nl1Qy z6~>S%vPchKb%>Gu(dq+GMTn<<!f7eIR0h?#C<O|d1w**PDFW37MaH1g7c_KHWDc)m zz%@*f1xUsUL|B7b?21J;AU3FWDFT&7*y@%dJCG7kl~QC6Vu410z!gf7BZ%t+B0#P4 zB5=PBRE-q5g2YhkjUqRYm^+97)fz>hK24D)hzqJUio8H9c%@O~4dQ}I>>?kKGljr4 
zMv)&#)E`6yfQUd40V?f_f<UZb5CPT`0%C=N2vFf#6b@oVfQU#C5yim35XA%P;>E-2 zv1pJuxVr)-z<n0(0!Xq3b<P_=1vSV{26h|;H4~zs1{cZLifJYZCM`xjW-(|H&BMsS zR3(jGL?hX%$pPwe=OyN*#>d~{ijU6)4|m1K-{OgnFDy;WfyjWniQw9;C<f$>SWxxH z3+<C8r{?6u$0OB0S|BAL|AD*opq_d$s2|V4$fL)>%;5|c)f6mp0VxBeid(|qKAv7) zo?cpM9&Dt5DKGC97lIFt+gltkm5`1#sE#QD_2j^PYH%03h#h1s$i65gq&~17WHb&+ p!3X8Q-FC2pklc@m9}XKxc-Vm&P{p8rG6xd}qY$G46AvQ~GXR!;l0^Uj diff --git a/SuperTagger/Encoder/EncoderInput.py b/SuperTagger/Encoder/EncoderInput.py deleted file mode 100644 index e19da7d..0000000 --- a/SuperTagger/Encoder/EncoderInput.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch - - -class EncoderInput(): - - def __init__(self, tokenizer): - """@params tokenizer (PretrainedTokenizer): Tokenizer that tokenizes text """ - self.tokenizer = tokenizer - - def fit_transform(self, sents): - return self.tokenizer(sents, padding=True,) - - def fit_transform_tensors(self, sents): - temp = self.tokenizer(sents, padding=True, return_tensors='pt', ) - return temp['input_ids'], temp['attention_mask'] - - def convert_ids_to_tokens(self, inputs_ids, skip_special_tokens=False): - return self.tokenizer.batch_decode(inputs_ids, skip_special_tokens=skip_special_tokens) diff --git a/SuperTagger/Encoder/EncoderLayer.py b/SuperTagger/Encoder/EncoderLayer.py deleted file mode 100644 index c954584..0000000 --- a/SuperTagger/Encoder/EncoderLayer.py +++ /dev/null @@ -1,67 +0,0 @@ -import sys - -import torch -from torch import nn - -from Configuration import Configuration - - -class EncoderLayer(nn.Module): - """Encoder class, imput of supertagger""" - - def __init__(self, model): - super(EncoderLayer, self).__init__() - self.name = "Encoder" - - self.bert = model - - self.hidden_size = self.bert.config.hidden_size - - def forward(self, batch): - r""" - Args : - batch: list[str,mask], list of sentences (NOTE: untokenized, continuous sentences) - Returns : - last_hidden_state: [batch_size, max_len_sentence, dim_encoder] Sequence of hidden-states at the output of the last layer of the model. - pooler_output: [batch_size, dim_encoder] Last layer hidden-state of the first token of the sequence (classification token) after further processing through the layers used for the auxiliary pretraining task - """ - b_input_ids = batch[0] - b_input_mask = batch[1] - - outputs = self.bert( - input_ids=b_input_ids, attention_mask=b_input_mask) - - return outputs[0], outputs[1] - - @staticmethod - def load(model_path: str): - r""" Load the model from a file. - Args : - model_path (str): path to model - Returns : - model (nn.Module): model with saved parameters - """ - params = torch.load( - model_path, map_location=lambda storage, loc: storage) - args = params['args'] - model = EncoderLayer(**args) - model.load_state_dict(params['state_dict']) - - return model - - def save(self, path: str): - r""" Save the model to a file. 
- Args : - path (str): path to the model - """ - print('save model parameters to [%s]' % path, file=sys.stderr) - - params = { - 'args': dict(bert_config=self.bert.config, dropout_rate=self.dropout_rate), - 'state_dict': self.state_dict() - } - - torch.save(params, path) - - def to_dict(self): - return {} diff --git a/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc b/SuperTagger/Encoder/__pycache__/EncoderInput.cpython-38.pyc deleted file mode 100644 index 03717155496997dc3ef4713b269d7115fda65fd7..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1182 zcmWIL<>g{vU|?9;=bKc)#K7<v#6iZ)3=9ko3=9m#QVa|XDGVu$ISf$@?hGkRDa<Vl zDa_4GQH&{!!3>%#FG0Hel0hUghU?d3U|>jPh+<4(h+;}%Okrwah+<A*PGM<bh+;`$ zO<`+ch+<9QPT@%5Y+;OIOW_J;(B!$r<C>S8pORYSnO9I+lFWo;Fo+Gp&LFpeyimhf z!w}B^b_-LLyF)=@QDSbfLP>shYF=hlYLS9QKv8N*QDSCZYDx%9P}52Q&Q>VNNGyS? zFIFf?tte6OTgiBf6K;kk(=C?b)SR>;7Ep*W6tOWdF#K}Y&&bbB)lW`IEyyn|D$#c; zElN!+)lW%HEl4j)%*=@o^7GRVE-gqc3Q0^)Pc71idI?5@y{1=Cd5a@HJ~J<~Bt9Nw zY%$1h3`|v`P}LAsFclaf0df#HB={H@7-|@@7-|?nVcg3U%%I8a7oy2{i@hK*B_%U2 zy$BQ`noLEk3=9mnSc_BhN{Ye$fDr5;QQox7lK7IM#Ju9P{Gwb?*g{;xSS5wx8Wd$H zP6F8hvH<KP1r#TxF!wUoGSx7put+l0Fr~0cGNgeV$>tYwi>aVQlc|Uk<OIH=)RNMo zy!evTyyE<#;#-`Vpac+~nNoa<FR`Q~HLoNyKQBHvu{ayzA7-$Bia>6?#Zr=*TLAG6 zC?t!(eiB7-EX-V9P~riF9Rs5PW0ewa-@!DvC4;;MONJmeNH^G<0#I*eG1f5FFr_ep z!=A}clkt{tadu`wd~rc)a%N&qJUBHKYchco(Jh{&#FFHU_>@#oGS`H7BZ>>`+hUM+ zi-Z^$APEiPBZy~`^YhA5i%LKW;!E<OMhb#F15P&RUcsu|Pm|G2lO2@m@)C1X<Ku5} z#mDF7r<CS^*gWy^g{6r(5SbzlkQ=x`1jvpe5fBR;v<L#^t|E}-pztULY2aYuVBuiq iU;&G2GTmY=$uCOI0P`STLDhVV!v<oz9V5&i9Lxat0Ui+m diff --git a/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc b/SuperTagger/Encoder/__pycache__/EncoderLayer.cpython-38.pyc deleted file mode 100644 index f685148ba6d4ea28a6f741836205060a3d60123d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2675 zcmWIL<>g{vU|=xO$V~dm&cN^(#6iX^3=9ko3=9m#b_@&*DGVu$ISf${nlXwog&~D0 zhcTBaiir^<#+<{P%N)hb$l%V9!ji(;!jQt6%9O>@%pAp<!Whh;$@UUtvY#g7EvCFY zO~zZi&iQ$1ndzlPi6xo&dC4G2WDK$z#4iqDU|>jP09hBsl)@Or+|H23n8K98+`<vX z0<kEHHI*%uJ%w!!V+wl;M+-|cV-!a!dkW_q#uTm;?iQ9P&J>;$-WG-^u670%hA8e} z22H+OJg#}k`6;PIK8cm7MODgBo<ed?VsWvKLS}A3X^BF9nnH1DL26M+VtRUNQ8E+A zJx~l{GczzSfN-%E0|P?|LkVLILo?F?riBcQ3@MB?4Drl0Ad;nqA&Vu8HH)o=A)YOT zA(%mv$?q0B)ILr2TdZI!ia`Fo#StH$nU`4-AAgG_FEKau7E4lUQOPZ~WN=8`;?BrS zNlDF%FV3t=)nvNGQk<HTc8fJPKP5G1CF3p5_;|46<KtH{{BqaN$j?pHPfkfK$S*A_ z(RV5>N=+=)Pf1NJNH0pv%!v>3^V1Io`7H$OH+`skU^F;X^a?7A*cliYct9zD1LSc= zK1MbsHl`|3sB(xhn4?hwh?RkX0purVkY+sw28J5O6h=vgTBaJt6edZATIL#tEQT7U z8s-${UZz@>8Wxaj4GTyvjVYKxlO>UXiGhJj0SX+8(u)<W;9QVsQesJRhLu82W^qY$ zaY>O*ZenqEtPYqDj=9vllGME9)M5n<KmQO{D}~a$lKkw{yv(Z96di@+{JfIPywd#A zVufOuLQS~wL8&FBMS18Jf-KEREG~(M#Y0JANvf4XG}s~FSkY0)O{|E|0mTv2P#uMo z%-r}?NLq_kPzX*fECsm=<U)usy2T)4iWL$|6iPBu74l0<ph*wJ16ik#14^PG4ImM4 zbnC%g3UU(&7v$&Xq!z_PbfdWm#RWcz#U)S^k?n^WpO#rvT%rIDD40+&%*h(y^p%;G znGDXd5LKEAiD@ONMG9%9MI{-jMG6H)`N^rp#hH2O3MCmu`K9R@U~hr_P^?f|oSLGL zmR|%GODwI(%*jkFs#GW_N-Zf$%*+GpN-WNXd#{L*fq~&Ah|pxZ#hIB`P+Ah7nNoa< zFR`Q~H4jv9#Dk(;ld*`Cfq_AjwFs0}Z?S@7B#JvJ9<C#b2gU)Zzr_v-sNy041_lNQ zA;`eMaEm=Hzo<O1C`AO6YWYA(o`ZvtgNcQaiGk@q8*`NiO2UUK)?~cJ1&)^Zg2a-H zWRPn?Izbqe13(y@3(Ua9L>6NWLkeRGQwsw)<uleWrm%qWMF~?HLk&|OV=Z$H^8)4? 
z#uQdbhJ}ob3^mN4!n10vf=_;83OL-rb}6J4<>x9SDx_uRq#}|7mTUr64R)D=MsZ1z zrj-K7l?o;K3Sb#{P-9Q?U^NOFd3k!i`6;D2shU<0Z3^X?B^e6EiDjuN3I&NpiMgpI zsYS(b!%&I>P+18IlweSxf<mK)A(&w$qaQfjHJNU)7nkH0C8np|V$R7=1`9$6aEx%+ z<Rs=Mr6k(vGB7ZF28AiSu!fXAIr)hxdNw)v$%#3|c6u<?nvA!2auW;ULE@lfa*HLg zD82X=7dW%Ur(`CVXtLa5Ey*uR&bY+_G6GTn^Mjc2$cl<2K<R=F?8o9Fc~ElU1rewT zr$~u`fk6f2Kae;B2Pna?{byt1VB})tW2zEFNj4xmG#PKPfZ_oqpg@%gD15-F1ytph zfGQG3F@_SxET$BuUM5C{5>UCxlEqrXki`Z{KQ#;s*cURSu*_kqWh&uVz*)mo!?+Ms zA+v&0O^iZtVp%Fuswv4QCe46chMrpBg%3QEt7Jj(4^0o?00JihP|}K4Esh0w5tKSL znWDHs^;kT#8skYR$}h+-Er~BmEJ@X5f`q980|SGfCf6<2f}+g4l3UEhmBqK%ic3;b zi;8ZsfT9bW!HYo6l_GUe_=7YQDT3mL1>_1%W^mRoQU%F@A_9_xK(SE-N=HRHpmq)@ zf<Q(vunI8?FtYt)V?~P^kWnZZ7?f^77#t;_xXEM)W+;*b*#a_KlM!Mqhz+q3WZ5nD zl6-I)a{^fiO8yLtXhjE9fhG^A&Ii@?@$t8~;^T9{nJhm37EgS9VQFFxRE9l1J|#an zJ|0{h7m0&>%>zoLnaR1SB^miCx0s7dioj_Q?9U=lUb)2!Z4e}<=H$f3LtF=TC#YP5 z1PaJQpn|g)B*ekU!o<VDA;82az|Y4A&JUVwMVcV5g2K57<c}gy61~L+DGR_#kjw=e a4)#BX4a6gMpr9)TH61vZI2d`Dc$fk7oUlOv diff --git a/SuperTagger/Linker/AtomTokenizer.py b/SuperTagger/Linker/AtomTokenizer.py index e400d4e..6df55ed 100644 --- a/SuperTagger/Linker/AtomTokenizer.py +++ b/SuperTagger/Linker/AtomTokenizer.py @@ -1,5 +1,7 @@ import torch +from SuperTagger.utils import pad_sequence + class AtomTokenizer(object): def __init__(self, atom_map, max_atoms_in_sentence): @@ -28,24 +30,3 @@ class AtomTokenizer(object): def convert_ids_to_atoms(self, ids): return [self.inverse_atom_map[int(i)] for i in ids] - - -def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): - max_size = sequences[0].size() - trailing_dims = max_size[1:] - if batch_first: - out_dims = (len(sequences), max_len) + trailing_dims - else: - out_dims = (max_len, len(sequences)) + trailing_dims - - out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) - for i, tensor in enumerate(sequences): - length = tensor.size(0) - # use index notation to prevent duplicate references to the tensor - if batch_first: - out_tensor[i, :length, ...] = tensor - else: - out_tensor[:length, i, ...] 
= tensor - - return out_tensor - diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index 6b5c6f1..6a39ae7 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -2,6 +2,7 @@ from itertools import chain import torch from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU +from torch.nn import Module from Configuration import Configuration from SuperTagger.Linker.AtomEmbedding import AtomEmbedding @@ -10,11 +11,12 @@ from SuperTagger.Linker.atom_map import atom_map from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch from SuperTagger.Linker.AttentionLayer import FFN, AttentionLayer +from SuperTagger.utils import pad_sequence - -class Linker: +class Linker(Module): def __init__(self): + super(Linker, self).__init__() self.__init__() self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) @@ -71,20 +73,25 @@ class Linker: atoms_polarity = find_pos_neg_idexes(category_batch) link_weights = [] - for sentence_idx in range(len(atoms_polarity)): - for atom_type in self.atom_map.keys(): - pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - x and atoms_batch[sentence_idx][i] == atom_type] - neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - not x and atoms_batch[sentence_idx][i] == atom_type] - - pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :] - neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :] - - pos_encoding = self.pos_transformation(pos_encoding) - neg_encoding = self.neg_transformation(neg_encoding) - - weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0)) - link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters)) + for atom_type in self.atom_map.keys(): + pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))] + neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + not x and atoms_batch[s_idx][i] == atom_type] for s_idx in + range(len(atoms_polarity))] + + # to do select with list of list + pos_encoding = pad_sequence( + [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + neg_encoding = pad_sequence( + [atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=self.max_atoms_in_sentence, padding_value=0) + + # pos_encoding = self.pos_transformation(pos_encoding) + # neg_encoding = self.neg_transformation(neg_encoding) + + weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) + link_weights.append(sinkhorn(weights, iters=3)) return link_weights diff --git a/SuperTagger/Linker/Sinkhorn.py b/SuperTagger/Linker/Sinkhorn.py index 912abb4..9cf9b45 100644 --- a/SuperTagger/Linker/Sinkhorn.py +++ b/SuperTagger/Linker/Sinkhorn.py @@ -1,4 +1,3 @@ - from torch import logsumexp diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc index afbe4b3ab22416e9fb3fa5b7e422587b47fe3c95..26d18c0959a2b1114a5564ff8b1ec4304f81acf3 100644 GIT binary patch delta 56 zcmZ3_x}KFcl$V!_fq{X+bw*K=%0^y(Mn;y&vW$(49FylWE@tGJ?93F)&&a^QP|U`_ Mz`(=I!NS1;07MT7ng9R* delta 56 
zcmZ3_x}KFcl$V!_fq{Wx#@~V@)s4LTjEt<4Wf>b8IVaC&T+GNj*_kPppNWBip_q+< Mfq{pagN1_y0Aj}qRR910 diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py index ddb8cb5..f2e72e1 100644 --- a/SuperTagger/Linker/utils.py +++ b/SuperTagger/Linker/utils.py @@ -92,4 +92,3 @@ def find_pos_neg_idexes(batch_symbols): list_symbols.append(cut_category_in_symbols(category)) list_batch.append(list_symbols) return list_batch - diff --git a/SuperTagger/Symbol/SymbolEmbedding.py b/SuperTagger/Symbol/SymbolEmbedding.py deleted file mode 100644 index b982ef0..0000000 --- a/SuperTagger/Symbol/SymbolEmbedding.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from torch.nn import Module, Embedding - - -class SymbolEmbedding(Module): - def __init__(self, dim_decoder, atom_vocab_size, padding_idx): - super(SymbolEmbedding, self).__init__() - self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_decoder, padding_idx=padding_idx, - scale_grad_by_freq=True) - - def forward(self, x): - return self.emb(x) diff --git a/SuperTagger/Symbol/SymbolTokenizer.py b/SuperTagger/Symbol/SymbolTokenizer.py deleted file mode 100644 index cded840..0000000 --- a/SuperTagger/Symbol/SymbolTokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ - -import torch - - -class SymbolTokenizer(object): - def __init__(self, symbol_map, max_symbols_in_sentence, max_len_sentence): - self.symbol_map = symbol_map - self.max_symbols_in_sentence = max_symbols_in_sentence - self.max_len_sentence = max_len_sentence - self.inverse_symbol_map = {v: k for k, v in self.symbol_map.items()} - self.sep_token = '[SEP]' - self.pad_token = '[PAD]' - self.sos_token = '[SOS]' - self.sep_token_id = self.symbol_map[self.sep_token] - self.pad_token_id = self.symbol_map[self.pad_token] - self.sos_token_id = self.symbol_map[self.sos_token] - - def __len__(self): - return len(self.symbol_map) - - def convert_symbols_to_ids(self, symbol): - return self.symbol_map[str(symbol)] - - def convert_sents_to_ids(self, sentences): - return torch.as_tensor([self.convert_symbols_to_ids(symbol) for symbol in sentences]) - - def convert_batchs_to_ids(self, batchs_sentences): - return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences], - max_len=self.max_symbols_in_sentence, padding_value=self.pad_token_id)) - - def convert_ids_to_symbols(self, ids): - return [self.inverse_symbol_map[int(i)] for i in ids] - - -def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): - max_size = sequences[0].size() - trailing_dims = max_size[1:] - if batch_first: - out_dims = (len(sequences), max_len) + trailing_dims - else: - out_dims = (max_len, len(sequences)) + trailing_dims - - out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) - for i, tensor in enumerate(sequences): - length = tensor.size(0) - # use index notation to prevent duplicate references to the tensor - if batch_first: - out_tensor[i, :length, ...] = tensor - else: - out_tensor[:length, i, ...] = tensor - - return out_tensor - diff --git a/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/SymbolEmbedding.cpython-38.pyc deleted file mode 100644 index 030ce696540244b363763b0b592d2a84a8c19fd9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 858 zcmWIL<>g{vU|^UZ>zlNck%8ech=Yt-7#J8F7#J9ebr={JQW#Pga~PsPG*b>^E>jd! 
zE^`z!BZE6b3Udle3quM^DpNCa6iW(YFoP!ROOQE!noPIYeDhOEb5d_{y5=UOrle%% zr6+@=kTElqQ>?<kz>vxi#hAhn#njG_#+bsG!qmbM#SF0^ilv=_g&~R+Wc)4u;L6;j z{2YX#Ah$xzVF0n&7#J8p27?VPVJKm&VQ6Mrz_gHok)edShN*_Jh8bjhFG~$WJWC2g zFoPzuUx+5lExx?c-1t<OTZ?b;!rAdDnYp*P3lhPeh|f%^xFuAaoS2gupI(%h5}#BV zpH`GwsL6VZwYan(wWtW>h+7=-@tJv<CGqh^Ah+CNPR&iyWVyvsoSKt%i#sJVH$Ejb zIX@+}D2hL^BtJL4EI&ChDZV(fDz%86fq`Kq<1Nnk_~e|#;^O%Dl?=Z;^)vEwQ}vTm zQVa4+i%RsJN{dnxOZ8JyQw!3I5;Jq+gZ%vTgF&{3B&MgQ7U@F*2};4kO0S@@2o&+4 zAS&hn1sWqCBL^eX|0)TH282qO^OKoC%Agp;26@LBlr%sF)i5kzs9{_Pig-ppO~xW_ z1_lOArXn5&28LUV6-5FJ3=9xL5G2B$mS0q!Sd;=%RSYtVfw4*)-D0R3O*T*v<|XE) z#>d~{ijU6)C#LxLTRidcg{6r(5E-y<iiAOi^FqTbIW;FIJ|1E#*nU2cBS7|pBZd{^ zbWjX&F!Hd1ML-c(l3$dZaf<^?>E-1WfgA^pIIs#N!@wqj47<f)196faBf>-hW&k<w B(5(Of diff --git a/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/SymbolTokenizer.cpython-38.pyc deleted file mode 100644 index 7c631aa22d6bb2cee2970ecccb0a089a347c9f95..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 2862 zcmWIL<>g{vU|`r0>zj0voq^#oh=Yt-7#J8F7#J9e)fgBUQW#Pga~Pr++!<1sQkYv9 zQkYX2o0+4SQaDoBQrKG<qnJ||gBdhAUxG~VO9qk1m<eW<4Fdy1Dnk@w3PTiA3S$aW z3qurh3Udle3qur33Tp~m3qurZ3VRAi3qurJ3TFye3qurp3U>-m3quq~3U4rjCf_ao z;L6;j{G5>d?9{x>s??%nkV9dXf!NFp3=Ga7hi5P_Fw`*CFvK&|Fx4=`GuANIFvLTx ztzpPwSiroH!G)n2q=F>{B*&V<Qo|6>mcm-Y5YG-$QNxhMD#=j85YGYSv4MG<P#!y& z$Cbhm%%I8Um&^!r1p@;E$PpmN1cMwR#=yXk$xy=(%NN5`%UH`)!&JjKnW>N^m|-PD z5lE#b^DQPlgIkQ*w;0R70$}2or+!9$ZmNEAN@_uVX;F#3Q)y9ZVyS*gYHC4xQDSCJ ze2|}?esF0)YEejHdU|S+J|t+M6g-6V3My}L*`#D9mn7%s7TAHJy_k)Gfq{*Ije&`w zN(Z6~p*%i5GcU6wK3>lzCqFqcr`S#pp}C4RI@mQJwu&`6z|jRn1^WlbYI5J=Dh9_& zd~RaFE%DsMig*aWI6gBkzBo0nBsDKN^_BogA}19|gf+7yHMjVdP-b3PYEf}2R2SG3 z&f?U9_!3Yo-{LGtOo1??IE(X(A<SDma0T(1DYtmwYT`3fqIlrS;xkh;S#Gfur{<&; zaWXJ46oI^7#0_GBLLMB@MeGa=47WI7{sx6ju?Q$K8QH*43{0{y@-gx;R!QIrE|mBO zr3Fx`2PZ_3l_d-{3|WlLjKK^m8T~YwZZYSi=7AlwlCek#i)+~9<3S-650WbenZ&?Y zC5~w|SWPlm9^_yK1_lrt<TJ3%e4s4Aki}5KSi{)ND9I4aPz3UrCKK4|Tg=5JMVd@t z`){#9Vhj=>AOj$wB9@#FiklK>)PMp%J~O3Q6qK$(RxvPEsi8XyyA}*rvO@Dou>=DH zLkYtI#&(7@MsTUXQNswzm<vH=3NvcP0@(@DC=7A{C}X9AGgd8Q2g3q}8pef;6Tuk^ z><vxEA}Iz222G|SP*yCG1O+ol3s@53JuaJ^%wkBs07(~v6oPYw1x`<;=9R!g0V#*T zjnQPf#afbIl$>#kGqE_nBsH%%9~Q<@oUp7=3~>V3G2p}}0uM7J<3Zs9i+dGf+=CIs zAhSVn2`<DzL0rOw9>^UGS<D>_Su82cz0BYMXQ3#-<q!dWixuSE;v%pwL9PG?d66^& z1A`1G_TXV{iOX+Ei6zMy1Ol5KTBzLOEl5mB$;?ZSFH6iRP1R&A0$EZd3(DHy!l5{| zuoP4bK}rBYh&w^qD@p)r56HP76&SH23X2`6!jgPY9sqfym;;pU7{wS_7+C&SsS@Q8 zj93DN4G0&5T!S7z!3>)CvH&Q7gR_7TC<`ERI&wB(D&hr&0&`|w2{bR<V$8%AEea^n zVui~$nJM5nffN|U*fIe)dBaK;=FAj~5QP^jsAkE7901DV_=*)&Eq<DuprSi3F*h|n z{uWn!d~SY9X%2|Z6CYn#nwSHTDFPK3MdF|$02E|JpmbWK4C1PT2vG2Y9StEs`KAcu z4p323%)`LIz`@AD!@&%OEMS?CmlGHn7*JXtpjrh~d4P-FdElZqg;A1W0mDLuT9z7? 
z6s8o0RwhY?TGkrIBE1sE8c^GVIgP1?Rh*%gEl;3^86pM~sbK)=&SJ`9u3_tASin-l zw2-luy@VB{&xIk@CWfh&qn5LVa{)sN<3h$-t{TRowi?bF_8N{9h8nI`rWEEhW*dgW z7?2o5q=o}7#vaU|$>LW8s;6Fp2t7@<TP($(1{J7izr~W0Sdw^)IWM*R7He8&PEPzS z&eXip+|;7Pl2lFZTb!`WRD6pYoG0SbGK-2!iWEU%2}(t`I6&DFWZo^_5>Q(Olrd8> zbBk|r<d>GjgIQc47NmeIQUxVYHVEeyTTW_TdPxQ(D!?%Vjt?aU1_n@!f$IwfMjl2U zMixdPCIKcfMgc}HCKg5!CK0A0kSa~)5U|gRv_ZL?1zZx`V#`m;N=+^SM;=%elKF5; UZgJQ^(vKY|_QBNw2eW`200G;je*gdg diff --git a/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc b/SuperTagger/Symbol/__pycache__/symbol_map.cpython-38.pyc deleted file mode 100644 index 1e1195f2ba885473f1d89b05684939ab6a2385b1..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 563 zcmWIL<>g{vU|`rJ8j^I7k%8ech=Yu!85kHG7#J9e?HCvsQW#PgQ<zeiQ&>`1Q`l12 zQ#evMQ@B#NQ+QH%Q}|N&Qv^~3Q-o54Q$$ikQ^ZolQzTL(Q>5ldMKPpE2Qz5Oyabv2 zl97Rd;UyD@U<MH^Ac7S{uz?765WxW=I6(x+buYO=EFKWS3nKVH1V4xn01<*9LI^|% zg9s52AqpbIK!iAmkN^>q3=9mKQc*0)Iq^lm7-McR=G|f}h+;2}FDNKVExyH^l9?FA zS{z@VQ5eOXlUnkNQ6I$3%u9=6D~>NnEG~{>DJY0fjAAP&h%ZPiiefD&h)+ocn^$>@ zIVr#57IR5O$t}iWh?$8+B~i@9@rAdT@(OM-6%<slMhCkF#DXaQ;Mgkm=-?2?pb!v0 zz|kevPm}Q$S8-)-QhrW+Zeqboh9Wfv1_<%XML#1yH&s75CAA>Gw5UYiskA6Hu~a`L zHMJnUC^0i9KFH5cKe)6YwJ0PpJw3HZKNxI|KEgP?g34PQHo5sJr8%i~pr9`{V_;xl LVk8+pXZa5R2VIM` diff --git a/SuperTagger/Symbol/symbol_map.py b/SuperTagger/Symbol/symbol_map.py deleted file mode 100644 index c16b8fd..0000000 --- a/SuperTagger/Symbol/symbol_map.py +++ /dev/null @@ -1,28 +0,0 @@ -symbol_map = \ - {'cl_r': 0, - '\\': 1, - 'n': 2, - 'p': 3, - 's_ppres': 4, - 'dia': 5, - 's_whq': 6, - 'let': 7, - '/': 8, - 's_inf': 9, - 's_pass': 10, - 'pp_a': 11, - 'pp_par': 12, - 'pp_de': 13, - 'cl_y': 14, - 'box': 15, - 'txt': 16, - 's': 17, - 's_ppart': 18, - 's_q': 19, - 'np': 20, - 'pp': 21, - '[SEP]': 22, - '[SOS]': 23, - '[START]': 24, - '[PAD]': 25 - } diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py index 3017c7f..372e68c 100644 --- a/SuperTagger/eval.py +++ b/SuperTagger/eval.py @@ -1,8 +1,7 @@ import torch from torch import Tensor from torch.nn import Module -from torch.nn.functional import cross_entropy - +from torch.nn.functional import nll_loss, cross_entropy # Another from Kokos function to calculate the accuracy of our predictions vs labels def measure_supertagging_accuracy(pred, truth, ignore_idx=0): @@ -42,3 +41,12 @@ class NormCrossEntropy(Module): def forward(self, predictions, truths): return cross_entropy(predictions.flatten(0, -2), truths.flatten(), weight=self.weights, reduction='sum', ignore_index=self.ignore_index) / count_sep(truths.flatten(), self.sep_id) + + +class SinkhornLoss(Module): + def __init__(self): + super(SinkhornLoss, self).__init__() + + def forward(self, predictions, truths): + return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean') + for link, perm in zip(predictions, truths)) \ No newline at end of file diff --git a/SuperTagger/utils.py b/SuperTagger/utils.py index 8712cca..cfacf25 100644 --- a/SuperTagger/utils.py +++ b/SuperTagger/utils.py @@ -5,6 +5,26 @@ import torch from tqdm import tqdm +def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): + max_size = sequences[0].size() + trailing_dims = max_size[1:] + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims + + out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) + for i, tensor in enumerate(sequences): + length = tensor.size(0) + # use index notation to prevent duplicate references to the 
tensor + if batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor + + def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500): print("\n" + "#" * 20) print("Loading csv...") diff --git a/test.py b/test.py index f208027..9e14d08 100644 --- a/test.py +++ b/test.py @@ -1,27 +1,54 @@ from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn import torch -atoms_batch = [["np", "v", "np", "v","np", "v", "np", "v"], - ["np", "np", "v", "v","np", "np", "v", "v"]] -atoms_polarity = [[False, True, True, False,False, True, True, False], - [True, False, True, False,True, False, True, False]] +def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400): + max_size = sequences[0].size() + trailing_dims = max_size[1:] + if batch_first: + out_dims = (len(sequences), max_len) + trailing_dims + else: + out_dims = (max_len, len(sequences)) + trailing_dims -atoms_encoding = torch.randn((2, 8, 24)) + out_tensor = sequences[0].data.new(*out_dims).fill_(padding_value) + for i, tensor in enumerate(sequences): + length = tensor.size(0) + # use index notation to prevent duplicate references to the tensor + if batch_first: + out_tensor[i, :length, ...] = tensor + else: + out_tensor[:length, i, ...] = tensor + + return out_tensor -matches = [] -for sentence_idx in range(len(atoms_polarity)): - for atom_type in ["np", "v"]: - pos_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - x and atoms_batch[sentence_idx][i] == atom_type] - neg_idx_per_atom_type = [i for i, x in enumerate(atoms_polarity[sentence_idx]) if - not x and atoms_batch[sentence_idx][i] == atom_type] +atoms_batch = [["np", "v", "np", "v", "np", "v", "np", "v"], + ["np", "np", "v", "v"]] - pos_encoding = atoms_encoding[sentence_idx, pos_idx_per_atom_type, :] - neg_encoding = atoms_encoding[sentence_idx, neg_idx_per_atom_type, :] +atoms_polarity = [[False, True, True, False, False, True, True, False], + [True, False, True, False]] - weights = torch.bmm(pos_encoding.unsqueeze(0), neg_encoding.transpose(1, 0).unsqueeze(0)) - matches.append(sinkhorn(weights, iters=3)) +atoms_encoding = torch.randn((2, 8, 24)) + +matches = [] +for atom_type in ["np", "v"]: + pos_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))] + neg_idx_per_atom_type = [[i for i, x in enumerate(atoms_polarity[s_idx]) if + not x and atoms_batch[s_idx][i] == atom_type] for s_idx in range(len(atoms_polarity))] + + # to do select with list of list + pos_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(pos_idx_per_atom_type)], max_len=3, padding_value=0) + neg_encoding = pad_sequence([atoms_encoding.select(0, index=i).index_select(0, index=torch.as_tensor(sentence)) + for i, sentence in enumerate(neg_idx_per_atom_type)], max_len=3, padding_value=0) + + print(neg_encoding.shape) + + weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1)) + print(weights.shape) + print("sinkhorn") + print(sinkhorn(weights, iters=3).shape) + matches.append(sinkhorn(weights, iters=3)) print(matches) diff --git a/train.py b/train.py index 58ebe45..25154db 100644 --- a/train.py +++ b/train.py @@ -1,134 +1,51 @@ import os +import pickle import time from datetime import datetime import numpy as np import torch import torch.nn.functional as F -import transformers from torch.optim import SGD, 
Adam, AdamW from torch.utils.data import Dataset, TensorDataset, random_split -from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup) -from transformers import (CamembertModel) +from transformers import get_cosine_schedule_with_warmup from Configuration import Configuration -from SuperTagger.Encoder.EncoderInput import EncoderInput -from SuperTagger.EncoderDecoder import EncoderDecoder -from SuperTagger.Symbol.SymbolTokenizer import SymbolTokenizer -from SuperTagger.Symbol.symbol_map import symbol_map -from SuperTagger.eval import NormCrossEntropy +from SuperTagger.Linker.Linker import Linker +from SuperTagger.Linker.atom_map import atom_map +from SuperTagger.eval import NormCrossEntropy, SinkhornLoss from SuperTagger.utils import format_time, read_csv_pgbar, checkpoint_save, checkpoint_load from torch.utils.tensorboard import SummaryWriter -transformers.TOKENIZERS_PARALLELISM = True torch.cuda.empty_cache() # region ParamsModel -max_len_sentence = int(Configuration.modelDecoderConfig['max_len_sentence']) -symbol_vocab_size = int(Configuration.modelDecoderConfig['symbols_vocab_size']) -num_gru_layers = int(Configuration.modelDecoderConfig['num_rnn_layers']) +max_len_sentence = int(Configuration.datasetConfig['max_len_sentence']) +atom_vocab_size = int(Configuration.datasetConfig['atoms_vocab_size']) # endregion ParamsModel # region ParamsTraining -file_path = 'Datasets/m2_dataset.csv' batch_size = int(Configuration.modelTrainingConfig['batch_size']) nb_sentences = batch_size * 40 epochs = int(Configuration.modelTrainingConfig['epoch']) seed_val = int(Configuration.modelTrainingConfig['seed_val']) learning_rate = float(Configuration.modelTrainingConfig['learning_rate']) -loss_scaled_by_freq = True # endregion ParamsTraining -# region OutputTraining - -outpout_path = str(Configuration.modelTrainingConfig['output_path']) - -training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M')) -logs_dir = os.path.join(training_dir, 'logs') - -checkpoint_dir = training_dir -writer = SummaryWriter(log_dir=logs_dir) - -use_checkpoint_SAVE = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_SAVE')) - -# endregion OutputTraining - -# region InputTraining - -input_path = str(Configuration.modelTrainingConfig['input_path']) -model_to_load = str(Configuration.modelTrainingConfig['model_to_load']) -model_to_load_path = os.path.join(input_path, model_to_load) -use_checkpoint_LOAD = bool(Configuration.modelTrainingConfig.getboolean('use_checkpoint_LOAD')) - -# endregion InputTraining - -# region Print config - -print("##" * 15 + "\nConfiguration : \n") - -print("ParamsModel\n") - -print("\tsymbol_vocab_size :", symbol_vocab_size) -print("\tbidirectional : ", False) -print("\tnum_gru_layers : ", num_gru_layers) - -print("\n ParamsTraining\n") - -print("\tDataset :", file_path) -print("\tb_sentences :", nb_sentences) -print("\tbatch_size :", batch_size) -print("\tepochs :", epochs) -print("\tseed_val :", seed_val) - -print("\n Output\n") -print("\tuse checkpoint save :", use_checkpoint_SAVE) -print("\tcheckpoint_dir :", checkpoint_dir) -print("\tlogs_dir :", logs_dir) - -print("\n Input\n") -print("\tModel to load :", model_to_load_path) -print("\tLoad checkpoint :", use_checkpoint_LOAD) - -print("\nLoss and optimizer : ") - -print("\tlearning_rate :", learning_rate) -print("\twith loss scaled by freq :", loss_scaled_by_freq) - -print("\n Device\n") -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -print("\t", device) - 
-print() -print("##" * 15) - -# endregion Print config - -# region Model +# region Data loader file_path = 'Datasets/m2_dataset.csv' -BASE_TOKENIZER = AutoTokenizer.from_pretrained( - 'camembert-base', - do_lower_case=True) -BASE_MODEL = CamembertModel.from_pretrained("camembert-base") -symbols_tokenizer = SymbolTokenizer(symbol_map, max_len_sentence, max_len_sentence) -sents_tokenizer = EncoderInput(BASE_TOKENIZER) -model = EncoderDecoder(BASE_TOKENIZER, BASE_MODEL, symbol_map) -model = model.to("cuda" if torch.cuda.is_available() else "cpu") +file_path_axiom_links = 'Datasets/axiom_links.csv' -# endregion Model - -# region Data loader df = read_csv_pgbar(file_path, nb_sentences) +df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences) -symbols_tokenized = symbols_tokenizer.convert_batchs_to_ids(df['sub_tree']) -sents_tokenized, sents_mask = sents_tokenizer.fit_transform_tensors(df['Sentences'].tolist()) - -dataset = TensorDataset(sents_tokenized, sents_mask, symbols_tokenized) +dataset = TensorDataset(df, df, df_axiom_links) # Calculate the number of samples to include in each set. train_size = int(0.9 * len(dataset)) @@ -137,46 +54,34 @@ val_size = len(dataset) - train_size # Divide the dataset by randomly selecting samples. train_dataset, val_dataset = random_split(dataset, [train_size, val_size]) -print('{:>5,} training samples'.format(train_size)) -print('{:>5,} validation samples'.format(val_size)) - training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True) validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=True) # endregion Data loader + +# region Models + +supertagger_path = "" +supertagger = pickle.load(supertagger_path) +linker = Linker() + +# endregion Models + # region Fit tunning # Optimizer -optimizer_encoder = AdamW(model.encoder.parameters(), - weight_decay=1e-5, - lr=5e-5) -optimizer_decoder = AdamW(model.decoder.parameters(), - weight_decay=1e-5, - lr=learning_rate) - -# Total number of training steps is [number of batches] x [number of epochs]. -# (Note that this is not the same as the number of training samples). -total_steps = len(training_dataloader) * epochs +optimizer_linker = AdamW(linker.parameters(), + weight_decay=1e-5, + lr=learning_rate) # Create the learning rate scheduler. 
-scheduler_encoder = get_cosine_schedule_with_warmup(optimizer_encoder, - num_warmup_steps=0, - num_training_steps=5) -scheduler_decoder = get_cosine_schedule_with_warmup(optimizer_decoder, - num_warmup_steps=0, - num_training_steps=total_steps) +scheduler_linker = get_cosine_schedule_with_warmup(optimizer_linker, + num_warmup_steps=0, + num_training_steps=100) # Loss -if loss_scaled_by_freq: - weights = torch.as_tensor( - [6.9952, 1.0763, 1.0317, 43.274, 16.5276, 11.8821, 28.2416, 2.7548, 1.0728, 3.1847, 8.4521, 6.77, 11.1887, - 6.6692, 23.1277, 11.8821, 4.4338, 1.2303, 5.0238, 8.4376, 1.0656, 4.6886, 1.028, 4.273, 4.273, 0], - device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) - cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id, - weights=weights) -else: - cross_entropy_loss = NormCrossEntropy(symbols_tokenizer.pad_token_id, symbols_tokenizer.sep_token_id) +cross_entropy_loss = SinkhornLoss() np.random.seed(seed_val) torch.manual_seed(seed_val) @@ -192,10 +97,6 @@ total_t0 = time.time() validate = True -if use_checkpoint_LOAD: - model, optimizer_decoder, last_epoch, loss = checkpoint_load(model, optimizer_decoder, model_to_load_path) - epochs = epochs - last_epoch - def run_epochs(epochs): # For each epoch... @@ -216,60 +117,38 @@ def run_epochs(epochs): # Reset the total loss for this epoch. total_train_loss = 0 - model.train() + linker.train() # For each batch of training data... for step, batch in enumerate(training_dataloader): + # Unpack this training batch from our dataloader. + batch_categories = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") + batch_sentences = batch[1].to("cuda" if torch.cuda.is_available() else "cpu") + batch_axiom_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu") + + optimizer_linker.zero_grad() + + # Find the prediction of categories to feed the linker and the sentences embedding + category_logits_pred, sents_embedding, sents_mask = supertagger(batch_categories, batch_sentences) + + # Predict the categories from prediction with argmax and softmax + category_batch = torch.argmax(torch.nn.functional.softmax(category_logits_pred, dim=2), dim=2) - # if epoch_i == 0 and step == 0: - # writer.add_graph(model, input_to_model=batch[0], verbose=False) - - # Progress update every 40 batches. - if step % 40 == 0 and not step == 0: - # Calculate elapsed time in minutes. - elapsed = format_time(time.time() - t0) - # Report progress. - print(' Batch {:>5,} of {:>5,}. Elapsed: {:}.'.format(step, len(training_dataloader), elapsed)) - - # Unpack this training batch from our dataloader. 
- b_sents_tokenized = batch[0].to("cuda" if torch.cuda.is_available() else "cpu") - b_sents_mask = batch[1].to("cuda" if torch.cuda.is_available() else "cpu") - b_symbols_tokenized = batch[2].to("cuda" if torch.cuda.is_available() else "cpu") - - optimizer_encoder.zero_grad() - optimizer_decoder.zero_grad() - - logits_predictions = model(b_sents_tokenized, b_sents_mask, b_symbols_tokenized) - - predict_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in - torch.argmax(F.softmax(logits_predictions, dim=2), dim=2)[0]] - true_trad = [{v: k for k, v in symbol_map.items()}[int(i)] for i in b_symbols_tokenized[0]] - l = len([i for i in true_trad if i != '[PAD]']) - if step % 40 == 0 and not step == 0: - writer.add_text("Sample", "\ntrain true (" + str(l) + ") : " + str( - [token for token in true_trad if token != '[PAD]']) + "\ntrain predict (" + str( - len([i for i in predict_trad if i != '[PAD]'])) + ") : " + str( - [token for token in predict_trad[:l] if token != '[PAD]'])) - - loss = cross_entropy_loss(logits_predictions, b_symbols_tokenized) + # Run the linker on the categories predictions + logits_predictions = linker(category_batch, sents_embedding, sents_mask) + + linker_loss = cross_entropy_loss(logits_predictions, batch_axiom_links) # Perform a backward pass to calculate the gradients. - total_train_loss += float(loss) - loss.backward() + total_train_loss += float(linker_loss) + linker_loss.backward() # This is to help prevent the "exploding gradients" problem. # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2) # Update parameters and take a step using the computed gradient. - optimizer_encoder.step() - optimizer_decoder.step() - - scheduler_encoder.step() - scheduler_decoder.step() - - # checkpoint + optimizer_linker.step() - if use_checkpoint_SAVE: - checkpoint_save(model, optimizer_decoder, epoch_i, checkpoint_dir, loss) + scheduler_linker.step() avg_train_loss = total_train_loss / len(training_dataloader) @@ -277,27 +156,18 @@ def run_epochs(epochs): training_time = format_time(time.time() - t0) if validate: - model.eval() + linker.eval() with torch.no_grad(): print("Start eval") - accuracy_sents, accuracy_symbol, v_loss = model.eval_epoch(validation_dataloader, cross_entropy_loss) + accuracy_sents, accuracy_atom, v_loss = linker.eval_epoch(validation_dataloader, cross_entropy_loss) print("") print(" Average accuracy sents on epoch: {0:.2f}".format(accuracy_sents)) - print(" Average accuracy symbol on epoch: {0:.2f}".format(accuracy_symbol)) - writer.add_scalar('Accuracy/sents', accuracy_sents, epoch_i + 1) - writer.add_scalar('Accuracy/symbol', accuracy_symbol, epoch_i + 1) + print(" Average accuracy atom on epoch: {0:.2f}".format(accuracy_atom)) print("") print(" Average training loss: {0:.2f}".format(avg_train_loss)) print(" Training epoch took: {:}".format(training_time)) - # writer.add_scalar('Loss/train', total_train_loss, epoch_i+1) - - writer.add_scalars('Training vs. Validation Loss', - {'Training': avg_train_loss, 'Validation': v_loss}, - epoch_i + 1) - writer.flush() - run_epochs(epochs) # endregion Train -- GitLab
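[Editorial note, not part of the patch] Both the Linker.forward change and test.py build, per atom type, a score matrix with torch.bmm over padded positive/negative atom encodings and pass it through sinkhorn_fn_no_exp, whose body is not shown in this diff. As a reading aid, here is a minimal, generic log-domain Sinkhorn sketch in the same spirit (an assumption about the expected behaviour, not the repository's implementation); the shapes follow test.py (batch of 2, up to 3 atoms of a given type per polarity):

import torch
from torch import logsumexp

def sinkhorn_log_domain(log_alpha, iters=3):
    # log_alpha: [batch, n_pos, n_neg] raw matching scores (square per batch entry).
    # Alternately normalise rows and columns in log space; exp(result) approaches a
    # doubly stochastic matrix, i.e. soft one-to-one links between + and - atoms.
    for _ in range(iters):
        log_alpha = log_alpha - logsumexp(log_alpha, dim=2, keepdim=True)  # row normalisation
        log_alpha = log_alpha - logsumexp(log_alpha, dim=1, keepdim=True)  # column normalisation
    return log_alpha

pos_encoding = torch.randn(2, 3, 24)
neg_encoding = torch.randn(2, 3, 24)
weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))  # [2, 3, 3]
log_links = sinkhorn_log_domain(weights, iters=3)
print(log_links.exp().sum(dim=1))  # each column sums to 1 after the final normalisation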
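[Editorial note, not part of the patch] The new SinkhornLoss in SuperTagger/eval.py consumes, per atom type, a batch of log-probability link matrices and the gold permutations (for each positive atom, the index of its negative partner, as the axiom links in Datasets/axiom_links.csv are expected to provide). A self-contained usage sketch with made-up shapes (2 sentences, 3 atoms per polarity); only the example tensors are invented, the class is the one added by this patch:

import torch
from torch.nn import Module
from torch.nn.functional import nll_loss, log_softmax

class SinkhornLoss(Module):  # as added in SuperTagger/eval.py
    def __init__(self):
        super(SinkhornLoss, self).__init__()

    def forward(self, predictions, truths):
        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
                   for link, perm in zip(predictions, truths))

link = log_softmax(torch.randn(2, 3, 3), dim=2)  # [batch, n_pos, n_neg] log-probabilities
perm = torch.tensor([[0, 2, 1], [1, 0, 2]])      # gold negative index for each positive atom
loss = SinkhornLoss()([link], [perm])            # lists: one entry per atom type
print(float(loss))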