From a874cd6dc34c2e29c05b7a5a513af62f4e6e1264 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 10 May 2022 17:24:25 +0200
Subject: [PATCH] Linker: pad axiom links with -1 and mask them in loss/accuracy (it runs, some corrections needed next)

---
 SuperTagger/Linker/Linker.py                  |   6 +++---
 .../Linker/__pycache__/Linker.cpython-38.pyc  | Bin 5697 -> 5536 bytes
 .../Linker/__pycache__/utils.cpython-38.pyc   | Bin 6078 -> 5906 bytes
 SuperTagger/Linker/utils.py                   |  18 ++++++++++--------
 SuperTagger/__pycache__/eval.cpython-38.pyc   | Bin 1686 -> 1775 bytes
 SuperTagger/eval.py                           |   6 ++++--
 6 files changed, 17 insertions(+), 13 deletions(-)

diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index f2f7192..b8dbf8c 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -90,13 +90,13 @@ class Linker(Module):
                              if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
                                  atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
                                  atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                            for s_idx in range(len(atoms_polarity_batch))], padding_value=19, max_len=self.max_atoms_in_one_type//2)
+                            for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
 
             neg_encoding = pad_sequence([torch.stack([x  for i, x in enumerate(atoms_encoding[s_idx])
                              if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
                                  atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
                                  not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                            for s_idx in range(len(atoms_polarity_batch))], padding_value=19, max_len=self.max_atoms_in_one_type//2)
+                            for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -104,7 +104,7 @@ class Linker(Module):
             weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
             link_weights.append(sinkhorn(weights, iters=3))
 
-        return torch.cat([link_weights[i].unsqueeze(0) for i in range(len(link_weights))])
+        return torch.stack(link_weights)
 
     def eval_batch(self, batch, cross_entropy_loss):
         batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
index db616de9a03c6a9f098f7797b6db73b69aea4751..d796c7a4452db32c4eb9d9237b80fc4bab334428 100644
GIT binary patch
delta 358
zcmX@8vp}0Sl$V!_fq{WxMMYK8u8q8Eoa&qm3=A9$3=GT+3=GBB7#J8jSQapJFfL@Q
zWvOB5WJ+hKWi4UsV3{1nsmRDOxrS3xk|m2hg-wbfg}s-#maT-PouQqvovEF<on`Vm
zPRUSYS@sT=1spYO3mI$KI#?EP*03*RtYt6ZTEJbyR>Qu4CxtDAb0O10rdo~~)&;yJ
zd^H>?Tq)eWObZzq845K@xEAo&ur6c_X3*r>Y{a#Oi7{&OeV%xxBF4=Qyrztdp_9w_
z3@5Ad$Fmm&FfcG^a!&5%H|EL7%*&21Pt8ovC@G%2hhK}SC}cCUfG(qK7y|=CQ6xwq
z<1OZt%v?>bTTCVSMb;qUl+C_^eoT@jAX%;=Nd^XnTU;rLC5buti7BZ?MR}9A2pe+e
kf>aiPh~mi~gmc+Is*B1eCyUHu%%037I-4<R@;uSS0GTCQkpKVy

delta 542
zcmZ3WeNcxtl$V!_fq{V`Govc$-9}zDPIV3j1_lsjW?*0_e#5}P(80QZp@VTDV=YS!
zOD9u0LnliKV+ZTxC{9I2*2y)Tl9H@h>?v$g3@Plr%(bi~EbR>KjO|SA%<Zg`*KtbL
zBg?XNurA=JVO_{r%i6)ZfU|~eA!99D3D*Md8rB-N1w1KiDVz(L7Bbbccd#trE#a$S
zPvJ`8?qyoY$jDHrQNp!=znvwGF@-0ErG=w}wS=vMrJ1pr(S@OzaUo+cgC=j)J1zw%
z2uLhS%uOvxE!u3rHG_%q*5s2s@k}p;H*4{lGBSouPT(`FPi6#p9EzD37#LU?7#M^>
zzLsKOU`S`EVTcu~W$a+6VU%Q8z>vbQkZ~eYAxkiWCSwsurzT^OJp%)SCR33c0|P^m
zJBZDAizg>DFFU?GH8VY<q_{{9BqI+Zd>I%RvO#9EFfcH%F)%S`PM*XU$9{`BIk7~O
zWAb-C<H;8M8nT*fMV=t#ULYbAM0kUUSP&t(If-AFQ8og^j)vIAoRXQV$#sjVB)`ZS
zB%HB%jesAMWCci;t4NZ8f#DWcN@7W3PJUuaYEe<qWMN@L?n02tG7wQd*<U!94WzoL
Wdh#LRd5rm!6Gdh-rcVAVvKRoye|&KO

diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
index 634e9ecc6fb1dfb6f18f0cb63c6e4db4b3607ca5..7f42ae68119786d9252d2c8eac5eb9c7a7083377 100644
GIT binary patch
delta 1420
zcmdm|KS@tJl$V!_fq{WxL334-fiwfdV-N=!voSC*I503U6mOcS-J&U!B9bDSB9_jS
z&e+TpC7ddfD%{M-$dD?M%9A3JDh!unp17}CmX(2l0i>N7q<sMc149Wz4O25?Cu2H8
zE%W3s#&UVa6!u=`T9z8-1?)8}3mF+13KdE?7I4-uFJug6(B#<6#N^E~IfzYaGBdj;
z2WxRjVsf@7>*OGIgUQF)6ec&a%QCV~p3g2dc^|toqyA<l4jx8E^~wC4tx~s`lM_pd
zK$b%Y5e5c^B2fkg2A#?4IK_B185kIDu@<GKr&bghPrk@$&uBJTkW0td5~K{Q%M`={
z>#}8FU{C}3KmZghjAD!&jBHE-jC_n7i~@{Bwv!9EqIk_27#K8}ia<(>bS7WtQa7^)
z={ExrVB^39*f@0t1_mRLaeQFoSQt4NL0Ev1hp|X~vI+NXWdo2dkdh)35DV-Qgnp1G
z948y`G;@P2h4{{5@&O)k3wMyIAcbrUObi?hnw&+TfG-01wTKgBk}=rV=|%aa1w~dM
z0c#LpGnt=PT+$B21Zge;+Xr?k$Uuk5_Po-J?jT<Zfh^QyDgp@?>44NbO+L=6Z0-xP
z7mIH^7#J9mnLxgTVo-WuVPIeYIjgv3vLc^_v|A0+0_GY<9tKe4PhiZm>R?#Fl)|=<
zvFO3%B0dWy9)`)Q_>}DrbucYpN?}>ZSPWsYE@Uh^2hj@Bm?u!fRKn501Tv_Z$(*5<
zDQ^MTAnt{X6BvscCiC$7ONMkXLd=-JSPWsaf@KpXm+-4F@i1(j!Oz7=MBpb+J}S_J
z9{6Hepg;hHA0+UL7(wCVI=NoZjT<RGCtna$U@V;cQP3UT0LH4x(L#YJYPSo?Ti1b%
z1%(;NSmYSJ#Z_FASX5G6o>`Isi41Uzf}*$x8m)rDN{nWcO@vikLHW1H5#%C}3pJUG
zzzGu+Z$&;JQE+VgGcYhjgRIj7MHf2@qZp$MqX07(BPgv4F$+wdD=f((z*OWvd6%%E
zMh%k-L#$FQOA2ExYY9UQOATu?NQAGJt%i94V+|W9S_;()MJ5Z2{9z24%qS!}Syfb>
zQ<JHP2V^7fWIs^_##xi|MV;h97DK!s50X*<5g;!YX@OYUlP`!Wsk$&QFhp_XWEPji
zCzj;r7DIdx07}qIu#6r!SzJsuqzmK)36K{!K=}%S`55^ag&3t6nHZS<7lkk|FuVj6
zdR2TWMH+@Wc?I!?I>qsZnpJ!$ImmoXQAobAntV;H)Ce4^`XDEPvpFaqL$U%Wofm;U
z0Wvv?6YRkFg8ZDxMdDg&N(>AP9E@Pd#KXwL$iu|K&LPUd!@<YF%fZK?&mk<p$iWBz
DBL4xG

delta 1556
zcmbQFw@+U?l$V!_fq{WRDzz%fSDJz0F^GeVSs54@92giFif2sJZs8YB6-p6F5ls<G
zXG&*mW}0}TT9$=@fq{*Ifq@yMXd43qLkU9-Q!`^HV>&}8>*NB)a#Qvcj$Y<k<_^{c
z>@~~_85tQ06-qc3aJI9iF{W^)Ft>1&u#~WNur@O`GrBM|GcIHdX3*r?%*N!+!c@dK
z*^f;&I++p4NlXk33?NqtgDjL{U|>jRs9}f|s%7k8s9}_3Siq3Nu#j;gQz1(*gC=7U
zNT(*_EzZ)s;=<C@)T-2#Ohr5l3=BoQ3=9mKjJGDgVN;MP0(qcF1SBlVz`&pcG8y78
zCWgt`Y@Te)$%!SJER*-J8BFG7mtka^ti>+Hs5#kz-JQ{B@-8mf$xGSaGpbFV%F)VH
zWH9->pty=Q$PAFNYz#~c91NQ5MW6sG(qUj=xW!tOnx0xwWHvdR)1J|C@^nreNdu7G
zAmv3CAQspu_6!URs*~?<N;29{e#jHWYXx#KQxQnGNPluHm%5xINUbG^0BZ#kV69pV
z3=D>o3wgyEwI=W4x~*&sQez4t%s~X$4uooupB*Oe;BMvySpf+QYmhG(-6qHJh~<JD
zUBn5}Xa;tEdQpC9L6Hqez!pT<fkXw96H8Ll^NT9uOY-9rOY(Dzi$LiHVj9>&u**Cd
z7#Ki4C>8=mDWe!82cr-p8<PMdA0r2&0ArCS$g4sicW5#dffN_%gRF3w?8~dn=mYkW
z*W^xKS;o4_%Xuv%U22#XFxN1GQq=^;JhKjl1xzVy3mJ><O#a1dG1-JqS#)0q(*mXx
zmW7PP5Ekn~#-dY`bNFO==YdsoFJzp+SX4E68lS&pKnEj4?F7bR2%8lw8#DPIpUP$x
zelA8LLNaM`pg<FPNQz~mrU!7kEOMXB$RjdYUeJgeDS{@)2r4iZOs*AlM>mAAa`JP*
zz*Ho)pcKJ)i!&!Rtt7rUGbI%g+Tef&MQSa`QdmOP<U~q>T*W1cMJ2`MnI##J00ze<
zD0+(^Q9609kP@Ti<O4z~Nlu^$a0U^eB%{e(1Ws}OAOSxR0k$rXfq@|kWSt%;UfEd~
z#TaE6`IxyFK`|@D%*PDk^Dqf9^FT$!7?l`#m;{*ln2G`?y9pa=q%hX9)G)a)#LCsO
zmN3+?*03~#czm^NHOvbbYuG?>T&PwkF?qi5A4a#yb3|;IiZmua6j5NDIhje+Ngm_^
zh+pJEu2KLIAio#sf>?T!lSP$O-9WD7%E>G)fhN%)kQ~TAx0n)3AfXpLd6lSaeJ98-
z5+J{DurP8kvM_>3K1M!9Ax0@iCI+VeMPUpK3@<^sqRKC&NW(xUB?m-+SOo>~rbaq>
z1@WfFnh@c<f_M`{kchF7rlt;9d2zg{iB3vpqK2VPQhtSop$<rwsivl;W|d4zkuu0E
zkbxjmEOd(FO)WKHDnuZ8(Ppx=SgATVwv0f&0OhJ8P{xO38!wO;I7~neoqR}4OG%D_
gfq{b&44HTsc^G+^SlBs4ICwbtICweuIP^J00Fp#5<NyEw

diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index 7e12ea2..898a921 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -27,11 +27,12 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
 
         linking_plus_to_minus = pad_sequence(
             [torch.as_tensor([l_polarity_minus[s_idx].index(x) for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
-             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=0)
+             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=-1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
 
-    return torch.cat([linking_plus_to_minus_all_types[i].unsqueeze(0) for i in range(len(linking_plus_to_minus_all_types))])
+    return torch.stack(linking_plus_to_minus_all_types)
+
 
 
 def category_to_atoms_axiom_links(category, categories_to_atoms):
@@ -97,7 +98,7 @@ def category_to_atoms_polarity(category, polarity):
     category_to_polarity = []
     res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
     if True in res or category.startswith("dia") or category.startswith("box"):
-        category_to_polarity.append(polarity)
+        category_to_polarity.append(not polarity)
     else:
         # dr = /
         if category.startswith("dr"):
@@ -106,7 +107,7 @@ def category_to_atoms_polarity(category, polarity):
             left_side, right_side = category_cut[0], category_cut[1]
 
             # for the left side
-            category_to_polarity += category_to_atoms_polarity(left_side, polarity)
+            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
 
             # for the right side
             res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
@@ -129,7 +130,7 @@ def category_to_atoms_polarity(category, polarity):
                 category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
 
             # for the right side
-            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
+            category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
 
     return category_to_polarity
 
@@ -145,13 +146,14 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
     list_batch = []
     for sentence in atoms_batch:
         list_atoms = []
-        polarity = False
         for category in sentence:
+            polarity = True
             for at in category_to_atoms_polarity(category, polarity):
                 list_atoms.append(at)
         list_batch.append(torch.as_tensor(list_atoms))
     return pad_sequence([list_batch[i] for i in range(len(list_batch))],
                         max_len=max_atoms_in_sentence, padding_value=0)
 
-
-print(find_pos_neg_idexes(9, [['dr(0,dl(0,dr(0,pp_52,np_53),dl(0,np_41,np_32)),dr(0,s_54,dia(1,box(1,pp_55))))', 'dr(0,dl(0,np_58,s_59),pp_55)']]))
\ No newline at end of file
+atoms_pol = find_pos_neg_idexes(10, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']])
+print(atoms_pol)
+print(get_axiom_links(10, atoms_pol, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']]))
diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc
index 4ed72f9719c0751839b70c3c027157a3fa1257c9..ec90d972078edc0b20a9275119656a5973b657f9 100644
GIT binary patch
delta 347
zcmbQn`<|CCl$V!_fq{YHT47ZZ6YEC4MkYaa1_p*=5Y}U0V5nh;RhztkNyZ_CrI#g)
zfsvt*DTOtdL6fbBk%58XC5X7ilAD^C_wxV$|Nk|aZgCc+rj#a^Waj7H;>k?U%P&fe
z&&*3nt=RmWiG`6-e)1b;uSiw~1_l-e1_o!4Ekz6r3?&Q;7;6|8FfC+Q$XLsi!dSyp
z!&t+R&1}Pv!c@al!;r$9&1Az+!?b{TA;SWeg$%XKC9E|J5OGO{X2x2U8s-`n35Ln}
zEVBG8k_;jYAf7ZsFoP!R<moIGhD9PE*D^6MFle&fVlBxpO3t{&lAo7ae2b+xvnsVn
zhJk_M7E5MHYVIxOoYXu`*2!k9-hxpAd8N7WxrxQusVVV^CHc9N8(7sOc|hUC%E!n9
Zh62n*Iv`^Oi&Q7?W}U~VGdYUQ1_0#2S5W`}

delta 258
zcmaFQJB^nwl$V!_fq{WxM`l%$JIhACMkYZv1_p*=5Y}N}V5nh;RhhhiNk+SuIgEjk
zp^zzsC73~zwTO{{f#D^HxW$s2nwY1_c#E?rHKjDUBr`v6^B*P_Mn>7mj4WP!EDQ_`
zAnXh>&~I`ii>z4<(*ov&3=3EmGSo8HFxN0HU|PtK!YIiAWid%Ih``v)Aht9^FoPz`
z<O?hn>MxlX7#K8JZ?Tr-7bR!hV#&`-ExyH4oLQAxB*nnMaEm3gBsEu)WpX8}_vCA=
is-hequd(tm@-Xr+7iog@3Kl6(7Gayms5yBvn+*Uc$vYST

diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index b486e6a..9bfdc85 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -14,7 +14,7 @@ class SinkhornLoss(Module):
         super(SinkhornLoss, self).__init__()
 
     def forward(self, predictions, truths):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean')
+        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
                    for link, perm in zip(predictions, truths))
 
 
@@ -25,7 +25,9 @@ def mesure_accuracy(linking_plus_to_minus, axiom_links_pred):
     """
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != linking_plus_to_minus] = 0
+    correct_links[linking_plus_to_minus == -1] = 1
     num_correct_links = correct_links.sum().item()
+    num_masked_atoms = len(linking_plus_to_minus[linking_plus_to_minus == -1])
 
     # diviser par nombre de links
-    return num_correct_links/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2])
+    return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
-- 
GitLab