From 86aa9a9c952f347634bfedeb13ddcc87a8900e47 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Wed, 13 Jul 2022 15:53:44 +0200
Subject: [PATCH] Change init: remove bundled Linker and find_config, clone Linker via init.sh

---
 Linker/AtomTokenizer.py          |  50 ----
 Linker/Linker.py                 | 495 -------------------------------
 Linker/PositionalEncoding.py     |  25 --
 Linker/Sinkhorn.py               |  16 -
 Linker/__init__.py               |   5 -
 Linker/atom_map.py               |  30 --
 Linker/eval.py                   |  34 ---
 Linker/utils_linker.py           | 403 -------------------------
 NeuralProofNet/NeuralProofNet.py |   1 -
 README.md                        |  23 +-
 find_config.py                   |  61 ----
 init.sh                          |   1 +
 train.py                         |   2 -
 13 files changed, 12 insertions(+), 1134 deletions(-)
 delete mode 100644 Linker/AtomTokenizer.py
 delete mode 100644 Linker/Linker.py
 delete mode 100644 Linker/PositionalEncoding.py
 delete mode 100644 Linker/Sinkhorn.py
 delete mode 100644 Linker/__init__.py
 delete mode 100644 Linker/atom_map.py
 delete mode 100644 Linker/eval.py
 delete mode 100644 Linker/utils_linker.py
 delete mode 100644 find_config.py

diff --git a/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py
deleted file mode 100644
index 1f5c1a1..0000000
--- a/Linker/AtomTokenizer.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-from utils import pad_sequence
-
-
-class AtomTokenizer(object):
-    r"""
-    Tokenizer for the atoms with padding
-    """
-    def __init__(self, atom_map, max_atoms_in_sentence):
-        self.atom_map = atom_map
-        self.max_atoms_in_sentence = max_atoms_in_sentence
-        self.inverse_atom_map = {v: k for k, v in self.atom_map.items()}
-        self.pad_token = '[PAD]'
-        self.pad_token_id = self.atom_map[self.pad_token]
-
-    def __len__(self):
-        return len(self.atom_map)
-
-    def convert_atoms_to_ids(self, atom):
-        r"""
-        Convert an atom to its id
-        :param atom: atom string
-        :return: atom id
-        """
-        return self.atom_map[str(atom)]
-
-    def convert_sents_to_ids(self, sentences):
-        r"""
-        Convert sentences to ids
-        :param sentences: List of atoms in a sentence
-        :return: List of atom ids
-        """
-        return torch.as_tensor([self.convert_atoms_to_ids(atom) for atom in sentences])
-
-    def convert_batchs_to_ids(self, batchs_sentences):
-        r"""
-        Convert a batch of sentences of atoms to ids
-        :param batchs_sentences: batch of sentences of atoms
-        :return: list of lists of atom ids
-        """
-        return torch.as_tensor(pad_sequence([self.convert_sents_to_ids(sents) for sents in batchs_sentences],
-                                            max_len=self.max_atoms_in_sentence, padding_value=self.pad_token_id))
-
-    def convert_ids_to_atoms(self, ids):
-        r"""
-        Translate ids back to atoms
-        :param ids: list of atom ids
-        :return: list of atom strings
-        """
-        return [self.inverse_atom_map[int(i)] for i in ids]
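For reference, a minimal usage sketch of the tokenizer removed above, assuming the Linker package cloned by init.sh exposes the same `AtomTokenizer` and `atom_map`:

```
# Illustrative sketch only; assumes the external Linker repo keeps this API.
from Linker.atom_map import atom_map
from Linker.AtomTokenizer import AtomTokenizer

tokenizer = AtomTokenizer(atom_map, max_atoms_in_sentence=10)
ids = tokenizer.convert_sents_to_ids(['np', 'n', '[SEP]'])
print(ids)                                  # tensor([ 6,  2, 16])
print(tokenizer.convert_ids_to_atoms(ids))  # ['np', 'n', '[SEP]']
```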
diff --git a/Linker/Linker.py b/Linker/Linker.py
deleted file mode 100644
index 1006b8f..0000000
--- a/Linker/Linker.py
+++ /dev/null
@@ -1,495 +0,0 @@
-import datetime
-import math
-import os
-import sys
-import time
-
-import torch
-import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Module, Linear, Dropout, TransformerEncoderLayer, TransformerEncoder, \
-    Embedding, GELU
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import TensorDataset, random_split
-from torch.utils.tensorboard import SummaryWriter
-from tqdm import tqdm
-
-from Configuration import Configuration
-from .AtomTokenizer import AtomTokenizer
-from .PositionalEncoding import PositionalEncoding
-from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.atom_map import atom_map, atom_map_redux
-from Linker.eval import measure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx, get_atoms_batch, \
-    find_pos_neg_idexes, get_num_atoms_batch
-from SuperTagger import SuperTagger
-from utils import pad_sequence
-
-
-def format_time(elapsed):
-    '''
-    Takes a time in seconds and returns a string hh:mm:ss
-    '''
-    # Round to the nearest second.
-    elapsed_rounded = int(round(elapsed))
-
-    # Format as hh:mm:ss
-    return str(datetime.timedelta(seconds=elapsed_rounded))
-
-
-def output_create_dir():
-    """
-    Create the output dir for tensorboard and checkpoints
-    @return: output dir, tensorboard writer
-    """
-    from datetime import datetime
-    outpout_path = 'TensorBoard'
-    training_dir = os.path.join(outpout_path, 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
-    logs_dir = os.path.join(training_dir, 'logs')
-    writer = SummaryWriter(log_dir=logs_dir)
-    return training_dir, writer
-
-
-def generate_square_subsequent_mask(sz):
-    """Generates an upper-triangular matrix of -inf, with zeros on diag."""
-    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)
-
-
-class Linker(Module):
-    def __init__(self, supertagger_path_model):
-        super(Linker, self).__init__()
-
-        config = Configuration.read_config()
-        datasetConfig = config["DATASET_PARAMS"]
-        modelEncoderConfig = config["MODEL_ENCODER"]
-        modelLinkerConfig = config["MODEL_LINKER"]
-        modelTrainingConfig = config["MODEL_TRAINING"]
-
-        # region parameters
-        dim_encoder = int(modelEncoderConfig['dim_encoder'])
-        # atom settings
-        atom_vocab_size = int(datasetConfig['atom_vocab_size'])
-        # Transformer
-        self.nhead = int(modelLinkerConfig['nhead'])
-        self.dim_emb_atom = int(modelLinkerConfig['dim_emb_atom'])
-        self.dim_feedforward_transformer = int(modelLinkerConfig['dim_feedforward_transformer'])
-        self.num_layers = int(modelLinkerConfig['num_layers'])
-        # torch cat
-        dropout = float(modelLinkerConfig['dropout'])
-        self.dim_cat_out = int(modelLinkerConfig['dim_cat_out'])
-        dim_intermediate_FFN = int(modelLinkerConfig['dim_intermediate_FFN'])
-        dim_pre_sinkhorn_transfo = int(modelLinkerConfig['dim_pre_sinkhorn_transfo'])
-        # sinkhorn
-        self.sinkhorn_iters = int(modelLinkerConfig['sinkhorn_iters'])
-        # settings
-        self.max_len_sentence = int(datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(datasetConfig['max_atoms_in_one_type'])
-        learning_rate = float(modelTrainingConfig['learning_rate'])
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # endregion
-
-        # SuperTagger for categories
-        supertagger = SuperTagger()
-        supertagger.load_weights(supertagger_path_model)
-        self.Supertagger = supertagger
-        self.Supertagger.model.to(self.device)
-
-        # Atoms embedding
-        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.atom_map_redux = atom_map_redux
-        self.padding_id = atom_map["[PAD]"]
-        self.sub_atoms_type_list = list(atom_map_redux.keys())
-        self.atom_encoder = Embedding(atom_vocab_size, self.dim_emb_atom, padding_idx=self.padding_id)
-        self.atom_encoder.weight.data.uniform_(-0.1, 0.1)
-        self.position_encoder = PositionalEncoding(self.dim_emb_atom, dropout, max_len=self.max_atoms_in_sentence)
-        encoder_layer = TransformerEncoderLayer(d_model=self.dim_emb_atom, nhead=self.nhead,
-                                                dim_feedforward=self.dim_feedforward_transformer, dropout=dropout)
-        self.transformer = TransformerEncoder(encoder_layer, num_layers=self.num_layers)
-
-        # Concatenation with word embedding
-        dim_cat = dim_encoder + self.dim_emb_atom
-        self.linker_encoder = Sequential(
-            Linear(dim_cat, self.dim_cat_out),
-            GELU(),
-            Dropout(dropout),
-            LayerNorm(self.dim_cat_out, eps=1e-8)
-        )
-
-        # Division into positive and negative
-        self.pos_transformation = Sequential(
-            FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo),
-            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
-        )
-        self.neg_transformation = Sequential(
-            FFN(self.dim_cat_out, dim_intermediate_FFN, dropout, d_out=dim_pre_sinkhorn_transfo),
-            LayerNorm(dim_pre_sinkhorn_transfo, eps=1e-8)
-        )
-
-        # Learning
-        self.cross_entropy_loss = SinkhornLoss()
-        self.optimizer = AdamW(self.parameters(),
-                               lr=learning_rate)
-        self.scheduler = StepLR(self.optimizer, step_size=2, gamma=0.5)
-
-        self.to(self.device)
-
-    def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
-        r"""
-        Args:
-            batch_size : int
-            df_axiom_links pandas DataFrame
-            validation_rate
-        Returns:
-            the training dataloader and the validation dataloader. They contain the lists of atoms, their polarities, the axiom links, the tokenized sentences and the sentence masks
-        """
-        print("Start preprocess Data")
-        sentences_batch = df_axiom_links["X"].str.strip().tolist()
-        sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
-
-        atoms_batch, polarities, num_atoms_per_word = get_GOAL(self.max_len_sentence, df_axiom_links)
-        atoms_polarity_batch = pad_sequence(
-            [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-            max_len=self.max_atoms_in_sentence, padding_value=0)
-        atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-
-        pos_idx = get_pos_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
-        neg_idx = get_neg_idx(atoms_batch, atoms_polarity_batch, self.max_atoms_in_one_type)
-
-        truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, atoms_polarity_batch,
-                                            df_axiom_links["Y"])
-        truth_links_batch = truth_links_batch.permute(1, 0, 2)
-
-        # Construction tensor dataset
-        dataset = TensorDataset(num_atoms_per_word, atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch,
-                                sentences_tokens, sentences_mask)
-
-        if validation_rate > 0.0:
-            train_size = int(0.9 * len(dataset))
-            val_size = len(dataset) - train_size
-            train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
-            validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
-        else:
-            validation_dataloader = None
-            train_dataset = dataset
-
-        training_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=False)
-        print("End preprocess Data")
-        return training_dataloader, validation_dataloader
-
-    def forward(self, batch_num_atoms_per_word, batch_atoms, batch_pos_idx, batch_neg_idx, sents_embedding):
-        r"""
-        Args:
-            batch_num_atoms_per_word : (batch_size, len_sentence) flattened categories
-            batch_atoms : atoms tok
-            batch_pos_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
-            batch_neg_idx : (batch_size, atom_vocab_size, max atom in one cat) flattened categories polarities
-            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        Returns:
-            link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat) log probabilities
-        """
-        # repeat the word embedding for each atom in the word, with +1 for the [SEP]
-        sents_embedding_repeat = pad_sequence(
-            [torch.repeat_interleave(input=sents_embedding[i], repeats=batch_num_atoms_per_word[i], dim=0)
-             for i in range(len(sents_embedding))], max_len=self.max_atoms_in_sentence, padding_value=0)
-
-        # atoms embedding
-        src_key_padding_mask = torch.eq(batch_atoms, self.padding_id)
-        src_mask = generate_square_subsequent_mask(self.max_atoms_in_sentence).to(self.device)
-        atoms_embedding = self.atom_encoder(batch_atoms) * math.sqrt(self.dim_emb_atom)
-        atoms_embedding = self.position_encoder(atoms_embedding)
-        atoms_embedding = atoms_embedding.permute(1, 0, 2)
-        atoms_embedding = self.transformer(atoms_embedding, src_mask,
-                                           src_key_padding_mask=src_key_padding_mask)
-        atoms_embedding = atoms_embedding.permute(1, 0, 2)
-
-        # cat
-        atoms_sentences_encoding = torch.cat([sents_embedding_repeat, atoms_embedding], dim=2)
-        atoms_encoding = self.linker_encoder(atoms_sentences_encoding)
-
-        # linking per atom type
-        batch_size, atom_vocab_size, _ = batch_pos_idx.shape
-        link_weights = torch.zeros(atom_vocab_size, batch_size, self.max_atoms_in_one_type // 2,
-                                   self.max_atoms_in_one_type // 2, device=self.device)
-        for atom_type in list(atom_map_redux.keys()):
-            pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
-            neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
-
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
-
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights[self.atom_map_redux[atom_type]] = sinkhorn(weights, iters=self.sinkhorn_iters)
-
-        return F.log_softmax(link_weights, dim=3)
-
-    def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
-                     batch_size=32, checkpoint=True, tensorboard=False):
-        r"""
-        Args:
-            df_axiom_links : pandas DataFrame containing the atoms annotated with _i
-            validation_rate : float
-            epochs : int
-            batch_size : int
-            checkpoint : boolean
-            tensorboard : boolean
-        Returns:
-            Final accuracy and final loss
-        """
-        training_dataloader, validation_dataloader = self.__preprocess_data(batch_size, df_axiom_links,
-                                                                            validation_rate)
-        if checkpoint or tensorboard:
-            checkpoint_dir, writer = output_create_dir()
-
-        for epoch_i in range(epochs):
-            print("")
-            print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
-            print('Training...')
-            avg_train_loss, avg_accuracy_train, training_time = self.train_epoch(training_dataloader)
-
-            print("")
-            print(f'Epoch: {epoch_i + 1:02} | Epoch Time: {training_time}')
-            print(f'\tTrain Loss: {avg_train_loss:.3f} | Train Acc: {avg_accuracy_train * 100:.2f}%')
-
-            if validation_rate > 0.0:
-                loss_test, accuracy_test = self.eval_epoch(validation_dataloader)
-                print(f'\tVal Loss: {loss_test:.3f} | Val Acc: {accuracy_test * 100:.2f}%')
-
-            if checkpoint:
-                self.__checkpoint_save(
-                    path=os.path.join("Output", 'linker' + datetime.datetime.today().strftime('%d-%m_%H-%M') + '.pt'))
-
-            if tensorboard:
-                writer.add_scalars(f'Accuracy', {
-                    'Train': avg_accuracy_train}, epoch_i)
-                writer.add_scalars(f'Loss', {
-                    'Train': avg_train_loss}, epoch_i)
-                if validation_rate > 0.0:
-                    writer.add_scalars(f'Accuracy', {
-                        'Validation': accuracy_test}, epoch_i)
-                    writer.add_scalars(f'Loss', {
-                        'Validation': loss_test}, epoch_i)
-
-            print('\n')
-
-    def train_epoch(self, training_dataloader):
-        r""" Train epoch
-
-        Args:
-            training_dataloader : DataLoader from torch, contains atoms, polarities, axiom_links, sents_tokenized, sents_masks
-        Returns:
-             average loss and average accuracy on the training set
-             and the formatted training time for the epoch
-        """
-        self.train()
-
-        # Reset the total loss for this epoch.
-        epoch_loss = 0
-        accuracy_train = 0
-        t0 = time.time()
-
-        # For each batch of training data...
-        with tqdm(training_dataloader, unit="batch") as tepoch:
-            for batch in tepoch:
-                # Unpack this training batch from our dataloader
-                batch_num_atoms = batch[0].to(self.device)
-                batch_atoms_tok = batch[1].to(self.device)
-                batch_pos_idx = batch[2].to(self.device)
-                batch_neg_idx = batch[3].to(self.device)
-                batch_true_links = batch[4].to(self.device)
-                batch_sentences_tokens = batch[5].to(self.device)
-                batch_sentences_mask = batch[6].to(self.device)
-
-                self.optimizer.zero_grad()
-
-                # get sentence embedding from BERT which is already trained
-                output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-
-                # Run the Linker on the atoms
-                logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx,
-                                          output['word_embeding'])
-
-                linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
-                # Perform a backward pass to calculate the gradients.
-                epoch_loss += float(linker_loss)
-                linker_loss.backward()
-
-                # This is to help prevent the "exploding gradients" problem.
-                # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0, norm_type=2)
-
-                # Update parameters and take a step using the computed gradient.
-                self.optimizer.step()
-
-                pred_axiom_links = torch.argmax(logits_predictions, dim=3)
-                accuracy_train += measure_accuracy(batch_true_links, pred_axiom_links, self.max_atoms_in_one_type)
-
-        self.scheduler.step()
-
-        # Measure how long this epoch took.
-        training_time = format_time(time.time() - t0)
-        avg_train_loss = epoch_loss / len(training_dataloader)
-        avg_accuracy_train = accuracy_train / len(training_dataloader)
-
-        return avg_train_loss, avg_accuracy_train, training_time
-
-    def eval_batch(self, batch):
-        batch_num_atoms = batch[0].to(self.device)
-        batch_atoms_tok = batch[1].to(self.device)
-        batch_pos_idx = batch[2].to(self.device)
-        batch_neg_idx = batch[3].to(self.device)
-        batch_true_links = batch[4].to(self.device)
-        batch_sentences_tokens = batch[5].to(self.device)
-        batch_sentences_mask = batch[6].to(self.device)
-
-        output = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-
-        logits_predictions = self(batch_num_atoms, batch_atoms_tok, batch_pos_idx, batch_neg_idx, output[
-            'word_embeding'])  # atom_vocab, batch_size, max atoms in one type, max atoms in one type
-        axiom_links_pred = torch.argmax(logits_predictions, dim=3)  # atom_vocab, batch_size, max atoms in one type
-
-        print('\n')
-        print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
-        print("Les prédictions : ", axiom_links_pred[2][1][:100])
-        print('\n')
-
-        accuracy = measure_accuracy(batch_true_links, axiom_links_pred, self.max_atoms_in_one_type)
-        loss = self.cross_entropy_loss(logits_predictions, batch_true_links, self.max_atoms_in_one_type)
-
-        return loss, accuracy
-
-    def eval_epoch(self, dataloader):
-        r"""Average the evaluation of all the batch.
-
-        Args:
-            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
-        """
-        self.eval()
-        accuracy_average = 0
-        loss_average = 0
-        with torch.no_grad():
-            for step, batch in enumerate(dataloader):
-                loss, accuracy = self.eval_batch(batch)
-                accuracy_average += accuracy
-                loss_average += float(loss)
-
-        return loss_average / len(dataloader), accuracy_average / len(dataloader)
-
-    def predict_with_categories(self, sentence, categories):
-        r""" Predict the links from a sentence and its categories
-
-        Args :
-            sentence : list of words composing the sentence
-            categories : list of categories (tags) of each word
-        """
-        self.eval()
-        with torch.no_grad():
-            self.cpu()
-            self.device = torch.device("cpu")
-            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
-            nb_sentence, len_sentence = sentences_tokens.shape
-
-            atoms = get_atoms_batch([categories])
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
-
-            polarities = find_pos_neg_idexes([categories])
-            polarities = pad_sequence(
-                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                max_len=self.max_atoms_in_sentence, padding_value=0)
-
-            num_atoms_per_word = get_num_atoms_batch([categories], len_sentence)
-
-            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
-            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
-
-            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
-
-            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding'])
-            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
-
-        return axiom_links_pred
-
-    def predict_without_categories(self, sentence):
-        r""" Predict the links from a sentence
-
-        Args :
-            sentence : list of words composing the sentence
-        """
-        self.eval()
-        with torch.no_grad():
-            self.cpu()
-            self.device = torch.device("cpu")
-            sentences_tokens, sentences_mask = self.Supertagger.sent_tokenizer.fit_transform_tensors([sentence])
-            nb_sentence, len_sentence = sentences_tokens.shape
-
-            hidden_state, categories = self.Supertagger.predict(sentence)
-
-            output = self.Supertagger.forward(sentences_tokens, sentences_mask)
-            atoms = get_atoms_batch(categories)
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms)
-
-            polarities = find_pos_neg_idexes(categories)
-            polarities = pad_sequence(
-                [torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                max_len=self.max_atoms_in_sentence, padding_value=0)
-
-            num_atoms_per_word = get_num_atoms_batch(categories, len_sentence)
-
-            pos_idx = get_pos_idx(atoms, polarities, self.max_atoms_in_one_type)
-            neg_idx = get_neg_idx(atoms, polarities, self.max_atoms_in_one_type)
-
-            logits_predictions = self(num_atoms_per_word, atoms_tokenized, pos_idx, neg_idx, output['word_embeding'])
-            axiom_links_pred = torch.argmax(logits_predictions, dim=3)
-
-        return axiom_links_pred
-
-    def load_weights(self, model_file):
-        print("#" * 15)
-        try:
-            params = torch.load(model_file, map_location=self.device)
-            self.atom_encoder.load_state_dict(params['atom_encoder'])
-            self.position_encoder.load_state_dict(params['position_encoder'])
-            self.transformer.load_state_dict(params['transformer'])
-            self.linker_encoder.load_state_dict(params['linker_encoder'])
-            self.pos_transformation.load_state_dict(params['pos_transformation'])
-            self.neg_transformation.load_state_dict(params['neg_transformation'])
-            self.cross_entropy_loss.load_state_dict(params['cross_entropy_loss'])
-            self.optimizer.load_state_dict(params['optimizer'])
-            print("\n The loading checkpoint was successful ! \n")
-        except Exception as e:
-            print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
-            raise e
-        print("#" * 15)
-
-    def __checkpoint_save(self, path='/linker.pt'):
-        """
-        @param path: path of the checkpoint file to write
-        """
-        self.cpu()
-
-        torch.save({
-            'atom_encoder': self.atom_encoder.state_dict(),
-            'position_encoder': self.position_encoder,
-            'transformer': self.transformer.state_dict(),
-            'linker_encoder': self.linker_encoder.state_dict(),
-            'pos_transformation': self.pos_transformation.state_dict(),
-            'neg_transformation': self.neg_transformation.state_dict(),
-            'cross_entropy_loss': self.cross_entropy_loss,
-            'optimizer': self.optimizer,
-        }, path)
-        self.to(self.device)
-
-    def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
-        """
-        :param bsd_tensor:
-            Tensor of shape batch size \times sequence length \times feature dimensionality.
-        :param positional_ids:
-            Tensor of shape (batch_size, atom_vocab_size, max_atoms_in_one_type // 2);
-            positional_ids[b][a] indexes the positions of atoms of type a in sentence b (-1 is padding).
-        :param atom_type: key of atom_map_redux selecting which atom type to gather
-        :return: Tensor of shape (batch_size, max_atoms_in_one_type // 2, dim_cat_out)
-        """
-
-        return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
-                                         if atom != -1 else torch.zeros(self.dim_cat_out, device=self.device)
-                                         for atom in sentence])
-                            for i, sentence in enumerate(positional_ids[:, self.atom_map_redux[atom_type], :])])
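As a quick illustration of the causal mask built by generate_square_subsequent_mask above (-inf blocks attention to later positions, zeros on and below the diagonal):

```
# Same computation as generate_square_subsequent_mask, shown standalone.
import torch

mask = torch.triu(torch.ones(3, 3) * float('-inf'), diagonal=1)
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```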
diff --git a/Linker/PositionalEncoding.py b/Linker/PositionalEncoding.py
deleted file mode 100644
index 19e1b96..0000000
--- a/Linker/PositionalEncoding.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import torch
-from torch import nn
-import math
-
-
-class PositionalEncoding(nn.Module):
-
-    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
-        super().__init__()
-        self.dropout = nn.Dropout(p=dropout)
-
-        position = torch.arange(max_len).unsqueeze(1)
-        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
-        pe = torch.zeros(1, max_len, d_model)
-        pe[0, :, 0::2] = torch.sin(position * div_term)
-        pe[0, :, 1::2] = torch.cos(position * div_term)
-        self.register_buffer('pe', pe)
-
-    def forward(self, x):
-        """
-        Args:
-            x: Tensor, shape [batch_size, seq_len, embedding_dim]
-        """
-        x = x + self.pe[:, :x.size(1)]
-        return self.dropout(x)
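A small shape check for the module above (assuming the Linker repo cloned by init.sh still provides it): the sinusoidal encoding is added to the input, so input and output shapes match.

```
# Illustrative; assumes Linker.PositionalEncoding from the cloned Linker repo.
import torch
from Linker.PositionalEncoding import PositionalEncoding

pe = PositionalEncoding(d_model=128, dropout=0.1, max_len=256)
x = torch.zeros(2, 50, 128)
print(pe(x).shape)  # torch.Size([2, 50, 128])
```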
diff --git a/Linker/Sinkhorn.py b/Linker/Sinkhorn.py
deleted file mode 100644
index 9cf9b45..0000000
--- a/Linker/Sinkhorn.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from torch import logsumexp
-
-
-def norm(x, dim):
-    return x - logsumexp(x, dim=dim, keepdim=True)
-
-
-def sinkhorn_step(x):
-    return norm(norm(x, dim=1), dim=2)
-
-
-def sinkhorn_fn_no_exp(x, tau=1, iters=3):
-    x = x / tau
-    for _ in range(iters):
-        x = sinkhorn_step(x)
-    return x
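A self-contained check of what the log-domain Sinkhorn above computes: after a few iterations, every row and column of the exponentiated matrix sums to roughly 1, i.e. the matrix is close to doubly stochastic.

```
# Stand-alone re-implementation for illustration (same maths as above).
import torch
from torch import logsumexp

def sinkhorn_fn_no_exp(x, tau=1, iters=3):
    x = x / tau
    for _ in range(iters):
        x = x - logsumexp(x, dim=1, keepdim=True)  # normalise over dim 1
        x = x - logsumexp(x, dim=2, keepdim=True)  # normalise over dim 2
    return x

log_p = sinkhorn_fn_no_exp(torch.randn(2, 4, 4), iters=10)
print(logsumexp(log_p, dim=2).exp())  # rows sum to ~1
print(logsumexp(log_p, dim=1).exp())  # columns sum to ~1
```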
diff --git a/Linker/__init__.py b/Linker/__init__.py
deleted file mode 100644
index 0983f0b..0000000
--- a/Linker/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .Linker import Linker
-from .atom_map import atom_map
-from .AtomTokenizer import AtomTokenizer
-from .PositionalEncoding import PositionalEncoding
-from .Sinkhorn import *
\ No newline at end of file
diff --git a/Linker/atom_map.py b/Linker/atom_map.py
deleted file mode 100644
index 0df2646..0000000
--- a/Linker/atom_map.py
+++ /dev/null
@@ -1,30 +0,0 @@
-atom_map = \
-    {'cl_r': 0,
-     "pp": 1,
-     'n': 2,
-     's_ppres': 3,
-     's_whq': 4,
-     's_q': 5,
-     'np': 6,
-     's_inf': 7,
-     's_pass': 8,
-     'pp_a': 9,
-     'pp_par': 10,
-     'pp_de': 11,
-     'cl_y': 12,
-     'txt': 13,
-     's': 14,
-     's_ppart': 15,
-     "[SEP]":16,
-     '[PAD]': 17
-     }
-
-atom_map_redux = {
-    'cl_r': 0,
-    'pp': 1,
-    'n': 2,
-    'np': 3,
-    'cl_y': 4,
-    'txt': 5,
-    's': 6
-}
diff --git a/Linker/eval.py b/Linker/eval.py
deleted file mode 100644
index 086f2a9..0000000
--- a/Linker/eval.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import torch
-from torch.nn import Module
-from torch.nn.functional import nll_loss
-from Linker.atom_map import atom_map, atom_map_redux
-
-
-class SinkhornLoss(Module):
-    r"""
-    Loss for the linker
-    """
-    def __init__(self):
-        super(SinkhornLoss, self).__init__()
-
-    def forward(self, predictions, truths, max_atoms_in_one_type):
-        return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))
-
-
-def measure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type):
-    r"""
-    batch_true_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    """
-    padding = -1
-    batch_true_links = batch_true_links.permute(1, 0, 2)
-    correct_links = torch.ones(axiom_links_pred.size())
-    correct_links[axiom_links_pred != batch_true_links] = 0
-    correct_links[batch_true_links == padding] = 1
-    num_correct_links = correct_links.sum().item()
-    num_masked_atoms = len(batch_true_links[batch_true_links == padding])
-
-    # divide by the number of links
-    return (num_correct_links - num_masked_atoms) / (
-                axiom_links_pred.size()[0] * axiom_links_pred.size()[1] * axiom_links_pred.size()[2] - num_masked_atoms)
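A tiny worked example of the accuracy above, assuming the Linker package cloned by init.sh still exposes measure_accuracy: padded slots (-1) are ignored, so 2 correct links out of 3 real ones give 2/3.

```
# Illustrative; shapes follow the docstring above (atom_vocab=1, batch=1, 4 slots).
import torch
from Linker.eval import measure_accuracy

axiom_links_pred = torch.tensor([[[0, 2, 1, 3]]])   # (atom_vocab, batch, slots)
batch_true_links = torch.tensor([[[0, 2, 2, -1]]])  # (batch, atom_vocab, slots)
print(measure_accuracy(batch_true_links, axiom_links_pred, max_atoms_in_one_type=4))
# (3 correct-or-padded - 1 padded) / (4 slots - 1 padded) = 0.666...
```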
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
deleted file mode 100644
index 199351e..0000000
--- a/Linker/utils_linker.py
+++ /dev/null
@@ -1,403 +0,0 @@
-import re
-
-import pandas as pd
-import regex
-import torch
-from torch.nn import Sequential, Linear, Dropout, GELU
-from torch.nn import Module
-
-from Linker.atom_map import atom_map, atom_map_redux
-from utils import pad_sequence
-
-
-class FFN(Module):
-    "Implements FFN equation."
-
-    def __init__(self, d_model, d_ff, dropout=0.1, d_out=None):
-        super(FFN, self).__init__()
-        self.ffn = Sequential(
-            Linear(d_model, d_ff, bias=False),
-            GELU(),
-            Dropout(dropout),
-            Linear(d_ff, d_out if d_out is not None else d_model, bias=False)
-        )
-
-    def forward(self, x):
-        return self.ffn(x)
-
-
-################################ Regex ########################################
-regex_categories_axiom_links = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
-regex_categories = r'\w+\(\d+,(?:((?R))|(\w+))*,?(?:((?R))|(\w+))*\)'
-
-
-# region get true axiom links
-def get_axiom_links(max_atoms_in_one_type, atoms_polarity, batch_axiom_links):
-    r"""
-    Args:
-        max_atoms_in_one_type : configuration
-        atoms_polarity : (batch_size, max_atoms_in_sentence)
-        batch_axiom_links : (batch_size, len_sentence) categories with the _i which allows linking atoms
-    Returns:
-        batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    """
-    atoms_batch = get_atoms_links_batch(batch_axiom_links)
-    linking_plus_to_minus_all_types = []
-    for atom_type in list(atom_map_redux.keys()):
-        # filter atoms_batch for this type, then filter with the indices on atom polarity
-        l_polarity_plus = [[x for i, x in enumerate(atoms_batch[s_idx]) if atoms_polarity[s_idx, i]
-                            and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
-                           in range(len(atoms_batch))]
-        l_polarity_minus = [[x for i, x in enumerate(atoms_batch[s_idx]) if not atoms_polarity[s_idx, i]
-                             and bool(re.match(r"" + atom_type + "(_{1}\w+)?_\d+\Z", atoms_batch[s_idx][i]))] for s_idx
-                            in range(len(atoms_batch))]
-
-        linking_plus_to_minus = pad_sequence(
-            [torch.as_tensor(
-                [l_polarity_minus[s_idx].index(x) if x in l_polarity_minus[s_idx] else -1
-                 for i, x in enumerate(l_polarity_plus[s_idx])], dtype=torch.long)
-                for s_idx in range(len(atoms_batch))], max_len=max_atoms_in_one_type // 2,
-            padding_value=-1)
-
-        linking_plus_to_minus_all_types.append(linking_plus_to_minus)
-
-    return torch.stack(linking_plus_to_minus_all_types)
-
-
-def category_to_atoms_axiom_links(category, categories_to_atoms):
-    r"""
-    Args:
-        category : str of kind AtomCat | CategoryCat(dr or dl)
-        categories_to_atoms : recursive list
-    Returns :
-        List of atoms inside the category in prefix order
-    """
-    res = [bool(re.match(r'' + atom_type + "_\d+", category)) for atom_type in atom_map.keys()]
-    if category.startswith("GOAL:"):
-        word, cat = category.split(':')
-        return category_to_atoms_axiom_links(cat, categories_to_atoms)
-    elif True in res:
-        return [category]
-    else:
-        category_cut = regex.match(regex_categories_axiom_links, category).groups()
-        category_cut = [cat for cat in category_cut if cat is not None]
-        for cat in category_cut:
-            categories_to_atoms += category_to_atoms_axiom_links(cat, [])
-        return categories_to_atoms
-
-
-def get_atoms_links_batch(category_batch):
-    r"""
-    Args:
-        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns :
-     (batch_size, max_atoms_in_sentence) flattened atoms in prefix order
-    """
-    batch = []
-    for sentence in category_batch:
-        categories_to_atoms = []
-        for category in sentence:
-            if category != "let" and not category.startswith("GOAL:"):
-                categories_to_atoms += category_to_atoms_axiom_links(category, [])
-                categories_to_atoms.append("[SEP]")
-            elif category.startswith("GOAL:"):
-                categories_to_atoms = category_to_atoms_axiom_links(category, []) + categories_to_atoms
-        batch.append(categories_to_atoms)
-    return batch
-
-
-print("test to create links ",
-      get_axiom_links(20, torch.stack([torch.as_tensor(
-          [True, False, True, False, False, False, True, False, True, False,
-           False, True, False, False, False, True, False, False, True, False,
-           True, False, False, True, False, False, False, False, False, False])]),
-                      [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6', 'dl(0,n_6,n_5)',
-                        'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9', 'GOAL:np_7']]))
-
-
-# endregion
-
-# region get atoms in sentence
-
-def category_to_atoms(category, categories_to_atoms):
-    r"""
-    Args:
-        category : str of kind AtomCat | CategoryCat(dr or dl)
-        categories_to_atoms : recursive list
-    Returns:
-        List of atoms inside the category in prefix order
-    """
-    res = [(category == atom_type) for atom_type in atom_map.keys()]
-    if category.startswith("GOAL:"):
-        word, cat = category.split(':')
-        return category_to_atoms(cat, categories_to_atoms)
-    elif True in res:
-        return [category]
-    else:
-        category_cut = regex.match(regex_categories, category).groups()
-        category_cut = [cat for cat in category_cut if cat is not None]
-        for cat in category_cut:
-            categories_to_atoms += category_to_atoms(cat, [])
-        return categories_to_atoms
-
-
-def get_atoms_batch(category_batch):
-    r"""
-    Args:
-        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns:
-     (batch_size, max_atoms_in_sentence) flattened atoms in prefix order
-    """
-    batch = []
-    for sentence in category_batch:
-        categories_to_atoms = []
-        for category in sentence:
-            if category != "let":
-                categories_to_atoms += category_to_atoms(category, [])
-                categories_to_atoms.append("[SEP]")
-        batch.append(categories_to_atoms)
-    return batch
-
-
-print(" test for get atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']",
-      get_atoms_batch([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'let']]))
-
-
-# endregion
-
-# region calculate num atoms per category
-
-def category_to_num_atoms(category, categories_to_atoms):
-    r"""
-    Args:
-        category : str of kind AtomCat | CategoryCat(dr or dl)
-        categories_to_atoms : running count of atoms (int)
-    Returns:
-        Number of atoms inside the category
-    """
-    res = [(category == atom_type) for atom_type in atom_map.keys()]
-    if category.startswith("GOAL:"):
-        word, cat = category.split(':')
-        return category_to_num_atoms(cat, 0)
-    elif category == "let":
-        return 0
-    elif True in res:
-        return 1
-    else:
-        category_cut = regex.match(regex_categories, category).groups()
-        category_cut = [cat for cat in category_cut if cat is not None]
-        for cat in category_cut:
-            categories_to_atoms += category_to_num_atoms(cat, 0)
-        return categories_to_atoms
-
-
-def get_num_atoms_batch(category_batch, max_len_sentence):
-    r"""
-    Args:
-        category_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-        max_len_sentence : max_len_sentence parameter
-    Returns:
-     (batch_size, max_len_sentence) number of atoms per word
-    """
-    batch = []
-    for sentence in category_batch:
-        num_atoms_sentence = [0]
-        for category in sentence:
-            num_atoms_in_word = category_to_num_atoms(category, 0)
-            # add 1 because for word we have SEP at the end
-            if category != "let":
-                num_atoms_in_word += 1
-            num_atoms_sentence.append(num_atoms_in_word)
-        batch.append(torch.as_tensor(num_atoms_sentence))
-    return pad_sequence(batch, max_len=max_len_sentence, padding_value=0)
-
-
-print(" test for get number of atoms in categories on ['dr(0,s,np)', 'let']",
-      get_num_atoms_batch([["dr(0,s,np)", "let"]], 10))
-
-
-# endregion
-
-# region get polarity
-
-def category_to_atoms_polarity(category, polarity):
-    r"""
-    Args:
-        category : str of kind AtomCat | CategoryCat(dr or dl)
-        polarity : current polarity in the recursion
-    Returns:
-        List of booleans, one per atom in prefix order, True for positive and False for negative
-    """
-    category_to_polarity = []
-    res = [(category == atom_type) for atom_type in atom_map.keys()]
-    # final word (GOAL)
-    if category.startswith("GOAL:"):
-        word, cat = category.split(':')
-        res = [bool(re.match(r'' + atom_type, cat)) for atom_type in atom_map.keys()]
-        if True in res:
-            category_to_polarity.append(True)
-        else:
-            category_to_polarity += category_to_atoms_polarity(cat, True)
-    # the word has an atomic category
-    elif True in res:
-        category_to_polarity.append(not polarity)
-    # otherwise it is a complex formula
-    else:
-        # dr = /
-        if category.startswith("dr"):
-            category_cut = regex.match(regex_categories, category).groups()
-            category_cut = [cat for cat in category_cut if cat is not None]
-            left_side, right_side = category_cut[0], category_cut[1]
-            # for the left side
-            category_to_polarity += category_to_atoms_polarity(left_side, polarity)
-            # for the right side : change polarity for next right formula
-            category_to_polarity += category_to_atoms_polarity(right_side, not polarity)
-
-        # dl = \
-        elif category.startswith("dl"):
-            category_cut = regex.match(regex_categories, category).groups()
-            category_cut = [cat for cat in category_cut if cat is not None]
-            left_side, right_side = category_cut[0], category_cut[1]
-            # for the left side
-            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
-            # for the right side
-            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
-
-        # p
-        elif category.startswith("p"):
-            category_cut = regex.match(regex_categories, category).groups()
-            category_cut = [cat for cat in category_cut if cat is not None]
-            left_side, right_side = category_cut[0], category_cut[1]
-            # for the left side
-            category_to_polarity += category_to_atoms_polarity(left_side, not polarity)
-            # for the right side
-            category_to_polarity += category_to_atoms_polarity(right_side, polarity)
-
-        # box
-        elif category.startswith("box"):
-            category_cut = regex.match(regex_categories, category).groups()
-            category_cut = [cat for cat in category_cut if cat is not None]
-            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
-
-        # dia
-        elif category.startswith("dia"):
-            category_cut = regex.match(regex_categories, category).groups()
-            category_cut = [cat for cat in category_cut if cat is not None]
-            category_to_polarity += category_to_atoms_polarity(category_cut[0], polarity)
-
-    return category_to_polarity
-
-
-def find_pos_neg_idexes(atoms_batch):
-    r"""
-    Args:
-        atoms_batch : (batch_size, max_atoms_in_sentence) flattened categories in prefix order
-    Returns:
-        (batch_size, max_atoms_in_sentence) flattened categories' polarities in prefix order
-    """
-    list_batch = []
-    for sentence in atoms_batch:
-        list_atoms = []
-        for category in sentence:
-            if category == "let":
-                pass
-            else:
-                for at in category_to_atoms_polarity(category, True):
-                    list_atoms.append(at)
-                list_atoms.append(False)
-        list_batch.append(list_atoms)
-    return list_batch
-
-
-print(" test for get polarities for atoms in categories on ['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)', 'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np'] \n",
-    find_pos_neg_idexes([['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
-                          'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']]))
-
-
-# endregion
-
-# region get atoms and polarities with GOAL
-
-def get_GOAL(max_len_sentence, df_axiom_links):
-    categories_batch = df_axiom_links["Z"]
-    categories_with_goal = df_axiom_links["Y"]
-    polarities = find_pos_neg_idexes(categories_batch)
-    atoms_batch = get_atoms_batch(categories_batch)
-    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
-    for s_idx in range(len(atoms_batch)):
-        goal = categories_with_goal[s_idx][-1]
-        polarities_goal = category_to_atoms_polarity(goal, True)
-        goal = re.search(r"(\w+)_\d+", goal).groups()[0]
-        atoms = category_to_atoms(goal, [])
-
-        atoms_batch[s_idx] = atoms + atoms_batch[s_idx]  # + ["[SEP]"]
-        polarities[s_idx] = polarities_goal + polarities[s_idx]  # + False
-        num_atoms_batch[s_idx][0] += len(atoms)  # +1
-
-    return atoms_batch, polarities, num_atoms_batch
-
-
-df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
-                                      'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
-                               "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
-                                      'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
-                                      'GOAL:np_7']]})
-print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
-
-
-# endregion
-
-# region get atoms and polarities after tagger
-
-def get_info_for_tagger(max_len_sentence, pred_categories):
-    categories_batch = pred_categories
-    polarities = find_pos_neg_idexes(categories_batch)
-    atoms_batch = get_atoms_batch(categories_batch)
-    num_atoms_batch = get_num_atoms_batch(categories_batch, max_len_sentence)
-
-    return atoms_batch, polarities, num_atoms_batch
-
-
-df_axiom_links = pd.DataFrame({"Z": [['dr(0,np,n)', 'n', 'dr(0,dl(0,np,np),np)', 'dr(0,np,n)', 'n', 'dl(0,n,n)',
-                                      'dr(0,dl(0,np,np),np)', 'dr(0,np,np)', 'np']],
-                               "Y": [['dr(0,np_1,n_2)', 'n_2', 'dr(0,dl(0,np_1,np_3),np_4)', 'dr(0,np_4,n_5)', 'n_6',
-                                      'dl(0,n_6,n_5)', 'dr(0,dl(0,np_3,np_7),np_8)', 'dr(0,np_8,np_9)', 'np_9',
-                                      'GOAL:np_7']]})
-print(" test for get GOAL ", get_GOAL(10, df_axiom_links))
-
-
-# endregion
-
-# region get idx for pos and neg
-
-def get_pos_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
-    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
-                                              atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch)],
-                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
-               for atom_type in list(atom_map_redux.keys())]
-
-    return torch.stack(pos_idx).permute(1, 0, 2)
-
-
-def get_neg_idx(atoms_batch, atoms_polarity_batch, max_atoms_in_one_type):
-    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if
-                                              bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i])) and
-                                              not atoms_polarity_batch[s_idx][i]])
-                             for s_idx, sentence in enumerate(atoms_batch)],
-                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
-               for atom_type in list(atom_map_redux.keys())]
-
-    return torch.stack(pos_idx).permute(1, 0, 2)
-
-
-print(" test for cut into pos neg on ['dr(0,s,np)', 's']",
-      get_neg_idx([['s', 's', 'np', 's', 'np', '[SEP]', 's', '[SEP]']],
-                  torch.as_tensor(
-                      [[True, True, False, False,
-                        True, False, False, False,
-                        False, False,
-                        False, False]]), 10))
-
-# endregion
\ No newline at end of file
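As a quick illustration of the FFN defined at the top of this file (assuming the cloned Linker repo keeps the same module): it projects d_model up to d_ff and back down to d_out, which defaults to d_model.

```
# Illustrative; assumes Linker.utils_linker.FFN from the external Linker repo.
import torch
from Linker.utils_linker import FFN

ffn = FFN(d_model=768, d_ff=256, dropout=0.1, d_out=32)
x = torch.randn(4, 10, 768)
print(ffn(x).shape)  # torch.Size([4, 10, 32])
```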
diff --git a/NeuralProofNet/NeuralProofNet.py b/NeuralProofNet/NeuralProofNet.py
index 92c783b..0558dcd 100644
--- a/NeuralProofNet/NeuralProofNet.py
+++ b/NeuralProofNet/NeuralProofNet.py
@@ -16,7 +16,6 @@ from Linker import Linker
 from Linker.eval import measure_accuracy, SinkhornLoss
 from Linker.utils_linker import get_axiom_links, get_GOAL, get_pos_idx, get_num_atoms_batch, get_neg_idx, \
     get_info_for_tagger
-from find_config import configurate
 from utils import pad_sequence
 
 
diff --git a/README.md b/README.md
index 3348ca6..8392122 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,13 @@
-# DeepGrail Linker
+# DeepGrail Total
 
 This repository contains a Python implementation of a Neural Proof Net using TLGbank data.
 
-This code was designed to work with the [DeepGrail Tagger](https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger). 
-In this repository we only use the embedding of the word from the tagger and the tags from the dataset, but next step is to use the prediction of the tagger for the linking step.
+This code was designed to work with the [DeepGrail Tagger](https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger) and 
+[DeepGrail Linker](https://gitlab.irit.fr/pnria/global-helper/deepgrail-linker). 
  
+
+In this version, the tagger is not retrained jointly with the linker.
+
 ## Usage
 
 ### Installation
@@ -13,7 +16,9 @@ Clone the project locally.
 
 ### Libraries installation
 
-Run the init.sh script or install the Tagger project under SuperTagger name.
+Run the init.sh script, or install the Tagger project under the name SuperTagger and the Linker project under the name Linker.
+
+Put the tagger .pt file in the models folder. (You may need to modify 'model_tagger' in train.py.)
 
 ### Dataset format
 
@@ -32,14 +37,8 @@ after each epoch. Use `tensorboard=True` for log in same folder. (`tensorboard -
 To predict on your data, you need to load a model (saved with this code).
 
 ```
-df = read_csv_pgbar(file_path,20)
-texts = df['X'].tolist()
-categories = df['Z'].tolist()
-
-linker = Linker(tagging_model)
-linker.load_weights("your/linker/path")
-
-links = linker.predict_with_categories(texts[7], categories[7])
+linker = neuralproofnet.linker
+links = linker.predict_without_categories(["le chat est noir"])
 print(links)
 ```
 
diff --git a/find_config.py b/find_config.py
deleted file mode 100644
index 5372528..0000000
--- a/find_config.py
+++ /dev/null
@@ -1,61 +0,0 @@
-import configparser
-import re
-
-import torch
-
-from Linker.atom_map import atom_map_redux
-from Linker.utils_linker import get_GOAL, get_atoms_links_batch, get_atoms_batch
-from SuperTagger.SuperTagger.SuperTagger import SuperTagger
-from utils import read_csv_pgbar, pad_sequence
-
-
-def configurate(dataset, model_tagger, nb_sentences=1000000000):
-    print("#" * 20)
-    print("#" * 20)
-    print("Configuration with dataset\n")
-    config = configparser.ConfigParser()
-    config.read('Configuration/config.ini')
-
-    file_path_axiom_links = dataset
-    df_axiom_links = read_csv_pgbar(file_path_axiom_links, nb_sentences)
-
-    supertagger = SuperTagger()
-    supertagger.load_weights(model_tagger)
-    sentences_batch = df_axiom_links["X"].str.strip().tolist()
-    sentences_tokens, sentences_mask = supertagger.sent_tokenizer.fit_transform_tensors(sentences_batch)
-    max_len_sentence = 0
-    for sentence in sentences_tokens:
-        if len(sentence) > max_len_sentence:
-            max_len_sentence = len(sentence)
-    print("Configure parameter max len sentence to ", max_len_sentence)
-    config.set('DATASET_PARAMS', 'max_len_sentence', str(max_len_sentence))
-
-    atoms_batch, polarities, num_batch = get_GOAL(max_len_sentence, df_axiom_links)
-    max_atoms_in_sentence = 0
-    for sentence in atoms_batch:
-        if len(sentence) > max_atoms_in_sentence:
-            max_atoms_in_sentence = len(sentence)
-    print("Configure parameter max atoms in categories to", max_atoms_in_sentence)
-    config.set('DATASET_PARAMS', 'max_atoms_in_sentence', str(max_atoms_in_sentence))
-
-    atoms_polarity_batch = pad_sequence([torch.as_tensor(polarities[i], dtype=torch.bool) for i in range(len(polarities))],
-                 max_len=max_atoms_in_sentence, padding_value=0)
-    pos_idx = [[torch.as_tensor([i for i, x in enumerate(sentence) if
-                 bool(re.match(r"" + atom_type + "(_{1}\w+)?\Z", atoms_batch[s_idx][i]))
-                 and atoms_polarity_batch[s_idx][i]])
-                for s_idx, sentence in enumerate(atoms_batch)]
-               for atom_type in list(atom_map_redux.keys())]
-    max_atoms_in_on_type = 0
-    for atoms_type_batch in pos_idx:
-        for sentence in atoms_type_batch:
-            length = sentence.size(0)
-            if length > max_atoms_in_on_type:
-                max_atoms_in_on_type = length
-    print("Configure parameter max atoms of one type in one sentence to", max_atoms_in_on_type)
-    config.set('DATASET_PARAMS', 'max_atoms_in_one_type', str(max_atoms_in_on_type * 2+2))
-
-    with open('Configuration/config.ini', 'w') as configfile:  # save
-        config.write(configfile)
-
-    print("#" * 20)
-    print("#" * 20)
\ No newline at end of file
diff --git a/init.sh b/init.sh
index be8706d..8ed2841 100644
--- a/init.sh
+++ b/init.sh
@@ -1,3 +1,4 @@
 git clone https://gitlab.irit.fr/pnria/global-helper/deepgrail_tagger.git SuperTagger
+git clone https://gitlab.irit.fr/pnria/global-helper/deepgrail-linker.git Linker
 
 pip install -r requirements.txt
\ No newline at end of file
diff --git a/train.py b/train.py
index d50fd88..8c0e647 100644
--- a/train.py
+++ b/train.py
@@ -3,7 +3,6 @@ import torch
 from Linker import *
 from NeuralProofNet.NeuralProofNet import NeuralProofNet
 from utils import read_csv_pgbar
-from find_config import configurate
 from Configuration import Configuration
 
 torch.cuda.empty_cache()
@@ -12,7 +11,6 @@ file_path_axiom_links = 'Datasets/goldANDsilver_dataset_links.csv'
 model_tagger = "models/flaubert_super_98_V2_50e.pt"
 
 # region config
-configurate(file_path_axiom_links, model_tagger, nb_sentences=nb_sentences)
 config = Configuration.read_config()
 version = config["VERSION"]
 datasetConfig = config["DATASET_PARAMS"]
-- 
GitLab