diff --git a/SuperTagger/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
similarity index 100%
rename from SuperTagger/Linker/AtomEmbedding.py
rename to Linker/AtomEmbedding.py
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py
similarity index 95%
rename from SuperTagger/Linker/AtomTokenizer.py
rename to Linker/AtomTokenizer.py
index a771eef0a83e31fd5e0f77449aec12f09b6e5c3d..568b3a5e3c8fb66058192ab5d005ab5cf41330c4 100644
--- a/SuperTagger/Linker/AtomTokenizer.py
+++ b/Linker/AtomTokenizer.py
@@ -1,6 +1,5 @@
 import torch
-
-from SuperTagger.utils import pad_sequence
+from utils import pad_sequence
 
 
 class AtomTokenizer(object):
diff --git a/Linker/Linker.py b/Linker/Linker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65325eef04bc224f025f209f99f9d1a6f653207
--- /dev/null
+++ b/Linker/Linker.py
@@ -0,0 +1,221 @@
+import torch
+from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Module
+import torch.nn.functional as F
+import sys
+from Configuration import Configuration
+from Linker.AtomEmbedding import AtomEmbedding
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.MHA import AttentionDecoderLayer
+from Linker.atom_map import atom_map
+from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
+from Linker.eval import mesure_accuracy
+from utils import pad_sequence
+
+
+class Linker(Module):
+    def __init__(self):
+        super(Linker, self).__init__()
+
+        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
+        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
+        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
+        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        self.dropout = Dropout(0.1)
+        self.device = ""
+
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define an encoding
+        self.linker_encoder = AttentionDecoderLayer()
+
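+        # Separate FFN + LayerNorm heads transform the encodings of positive and of
+        # negative atom occurrences before the two sets are matched against each other.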
+        self.pos_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+        self.neg_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+
+    def make_decoder_mask(self, atoms_token):
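+        # The mask is 1.0 for real atoms and 0.0 for [PAD] positions, expanded to
+        # (batch_size * nhead, max_atoms_in_sentence, max_atoms_in_sentence) for the decoder attention.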
+        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
+        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
+        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
+
+    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+        sents_mask : attention mask over the sentence tokens (optional)
+        Returns :
+        link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_type/2, max_atoms_in_one_type/2)
+        '''
+
+        # atoms embedding
+        atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
+
+        # MHA or LSTM over the BERT output
+        # placeholder tensors standing in for the sentence embedding and mask until the
+        # supertagger output is plugged in
+        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_batch_tokenized))
+
+        link_weights = []
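+        # For each atom type (the last map entry is skipped), collect the encodings of its
+        # positive and negative occurrences in every sentence, pad both sides to
+        # max_atoms_in_one_type // 2, apply the polarity transformations, and score every
+        # positive/negative pair; Sinkhorn normalisation turns the scores into link weights.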
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          atoms_polarity_batch[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          not atoms_polarity_batch[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
+
+        return torch.stack(link_weights)
+
+    def predict(self, categories, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+        categories : (batch_size, len_sentence)
+        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+        sents_mask : attention mask over the sentence tokens (optional)
+        Returns :
+        axiom_links : (batch_size, atom_vocab_size, max_atoms_in_one_type/2) index of the matched negative atom for every positive atom
+        '''
+        self.eval()
+
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+
+        # get atoms
+        atoms_batch = get_atoms_batch(categories)
+        atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+
+        # get polarities
+        polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
+
+        # atoms embedding
+        atoms_embedding = self.atoms_embedding(atoms_tokenized)
+
+        # MHA or LSTM over the BERT output
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_tokenized))
+
+        link_weights = []
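+        # Same per-atom-type matching as in forward, but with polarities recomputed
+        # from the predicted categories.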
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          polarities[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          not polarities[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
+
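+        # Stack the per-type matrices, move the batch axis first, and pick, for every
+        # positive atom, the index of its most probable negative counterpart.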
+        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
+        axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+        return axiom_links
+
+    def eval_batch(self, batch, cross_entropy_loss):
+        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+        # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+
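+        # forward returns one matrix of link weights per atom type, shaped
+        # (n_types, batch_size, n, n); put the batch axis first before argmax and the loss.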
+        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
+        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
+        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+
+        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
+        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+
+        return accuracy, loss
+
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""Average the evaluation of all the batch.
+
+        Args:
+            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
+
+    def load_weights(self, model_file):
+        print("#" * 15)
+        try:
+            params = torch.load(model_file, map_location=self.device)
+            args = params['args']
+            self.atom_map = args['atom_map']
+            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
+            self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
+            self.atoms_embedding.load_state_dict(params['atoms_embedding'])
+            self.linker_encoder.load_state_dict(params['linker_encoder'])
+            self.pos_transformation.load_state_dict(params['pos_transformation'])
+            self.neg_transformation.load_state_dict(params['neg_transformation'])
+            print("\n The loading checkpoint was successful ! \n")
+        except Exception as e:
+            print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
+            raise e
+        print("#" * 15)
+
+    def __checkpoint_save(self, path='/linker.pt'):
+        self.cpu()
+
+        torch.save({
+            'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
+            'atoms_embedding': self.atoms_embedding.state_dict(),
+            'linker_encoder': self.linker_encoder.state_dict(),
+            'pos_transformation': self.pos_transformation.state_dict(),
+            'neg_transformation': self.neg_transformation.state_dict()
+        }, path)
+        self.to(self.device)
diff --git a/SuperTagger/Linker/MHA.py b/Linker/MHA.py
similarity index 93%
rename from SuperTagger/Linker/MHA.py
rename to Linker/MHA.py
index d85d5e03b29ad33077224bb19f90c44d7b3d630f..c1554f9a3454a8be0ed66917824e49534bb01f6a 100644
--- a/SuperTagger/Linker/MHA.py
+++ b/Linker/MHA.py
@@ -1,13 +1,8 @@
-import copy
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-from Configuration import Configuration
-from torch import Tensor, LongTensor
-from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
-                      ModuleList, Sequential)
+from torch import Tensor
+from torch.nn import (Dropout, LayerNorm, Module, MultiheadAttention)
 
-from SuperTagger.Linker.utils import FFN
+from Configuration import Configuration
+from Linker.utils_linker import FFN
 
 
 class AttentionDecoderLayer(Module):
diff --git a/SuperTagger/Linker/Sinkhorn.py b/Linker/Sinkhorn.py
similarity index 100%
rename from SuperTagger/Linker/Sinkhorn.py
rename to Linker/Sinkhorn.py
diff --git a/Linker/__init__.py b/Linker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SuperTagger/Linker/atom_map.py b/Linker/atom_map.py
similarity index 100%
rename from SuperTagger/Linker/atom_map.py
rename to Linker/atom_map.py
diff --git a/SuperTagger/eval.py b/Linker/eval.py
similarity index 81%
rename from SuperTagger/eval.py
rename to Linker/eval.py
index 2731514885b6da2bd84661d8c6c2149ad1645b9d..1113596e276a190edfc49ac50ce511ad64b4e6c8 100644
--- a/SuperTagger/eval.py
+++ b/Linker/eval.py
@@ -1,12 +1,6 @@
 import torch
-from torch import Tensor
 from torch.nn import Module
-from torch.nn.functional import nll_loss, cross_entropy
-from SuperTagger.Linker.atom_map import atom_map
-import re
-
-from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.utils import pad_sequence
+from torch.nn.functional import nll_loss
 
 
 class SinkhornLoss(Module):
diff --git a/SuperTagger/Linker/utils.py b/Linker/utils_linker.py
similarity index 97%
rename from SuperTagger/Linker/utils.py
rename to Linker/utils_linker.py
index abd6814fc0bc8ae839b8efe40d3e50a8921cbfb1..f968984872d4513c0b31ae5ca6e5fc06ced70da0 100644
--- a/SuperTagger/Linker/utils.py
+++ b/Linker/utils_linker.py
@@ -1,12 +1,10 @@
 import re
 import regex
-import numpy as np
 import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
+from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.utils import pad_sequence
+from Linker.atom_map import atom_map
+from utils import pad_sequence
 
 
 class FFN(Module):
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
deleted file mode 100644
index 93028fdeaf6cc7f1cc978e796d56d73fd0ff6b5a..0000000000000000000000000000000000000000
--- a/SuperTagger/Linker/Linker.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from itertools import chain
-
-import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
-from torch.nn import Module
-import torch.nn.functional as F
-
-from Configuration import Configuration
-from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.MHA import AttentionDecoderLayer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
-from SuperTagger.eval import mesure_accuracy
-from SuperTagger.utils import pad_sequence
-
-
-class Linker(Module):
-    def __init__(self):
-        super(Linker, self).__init__()
-
-        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
-        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
-        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
-        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
-        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
-        self.dropout = Dropout(0.1)
-
-        self.atom_map = atom_map
-        self.padding_id = self.atom_map['[PAD]']
-        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
-
-        # to do : definit un encoding
-        self.linker_encoder = AttentionDecoderLayer()
-
-        self.pos_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-        self.neg_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-
-    def make_decoder_mask(self, atoms_token) :
-        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
-        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
-        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
-
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
-        '''
-
-        # atoms embedding
-        atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
-        print(atoms_embedding.shape)
-
-        # MHA ou LSTM avec sortie de BERT
-        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
-        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
-        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
-        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized))
-        #atoms_encoding = atoms_embedding
-
-        link_weights = []
-        for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x  for i, x in enumerate(atoms_encoding[s_idx])
-                             if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                 atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                 atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                            for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            neg_encoding = pad_sequence([torch.stack([x  for i, x in enumerate(atoms_encoding[s_idx])
-                             if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                 atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                 not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                            for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
-
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=3))
-
-        return torch.stack(link_weights)
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-        #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-
-        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
-        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
-
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
-
-        return accuracy, loss
-
-    def eval_epoch(self, dataloader, cross_entropy_loss):
-        r"""Average the evaluation of all the batch.
-
-        Args:
-            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
-        """
-        accuracy_average = 0
-        loss_average = 0
-        compt = 0
-        for step, batch in enumerate(dataloader):
-            compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += loss
-
-        return accuracy_average / compt, loss_average / compt
diff --git a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc
deleted file mode 100644
index a6ce66525d13768733269e6a9b3ec0bd2c64a4d7..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
deleted file mode 100644
index e3cc2ea0978bb23bd6c75f39863a19e5f9fc27d6..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
deleted file mode 100644
index facf9eafa213710664a3d7f18c9150693e6c0a2c..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
deleted file mode 100644
index 679c41a8ef96c82cce084ed1d859b6f0933c12f5..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
deleted file mode 100644
index 26d18c0959a2b1114a5564ff8b1ec4304f81acf3..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
deleted file mode 100644
index f189466986ce4136d6ed66a59cd0d49b61b4aad8..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index c4eef1e07886db024496a876b846bd69646a7538..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/__init__.py b/SuperTagger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc
deleted file mode 100644
index f5cec815f7dbc90d4ab075b8e48682a5c6e119fd..0000000000000000000000000000000000000000
Binary files a/SuperTagger/__pycache__/eval.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/__pycache__/utils.cpython-38.pyc b/SuperTagger/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index 9e66bb4e0c44377cdeed2d562db8eef201427af3..0000000000000000000000000000000000000000
Binary files a/SuperTagger/__pycache__/utils.cpython-38.pyc and /dev/null differ
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdf7bc279b00aeb15d37a2e450e1e36eaf495c42
--- /dev/null
+++ b/main.py
@@ -0,0 +1,20 @@
+import torch.nn.functional as F
+
+from Configuration import Configuration
+from Linker.Linker import Linker
+
+max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+
+# categories tagger
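+# SuperTagger is assumed to be provided by the supertagging package (not included here)
+# and to expose load_weights() and predict(); its import is omitted in this sketch.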
+tagger = SuperTagger()
+tagger.load_weights("models/model_check.pt")
+
+# axiom linker
+linker = Linker()
+linker.load_weights("models/linker.pt")
+
+# predict categories and links for this sentence
+sentence = [[]]
+categories, sentence_embedding = tagger.predict(sentence)
+
+axiom_links = linker.predict(categories, sentence_embedding)
diff --git a/train.py b/train.py
index f8290f8554a28f22598b5a8780abcdb1d3ca1470..a37aaa46a6a12b8f1272914c1b6a85be627c2b5a 100644
--- a/train.py
+++ b/train.py
@@ -1,21 +1,20 @@
-import pickle
+import os
 import time
+from datetime import datetime
 
 import numpy as np
 import torch
-import torch.nn.functional as F
-from torch.optim import SGD, Adam, AdamW
+from torch.optim import AdamW
 from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import get_cosine_schedule_with_warmup
+from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
-from SuperTagger.Linker.Linker import Linker
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.utils import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.eval import SinkhornLoss
-from SuperTagger.utils import format_time, read_csv_pgbar
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.Linker import Linker
+from Linker.atom_map import atom_map
+from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
+from Linker.eval import SinkhornLoss
+from utils import format_time, read_csv_pgbar
 
 torch.cuda.empty_cache()
 
@@ -63,7 +62,6 @@ print("atoms_batch", atoms_batch[14])
 print("atoms_polarity_batch", atoms_polarity_batch[14])
 print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16])
 
-
 # Construction tensor dataset
 dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
 
@@ -82,6 +80,9 @@ validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batc
 
 # region Models
 
+# supertagger = SuperTagger()
+# supertagger.load_weights("models/model_check.pt")
+
 linker = Linker()
 
 # endregion Models
@@ -115,6 +116,7 @@ torch.autograd.set_detect_anomaly(True)
 total_t0 = time.time()
 
 validate = True
+checkpoint = True
 
 
 def run_epochs(epochs):
@@ -141,14 +143,16 @@ def run_epochs(epochs):
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
             # Unpack this training batch from our dataloader
-
             batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
 
             optimizer_linker.zero_grad()
 
+            # get sentence embedding from BERT which is already trained
+            # sentences_embedding = supertagger(batch_sentences)
+
             # Run the linker on the categories predictions
             logits_predictions = linker(batch_atoms, batch_polarity, [])
 
@@ -169,6 +173,10 @@ def run_epochs(epochs):
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
 
+        if checkpoint:
+            checkpoint_dir = os.path.join("Output", 'Tranning_' + datetime.today().strftime('%d-%m_%H-%M'))
+            linker.__checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
         if validate:
             linker.eval()
             with torch.no_grad():
diff --git a/SuperTagger/utils.py b/utils.py
similarity index 100%
rename from SuperTagger/utils.py
rename to utils.py