From 9f49c6d7646db371b2d0127579b14c584c4cc7d3 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 11 May 2022 14:29:49 +0200
Subject: [PATCH] it runs, some corrections needed next

---
 SuperTagger/Linker/Linker.py                  |   5 -
 .../Linker/__pycache__/utils.cpython-38.pyc   | Bin 5906 -> 7029 bytes
 SuperTagger/Linker/utils.py                   |  95 ++++++++++++------
 SuperTagger/eval.py                           |   8 +-
 train.py                                      |   2 +
 5 files changed, 70 insertions(+), 40 deletions(-)

diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
index b8dbf8c..281a7ab 100644
--- a/SuperTagger/Linker/Linker.py
+++ b/SuperTagger/Linker/Linker.py
@@ -61,11 +61,6 @@ class Linker(Module):
             LayerNorm(self.dim_embedding_atoms, eps=1e-12)
         )
 
-    def make_decoder_mask(self, atoms_batch):
-        decoder_attn_mask = torch.ones_like(atoms_batch, dtype=torch.float64)
-        decoder_attn_mask[atoms_batch.eq(self.padding_id)] = 0.0
-        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_batch.shape[1], 1)
-
     def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding):
         '''
         Parameters :
diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
index 7f42ae68119786d9252d2c8eac5eb9c7a7083377..b13bf57006d2d3eb78625e40b86218f03e27798c 100644
GIT binary patch
delta 1561
zcmbQF_ti{0l$V!_fq{WRZDw_nwF(2nV-N=!GchnQI503U6t_&&Ze~ki3}(<2+4!iI
ziBWuV0&^Lo_~h%%-xwt)pJOp*l$^}OT4kokz`&5sP{R<b7Q<A_Sj$wy)WOifD9Mn`
zR-^`DEnui&TF98kl)^BPsgNa@L6fnFb@BpMOGdfLw^&m}<UqFZFfcH%F){rUVC4GG
z!pO4On~jf=QD|}$yF5P=*FP2}HU_4@Y^*#?Mf#I>amh|z%Kn~FWAapvd5pS~O*oe`
zs!zVdd6dz3av`ty<egmq7@a5Y;9kn;IXR9ejnQfHQJ&R|{*w!NO_|b|C$Hc&k9MhH
zTELvbwvdsLVFGiWTnEDfrWE#tj1w4(E|st>U@c+mU|7hsfPEoDElUbV3R4S14T}py
zGh;1F4ND5ABts2L3YR1U6GJCsCsPe$3UfBo1jZt*cE&Wu6z<6meB!2iIv5u)rLZhy
zEQYXH7cv%|K-ihbSHo1o(ZQI)+{|RoP|MWLl*X9CGueSp%x(f>F<%E0L^GVt3U-<l
z*Z^jbqoD>crLaK^$Xf!ooEKz-%>>5cB$)khHpGZL3?mp}Mu6SG2Qr|30%P$OBsW0V
z5CaZ?4FCrm+zn7qOkm8L01in0g^UxJi;^$`Qh=y{WGZ6c2lk90$SpP#n2H17o`JDJ
zZec2l0lNh`JWxF&1Tw&G0#orKxMyH&hym-s24M9JG(d$xMu6P#4ap4<HpmF(A|?zY
zP$N|YVgPfo2iy%XHpGAsumQ-PfJbT#Q!s<3=w<_c1xCiy$$<jytZuiMQi>+`3Pv*K
zO#UX2$-a{D7H3Xs+T?UWrOAAPfs89BCkV>$l!NjU3n)J^Fiqwa3{oiaVPIe=VgnH(
zAVQRZfuSBvk!X=0ND5U^>*QL&Ovb*+&jsVL8ZsHp5Cv?8%*0TH#W@Sn6k&DF^2y0U
z35-#bj|f@U`-5By_DvuI14ArGdmN}3V`pI$W0Ya!W9DHLVB}*IV&Y-~@p+i|n8X;B
z7<rfkSRgcvFNC5Jqyi>`u9A<b2xM0>sD$(e*#*L&00UuWkPEWFMPVmH4WkP~tW7Ob
zIzugU4U-E)tXwTi38>UyYzFc8YFTTT7ckbag0ga<(&o9sDU6K!Co_pUF$PX{6;)zn
zo17tP#C?k?u>=xpL6cXB%BpPuSs@0pl!JwlgAs)JKwe~&Vq{`q`d<{vz`&p>JXu6+
s7AMG^VE*KjV#*3K3=9k$j9|#b!^p$P!^FbDA<DtS!Oy|V!N(B{0FO&eD*ylh

delta 1028
zcmexrHc3xAl$V!_fq{WxL334-fiwfdV-N=!voSC*I503U6mOcS-E5L7lp>NMnj)6Y
zl+M`96eXN0k}BNH$jFc?lFE}Jk}3?BV+M()Fa|Sdif`Of%fu))*@(G}QEc)`=5LG=
zlNYcUGfGUp!cwIo$H2gl&QQY;D;C34%UH|Q!O+1d$*_Q-hG`+=M5aQPV1^>L$t|pw
zjIxvWv8M9Kf?UGFz`(%9$h6s<jgOI0aIzb_JU1iPe-=hI2Bv>(ESn411sECCCpU7;
zW7L_fz`2}JZSo1uql`wI`?y3I8677t;a<w<KG};WjnQH9MxND-zLWd7#3z^V$}zPt
zPM*SRp6FJ?w1ByWk%wU+BO}8E#yqPIh6PM1YzrBS9+a>wU@c+mU|7hsfPEoDElUb}
z3R4S14T}pyGh;1F4ND4#Bts2L3a2Ck6GJCsCsPem2}cJb$gE~2bB0=`cBVAO6t2k&
zc*W}vbucYpN?}>ZSPWsYE@Uh^2eBDsZ=L{HFB3v<-U6^YxEC@`U@U6Da0w57m$Wmc
zF{bdQFtu=$Ff3qP$T0Z=AHR1<2P4D{6BvsjY*w)82^j7`@&mfNYM6o<H2F6F6Hs7e
zOr9(+<j!)7DWz!gK@PLYA4HTHvnC%E%4F1>yi!<ZvX?OL<ZU9cd{v-m1f?uC1||lM
z$rpu#7&RyB3rjQBO};0Vz<6u2y|{I~4+8^3kqC(JXJBB625Hd)rA&4fMlnViMge9n
zMgc}XMj>VaW)Po;Nr;&TDk8?H#K^-Wz%0O21k#(#1kwk^?4Z=cz`(!^GB6LE{5lzG
z7+n}*9cr1<8ETnpm|PfQm1^Z#QW$GlOBiZcYFL{=LVUGsHOvbbYuG^1Q>eDtMIwcf
zan|IOl1_~Nlix@xF|tqQmogG43SeMhxW$xMQUvl~QQ%}BDOukxkYy4e7jUpJaxg+L
zAISBLQjAOtO#h2Q7#J8{g0f;2UrLdNp-x^wyrE8UyrE_lUrG)#UsH5)ztk*gkiWp1
zqBs*v@^g#h3-WVru@)3%=9Nq~l2%qzVqjq4U<5-Z9!4HU9wruc4p9ys4n7WE4n7Wj
K4q*XC4n_bVv)(EI

diff --git a/SuperTagger/Linker/utils.py b/SuperTagger/Linker/utils.py
index 898a921..3f8e892 100644
--- a/SuperTagger/Linker/utils.py
+++ b/SuperTagger/Linker/utils.py
@@ -9,6 +9,7 @@ from SuperTagger.utils import pad_sequence
 
 regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
+
 #########################################################################################
 ################################ Liste des atoms avc _i########################################
 #########################################################################################
@@ -26,17 +27,17 @@ def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
                             range(len(atoms_batch))]
 
         linking_plus_to_minus = pad_sequence(
-            [torch.as_tensor([l_polarity_minus[s_idx].index(x) for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
-             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type//2, padding_value=-1)
+            [torch.as_tensor([l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1 for i, x in
+                              enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
+             for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2, padding_value=-1)
 
         linking_plus_to_minus_all_types.append(linking_plus_to_minus)
 
     return torch.stack(linking_plus_to_minus_all_types)
 
 
-
 def category_to_atoms_axiom_links(category, categories_to_atoms):
-    res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
+    res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if True in res:
         return [category]
     else:
@@ -56,13 +57,14 @@ def get_atoms_links_batch(category_batch):
         batch.append(categories_to_atoms)
     return batch
 
+
 #########################################################################################
 ################################ Liste des atoms ########################################
 #########################################################################################
 
 
 def category_to_atoms(category, categories_to_atoms):
-    res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
+    res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if True in res:
         category = re.match(r'([a-zA-Z|_]+)_\d+', category).group(1)
         return [category]
@@ -96,9 +98,9 @@ def category_to_atoms_polarity(category, polarity):
     Boolean Tensor of shape max_symbols_in_word, containing 1 for pos indexes and 0 for neg indexes
     '''
     category_to_polarity = []
-    res = [bool(re.match(r''+atom_type+"_\d+", category)) for atom_type in atom_map.keys()]
+    res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
     if True in res or category.startswith("dia") or category.startswith("box"):
-        category_to_polarity.append(not polarity)
+        category_to_polarity.append(False)
     else:
         # dr = /
         if category.startswith("dr"):
@@ -106,15 +108,33 @@ def category_to_atoms_polarity(category, polarity):
             category_cut = [cat for cat in category_cut if cat is not None]
             left_side, right_side = category_cut[0], category_cut[1]
 
-            # for the left side
-            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
-
-            # for the right side
-            res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
-            if True in res or right_side.startswith("dia") or right_side.startswith("box"):
-                category_to_polarity.append(polarity)
-            else :
-                category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
+            if polarity == True:
+                # for the left side : normal
+                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
+                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
+                    category_to_polarity.append(False)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(left_side, True)
+                # for the right side : change polarity for next right formula
+                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
+                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
+                    category_to_polarity.append(True)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(right_side, False)
+
+            else:
+                # for the left side
+                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
+                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
+                    category_to_polarity.append(True)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(left_side, False)
+                # for the right side : change polarity for next right formula
+                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
+                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
+                    category_to_polarity.append(False)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(right_side, True)
 
         # dl = \
         elif category.startswith("dl"):
@@ -122,15 +142,33 @@ def category_to_atoms_polarity(category, polarity):
             category_cut = [cat for cat in category_cut if cat is not None]
             left_side, right_side = category_cut[0], category_cut[1]
 
-            # for the left side
-            res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
-            if True in res or left_side.startswith("dia") or left_side.startswith("box"):
-                category_to_polarity.append(polarity)
-            else :
-                category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
-
-            # for the right side
-            category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
+            if polarity == True:
+                # for the left side : change polarity
+                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
+                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
+                    category_to_polarity.append(True)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(left_side, False)
+                # for the right side : normal
+                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
+                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
+                    category_to_polarity.append(False)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(right_side, True)
+
+            else:
+                # for the left side
+                res = [bool(re.match(r'' + atom_type + "_\d+", left_side)) for atom_type in atom_map.keys()]
+                if True in res or left_side.startswith("dia") or left_side.startswith("box"):
+                    category_to_polarity.append(False)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(left_side, True)
+                # for the right side
+                res = [bool(re.match(r'' + atom_type + "_\d+", right_side)) for atom_type in atom_map.keys()]
+                if True in res or right_side.startswith("dia") or right_side.startswith("box"):
+                    category_to_polarity.append(True)
+                else:
+                    category_to_polarity += category_to_atoms_polarity(right_side, False)
 
     return category_to_polarity
 
@@ -147,13 +185,8 @@ def find_pos_neg_idexes(max_atoms_in_sentence, atoms_batch):
     for sentence in atoms_batch:
         list_atoms = []
         for category in sentence:
-            polarity = True
-            for at in category_to_atoms_polarity(category, polarity):
+            for at in category_to_atoms_polarity(category, True):
                 list_atoms.append(at)
         list_batch.append(torch.as_tensor(list_atoms))
     return pad_sequence([list_batch[i] for i in range(len(list_batch))],
                         max_len=max_atoms_in_sentence, padding_value=0)
-
-atoms_pol = find_pos_neg_idexes(10, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']])
-print(atoms_pol)
-print(get_axiom_links(10, atoms_pol, [['dr(1,np_1,s_1)', 'dl(1,np_1,s_1)']]))
diff --git a/SuperTagger/eval.py b/SuperTagger/eval.py
index 9bfdc85..b287d4b 100644
--- a/SuperTagger/eval.py
+++ b/SuperTagger/eval.py
@@ -18,16 +18,16 @@ class SinkhornLoss(Module):
                    for link, perm in zip(predictions, truths))
 
 
-def mesure_accuracy(linking_plus_to_minus, axiom_links_pred):
+def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
     batch_axiom_links : (batch_size, ...)
     axiom_links_pred : (batch_size, max_atoms_type_polarity)
     """
     correct_links = torch.ones(axiom_links_pred.size())
-    correct_links[axiom_links_pred != linking_plus_to_minus] = 0
-    correct_links[linking_plus_to_minus == -1] = 1
+    correct_links[axiom_links_pred != batch_true_links] = 0
+    correct_links[batch_true_links == -1] = 1
     num_correct_links = correct_links.sum().item()
-    num_masked_atoms = len(linking_plus_to_minus[linking_plus_to_minus == -1])
+    num_masked_atoms = len(batch_true_links[batch_true_links == -1])
 
     # diviser par nombre de links
     return (num_correct_links - num_masked_atoms)/(axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
diff --git a/train.py b/train.py
index fced23b..05d223f 100644
--- a/train.py
+++ b/train.py
@@ -53,9 +53,11 @@ print("atoms_polarity_batch", atoms_polarity_batch.shape)
 
 truth_links_batch = get_axiom_links(max_atoms_in_one_type, atoms_polarity_batch, df_axiom_links["sub_tree"])
 print("truth_links_batch", truth_links_batch.permute(1, 0, 2).shape)
+print(" truth_links_batch example on first sentence class cl_r", truth_links_batch[0][0])
 
 sentences_batch = df_axiom_links["Sentences"]
 
+# Construction tensor dataset
 dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch.permute(1, 0, 2))
 
 # Calculate the number of samples to include in each set.
-- 
GitLab