From 85a4adabea03e42e68e4eb57c1bac1c298653460 Mon Sep 17 00:00:00 2001 From: Caroline DE POURTALES <cdepourt@montana.irit.fr> Date: Thu, 12 May 2022 17:05:49 +0200 Subject: [PATCH] adding mha --- Configuration/config.ini | 2 +- SuperTagger/Linker/Linker.py | 44 ++++----- SuperTagger/Linker/MHA.py | 92 ++++++++++++++++++ .../Linker/__pycache__/Linker.cpython-38.pyc | Bin 5231 -> 5632 bytes .../Linker/__pycache__/MHA.cpython-38.pyc | Bin 0 -> 4389 bytes .../Linker/__pycache__/utils.cpython-38.pyc | Bin 7210 -> 8757 bytes SuperTagger/Linker/utils.py | 25 ++++- SuperTagger/__pycache__/eval.cpython-38.pyc | Bin 1770 -> 1899 bytes 8 files changed, 135 insertions(+), 28 deletions(-) create mode 100644 SuperTagger/Linker/MHA.py create mode 100644 SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc diff --git a/Configuration/config.ini b/Configuration/config.ini index dafdae6..8e6c08c 100644 --- a/Configuration/config.ini +++ b/Configuration/config.ini @@ -18,7 +18,7 @@ dropout=0.1 teacher_forcing=0.05 [MODEL_LINKER] -nhead=8 +nhead=1 dim_feedforward=246 dim_embedding_atoms=8 dim_polarity_transfo=128 diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py index ef03e0e..93028fd 100644 --- a/SuperTagger/Linker/Linker.py +++ b/SuperTagger/Linker/Linker.py @@ -1,36 +1,21 @@ from itertools import chain import torch -from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU +from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention from torch.nn import Module import torch.nn.functional as F from Configuration import Configuration from SuperTagger.Linker.AtomEmbedding import AtomEmbedding from SuperTagger.Linker.AtomTokenizer import AtomTokenizer +from SuperTagger.Linker.MHA import AttentionDecoderLayer from SuperTagger.Linker.atom_map import atom_map from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn -from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch +from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN from SuperTagger.eval import mesure_accuracy from SuperTagger.utils import pad_sequence -class FFN(Module): - "Implements FFN equation." - - def __init__(self, d_model, d_ff, dropout=0.1): - super(FFN, self).__init__() - self.ffn = Sequential( - Linear(d_model, d_ff, bias=False), - GELU(), - Dropout(dropout), - Linear(d_ff, d_model, bias=False) - ) - - def forward(self, x): - return self.ffn(x) - - class Linker(Module): def __init__(self): super(Linker, self).__init__() @@ -39,6 +24,8 @@ class Linker(Module): self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo']) self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms']) self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters']) + self.nhead = int(Configuration.modelLinkerConfig['nhead']) + self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence']) self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type']) self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size']) @@ -50,7 +37,7 @@ class Linker(Module): self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id) # to do : definit un encoding - # self.linker_encoder = + self.linker_encoder = AttentionDecoderLayer() self.pos_transformation = Sequential( FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1), @@ -61,23 +48,32 @@ class Linker(Module): LayerNorm(self.dim_embedding_atoms, eps=1e-12) ) - def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding): + def make_decoder_mask(self, atoms_token) : + decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64) + decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0 + return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1) + + def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None): r''' Parameters : atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities - sents_embedding : output of BERT for context + sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context + sents_mask Returns : link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat) ''' # atoms embedding atoms_embedding = self.atom_embedding(atoms_batch_tokenized) + print(atoms_embedding.shape) # MHA ou LSTM avec sortie de BERT - # decoder_mask = self.make_decoder_mask(atoms_batch) - # atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, decoder_mask) - atoms_encoding = atoms_embedding + sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder) + batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape + sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence) + atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized)) + #atoms_encoding = atoms_embedding link_weights = [] for atom_type in list(self.atom_map.keys())[:-1]: diff --git a/SuperTagger/Linker/MHA.py b/SuperTagger/Linker/MHA.py new file mode 100644 index 0000000..d85d5e0 --- /dev/null +++ b/SuperTagger/Linker/MHA.py @@ -0,0 +1,92 @@ +import copy +import torch +import torch.nn.functional as F +import torch.optim as optim +from Configuration import Configuration +from torch import Tensor, LongTensor +from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention, + ModuleList, Sequential) + +from SuperTagger.Linker.utils import FFN + + +class AttentionDecoderLayer(Module): + r"""TransformerDecoderLayer is made up of self-attn, multi-head-attn and feedforward network. + This standard decoder layer is based on the paper "Attention Is All You Need". + Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez, + Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in + Neural Information Processing Systems, pages 6000-6010. Users may modify or implement + in a different way during application. + + Args: + dim_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + activation: the activation function of the intermediate layer, can be a string + ("relu" or "gelu") or a unary callable. Default: relu + layer_norm_eps: the eps value in layer normalization components (default=1e-5). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False``. + norm_first: if ``True``, layer norm is done prior to self attention, multihead + attention and feedforward operations, respectivaly. Otherwise it's done after. + Default: ``False`` (after). + """ + __constants__ = ['batch_first', 'norm_first'] + + def __init__(self) -> None: + super(AttentionDecoderLayer, self).__init__() + + # init params + dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder']) + dim_decoder = int(Configuration.modelDecoderConfig['dim_decoder']) + max_len_sentence = int(Configuration.datasetConfig['max_len_sentence']) + atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size']) + nhead = int(Configuration.modelLinkerConfig['nhead']) + dropout = float(Configuration.modelLinkerConfig['dropout']) + dim_feedforward = int(Configuration.modelLinkerConfig['dim_feedforward']) + layer_norm_eps = float(Configuration.modelLinkerConfig['layer_norm_eps']) + + # layers + self.dropout = Dropout(dropout) + self.self_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, batch_first=True, + kdim=dim_decoder, vdim=dim_decoder) + self.norm1 = LayerNorm(dim_decoder, eps=layer_norm_eps) + self.multihead_attn = MultiheadAttention(dim_decoder, nhead, dropout=dropout, + kdim=dim_encoder, vdim=dim_encoder, + batch_first=True) + self.norm2 = LayerNorm(dim_decoder, eps=layer_norm_eps) + self.ffn = FFN(d_model=dim_decoder, d_ff=dim_feedforward, dropout=dropout) + self.norm3 = LayerNorm(dim_decoder, eps=layer_norm_eps) + + def forward(self, atoms_embedding: Tensor, sents_embedding: Tensor, encoder_mask: Tensor, + decoder_mask: Tensor) -> Tensor: + r"""Pass the inputs through the decoder layer. + + Args: + atoms_embedding: the sequence to the decoder layer (required). + sents_embedding: the sequence from the last layer of the encoder (required) + encoder_mask + decoder_mask + """ + x = atoms_embedding + x = self.norm1(x + self._mask_mha_block(atoms_embedding, decoder_mask)) + x = self.norm2(x + self._mha_block(x, sents_embedding, encoder_mask)) + x = self.norm3(x + self._ff_block(x)) + + return x + + # self-attention block + def _mask_mha_block(self, x: Tensor, decoder_mask: Tensor) -> Tensor: + x = self.self_attn(x, x, x, attn_mask=decoder_mask)[0] + return x + + # multihead attention block + def _mha_block(self, x: Tensor, sents_embs: Tensor, encoder_mask: Tensor) -> Tensor: + x = self.multihead_attn(x, sents_embs, sents_embs, attn_mask=encoder_mask)[0] + return x + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.ffn.forward(x) + return x diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc index e732b7dd6a2e215b497d136ba944e5e832aeacd6..facf9eafa213710664a3d7f18c9150693e6c0a2c 100644 GIT binary patch delta 3067 zcmaE_(V)W@%FD~ez`($;@?T9-xYR^G8OB=^wGHcYIiolk8B!QhSaO(ixuUqhY^EIU zTplI{MuuG8C_b<VYYu-de-uBM&6XpOD;OmRX0zuA<qAg$gV`K8BDtbbqF^>>j##dE zlsK5pl_QZW86^c~bLU9s%0$WJ%0|iN%0<aBI)JR^$&t@hh*AJ6;LTCYRf<w#WN>Fl z;Y;CfVMyUm<!WY*QchtCX3!LP3G#v_<1N<YjKs`5KTWpD;*5GCw}gC4b4oHZQWH}g zOG;AnN;32FCc84)ifc075`~Mnq$cO5q!#%kR!(kURApqIyo^ze^%k?6o8RO!j3&O! zAfp)=7>ex}7#LC+q8L*cqL|tl(il@1Q<z#fqL@>dQ&?IUqF7RRQrJ@1TNtBQQ#evM zTNt9)Qn*sMTNt9)+Zk9GqBw#XG<hfAW>I3)nf#9>te!oERg$5Zv4$a@yM(!fC5tD8 z4a(<*^4Xz$J}939%IAmjIiY-k62TNMut`EC!Zi%>!Zi$8!YSO63^fe#A|;|Au`Ir3 zrW%HLu@dnbhAhq+hAe@Y8ip*9X671(c!?6pW{^^;66q2ds5D0kS2J^oEQpuFlfv7} z1lBG`P=$O7LokCT-(+?+FFw{hP{7|3$W5$>&q<ve%O*KFp3R^>eb!76*lAzI8Xe&1 z5}PjgO1E+5^g>%r##_v(1;v^Qw^)ly3sQ@U#26SDZgIrNXXa&=#K+%a&de(*0;Sbk zLb>@VsX4BB;COS+&r8cpFA@hS76eQ9Waee37D2@&7#J9eBte7}h>(r}spCyaEJ-X* zErF_$0m;aM2sscT4<gt=>ey3?@(c1yON#hGJg$Pol$6Z8^!Ut_A|8+cUt&pqZhT38 zc4}T`RccWY_vAo!5fNUnR1Vk?@u?6;PEKLBWV^*uoSKt1S%NuEUx|T%K?Ibj3_z)i zk&j7$5gl?du`%*6@-d1qb1+T5&8|?-1X2pcphU^az`)=PGQ*95fuV#Ui?N2WgdvM5 zg|U~ZmMMj)hN*^e0doyQ7Ry3L8-^OD1*|E|3mF%%r7+hpX0fL*OERP|FJxZ8R>P3R zk-`jO2Qz50_<=G}N=aowYC6;^hL?;C3=Eo_^|x3{@{5u)ZgJ-4r54BMWM-${Vo%G- zPb@Jrxy6)PSfl~+H)m;Habam{YE|kjwxZO6)Wnioti>6L1*t{SAXAu&^cWZzqPRf; zR}2ZcD8Up+28>TEDanh^O)SnXG6cl{m@onvDwvy?omw9cQwLIF4DvWAju^lJS0w?7 zbUlzt6or1tpm2iu1r#?N3=9mQ@GI6}WMJrES-{Z2xR9}ywS=*THH)d4k&&T<c>zlb zV+TtXYX?gf+d}4A<{IWKrurD>TDDsDS`LtO4O<6G7P|;T2TK+QNZ~@}T9yu$1)Mdk zHOw_278l6Sg)E&+=?t};CEOh>Sv(7P7cw+6rZA;2w=zjGxG=;n>SQS4YiDR@Y-eg` zZf8kjO<_r4ZQ&^4FA?ZqZf0y|bOEUc*(8|4E>+Kv!qLlI%Y~+e6OR^%4(<+?1wu7k z3mI#<I#?D6*KjXntmQ6YULaD#Rl~hNG=)8dYa!D@rdpmF&IO<Xf+vMLg{PNkAtNJ0 zVO0t90*M;Vg^a-r^_skiM;I9xxD=ouAh9ShH?<_Ss93=YE(nR(q{Ncs3`q25R;8vW zSSe^ggo-n(QgswSB?3fYW?p=LUTS=CYF<fdUUI6YLRwBDsH9F!QAkcKNlni$%1kYW z7@nx0pa9WSke`!Slvz?44>prR!xf;WWu_KSPT-VqMK>-dH4pAQ9fg$4Tv(RYRLCza zDJU&b$WK#nat#VmNXsu$NY2kINv$ZET)`<nc?PFk7|6l#W%<d8N%0V`gPp7kVPgbC za$*UVz)4Ol(S!tnf`Y<J1qKF&$qzWsi-NNUAE@SLzQv?xaEmeX7GuTaV_Z(spezQ- zLg3uPWs{RxT#}rhTVNM3S()3Jwa64?MlQFMGpJl#$ynq7GKICs1(fStLHQG0uoXjm znwghg1S(;`Id>&vkvm8g$lxLm1_p+J&8N5(8SCMWDY6GigNh}D3&0A)Kp7TPdb2UG zF$yt?G4lK^@&u{%0ug>7!XHEgfC#wpU|paL%f`a=M}?W^Z;=TD1H(&D_5~GGnV{6C zDOr@E1u7<33yMHhgf++nkQZ;U7A5AT<Q0L7cToKRE`y4|1-&*%kqt-}3n+9T1w2c3 zYGrYeC`c5fH9hqfb53er5x7#&0cp^kypUI|-Wybs@q)vn2ozl4;ECb~CF0`v)ZC<0 za4iwV1rY$1R#E(*gn%S@iwm3(z=<M?2RV&J2}88O&52LR%#Gp!RVVS~shQ~+C6l@N z)b+pt4h}F-Slwbz%P%TVEK0Eeg)&HlfrpEci&=<~hf#u2jCpei9}lA`DCmj;LH-B< z`GD~jb4q5eCf6;dl6-KIC`tzToF}t5KCvt@Gbb@AC$%VP@>G6B4p3F4$yKB`c^`kL zG}vk|0kY^8S87>e4kRn4Y)%qLVpIj$qsdg{1G1?YWD^%StZs3EYTcat#FW&cqTI<Z z1r4=wKuQWhL=lJp8ww_nOinGxPtGWw>>{Mc!sw^TIa!27rM?JMe-?qFuP6zm5oAkI z8OZg#@$tzyiN(dqsX00E@sKnKHc=m>4rDC2Ss@GxCs1nSVC0eKVCB-`;0KFqs@&oP zC90DA{G8$<P>l-?nOhv-s!uO3uLzX%CVvo;WGdpE%qFbKl?5^vWKz`R1YrqoS6K8H zaZipBR$~LT6QaZ>CkhL5`Fc1Of$E=I(je~S*}|M`pk_rBgu7RmlT83*R{Z2M!gAb# zAf^zA5T5)&Sj-3%1GfZ&LG7ZD#PsykBE8a*%p6E!1-l%{t6={cO|})0l2BmeVH9HG LVMIU?5StkQ+O`S= delta 2603 zcmZqBd9T42%FD~ez`(#@wWd00wfICn8OCc9wGHbT8B!QhSaO(iIioniY^EHpTy7=? zMuuFTC|<A#YYty7Ulbph&6dNTD-b0BX0zuA<_bj#f!Q26!nq<*B49RWj%cn}lsK5p zl_QZW86}x36(t4cbLU9s%0$V4**rP2xi)f9a*Pb_3@N-Rd@T$qe5stx%u(|03@Q96 z0xb+F0${#E3R5tHrr=AE^EDZ7u_k9EX6E^6vQCy})SDd3Xghf`qcS7Y<m-%TlbM)I z>RCYQ7#J9ewHO!}QW>HcQy8L{QW&F{Q`p-X(il^iQ&?I!qgYZ{Q`lM<qFCD*SQw(% zf*CY9ZZW&L`Bh1H<`(3n=BDPA6f3y7`6;9pmL`^D=I7}>pEVN%cG@SyYyepVVskPu zFffB`v|(Uis4rnCVXR?jW?I0skb#k*gt>&JhOvezg)xPxm$`(snW2QOhPj!sgr$b5 zh7ls#%u>S;&z`~%%%I8a=cdVcizO*Du~?Ji7He^7L26NvFara_Esps3%)HE!`1m4_ zRYj}}3=Bo=Ac74f%bb>$r^#}Qr8re8C+!w{N_=j9N@`9NOG<nih+C9jkY8Fdc^7M3 zy*L8{gDA+cd?4pD@-Z?o@-VV6vN7`fsp4^S^V5UblgtEC0L36S$X;iVPEbOsVOYRW z!?=(!m_d`#Pm{4of`NfSlc`7&<QT?^A~}%Dz=S+Vggq_4s64SKMUFw9fuR^=7z1M! zH^?@qgeGH=2*?EX$$NxVvTm`*$EV~c$Hy0eTwA0FGJ-cgJ~=0`xHvgACnr80$pRUW zRUliy;UoldASg&U7+IM3IKg7cP*1=qkWWCIVw1@ZObU|Fl#s%n!qLJI#Rg6h?7<A0 zoRgcGlo(kiFJlU;=T2djWN2orVTk7eMQIjy3LBKq3+1yz`Fth(DI8#x0wsbq4Do_B z3|WFHoRSPR4DmuG!XU9Mo@S;RhIo+@(He#<&Kia+z8Z!sp=Rb9hIp|O@e&EBC`SrM zGjoX~h-Z?*mBQW21lA^nT?vv3X>eNNnS7nuYqAxK9V<9&CpWT47|0Zffr3jNM1aEf z7IS7^Nf9R~fQ7(G#x*ZFKP9!uIX^EgGrdS-@&y+0dQFgOP;M&H1}WrCNi0b$PA!3| z(E-Wof(Shj0SfyfX%LH-fq@~4s~|BYB{MHQJ~O3=8>E{ru_Qk?z9c_8H7~O&wWtV` zm~KG?Q*)D2!HOYC9F**fSQr>4cd}}$fm2oyIIiKz3Y<EWK@kpe90MaC(_~#XaYok3 z?raM1VgMA+Apdc|GU+u228Irn1q>aG3mI!!YFIj%(iv)5OBg#?vX~YyFJx#2=gL+l zNd^~&*hQTTB`oa>?TqbA?ab{gX{_MlgQJACgsp?QnX#GC1*D#}gC&bSg-wbfg}s-# zmJOt%9!U#F3QG%qE$kgE3pi@n7Bbedb+9bptYKfsSj%3*wSc>Zt%iL8PYPQK=R&51 zOtl;}tP6Nc_-Z&Hk-d<Sk)cqdglhqR4J$Y+@>IR!QUHU=A2`G_k`qf()ANfe<C79g zk~0*n6d-hdnnH1ARjPspm>Umb>nIeb7M7;wC8x&cq~@iUWN0ebD&&`z6qJ@I<fkd5 zq=Hig)NBQXVo>oI509_O@?7$hZMozPzy|6<3<Sq>S$=Y264VAgJv~jh34SmC|NsAg z@&}%MN<}O+OcR+3S%MiLxtfQ8fng={EhasKTa1}S3JeSklc#YzNrQ?OhyqY%zr|&f zlUZDnoS$1@7c%)Xw=?@K){^|9<c!JIJW|e}^sth#$O>c<YmqG|HQIqvBRHKE$ESi4 zd1hXEkv&KTWbjJHA_ou~WOR`u0|P_I<{3PSjP>BO269l5B}f`nz$4rMRu~0Jf1pIn z#=yoX#3;td^S8(eq!v_E7I}bJo*=>tM8J&)>jEW7HWsEoD$G28UotW<FuVj6OrR<* zvm~{sSW}?L+!T~$SU^540%fOLEZM1*#YLb(?iOoNVqSXcE#{omydrS+GXvRR4kBD8 zFW^%%cLf!woREwI@;%tMQT%9e7sU^XsA8nZwBR>U0{a{6Zm^RT85kHq;RCKy8MwJv zS(rA@;^$#B^#fVw4I=zOwlUsfPRY#G<hsRFl3(Nk5(atn7Efkzd}3K*W=>*KPHK_j z<l6#@9D*SATt$+TSp++!!PbEZkaf4XQp*x^Ac;0(^9sQvMpckKnoLD*AS1J&Izd6h z1uBbj@)J{1i;7Yv>kAudfeQPg43N4^5CJw6Ody$@T9BWdQ9ij#Sg)R=NC~72REiX- zfLNeX5nSUGf$S;D1&OJF2(Tew0$d7#j4J|_qagjops3_v<dNmj;NS;~YAW911jS-W zetu4I5vaxkhsP}ra3a;q%Y&40lM_TFCfACnu_c16;GDctM2(Sa@&yq!M()WXqKb^X zlg&gG*+5lwQN-j}Q8`(DkU9a7*9C$>jf0TH^z_ssJy3uZ2~M6XDlQL79N=mX><&n3 e1v?u|$WJ~eDkUMq$ipba#KQ=K0w5N%fB*oxT0&O< diff --git a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..679c41a8ef96c82cce084ed1d859b6f0933c12f5 GIT binary patch literal 4389 zcmWIL<>g{vU|`t!rzUBUAOpi=5C<8vFfcGUFfcF_Ph((UNMT4}%wdRv(2P-xU_MhW zGZO<NLoQ1cD<edPEtfrt9n5CR;mGBP;sCRma~N|uqd0T9qPW0(mK^R}o+zGN-YDK& zz9_z2{wV%jfhd7o!6?C8p(vqT;V9u;kth+cTGkxVT(KxIMh16=6t)!h7KRk|RKaHE zDDf1=U<OT&mmnAWX)@m8b<WRA%S<mVN-W9D&(mbO#TJs9SDasTi_0fJFCD_x<h;e= z?&=eIi^V56#P=4vOHqD7erd@qPM^fe)FQw9qTE|-KACx`iAA^AeDhOEb5d^!`IhFC zWM-r$rZ|?Aq~?LlzQqNR^T{kOxy2QnT38AaPt4I|yv6M1=9dg|5i(|ja*C@M7#LC+ zKw%ffl)@OroWhjCoX!-*lERY0I)|~HA&n92L!KzsR5ln);h4jk!kNO=!V<*}5li8o z!<@pC!rQ_U#Q_mXWlQ0k!;-?EBGAGT#R(Ei5uC%6B9tQB!V<;R&cMPD#U0F`DRN5` z?gy9D<ouM>BCx*`i&z;LLW&aeiqrCoa#M?t<rFfD6><|(QWZ)I6!Oy)ic@pabQ4QT z@^ln(LE)kc3K+0}LSkNuLRxBSN?LwVd16tDLSAY~d45s09+!fGf<j0}X0bwXNn&0K zNGb(lxk3&o)Il~RB^IZqDCFlUlw_nT6eJd;7AYvfovPqjtl*fFqY#;2s^FKJnxX_T z(y=%rvp7Q`EU~yeF)vd`!7o2CS0Okfu_`sSNJqghGdoiuAh9Snu}DY3D=|AiNg=ej zDm%X@HM2xV!6zp(KTp9cKQFadN5L^OB{5IIPr*GuH?>L!>|CGH?8M?K1@FYnVvu1V zuX*O=WF{&E<mcoUmu6>V=IJRI85o-DDIlDgS*(zllcP|XU#gIonwp}g;F+i3m{OLQ zmz-LxkeLTI*e|uTC^1LDGY=Hf;Do9WP?VpXT3no&m#z?8SzMBuTdbo{keHrYtYBtf zV4!PeU}&JH5L%pC1d51Ch1~p<%(O~{{33<S+=86c+|;}hu*sQu3W*9SnQ3XMMX7lu z3gwBF3Mr*UAX5_y3UV@&!ItT9fz>(|r59U)S)jN{$;^$<%}+_qu>uEwUTJPpY7r>1 zQ!5HmlS@)T(VAFNS_Jk5SV?AHL1~GCMp0^EX=YJsiY6pNKxXBE(gV6VATdw~fc1it z4M-#rmYqOeO3f+8Y97dKNQnw!a!O`yYF;rY2J_RvW+JJ8B|C_r3K}V?X^EvdCALNe zCKgEUOo3#4sCf{MLRn%?X(~dCfu13fj>P1W%rbBwLNp=x3TdTz$)Iup6lS0VoS9dW zT9libl9^bN3Q6NS3dxCi3Q4I7i3-Ifph$rSJ}A&Ml!{VwN|iuys+105X@Z!E3Z;38 zMU@K4i8(omNja%{3N8@8SSf&%!%YBZqWHZ0qTKk@f?|l{QVWWqzRApkxIh7<ATcMi z3hV`i<ow)%{Jhk>l468^4O4YZkvyD~SdyF(pO#rvTw<l*nWm7C5K>f{nvkFa3UWwJ zg`@*e?#M4K0VRnNaEV*2kXV$eP*9X#mYI^80(U6LtBJ)58eoHU6pB*|brjN4!6{1< z$#V$_ZizX?sR;>4sS4yds7o@DT>|zy*lVCtBqcu&WNl`CkwQs6xTsQqCo*Vx1xk3R z4uYpbkT;MMQhq@yq$(-aQ7B3+24#iJvc#N9Jq7=gjMSp?%;Hpq%o6ot1+cY=X(g#e z2p5C=3Xd+hj}$b(%AkRKiyIsgkRZOr1&SaD8&sRI6{VJx7Ud-~LhB9?#mT_Hz`?-4 zzznLBKQS;clrWSq)-W_PEnr&6z{pU-T*8vYn!+f_(9BrN2<EY+FoAhYU><u4GnmH= z<*|TyEMOi-3M-h$TEbbvlEsn22IjFr<=DYIcBmW&n8#7VRm0ZISi=y{UBXktRKr@s zmcm)XRKt|Q)yr1H5YJn}SHo1pk;2`}RKpO@4^vaaSi_jYnZg5BBY>htutcbasfN9V zErqw2xrQNL7)4ejg&~+hlh5xKH>glZ%>&n85Dv6Hxh0UBSP`F-nipT3npcvVmz;Ww zKd~e~H@+-CIWZ}|II}AC7Av@fy2TDH0B-Sv3`2_cTYTU&1<y|*nk+>U3=9lK0t^fc zw^*`MGIMXSl%-_mf-3vef?}AoG)RpMhyW?nWWL3o0x49ZSW@EC(!geDir!)^E-gqc zD&hwj#t|Q%nU`4-AAgHEGq0qG2_z;2E{I(ro`W=KZb8JLl`>R}Hzlzou{gB^$`u4_ z^U2K1PA!58v8LtZCzcekgPa8R05`~dPEZ9N530b6Kq=`KD=5_&-r|Fo%3wK&uu%~o zNDFgXS{_8e7}R_M8Bin)GF1dbh=K?)5FrjCz&1&OSW*lO3@aILamL3d=Oh*v$H%W^ z_~ogek)NBYpPZ6fkY8F<qVH5%l$uzopOTtdkY1FSnG+x6=NDg|S&|W-1NMr3FetP_ z64TRDi}WGB)A#jo)GMegQea?UPzB`!Jy6bI<YVSzM1jm)j9iRN3{3wynD`huKrBY4 ze;mwwi~@{Ra_~AE8mQn_kRB{LG+80ZrZ_$|Hz_qGB{MJm7C$IOAq((86JUI9VsZ8@ z9%u>#^NNHR7#NaSLGFTLPy-MY+|D2edoeID)G*ev)-Wt!r~#3THH<YZ3z@_j7D7Zp zS&3l*a}8?^Qw{S%X1F*DST$=6>q17T7;7+tCS#&FBLhP~VsSC3Tvh-V$slG?erb9J zSP1GiNId{;0f5RUN0e3oNCcEwkUR;gONznm+vHSGnF^}iP>jP`X@iVH3okUI(u(qP z!EVb*EG~gK3|4tU!xhP3l=2!X4Gw4&38=GBB{bQJG(c&XA1oZ7n~@lwl#`#FU8DsP z<w6qB0SR!%r=>v@XtEXQf>N;&h%g2bCLqETq>8bk$N<Cv6NVrWc4(W|4ivxOJ{1=W zqZp$EqX<)#EcV0()uG8;WC=10WH2P%fC_hLvH=(LptJ){H=qnv!;r-QDnM%(YnW0P zds#rmdo5!PV=%)?W>Brc3Ch`EzZQWi4o$`)bp{3oO=gH^5H^9G4kkeEFEV3bU<d%Y z9n|AvU=(AlQpN6Su+fN6(_{g={uUQJwu_6v;RMbBV8ftZ0kN_9gb9?!QGHTm4RSJL zktT>@L2|h*D4@W`gIx|Lz|J>kU|<MEaXz9|j+(>a2|$wx>~ctugGyegI~hQ1P!xcJ z98{Epf_wpE4dX&aNI+{c`e`y1X@lHi1#%%1k_*930ux|&Suij#L?PV8!&oJcBUs_d zj^`Ffe0*MFZfbn|Ev|S_4N{r|V)Mku7nUaGKxNqD<5TjJ<Ku7f#>Xe;=YjfFCB^aa z;6kwo<oY6z)4<VM1PYuYkW0XdK}B(LYEDjkJW?746`@677lZ00UXY7HWjhBWizo+^ zkR%(Z^!(2wB*ZGhS_D#|DRGM>IlrLt7Hdg<QF6vDVKAkam!}772`A><V#>?A#f8LY zbi2h3QI}s(l9_vp6-<F^J&2P*N#qs>L`hyAs7uOF1WGcHDhyP77IA_?9hAR{K$Te$ z4~PqL?kx$Vib4-kQRtPHWabout1z&mk^Bpfp<5g_x!_*D9VoXHgQA;*k%Li)k%y6o LQGij0nTHtw*0ILR literal 0 HcmV?d00001 diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc index 0e04fe2ee2a37736df214e8faee8e94eaa55ae64..c4eef1e07886db024496a876b846bd69646a7538 100644 GIT binary patch delta 4190 zcmZ2wvDHO8l$V!_fq{X+@^?*=k`4pIV-N=!voJ6)I503U6d#zVty`bV9>t!^5yg?q z8O52)6~&dy9mUPakiw9{oWq*S6U75&v*hsR@<s81*{nJIxdKrFU^ZKhV6IS<5F>*- zLkfEeM+-v=M=DP<bChrjXOu_^R|<CvLzHL=PYQ1fLzGwwUkZN<LzH-mK#E`sLzF~{ zP>OI%Jwudaib#rR3qzDtidc$x3qzE23S%&Xro>B-oBcG|ZgB;t7M7;wm1HL7+~V{} ztV}KP%P-2k#paWlmzr2~i`}ItzaYP~<Q9v&t54`HA>Y!RlFW?M#1zMp5|FO^JWa-1 zY`*y^r8%jZjJJ3lOY(C=^0QO(GOJRHG$)>su1^Npgp5HU1mYKKF)%QsGDI<^Fhnt> zFh((_u(va$F{Uu5u(WVSv81r3u(dEmv9>d?FhsEhGiY+$Vs>-$tCH}{EyzjDP0cGQ zR&aCkQ%Ef=O)SaG&(nK8YbFTnv`>cF0J5bX#NlLMU|<GCiVXt;LkU9(V+}(y(*mZ2 z42%pV%q1)}j5SOtj44dL%q6VN3?*ze%*~7?EHz9uj1bXgmKuh5_7sL-22Ex^H%-P{ zEJ>M(#hM(qSc^*wQj3as7#J9Cam2@G=4F<|#}_d%FfbH>tSw?=V6b2)Vg*Svr={g- zvfN@RPR&WX#hwzMo1c=JbBiS<J`Kc$1oBG8Tb%Ln$vKI|#qsfzZ!!7T7YQ&hFz|z1 z&Ij@gBOfCZBM&1BBO4>npDG?VH$Od?t;rx;L25u4#0L4-8Ke`GUTYW@Fw`(EWDI6l z$>^uaSR}~6z@W)gBm{C9V?~h!$i-kn5+uT&mS0q!Sd=1B57G}ZjDfL=8)O?)LX)wG z7i0oE$k%y^xvBB-x47a#ek#oYv3cU-3riDopfc?7@hSPq@$q2K6v=>W;DvfKIW;FI zJ|4+_F_23@jsOR$2q<KW*%%lYI2c)&_&C9W6Q8HmGo~=LGD$MHFvOO0GPE<aGqy9O zF@sYMM+r*_YX@U9V>6=*Lo;J13tW~pg{1{umbrwjow=PQjVXl<Q?`a7g<Xmvg(-!j zm${auhIs*d4a-7CMutL#$!8e3>p2q>7#SG26ciK`0uqZ7b5l!Fi;5Mjzyi6674eBB z`MJgMnR)T~d8zRwl?AB^Rtm}ad1;yHrA6Qr3swnHRgj;PSd>{(sbHm`k(5}HoDpA~ zS(U1zfMQN@YF<fdUUI4?SYtgzU1CLMer|kDW?ptNy5^kJym+`Ch2+GN)b#wK%+z9q z^30M9g_4X^h4@T`@{G*n428s;oc!`)1&}eBdFcvZM}ch#N-ZfZ%7b|eVr5BDX=;2< zJvJ*r8sp3IlM|D$c>oj;$%!SJpkOFT%*-nWJ0deLCAC5!KMl;uOHEHK$t+7na)RH> z|NsC0Pi6!qY$yh?C%<Rh&RQhRz`!thC6l2zBn^R5&MhvRoXp~q<ow(MJHyFl%mM14 zf@URikuE50>46A+1_p+ej70_@Hb_g6Ap--0;pDZ<H^P-cVyYlQ6GZ5M2zbr_t1|&* z3JwMa1~vvZ1|CM4|3yY1MaCe)6hxST2y+ktH``?L9hUW;N}xQ$c#9S6<06pxnoPF@ za^hjJAD^3<S6U2l7OI_=AUi=t3dqhPqsdoT70p5JxW$r_pO;<)vR0F+2&C#3KazO` zIi<x#s4lUbtiu+}$ThivO^lBlloi<$3kp*6Qi`l6FJV(*)SrBk%|8WXxh5-GfPsxK zvV-~S7JqtbNjy^WQvtaVL@;o2FtRX8F!3>Q{bOSWlPpY(JWO2wSeV!tnEtY{@-P*7 zPp)B~SI-VA*MCg_CpC;r1WmD(U|A?9GqpIrBp;Hu6s#1AQj<%Iib3fG6y;!rs43Pb zv$zD5wm=E9SRpg7I5Q;`oY27LN#qw*DrDv<6cnYVWmYKU7p0^Yfeo3gz~OJ92J)&p zD2!NhL3wf|Q;`-(2$YUiGT!0<8<AgBiAZ|j^x-vmIY)zYkrYTNC?KKQZZRh(mOzw( zM2p~QBM203pfZk)fr){G!B3O32vlSjxlT^!lw%B-+{S6o7(DqHr;2+x$SpiDYvYqk zq2_^1fS4D{z`)=HG7nUaGH{47axwBTvN3WnvM~uT@-cES3NRMMPB!F<tY67kwIA$N zxN+djjb4UeFMt%%auPxHe`*SHv4SliQDXy9xcTBP-H10_lc@;g;UbU8Gq}{v;z0o( z3?jhc3nsvU45|eqK!F0TL)lmuLCF&o6g-SY&Xa4{WhV=9FJ<KnX3*rCJVDq)*%xGn zKZpnf5#Z28m<B2;N+v7wn6QO_WJ4#%@&wnzBfBUJBnVawCcq8=X{iG_06CVqp|K27 z4{}{mB!~q{Zbbp0cxFv6$}cS_iUNs7gNPWAs31JLV7Z_uj)6ge0g|x5P6b<;$-uyn z1F{n2G6s&xe!P;5nUf8<qWnPZF-@i-kb)u)khv)!0_<2Y0n!f5+u#ZeoGe;EMqvrQ ztjz*^i<lV?Pu?u(#HEo^q>x{vkdmW0`K#bAB7$-L<V!-fj3JW+guSdl?u2*(R|u{E zxd9P^MW9NcD4l^}vc9;BHfo%K8d5o_X(jQX><%#%<SlUaTs!%Lh`$bodak0(^bD*T zHiI;v`z8+_XmEpfV^ddvsqWz9BvE%|!u~xq`HW~F5%!)Z(%!4s)M4?*U68tH)Bq~N z;tx>Qrx;QvF)_dkEs%e~$r+rsz^NCMoQn`iSZeYEAt@G3)}r9aZ^Tu!5<ohVKm;iH zX)+gqvv590pb$iW!?%cmfnoY&D~W8zqRGo7`szP|i!Q8%6P{KUI4M)3gsVrcD{vKY zB$jaM&{hDbfrhOl4MA!FP5#5dHJMjR&KMMA;PkBtt(&yLEt@DVP<0p&_Agi(Qagak zp<k2Zq^#9IJycEhB2chG3IS_SP=k6$MW9AOksFBXK6$^CY(2Qal*Yio05ud`lY%4Z z789&OD+VbCMbs_fw9LGe_=5c6_`KBg_{@~liqzsapo|Nykl0~mFCQZxqY$GMBNGGD z|DrNbvq+{0T=anw%q`Zu(%gc|TTFQcMb@Co_7+D;eo=CUUS1xgZw2bwKss2Ut`(@4 z#Zbf!astQXXlb>2PEbK99SrK4g(Rk@rxxk?Waee37U_YU9-o_704{lN2_mZ~Ey>I& zhWH)S=YS+vZ;+Yb8U|E_K#I&PkQg}JK+c8)E<6^%!EAbq!zLHpA+`e*QN^HE9|t1@ aDlqauAhUn~3kQcBhZu(lhXjW(M<@VHh~|g@ delta 2536 zcmdn$vdTg`l$V!_fq{Wx<NoTTg{lk;k3k${%*4RJ;K0DZP~0+6TQ`Z3A%!7@Ifp%$ zBZ>peX362q<%;5BWJqC+;!a^pVQ*oG;z{91;cQ`u;!WX7;cj7w;!EL4;ca1v;!ojA z;csDx5=aq95o}?I5=;?F5pH3K5=vnVX3!LQ39{C2vJ0c`WDb@zW`-2T$%QP!kxVJf zEgU5*C9EBc&5X^AE)30#oh<DP?TqbAY0N1sSY(+?*xH%fS<;wNSTSX57*g1z7*ZHh z*eBm*k!IwW9LIK!`4(fvWJ7kt$;+Aid2eyq<YX3?B<JTA*ojZBWDij1Wnf@f$y_AJ zz`#%>1R{h%0*pl>AhsBY5NBXu5TE>k{YE$&h|2*YctC^zh!6!4aCMTPzyzsaV_;+8 zVU+n_Bmt5K2^UF&STZ0&7DOQQOqS<d&(F%hz@W)^i#0PZCAFf6XYyT62_>)_P%V%H zS-`@;zyPwKNMf=$m!i280|Ub?mYn>&bg&(oOhxh_5&oR`g8ZDsqRf)Y_=23$;v!UQ z<tFEI1*<ZHRI+5JRu<o4ElSKwPrb#QlbTlqvf&n6VnIP_UP_VT<mX%}jKY(7xcyT= zuF+&g3ofvIMJg~~-QrJAEs0O8$jr};&&kZoE(V2bF-V$$k%N(iQG$t&k?TJj6CV@T zKQ<O7Mjj@ve=JOF3`~F7Sb3O=3@7jBp2x-z%%I8cH(5Z!$#o^uEru$V_?Q&!mCQw) zAX`AG`W9<$Vo7qwN~R)E6oOr}lJOQta$-qpdVW!55y+)Q@Icm>T+Q3yu#(XaEDqIl zi#a*51S|j{;HnKkp$v*`HU=gJ4hBumB2YLLX-~G}lVdcU9Ls0VXfb&%pNhLJDEN3_ zcE%@{LX9&Axe{!gGXn#II><N>kAYE)k&BU!k&Tgqk&Q`!k&lssQGl_?c`^roBwvv= zC<!nXfpiq<P1Y4qH**CUV*w(-z5x>;bBZ(>7#Pezrt*RE6(b8cTtQ~>FcxV}-p?&N zdAYz+a2Rre!q8dS2xN!}hyZCT0=o}k94LvqP2MAD!e$9lVm0}jU~me^IgoI%0SSUt zg9)$;Kw1JoE<g$<ZfG!pl-q#_dk_JNtRhoTK(VG5<(C!|fl8brN05}$<aI*gwk{wh z*krJ?z{dMAFff38UkpmzuwWAc2b>Uc(D{La&L0+ZAYDazAmcni1lUY40WuSsEKBlp zi{q0(Nhb{Cd@KRzzj>ALB4);&lmD@ZO->fqnLJZ`7ZCxNGr3H{meF$Z76~t9!eLiB zSy57Ca*&|AHfq$YWW2?hlbTi%U!0ke3JHFY<BLGyT{ZcHjK2<sdak0(^bD*T8bBJ* zogNGi8o0q7*wlq$s_UD~Chbmy=cZ3CkPalm-Z@0ty9k>)EdE#pQWu6AKt)*mu>lnG zpfV1Wzu~zX<X>>I1*afzDg`B5SW?ZERbsT5JV91P%MGN`14MulmnL%&IE#jW1VCvD zQtE^=FfgP{z9*Z_7(Urcu8;8$C^^-G6JC6AWo}Y_PO*ZOf(BF|vno|bp*Xd$G&L_d zH9jXbFTEr~Q=udyRROF(AwNw4s)>tBK|w(wD7B=tD6d$-3d}<>B{#7GW<h*rUVL$C zUJ1xhO@*ZVeBGSX#5{$P)V$*SB8AMnl+47E%)E4kg8X8I#Jm)Ryi|~2YDH=>*x1Q} zLSmC2$jcdnya`T=n$Xe~lxtQpMsa~EQ%Gh4OGA7PN~$L&8!A}qs)0&0_99SFK(eJG zD2$Xq1Ss^2bU-X!kRrBQqLb_8Me8ATur~t(1Jv+DP<{mm@GT}-VI2X|21>lQgwrzf zQsN8pi{tZB)8jK!Ku+HYid>K^0|%^J<YVMx6k?QOWMW|YUla|pPq+wNEP*2V7HeK< zZb9WOro4h8M#ag|ifZ+scF`^A;L?KBqL9S&^wc6fpUk}M)FM5QYvOYg3%~{4EkR@z zr6rj;#o&q)l3)x$js%zGpu)NcTr~NE#K55eas(vE;9(06yjvVLx%pZtr8%i~j39r2 f+6x?vV93P7$iv9P#KOTL%E8aU&mqXc#}NturLXV3 diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py index d951926..abd6814 100644 --- a/SuperTagger/Linker/utils.py +++ b/SuperTagger/Linker/utils.py @@ -2,11 +2,29 @@ import re import regex import numpy as np import torch - +from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention +from torch.nn import Module from SuperTagger.Linker.AtomTokenizer import AtomTokenizer from SuperTagger.Linker.atom_map import atom_map from SuperTagger.utils import pad_sequence + +class FFN(Module): + "Implements FFN equation." + + def __init__(self, d_model, d_ff, dropout=0.1): + super(FFN, self).__init__() + self.ffn = Sequential( + Linear(d_model, d_ff, bias=False), + GELU(), + Dropout(dropout), + Linear(d_ff, d_model, bias=False) + ) + + def forward(self, x): + return self.ffn(x) + + regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)' @@ -29,9 +47,10 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links): for atom_type in list(atom_map.keys())[:-1]: # filtrer sur atom_batch que ce type puis filtrer avec les indices sur atom polarity l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i] - and bool(re.search(atom_type+"_", atoms_batch[s_idx][i]))] for s_idx in range(len(atoms_batch))] + and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in + range(len(atoms_batch))] l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i] - and bool(re.search(atom_type+"_", atoms_batch[s_idx][i]))] for s_idx in + and bool(re.search(atom_type + "_", atoms_batch[s_idx][i]))] for s_idx in range(len(atoms_batch))] linking_plus_to_minus = pad_sequence( diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc index fce253cccfad91c0b80dd6cd68cc5e05e3437104..f5cec815f7dbc90d4ab075b8e48682a5c6e119fd 100644 GIT binary patch delta 240 zcmaFG`<jn0l$V!_fq{Xc@@#cd7VAd7!z}f0xfB!>6p|84k~88<ib_-Cb29U?ixsRC zG$3NdnN_Je3W+88x$$NB$%#qv5P{spig=J<aeQW8e12YPd~#xmrb2RlUP)qRUa>++ yMyf()UP@|(LVg;Ulb4#FSdv+m3O1=2Y-eIcW`6EuHdbkWkU>;5eX}L&Mn(X5CRLCC delta 84 zcmaFO_llP<l$V!_fq{X6e|2@zOO}m%hgk%3xD*r=6p|84k~88HD>CzQCqHBnW6{&o o)12(asw$D2SP`FClAl`~Us73+8efo~lUS5lQmMIl7wbkw04H)AasU7T -- GitLab