From 5ebfe402fc00013ce07ce6fad058c1f00febbca4 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Tue, 5 Jul 2022 17:18:05 +0200
Subject: [PATCH] best score 81%

---
 Configuration/config.ini |  4 +-
 Linker/Linker.py         |  4 +-
 Linker/utils_linker.py   | 25 +++++++++--
 command_line.txt         |  4 --
 find_config.py           | 63 ++++++++++++++++++++++++++
 postprocessing.py        | 96 ++++++++++++++++++++++++++++++++++++++++
 6 files changed, 183 insertions(+), 13 deletions(-)
 delete mode 100644 command_line.txt
 create mode 100644 find_config.py
 create mode 100644 postprocessing.py

diff --git a/Configuration/config.ini b/Configuration/config.ini
index 0cf354c..4de3f49 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -13,8 +13,8 @@ dim_encoder = 768
 
 [MODEL_LINKER]
 nhead=8
-dim_emb_atom = 256
-dim_feedforward_transformer = 512
+dim_emb_atom = 512
+dim_feedforward_transformer = 768
 num_layers=3
 dim_cat_inter=768
 dim_cat_out=512
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 370b25b..498a828 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -128,7 +128,7 @@ class Linker(Module):
         self.cross_entropy_loss = SinkhornLoss()
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)
-        self.scheduler = StepLR(self.optimizer, step_size=3, gamma=0.5)
+        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
 
         self.to(self.device)
 
@@ -257,8 +257,6 @@ class Linker(Module):
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
                     'Train': avg_accuracy_train}, epoch_i)
-                writer.add_scalars(f'Learning rate', {
-                    'learning_rate': self.scheduler.get_last_lr()}, epoch_i)
                 writer.add_scalars(f'Loss', {
                     'Train': avg_train_loss}, epoch_i)
                 if validation_rate > 0.0:
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 92b41d5..3b38a77 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -352,8 +352,7 @@ print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links))
 
 def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(
-                                                  re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -364,8 +363,7 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
 
 def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(
-                                                  re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
+                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               not atoms_polarity_batch[s_idx][i]])
                              for s_idx, sentence in enumerate(atoms_batch)],
                             max_len=max_atoms_in_one_type // 2, padding_value=-1)
@@ -383,3 +381,22 @@ print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
                         False, False]]), 10, 50))
 
 # endregion
+
+
+# region style Output
+
+def get_output(links_pred, atoms_batch, atoms_polarity):
+    r"""
+    Unfinished stub: walks the batch, builds nothing and returns None.
+    Parameters (shapes as given by the author, not checked by this code):
+        links_pred : atom_vocab_size, batch_size, max atoms in one type
+        atoms_batch, atoms_polarity : batch_size, max atoms in sentence
+    """
+    sentences_with_links = []  # never filled or returned in this version
+    for s_idx in range(len(atoms_batch)) :
+        atoms = atoms_batch[s_idx]  # read but not used yet
+        polarities = atoms_polarity[s_idx]  # read but not used yet
+
+
+
+# endregion
\ No newline at end of file
diff --git a/command_line.txt b/command_line.txt
deleted file mode 100644
index 31b0fc4..0000000
--- a/command_line.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-scp -r cdepourt@osirim-slurm.irit.fr:projets/deepgrailGPU1/deepgrail_RNN_with_linker/TensorBoard/ /home/cdepourt/Bureau/deepgrail_RNN_with_linker/TensorBoard
-
-rsync -av -e ssh --exclude="__pycache__" --exclude="venv" --exclude=".git" --exclude=".idea"  -r /home/cdepourt/Bureau/deepgrail_RNN_with_linker cdepourt@osirim-slurm.irit.fr:projets/deepgrail2
-
diff --git a/find_config.py b/find_config.py
new file mode 100644
index 0000000..58d95bd
--- /dev/null
+++ b/find_config.py
@@ -0,0 +1,63 @@
+import numpy as np
+import torch
+from Configuration import Configuration
+from Linker import *
+from Linker.atom_map import atom_map_redux
+from Linker.utils_linker import get_atoms_batch, get_GOAL, get_atoms_links_batch, get_axiom_links
+from Supertagger.SuperTagger.SuperTagger import SuperTagger
+from utils import read_csv_pgbar
+import re
+
+
+# Diagnostic script: load the axiom-link dataset and print the largest sizes
+# found in it (link index, atoms per type, atoms per sentence, tokens).
+torch.cuda.empty_cache()
+batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+nb_sentences = batch_size * 800
+epochs = int(Configuration.modelTrainingConfig['epoch'])
+file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+# NOTE(review): 290, 875 and 324 are presumably the padding sizes expected by
+# these helpers — confirm against their definitions in Linker/utils_linker.py.
+atoms_batch, atoms_polarity_batch, num_batch = get_GOAL(290, 875, df_axiom_links)
+
+truth_links_batch = get_axiom_links(324, atoms_polarity_batch, df_axiom_links["Y"])
+print("max idx for link", torch.max(truth_links_batch))
+
+
+def _longest(sequences):
+    """Return the length of the longest item of ``sequences`` (0 if empty)."""
+    return max((len(item) for item in sequences), default=0)
+
+
+# One compiled pattern per atom type. The suffix is a raw string because "\w"
+# and "\Z" are invalid escapes in a normal literal and newer Pythons warn.
+atom_patterns = [re.compile(atom_type + r"(_{1}\w+)?\Z")
+                 for atom_type in atom_map_redux.keys()]
+# Per atom type and per sentence: positions of matching atoms whose polarity
+# flag is false.
+neg_idx = [[[i for i, atom in enumerate(sentence)
+             if pattern.match(atom) and not atoms_polarity_batch[s_idx][i]]
+            for s_idx, sentence in enumerate(atoms_batch)]
+           for pattern in atom_patterns]
+max_atoms_in_on_type = _longest(sentence
+                                for atoms_type_batch in neg_idx
+                                for sentence in atoms_type_batch)
+print("max atoms of one type in one sentence", max_atoms_in_on_type)
+
+# Longest item returned by get_atoms_links_batch for the "Y" column.
+atoms_links_batch = get_atoms_links_batch(df_axiom_links["Y"])
+max_atoms_in_links = _longest(atoms_links_batch)
+print("max atoms in links", max_atoms_in_links)
+
+max_atoms_in_sentence = _longest(atoms_batch)
+print("max atoms in categories", max_atoms_in_sentence)
+
+# Tokenise the sentences with the supertagger's tokenizer and measure them.
+supertagger = SuperTagger()
+supertagger.load_weights("models/flaubert_super_98_V2_50e.pt")
+sentences_batch = df_axiom_links["X"].str.strip().tolist()
+sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+max_len_sentence = _longest(sentences_tokens)
+print(" max len sentence", max_len_sentence)
diff --git a/postprocessing.py b/postprocessing.py
new file mode 100644
index 0000000..4dbf200
--- /dev/null
+++ b/postprocessing.py
@@ -0,0 +1,96 @@
+import re
+
+import graphviz
+import numpy as np
+import regex
+from Linker.atom_map import atom_map, atom_map_redux
+
+regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'  # (?R) recursion: for regex.match, not re
+
+
+def recursive_linking(links, dot, category, parent_id, word_idx, depth,
+                      polarity, compt_plus, compt_neg, sibling_idx=0):
+    """
+    Add ``category`` and everything nested in it to ``dot`` below ``parent_id``.
+
+    A key of ``atom_map`` becomes a leaf with id ``<category>_<polarity>_<n>``:
+    a positive atom takes the next ``compt_plus[category]``; a negative atom
+    takes the position holding its rank (``compt_neg[category]``) in
+    ``links[atom_map_redux[category]]``. Both counters are updated in place.
+    Any other category becomes an inner node and its parts are added recursively.
+    ``sibling_idx`` (position among the parent's children) keeps ids unique.
+    """
+    if category in atom_map:
+        polarity = not polarity  # atoms are drawn with the opposite of the polarity passed in
+        if polarity:
+            atoms_idx = compt_plus[category]
+            compt_plus[category] += 1
+        else:
+            idx_neg = compt_neg[category]
+            compt_neg[category] += 1
+            # IndexError here means no entry of the link row equals idx_neg
+            atoms_idx = np.where(links[atom_map_redux[category]] == idx_neg)[0][0]
+        atom_id = category + "_" + str(polarity) + "_" + str(atoms_idx)
+        dot.node(atom_id, category + " " + str("+" if polarity else "-"))
+        dot.edge(parent_id, atom_id)
+        return
+
+    # Build the id from the parent's id: category, word and depth alone repeat
+    # when a word has two identical sub-categories at the same depth.
+    category_id = parent_id + "_" + str(sibling_idx) + "_" + category
+    dot.node(category_id, category + " " + str("+" if polarity else "-"))
+    dot.edge(parent_id, category_id)
+
+    # Polarity handed to each sub-category, in order, for each connective.
+    if category.startswith("dr"):
+        polarities_inside = [polarity, not polarity]
+    elif category.startswith(("dl", "p")):
+        polarities_inside = [not polarity, polarity]
+    elif category.startswith(("box", "dia")):
+        polarities_inside = [polarity]
+    else:
+        polarities_inside = []
+    if polarities_inside:
+        found = regex.match(regex_categories, category).groups()
+        categories_inside = [cat for cat in found if cat is not None]
+    for cat_id, sub_polarity in enumerate(polarities_inside):
+        recursive_linking(links, dot, categories_inside[cat_id], category_id, word_idx,
+                          depth + 1, sub_polarity, compt_plus, compt_neg, cat_id)
+
+
+def draw_sentence_output(sentence, categories, links, view=True):
+    """
+    Draw each word with its category tree, then join every positive atom to the
+    negative atom of the same type and index with a dashed red edge.
+
+    Renders an SVG via ``dot.render`` (opened in a viewer unless ``view`` is
+    false) and returns the DOT source text.
+    """
+    dot = graphviz.Graph('linking', comment='Axiom linking')
+    dot.graph_attr['rankdir'] = 'BT'
+    dot.attr('edge', tailport='n')
+    dot.attr('edge', headport='s')
+
+    compt_plus = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
+    compt_neg = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
+    for word_idx, word in enumerate(sentence):
+        word_id = word + "_" + str(word_idx)
+        dot.node(word_id, word)
+        recursive_linking(links, dot, categories[word_idx], word_id, word_idx, 0, True, compt_plus, compt_neg)
+
+    dot.attr('edge', color='red')
+    dot.attr('edge', style='dashed')
+    dot.attr('edge', tailport='n')
+    dot.attr('edge', headport='n')
+    for atom_type in atom_map_redux:
+        for atom_idx in range(compt_plus[atom_type]):
+            dot.edge(atom_type + "_True_" + str(atom_idx), atom_type + "_False_" + str(atom_idx), constraint="false")
+
+    dot.render(format="svg", view=view)
+    return dot.source
+
+
+if __name__ == "__main__":  # demo only: importing this module must not render or open a viewer
+    demo_links = np.array([[0,0,0,0],[0,0,0,0],[1,0,2,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]])
+    draw_sentence_output(["Le", "chat", "est", "noir", "bleu"],
+                         ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"], demo_links)
\ No newline at end of file
-- 
GitLab