From a2a36c49d8a10a13fdcbfcdf9e7f10e5bf266a31 Mon Sep 17 00:00:00 2001
From: CarodePourtales <49019500+CarodePourtales@users.noreply.github.com>
Date: Thu, 7 Jul 2022 19:03:15 +0200
Subject: [PATCH] Add predict method, comments, and auto config

---
 Configuration/Configuration.py |   1 -
 Configuration/config.ini       |  38 +++++------
 Linker/Linker.py               |  73 ++++++++++++++------
 Linker/README.md               |  50 ++++++++++++++
 Linker/utils_linker.py         |  32 ++-------
 README.md                      |  14 +---
 find_config.py                 | 118 ++++++++++++++++-----------------
 postprocessing.py              |  47 +++++++++----
 train.py                       |  15 +++--
 utils.py                       |  17 ++++-
 10 files changed, 248 insertions(+), 157 deletions(-)
 create mode 100644 Linker/README.md

diff --git a/Configuration/Configuration.py b/Configuration/Configuration.py
index 12a4b5f..9a120c7 100644
--- a/Configuration/Configuration.py
+++ b/Configuration/Configuration.py
@@ -10,7 +10,6 @@ config.read(path_config_file)
 # region Get section
 
 version = config["VERSION"]
-
 datasetConfig = config["DATASET_PARAMS"]
 modelEncoderConfig = config["MODEL_ENCODER"]
 modelLinkerConfig = config["MODEL_LINKER"]
diff --git a/Configuration/config.ini b/Configuration/config.ini
index 4de3f49..61314f0 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -2,29 +2,29 @@
 transformers = 4.16.2
 
 [DATASET_PARAMS]
-symbols_vocab_size=26
-atom_vocab_size=18
-max_len_sentence=290
-max_atoms_in_sentence=875
-max_atoms_in_one_type=324
+symbols_vocab_size = 26
+atom_vocab_size = 18
+max_len_sentence = 83
+max_atoms_in_sentence = 238
+max_atoms_in_one_type = 102
 
 [MODEL_ENCODER]
 dim_encoder = 768
 
 [MODEL_LINKER]
-nhead=8
-dim_emb_atom = 512
-dim_feedforward_transformer = 768
-num_layers=3
-dim_cat_inter=768
-dim_cat_out=512
-dim_intermediate_FFN=256
-dim_pre_sinkhorn_transfo=32
-dropout=0.1
-sinkhorn_iters=5
+nhead = 8
+dim_emb_atom = 256
+dim_feedforward_transformer = 512
+num_layers = 3
+dim_cat_out = 512
+dim_intermediate_ffn = 256
+dim_pre_sinkhorn_transfo = 32
+dropout = 0.1
+sinkhorn_iters = 5
 
 [MODEL_TRAINING]
-batch_size=32
-epoch=30
-seed_val=42
-learning_rate=2e-3
\ No newline at end of file
+batch_size = 32
+epoch = 30
+seed_val = 42
+learning_rate = 2e-3
+
diff --git a/Linker/Linker.py b/Linker/Linker.py
index 498a828..2f10844 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -1,9 +1,7 @@
+import datetime
 import math
 import os
-import re
 import sys
-import datetime
-
 import time
 
 import torch
@@ -17,17 +15,16 @@ from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
+from Linker.AtomTokenizer import AtomTokenizer
 from Linker.PositionalEncoding import PositionalEncoding
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.AtomTokenizer import AtomTokenizer
 from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx
-from Supertagger import SuperTagger
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
+    find_pos_neg_idexes, get_num_atoms_batch
+from SuperTagger import SuperTagger
 from utils import pad_sequence
 
-import torch
-
 
 def format_time(elapsed):
     '''
@@ -73,7 +70,6 @@ class Linker(Module):
         self.num_layers = int(Configuration.modelLinkerConfig['num_layers'])
         # torch cat
         dropout = float(Configuration.modelLinkerConfig['dropout'])
-        self.dim_cat_inter = int(Configuration.modelLinkerConfig['dim_cat_inter'])
         self.dim_cat_out = int(Configuration.modelLinkerConfig['dim_cat_out'])
         dim_intermediate_FFN = int(Configuration.modelLinkerConfig['dim_intermediate_FFN'])
         dim_pre_sinkhorn_transfo = int(Configuration.modelLinkerConfig['dim_pre_sinkhorn_transfo'])
@@ -87,7 +83,7 @@ class Linker(Module):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         # endregion
 
-        # Supertagger for categories
+        # SuperTagger for categories
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
@@ -145,11 +141,14 @@ class Linker(Module):
         sentences_batch = df_axiom_links["X"].str.strip().tolist()
         sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
 
-        atoms_batch, atoms_polarity_batch, num_atoms_per_word = get_GOAL(self.max_len_sentence, self.max_atoms_in_sentence, df_axiom_links)
+        atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
+        atoms_polarity_batch = pad_sequence(
+            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+            max_len=self.max_atoms_in_sentence, padding_value=0)
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
-        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type, self.max_atoms_in_sentence)
+        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
 
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
                                             df_axiom_links["Y"])
@@ -203,8 +202,8 @@ class Linker(Module):
         atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
 
         # linking per atom type
-        batch_size, atom_vocan_size, _ = batch_pos_idx.shape
-        link_weights = torch.zeros(atom_vocan_size, batch_size, self.max_atoms_in_one_type // 2,
+        batch_size, atom_vocab_size, _ = batch_pos_idx.shape
+        link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2,
                                    self.max_atoms_in_one_type // 2, device=self.device)
         for atom_type in list(atom_map_redux.keys()):
             pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
@@ -252,7 +251,7 @@ class Linker(Module):
 
             if checkpoint:
                 self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
+                    path=os.path.join("Output", 'linker' + datetime.datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
 
             if tensorboard:
                 writer.add_scalars(f'Accuracy', {
@@ -319,7 +318,6 @@ class Linker(Module):
                 accuracy_train += mesure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
 
         self.scheduler.step()
-        print("learning rate ", self.scheduler.get_last_lr())
 
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
@@ -370,19 +368,51 @@ class Linker(Module):
 
         return loss_average / len(dataloader), accuracy_average / len(dataloader)
 
+    def predict(self, sentence, categories):
+        r""" Predict the links from a sentence and its categories
+
+        Args:
+            sentence: list of words composing the sentence
+            categories: list of categories (tags) of each word
+        """
+        self.eval()
+        with torch.no_grad():
+            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
+            sentences_tokens = sentences_tokens.to(self.device)
+            nb_sentence, len_sentence = sentences_tokens.shape
+            sentences_mask = sentences_mask.to(self.device)
+
+            atoms = get_atoms_batch([categories])
+            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms).to(self.device)
+
+            polarities = find_pos_neg_idexes([categories])
+            polarities = pad_sequence(
+                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+                max_len=self.max_atoms_in_sentence, padding_value=0).to(self.device)
+
+            num_atoms_per_word = get_num_atoms_batch([categories], len_sentence).to(self.device)
+
+            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type).to(self.device)
+            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type).to(self.device)
+
+            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
+
+            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding'])
+            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
+
+        return axiom_links_pred
+
     def load_weights(self, model_file):
         print("#" * 15)
         try:
             params = torch.load(model_file, map_location=self.device)
-            args = params['args']
-            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
-            self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
             self.atom_encoder.load_state_dict(params['atom_encoder'])
             self.position_encoder.load_state_dict(params['position_encoder'])
             self.transformer.load_state_dict(params['transformer'])
             self.linker_encoder.load_state_dict(params['linker_encoder'])
             self.pos_transformation.load_state_dict(params['pos_transformation'])
             self.neg_transformation.load_state_dict(params['neg_transformation'])
+            self.cross_entropy_loss.load_state_dict(params['cross_entropy_loss'])
             self.optimizer.load_state_dict(params['optimizer'])
             print("\n The loading checkpoint was successful ! \n")
         except Exception as e:
@@ -399,10 +429,11 @@ class Linker(Module):
         torch.save({
             'atom_encoder': self.atom_encoder.state_dict(),
             'position_encoder': self.position_encoder,
-            'transformer': self.transformer,
+            'transformer': self.transformer.state_dict(),
             'linker_encoder': self.linker_encoder.state_dict(),
             'pos_transformation': self.pos_transformation.state_dict(),
             'neg_transformation': self.neg_transformation.state_dict(),
+            'cross_entropy_loss': self.cross_entropy_loss.state_dict(),
             'optimizer': self.optimizer,
         }, path)
         self.to(self.device)
diff --git a/Linker/README.md b/Linker/README.md
new file mode 100644
index 0000000..d6903fa
--- /dev/null
+++ b/Linker/README.md
@@ -0,0 +1,50 @@
+# DeepGrail Linker
+
+This repository contains a Python implementation of a Neural Proof Net using TLGbank data.
+
+This code was designed to work with the [DeepGrail Tagger](https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger). 
+In this repository we only use the word embeddings from the tagger and the tags from the dataset; the next step is to use the tagger's predictions for the linking step.
+ 
+## Usage
+
+### Installation
+Python 3.9.10 **(warning: do not use Python 3.10+)**
+
+Clone the project locally.
+
+### Libraries installation
+
+In a clean Python venv, run `pip install -r requirements.txt`.
+
+### Dataset format
+
+The sentences should be in a column "X", the links (atoms carrying an '_x' suffix) in a column "Y", and the categories in a column "Z".
+For the links, each atom_x is paired with the one and only other atom_x in the sentence.
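+
+As a quick sanity check of the format, the columns can be inspected as in this minimal sketch (the dataset path is the one used in `train.py`):
+
+```
+from utils import read_csv_pgbar
+
+# Load a few rows of the axiom-links dataset (path as in train.py).
+df = read_csv_pgbar('Datasets/goldANDsilver_dataset_links.csv', 5)
+print(df["X"].iloc[0])  # the raw sentence
+print(df["Z"].iloc[0])  # the categories, one per word
+print(df["Y"].iloc[0])  # the linked categories, atoms carrying an '_x' suffix
+```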
+
+## Training
+
+Launch train.py; inside it you can point to another dataset file and another tagging model.
+
+In train.py, if you use `checkpoint=True`, the model is automatically saved after each epoch in a folder named
+Training_XX-XX_XX-XX. Use `tensorboard=True` to write logs to the same folder (run `tensorboard --logdir=logs` to view them).
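+
+The snippet below is a minimal sketch mirroring `train.py` (dataset path, tagger model file and hyperparameters are the defaults used there):
+
+```
+from Configuration import Configuration
+from Linker import *
+from utils import read_csv_pgbar
+from find_config import configurate
+
+batch_size = int(Configuration.modelTrainingConfig['batch_size'])
+nb_sentences = batch_size * 4
+file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+model_tagger = "models/flaubert_super_98_V2_50e.pt"
+
+# Auto-configure the DATASET_PARAMS section of config.ini from the dataset.
+configurate(file_path_axiom_links, model_tagger, nb_sentences=nb_sentences)
+
+epochs = int(Configuration.modelTrainingConfig['epoch'])
+df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+# Load the Linker on top of the trained tagger, then train the linking step.
+linker = Linker(model_tagger)
+linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=epochs,
+                    batch_size=batch_size, checkpoint=True, tensorboard=True)
+```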
+
+## Predicting
+
+To predict on your data, you need to load a model (saved with this code).
+
+```
+from Linker import *
+from utils import read_csv_pgbar
+
+file_path = 'Datasets/goldANDsilver_dataset_links.csv'  # your dataset
+tagging_model = "models/flaubert_super_98_V2_50e.pt"    # your tagger model
+
+df = read_csv_pgbar(file_path, 20)
+texts = df['X'].tolist()
+categories = df['Z'].tolist()
+
+linker = Linker(tagging_model)
+linker.load_weights("your/linker/path")
+
+links = linker.predict(texts[7], categories[7])
+print(links)
+```
+
+The file `postprocessing.py` lets you draw the predicted links (keep the sentence short, otherwise the drawing gets confusing).
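+
+For example, a minimal sketch reusing the demo values from the bottom of `postprocessing.py` (the `links` array there is a hand-written placeholder, one row per atom type):
+
+```
+import numpy as np
+
+from postprocessing import draw_sentence_output
+
+sentence = ["Le", "chat", "est", "noir", "bleu"]
+categories = ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"]
+# One row per atom type (cl_r, pp, n, np, cl_y, txt, s); placeholder links.
+links = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 2, 0], [0, 0, 0, 0],
+                  [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+draw_sentence_output(sentence, categories, links)
+```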
+
+## Authors
+
+[de Pourtales Caroline](https://www.linkedin.com/in/caroline-de-pourtales/), [Rabault Julien](https://www.linkedin.com/in/julienrabault)
\ No newline at end of file
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 3b38a77..15b37f3 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -318,7 +318,7 @@ print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', '
 
 # region get atoms and polarities with GOAL
 
-def get_GOAL(max_len_sentence, max_atoms_in_sentence, df_axiom_links):
+def get_GOAL(max_len_sentence, df_axiom_links):
     categories_batch = df_axiom_links["Z"]
     categories_with_goal = df_axiom_links["Y"]
     polarities = find_pos_neg_idexes(categories_batch)
@@ -334,8 +334,7 @@ def get_GOAL(max_len_sentence, max_atoms_in_sentence, df_axiom_links):
         polarities[s_idx] = polarities_goal + polarities[s_idx]  # + False
         num_atoms_batch[s_idx][0] += len(atoms)  # +1
 
-    return atoms_batch, pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                                     max_len=max_atoms_in_sentence, padding_value=0), num_atoms_batch
+    return atoms_batch, polarities, num_atoms_batch
 
 
 df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
@@ -343,14 +342,14 @@ df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)',
                                "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
                                       'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
                                       'GOAL:np_7']]})
-print(" test for get GOAL ", get_GOAL(10, 30, df_axiom_links))
+print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
 
 
 # endregion
 
 # region get idx for pos and neg
 
-def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
+def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
                                               bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               atoms_polarity_batch[s_idx][i]])
@@ -361,7 +360,7 @@ def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_at
     return torch.stack(pos_idx).permute(1, 0, 2)
 
 
-def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type, max_atoms_in_sentence):
+def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
     pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
                                               bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
                                               not atoms_polarity_batch[s_idx][i]])
@@ -378,25 +377,6 @@ print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
                       [[True, True, False, False,
                         True, False, False, False,
                         False, False,
-                        False, False]]), 10, 50))
-
-# endregion
-
-
-# region style Output
-
-def get_output(links_pred, atoms_batch, atoms_polarity):
-    r"""
-    Parameters:
-        links_pred : atom_vocab_size, batch_size, max atoms in one type
-        atoms_batch : batch_size, max atoms in sentence
-        atoms_polarity : batch_size, max atoms in sentence
-    """
-    sentences_with_links = []
-    for s_idx in range(len(atoms_batch)) :
-        atoms = atoms_batch[s_idx]
-        polarities = atoms_polarity[s_idx]
-
-
+                        False, False]]), 10))
 
 # endregion
\ No newline at end of file
diff --git a/README.md b/README.md
index 5994f14..15a8616 100644
--- a/README.md
+++ b/README.md
@@ -1,15 +1,7 @@
 # DeepGrail
 
-## Usage
+This repository contains a Python implementation of a Neural Proof Net using TLGbank data.
 
-### Installation
-Python 3.9.10 **(Warning don't use Python 3.10**+**)**
-
-Clone the project locally. In a clean python venv do `pip install -r requirements.txt`
-
-## How To use
-
-TODO ...
-
-tensorboard --logdir=logs
+## Authors
 
+[de Pourtales Caroline](https://www.linkedin.com/in/caroline-de-pourtales/), [Rabault Julien](https://www.linkedin.com/in/julienrabault)
\ No newline at end of file
diff --git a/find_config.py b/find_config.py
index 58d95bd..5372528 100644
--- a/find_config.py
+++ b/find_config.py
@@ -1,63 +1,61 @@
-import numpy as np
-import torch
-from Configuration import Configuration
-from Linker import *
-from Linker.atom_map import atom_map_redux
-from Linker.utils_linker import get_atoms_batch, get_GOAL, get_atoms_links_batch, get_axiom_links
-from Supertagger.SuperTagger.SuperTagger import SuperTagger
-from utils import read_csv_pgbar
+import configparser
 import re
 
+import torch
 
-torch.cuda.empty_cache()
-batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 800
-epochs = int(Configuration.modelTrainingConfig['epoch'])
-file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
-df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
-
-atoms_batch, atoms_polarity_batch, num_batch = get_GOAL(290, 875, df_axiom_links)
-
-truth_links_batch = get_axiom_links(324, atoms_polarity_batch, df_axiom_links["Y"])
-print("max idx for link", torch.max(truth_links_batch))
-
-neg_idx = [[[i for i, x in enumerate(sentence) if
-             bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
-             and not atoms_polarity_batch[s_idx][i]]
-            for s_idx, sentence in enumerate(atoms_batch)]
-           for atom_type in list(atom_map_redux.keys())]
-max_atoms_in_on_type = 0
-for atoms_type_batch in neg_idx:
-    for sentence in atoms_type_batch:
-        if len(sentence) > max_atoms_in_on_type:
-            max_atoms_in_on_type = len(sentence)
-print("max atoms of one type in one sentence", max_atoms_in_on_type)
-
-atoms_links_batch = get_atoms_links_batch(df_axiom_links["Y"])
-max_atoms_in_links = 0
-sentence_max = ""
-for sentence in atoms_links_batch:
-    if len(sentence) > max_atoms_in_links:
-        max_atoms_in_links = len(sentence)
-        sentence_max = sentence
-print("max atoms in links", max_atoms_in_links)
-
-max_atoms_in_sentence = 0
-sentence_max = ""
-for sentence in atoms_batch:
-    if len(sentence) > max_atoms_in_sentence:
-        max_atoms_in_sentence = len(sentence)
-        sentence_max = sentence
-print("max atoms in categories", max_atoms_in_sentence)
-
-supertagger = SuperTagger()
-supertagger.load_weights("models/flaubert_super_98_V2_50e.pt")
-sentences_batch = df_axiom_links["X"].str.strip().tolist()
-sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
-max_len_sentence = 0
-sentence_max = ""
-for sentence in sentences_tokens:
-    if len(sentence) > max_len_sentence:
-        max_len_sentence = len(sentence)
-        sentence_max = sentence
-print(" max len sentence", max_len_sentence)
+from Linker.atom_map import atom_map_redux
+from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
+from SuperTagger.SuperTagger.SuperTagger import SuperTagger
+from utils import read_csv_pgbar, pad_sequence
+
+
+def configurate(dataset, model_tagger, nb_sentences=1000000000):
+    print("#" * 20)
+    print("#" * 20)
+    print("Configuration with dataset\n")
+    config = configparser.ConfigParser()
+    config.read('Configuration/config.ini')
+
+    file_path_axiom_links = dataset
+    df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
+
+    supertagger = SuperTagger()
+    supertagger.load_weights(model_tagger)
+    sentences_batch = df_axiom_links["X"].str.strip().tolist()
+    sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
+    max_len_sentence = 0
+    for sentence in sentences_tokens:
+        if len(sentence) > max_len_sentence:
+            max_len_sentence = len(sentence)
+    print("Configure parameter max len sentence to ", max_len_sentence)
+    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))
+
+    atoms_batch, polarities, num_batch = get_GOAL(max_len_sentence, df_axiom_links)
+    max_atoms_in_sentence = 0
+    for sentence in atoms_batch:
+        if len(sentence) > max_atoms_in_sentence:
+            max_atoms_in_sentence = len(sentence)
+    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
+    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))
+
+    atoms_polarity_batch = pad_sequence(
+        [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
+        max_len=max_atoms_in_sentence, padding_value=0)
+    pos_idx = [[torch.as_tensor([i for i, x in enumerate(sentence) if
+                                 bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
+                                 and atoms_polarity_batch[s_idx][i]])
+                for s_idx, sentence in enumerate(atoms_batch)]
+               for atom_type in list(atom_map_redux.keys())]
+    max_atoms_in_one_type = 0
+    for atoms_type_batch in pos_idx:
+        for sentence in atoms_type_batch:
+            length = sentence.size(0)
+            if length > max_atoms_in_one_type:
+                max_atoms_in_one_type = length
+    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_one_type)
+    # Store twice the per-polarity maximum, plus a margin of 2.
+    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_one_type * 2 + 2))
+
+    with open('Configuration/config.ini', 'w') as configfile:  # save
+        config.write(configfile)
+
+    print("#" * 20)
+    print("#" * 20)
\ No newline at end of file
diff --git a/postprocessing.py b/postprocessing.py
index 4dbf200..d2d43f0 100644
--- a/postprocessing.py
+++ b/postprocessing.py
@@ -10,6 +10,19 @@ regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
 
 def recursive_linking(links, dot, category, parent_id, word_idx, depth,
                       polarity, compt_plus, compt_neg):
+    r"""
+    recursive linking between atoms inside a category
+    :param links:
+    :param dot:
+    :param category:
+    :param parent_id:
+    :param word_idx:
+    :param depth:
+    :param polarity:
+    :param compt_plus:
+    :param compt_neg:
+    :return:
+    """
     res = [(category == atom_type) for atom_type in atom_map.keys()]
     if True in res:
         polarity = not polarity
@@ -54,43 +67,53 @@ def recursive_linking(links, dot, category, parent_id, word_idx, depth,
             polarities_inside = []
 
         for cat_id in range(len(categories_inside)):
-            recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx, depth + 1, polarities_inside[cat_id], compt_plus,
-                                  compt_neg)
+            recursive_linking(links, dot, categories_inside[cat_id], parent_id, word_idx, depth + 1,
+                              polarities_inside[cat_id], compt_plus,
+                              compt_neg)
 
 
 def draw_sentence_output(sentence, categories, links):
+    r"""
+    Drawing the prediction of a sentence when given categories and links predictions
+    :param sentence: list of words
+    :param categories: list of categories
+    :param links: links predicted
+    :return: dot source
+    """
     dot = graphviz.Graph('linking', comment='Axiom linking')
     dot.graph_attr['rankdir'] = 'BT'
-    dot.attr('edge', tailport='n')
-    dot.attr('edge', headport='s')
+    dot.graph_attr['splines'] = 'ortho'
+    dot.graph_attr['ordering'] = 'in'
 
     compt_plus = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
     compt_neg = {'cl_r': 0, 'pp': 0, 'n': 0, 'np': 0, 'cl_y': 0, 'txt': 0, 's': 0}
+    last_word_id = ""
     for word_idx in range(len(sentence)):
         word = sentence[word_idx]
         word_id = word + "_" + str(word_idx)
         dot.node(word_id, word)
+        if word_idx > 0:
+            dot.edge(last_word_id, word_id, constraint="false", style="invis")
 
         category = categories[word_idx]
         polarity = True
         parent_id = word_id
         recursive_linking(links, dot, category, parent_id, word_idx, 0, polarity, compt_plus, compt_neg)
+        last_word_id = word_id
 
     dot.attr('edge', color='red')
     dot.attr('edge', style='dashed')
-    dot.attr('edge', tailport='n')
-    dot.attr('edge', headport='n')
     for atom_type in list(atom_map_redux.keys()):
         for id in range(compt_plus[atom_type]):
-            atom_plus = atom_type+"_"+str(True)+"_"+str(id)
-            atom_moins = atom_type+"_"+str(False)+"_"+str(id)
+            atom_plus = atom_type + "_" + str(True) + "_" + str(id)
+            atom_moins = atom_type + "_" + str(False) + "_" + str(id)
             dot.edge(atom_plus, atom_moins, constraint="false")
 
     dot.render(format="svg", view=True)
     return dot.source
 
 
-sentence = ["Le", "chat","est","noir","bleu"]
-categories = ["dr(0,s,n)", "dl(0,s,n)","dr(0,dl(0,n,np),n)", "dl(0,np,n)","n"]
-links = np.array([[0,0,0,0],[0,0,0,0],[1,0,2,0],[0,0,0,0],[0,0,0,0],[0,0,0,0],[0,0,0,0]])
-draw_sentence_output(sentence, categories, links)
\ No newline at end of file
+sentence = ["Le", "chat", "est", "noir", "bleu"]
+categories = ["dr(0,s,n)", "dl(0,s,n)", "dr(0,dl(0,n,np),n)", "dl(0,np,n)", "n"]
+links = np.array([[0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 2, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
+draw_sentence_output(sentence, categories, links)
diff --git a/train.py b/train.py
index fdf3936..0f1d17c 100644
--- a/train.py
+++ b/train.py
@@ -2,16 +2,21 @@ import torch
 from Configuration import Configuration
 from Linker import *
 from utils import read_csv_pgbar
+from find_config import configurate
 
 torch.cuda.empty_cache()
 batch_size = int(Configuration.modelTrainingConfig['batch_size'])
-nb_sentences = batch_size * 800
-epochs = int(Configuration.modelTrainingConfig['epoch'])
-
+nb_sentences = batch_size * 4
 file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
+model_tagger = "models/flaubert_super_98_V2_50e.pt"
+configurate(file_path_axiom_links, model_tagger, nb_sentences=nb_sentences)
+
+epochs = int(Configuration.modelTrainingConfig['epoch'])
 df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
 
 print("Linker")
-linker = Linker("models/flaubert_super_98_V2_50e.pt")
+# Load the Linker with the trained tagger
+linker = Linker(model_tagger)
 print("\nLinker Training\n")
-linker.train_linker(df_axiom_links, validation_rate=0.1, epochs=epochs, batch_size=batch_size, checkpoint=False, tensorboard=True)
\ No newline at end of file
+linker.train_linker(df_axiom_links, validation_rate=0.05, epochs=1, batch_size=batch_size,
+                    checkpoint=True, tensorboard=True)
diff --git a/utils.py b/utils.py
index 0433510..c4fae14 100644
--- a/utils.py
+++ b/utils.py
@@ -6,6 +6,14 @@ from tqdm import tqdm
 
 
 def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
+    r"""
+    Padding sequence for preparation to tensorDataset
+    :param sequences: data to pad
+    :param batch_first: boolean indicating whether the batch are in first dimension
+    :param padding_value: the value for pad
+    :param max_len: the maximum length
+    :return: padding sequences
+    """
     max_size = sequences[0].size()
     trailing_dims = max_size[1:]
     if batch_first:
@@ -26,7 +34,13 @@ def pad_sequence(sequences, batch_first=True, padding_value=0, max_len=400):
 
 
 def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
-    print("\n" + "#" * 20)
+    r"""
+    Preparing csv dataset
+    :param csv_path:
+    :param nrows:
+    :param chunksize:
+    :return:
+    """
     print("Loading csv...")
 
     rows = sum(1 for _ in open(csv_path, 'r', encoding="utf8")) - 1  # minus the header
@@ -42,7 +56,6 @@ def read_csv_pgbar(csv_path, nrows=float('inf'), chunksize=500):
             bar.update(len(chunk))
 
     df = pd.concat((f for f in chunk_list), axis=0)
-    print("#" * 20)
 
     return df
 
-- 
GitLab