From 154eabc1bfdd823c0bd6fba92964fcce766ecac1 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Fri, 13 May 2022 14:59:05 +0200
Subject: [PATCH] architecture and main

---
 .../Linker => Linker}/AtomEmbedding.py        |   0
 .../Linker => Linker}/AtomTokenizer.py        |   3 +-
 Linker/Linker.py                              | 221 ++++++++++++++++++
 {SuperTagger/Linker => Linker}/MHA.py         |  13 +-
 {SuperTagger/Linker => Linker}/Sinkhorn.py    |   0
 Linker/__init__.py                            |   0
 {SuperTagger/Linker => Linker}/atom_map.py    |   0
 {SuperTagger => Linker}/eval.py               |   8 +-
 .../Linker/utils.py => Linker/utils_linker.py |   8 +-
 SuperTagger/Linker/Linker.py                  | 130 -----------
 .../__pycache__/AtomEmbedding.cpython-38.pyc  | Bin 867 -> 0 bytes
 .../__pycache__/AtomTokenizer.cpython-38.pyc  | Bin 2292 -> 0 bytes
 .../Linker/__pycache__/Linker.cpython-38.pyc  | Bin 5632 -> 0 bytes
 .../Linker/__pycache__/MHA.cpython-38.pyc     | Bin 4389 -> 0 bytes
 .../__pycache__/Sinkhorn.cpython-38.pyc       | Bin 687 -> 0 bytes
 .../__pycache__/atom_map.cpython-38.pyc       | Bin 483 -> 0 bytes
 .../Linker/__pycache__/utils.cpython-38.pyc   | Bin 8757 -> 0 bytes
 SuperTagger/__init__.py                       |   0
 SuperTagger/__pycache__/eval.cpython-38.pyc   | Bin 1899 -> 0 bytes
 SuperTagger/__pycache__/utils.cpython-38.pyc  | Bin 1851 -> 0 bytes
 main.py                                       |  20 ++
 train.py                                      |  36 +--
 SuperTagger/utils.py => utils.py              |   0
 23 files changed, 272 insertions(+), 167 deletions(-)
 rename {SuperTagger/Linker => Linker}/AtomEmbedding.py (100%)
 rename {SuperTagger/Linker => Linker}/AtomTokenizer.py (95%)
 create mode 100644 Linker/Linker.py
 rename {SuperTagger/Linker => Linker}/MHA.py (93%)
 rename {SuperTagger/Linker => Linker}/Sinkhorn.py (100%)
 create mode 100644 Linker/__init__.py
 rename {SuperTagger/Linker => Linker}/atom_map.py (100%)
 rename {SuperTagger => Linker}/eval.py (81%)
 rename SuperTagger/Linker/utils.py => Linker/utils_linker.py (97%)
 delete mode 100644 SuperTagger/Linker/Linker.py
 delete mode 100644 SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
 create mode 100644 SuperTagger/__init__.py
 delete mode 100644 SuperTagger/__pycache__/eval.cpython-38.pyc
 delete mode 100644 SuperTagger/__pycache__/utils.cpython-38.pyc
 create mode 100644 main.py
 rename SuperTagger/utils.py => utils.py (100%)

diff --git a/SuperTagger/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
similarity index 100%
rename from SuperTagger/Linker/AtomEmbedding.py
rename to Linker/AtomEmbedding.py
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py
similarity index 95%
rename from SuperTagger/Linker/AtomTokenizer.py
rename to Linker/AtomTokenizer.py
index a771eef..568b3a5 100644
--- a/SuperTagger/Linker/AtomTokenizer.py
+++ b/Linker/AtomTokenizer.py
@@ -1,6 +1,5 @@
 import torch
-
-from SuperTagger.utils import pad_sequence
+from utils import pad_sequence
 
 
 class AtomTokenizer(object):
diff --git a/Linker/Linker.py b/Linker/Linker.py
new file mode 100644
index 0000000..f65325e
--- /dev/null
+++ b/Linker/Linker.py
@@ -0,0 +1,221 @@
+import torch
+from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Module
+import torch.nn.functional as F
+import sys
+from Configuration import Configuration
+from Linker.AtomEmbedding import AtomEmbedding
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.MHA import AttentionDecoderLayer
+from Linker.atom_map import atom_map
+from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
+from Linker.eval import mesure_accuracy
+from utils import pad_sequence
+
+
+class Linker(Module):
+    def __init__(self):
+        super(Linker, self).__init__()
+
+        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
+        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
+        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
+        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        self.dropout = Dropout(0.1)
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define an encoding
+        self.linker_encoder = AttentionDecoderLayer()
+
+        self.pos_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+        self.neg_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+
+    def make_decoder_mask(self, atoms_token):
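+        # (b, n) atom ids -> (b * nhead, n, n) float mask: 1.0 where the key
+        # atom is real, 0.0 where it is [PAD], broadcast over query positions
+        # and repeated once per attention head.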
+        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
+        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
+        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
+
+    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+        sents_mask
+        Returns :
+        link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_type // 2, max_atoms_in_one_type // 2)
+        '''
+
+        # atoms embedding
+        atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
+
+        # MHA or LSTM over the BERT output
+        # NOTE: placeholder inputs below, random sentence embeddings and mask
+        # stand in for the real BERT output until the supertagger is wired in
+        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_batch_tokenized))
+
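+        # For each atom type except [PAD]: gather the encodings of its positive
+        # occurrences and of its negative occurrences, pad both lists to
+        # max_atoms_in_one_type // 2, then run Sinkhorn on the pos x neg
+        # similarity matrix to get soft permutation weights linking them.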
+        link_weights = []
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          atoms_polarity_batch[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          not atoms_polarity_batch[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))
+
+        return torch.stack(link_weights)
+
+    def predict(self, categories, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+        categories : (batch_size, len_sentence)
+        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+        sents_mask
+        Returns :
+        axiom_links : (batch_size, atom_vocab_size, max_atoms_in_one_type // 2)
+        '''
+        self.eval()
+
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+
+        # get atoms
+        atoms_batch = get_atoms_batch(categories)
+        atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+
+        # get polarities
+        polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
+
+        # atoms embedding
+        atoms_embedding = self.atoms_embedding(atoms_tokenized)
+
+        # MHA or LSTM over the BERT output
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_tokenized))
+
+        link_weights = []
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          polarities[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[
+                                                              atom_type] and
+                                                          not polarities[s_idx][i])] + [
+                                                         torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))
+
+        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
+        axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+        return axiom_links
+
+    def eval_batch(self, batch, cross_entropy_loss):
+        batch_atoms = batch[0].to(self.device)
+        batch_polarity = batch[1].to(self.device)
+        batch_true_links = batch[2].to(self.device)
+        # batch_sentences = batch[3].to(self.device)
+
+        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
+        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
+        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+
+        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
+        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+
+        return accuracy, loss
+
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""Average the evaluation of all the batch.
+
+        Args:
+            dataloader: yields the batches of (tokenized atoms, polarities, true links)
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
+
+    def load_weights(self, model_file):
+        print("#" * 15)
+        try:
+            params = torch.load(model_file, map_location=self.device)
+            args = params['args']
+            self.atom_map = args['atom_map']
+            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
+            self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
+            self.atoms_embedding.load_state_dict(params['atoms_embedding'])
+            self.linker_encoder.load_state_dict(params['linker_encoder'])
+            self.pos_transformation.load_state_dict(params['pos_transformation'])
+            self.neg_transformation.load_state_dict(params['neg_transformation'])
+            print("\n The loading checkpoint was successful ! \n")
+        except Exception as e:
+            print("\n/!\ Can't load checkpoint model /!\ because :\n\n " + str(e), file=sys.stderr)
+            raise e
+        print("#" * 15)
+
+    def checkpoint_save(self, path='linker.pt'):
+        self.cpu()
+
+        torch.save({
+            'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
+            'atoms_embedding': self.atoms_embedding.state_dict(),
+            'linker_encoder': self.linker_encoder.state_dict(),
+            'pos_transformation': self.pos_transformation.state_dict(),
+            'neg_transformation': self.neg_transformation.state_dict()
+        }, path)
+        self.to(self.device)
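
Note: Sinkhorn.py is moved unchanged by this patch, so the body of sinkhorn_fn_no_exp
does not appear here. As a minimal sketch, assuming it follows the usual log-domain
Sinkhorn recipe (alternating row and column normalisation of the score matrix), the
call sinkhorn(weights, iters=3) behaves roughly like:

    import torch

    def sinkhorn_log_sketch(scores: torch.Tensor, iters: int = 3) -> torch.Tensor:
        # scores: (batch, n, n) raw similarities, here pos_encoding @ neg_encoding^T
        log_alpha = scores
        for _ in range(iters):
            # rows: each positive atom's outgoing weights sum to 1
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)
            # columns: each negative atom's incoming weights sum to 1
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)
        return log_alpha  # log of an (approximately) doubly stochastic matrix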
diff --git a/SuperTagger/Linker/MHA.py b/Linker/MHA.py
similarity index 93%
rename from SuperTagger/Linker/MHA.py
rename to Linker/MHA.py
index d85d5e0..c1554f9 100644
--- a/SuperTagger/Linker/MHA.py
+++ b/Linker/MHA.py
@@ -1,13 +1,8 @@
-import copy
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-from Configuration import Configuration
-from torch import Tensor, LongTensor
-from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
-                      ModuleList, Sequential)
+from torch import Tensor
+from torch.nn import (Dropout, LayerNorm, Module, MultiheadAttention)
 
-from SuperTagger.Linker.utils import FFN
+from Configuration import Configuration
+from utils_linker import FFN
 
 
 class AttentionDecoderLayer(Module):
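
Note: only the import block of MHA.py changes in this hunk (the file is 93% unchanged).
From its call site in Linker.forward (atom embeddings as queries, BERT sentence
embeddings as keys/values, plus the two float masks), AttentionDecoderLayer is
presumably a self-attention plus cross-attention decoder layer. A hypothetical sketch,
not the actual code:

    import torch
    from torch.nn import Dropout, LayerNorm, Module, MultiheadAttention

    class AttentionDecoderLayerSketch(Module):
        def __init__(self, dim_atoms: int, dim_encoder: int, nhead: int, dropout: float = 0.1):
            super().__init__()
            self.self_attn = MultiheadAttention(dim_atoms, nhead, batch_first=True)
            self.cross_attn = MultiheadAttention(dim_atoms, nhead, kdim=dim_encoder,
                                                 vdim=dim_encoder, batch_first=True)
            self.norm1 = LayerNorm(dim_atoms, eps=1e-12)
            self.norm2 = LayerNorm(dim_atoms, eps=1e-12)
            self.dropout = Dropout(dropout)

        def forward(self, atoms, sents, sents_mask, atoms_mask):
            # atoms attend to each other, masking padded positions
            x, _ = self.self_attn(atoms, atoms, atoms, attn_mask=atoms_mask)
            atoms = self.norm1(atoms + self.dropout(x))
            # then attend to the BERT representation of the sentence
            x, _ = self.cross_attn(atoms, sents, sents, attn_mask=sents_mask)
            return self.norm2(atoms + self.dropout(x))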
diff --git a/SuperTagger/Linker/Sinkhorn.py b/Linker/Sinkhorn.py
similarity index 100%
rename from SuperTagger/Linker/Sinkhorn.py
rename to Linker/Sinkhorn.py
diff --git a/Linker/__init__.py b/Linker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/SuperTagger/Linker/atom_map.py b/Linker/atom_map.py
similarity index 100%
rename from SuperTagger/Linker/atom_map.py
rename to Linker/atom_map.py
diff --git a/SuperTagger/eval.py b/Linker/eval.py
similarity index 81%
rename from SuperTagger/eval.py
rename to Linker/eval.py
index 2731514..1113596 100644
--- a/SuperTagger/eval.py
+++ b/Linker/eval.py
@@ -1,12 +1,6 @@
 import torch
-from torch import Tensor
 from torch.nn import Module
-from torch.nn.functional import nll_loss, cross_entropy
-from SuperTagger.Linker.atom_map import atom_map
-import re
-
-from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.utils import pad_sequence
+from torch.nn.functional import nll_loss
 
 
 class SinkhornLoss(Module):
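
Note: this hunk only trims the imports of eval.py (the rest of the file is unchanged),
so the bodies of SinkhornLoss and mesure_accuracy are not shown. Since only nll_loss
survives the trim, a hypothetical sketch of the loss, with shapes assumed from how
Linker.eval_batch calls it:

    from torch.nn import Module
    from torch.nn.functional import nll_loss

    class SinkhornLossSketch(Module):
        def forward(self, link_weights, true_links):
            # link_weights: (batch, atom_vocab, n, n) log-distributions from sinkhorn
            # true_links:   (batch, atom_vocab, n) index of the matched negative atom
            return nll_loss(link_weights.flatten(0, 2), true_links.flatten(),
                            reduction='mean')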
diff --git a/SuperTagger/Linker/utils.py b/Linker/utils_linker.py
similarity index 97%
rename from SuperTagger/Linker/utils.py
rename to Linker/utils_linker.py
index abd6814..f968984 100644
--- a/SuperTagger/Linker/utils.py
+++ b/Linker/utils_linker.py
@@ -1,12 +1,10 @@
 import re
 import regex
-import numpy as np
 import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
+from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.utils import pad_sequence
+from atom_map import atom_map
+from ..utils import pad_sequence
 
 
 class FFN(Module):
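
Note: FFN itself is untouched by the rename, so its body does not appear in this hunk.
Given the imports kept above (Sequential, Linear, Dropout, GELU) and the call
FFN(dim_embedding_atoms, dim_polarity_transfo, 0.1) in Linker, it is presumably the
standard two-layer feed-forward block; a minimal sketch, not the actual code:

    import torch
    from torch.nn import Module, Sequential, Linear, Dropout, GELU

    class FFNSketch(Module):
        def __init__(self, d_model: int, d_ff: int, dropout_rate: float = 0.1):
            super().__init__()
            self.net = Sequential(
                Linear(d_model, d_ff),   # expand to the polarity dimension
                GELU(),
                Dropout(dropout_rate),
                Linear(d_ff, d_model),   # project back to the atom dimension
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.net(x)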
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
deleted file mode 100644
index 93028fd..0000000
--- a/SuperTagger/Linker/Linker.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from itertools import chain
-
-import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
-from torch.nn import Module
-import torch.nn.functional as F
-
-from Configuration import Configuration
-from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.MHA import AttentionDecoderLayer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
-from SuperTagger.eval import mesure_accuracy
-from SuperTagger.utils import pad_sequence
-
-
-class Linker(Module):
-    def __init__(self):
-        super(Linker, self).__init__()
-
-        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
-        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
-        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
-        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
-        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
-        self.dropout = Dropout(0.1)
-
-        self.atom_map = atom_map
-        self.padding_id = self.atom_map['[PAD]']
-        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
-
-        # to do : definit un encoding
-        self.linker_encoder = AttentionDecoderLayer()
-
-        self.pos_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-        self.neg_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-
-    def make_decoder_mask(self, atoms_token) :
-        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
-        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
-        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
-
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
-        '''
-
-        # atoms embedding
-        atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
-        print(atoms_embedding.shape)
-
-        # MHA ou LSTM avec sortie de BERT
-        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
-        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
-        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
-        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized))
-        #atoms_encoding = atoms_embedding
-
-        link_weights = []
-        for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x  for i, x in enumerate(atoms_encoding[s_idx])
-                             if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                 atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                 atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                            for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            neg_encoding = pad_sequence([torch.stack([x  for i, x in enumerate(atoms_encoding[s_idx])
-                             if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                 atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                 not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                            for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
-
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=3))
-
-        return torch.stack(link_weights)
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-        #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-
-        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
-        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
-
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
-
-        return accuracy, loss
-
-    def eval_epoch(self, dataloader, cross_entropy_loss):
-        r"""Average the evaluation of all the batch.
-
-        Args:
-            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
-        """
-        accuracy_average = 0
-        loss_average = 0
-        compt = 0
-        for step, batch in enumerate(dataloader):
-            compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += loss
-
-        return accuracy_average / compt, loss_average / compt
diff --git a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc
deleted file mode 100644
index a6ce66525d13768733269e6a9b3ec0bd2c64a4d7..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
deleted file mode 100644
index e3cc2ea0978bb23bd6c75f39863a19e5f9fc27d6..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
deleted file mode 100644
index facf9eafa213710664a3d7f18c9150693e6c0a2c..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
deleted file mode 100644
index 679c41a8ef96c82cce084ed1d859b6f0933c12f5..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
deleted file mode 100644
index 26d18c0959a2b1114a5564ff8b1ec4304f81acf3..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
deleted file mode 100644
index f189466986ce4136d6ed66a59cd0d49b61b4aad8..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index c4eef1e07886db024496a876b846bd69646a7538..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/__init__.py b/SuperTagger/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc
deleted file mode 100644
index f5cec815f7dbc90d4ab075b8e48682a5c6e119fd..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/SuperTagger/__pycache__/utils.cpython-38.pyc b/SuperTagger/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index 9e66bb4e0c44377cdeed2d562db8eef201427af3..0000000000000000000000000000000000000000
GIT binary patch
(base85 literal omitted; the committed .pyc file is simply deleted)

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..bdf7bc2
--- /dev/null
+++ b/main.py
@@ -0,0 +1,20 @@
+from Configuration import Configuration
+from Linker.Linker import Linker
+# NOTE: import path assumed; the SuperTagger class itself is not added by this patch
+from SuperTagger.SuperTagger import SuperTagger
+
+max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+
+# categories tagger
+tagger = SuperTagger()
+tagger.load_weights("models/model_check.pt")
+
+# axiom linker
+linker = Linker()
+linker.load_weights("models/linker.pt")
+
+# predict categories and links for this sentence
+sentence = [[]]  # placeholder: supply a real tokenized sentence here
+categories, sentence_embedding = tagger.predict(sentence)
+
+axiom_links = linker.predict(categories, sentence_embedding)
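
Note on the tensors flowing through main.py, taken from the docstrings in
Linker/Linker.py (the tagger side is an assumption, since SuperTagger is not defined
in this patch):

    # categories:         (batch_size, len_sentence) supertags from the tagger
    # sentence_embedding: (batch_size, len_sentence, dim_encoder) BERT output
    # axiom_links:        (batch_size, atom_vocab_size, max_atoms_in_one_type // 2),
    #                     for each positive atom, the index of its linked negative atom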
diff --git a/train.py b/train.py
index f8290f8..a37aaa4 100644
--- a/train.py
+++ b/train.py
@@ -1,21 +1,20 @@
-import pickle
+import os
 import time
+from datetime import datetime
 
 import numpy as np
 import torch
-import torch.nn.functional as F
-from torch.optim import SGD, Adam, AdamW
+from torch.optim import AdamW
 from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import get_cosine_schedule_with_warmup
+from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
-from SuperTagger.Linker.Linker import Linker
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.utils import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.eval import SinkhornLoss
-from SuperTagger.utils import format_time, read_csv_pgbar
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.Linker import Linker
+from Linker.atom_map import atom_map
+from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
+from Linker.eval import SinkhornLoss
+from utils import format_time, read_csv_pgbar
 
 torch.cuda.empty_cache()
 
@@ -63,7 +62,6 @@ print("atoms_batch", atoms_batch[14])
 print("atoms_polarity_batch", atoms_polarity_batch[14])
 print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16])
 
-
 # Construction tensor dataset
 dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
 
@@ -82,6 +80,9 @@ validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batc
 
 # region Models
 
+# supertagger = SuperTagger()
+# supertagger.load_weights("models/model_check.pt")
+
 linker = Linker()
 
 # endregion Models
@@ -115,6 +116,7 @@ torch.autograd.set_detect_anomaly(True)
 total_t0 = time.time()
 
 validate = True
+checkpoint = True
 
 
 def run_epochs(epochs):
@@ -141,14 +143,16 @@ def run_epochs(epochs):
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
             # Unpack this training batch from our dataloader
-
             batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
 
             optimizer_linker.zero_grad()
 
+            # get the sentence embedding from BERT, which is already trained
+            # sentences_embedding = supertagger(batch_sentences)
+
+            # Run the linker on the category predictions
             logits_predictions = linker(batch_atoms, batch_polarity, [])
 
@@ -169,6 +173,10 @@ def run_epochs(epochs):
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
 
+        if checkpoint:
+            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
+            os.makedirs(checkpoint_dir, exist_ok=True)
+            linker.checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
         if validate:
             linker.eval()
             with torch.no_grad():
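
Note: a checkpoint written by the block above can be restored with the load_weights
method added in Linker/Linker.py; a short usage sketch (example path, matching what
run_epochs writes):

    import os
    from Linker.Linker import Linker

    linker = Linker()
    checkpoint_path = os.path.join("Output", "Training_13-05_14-59", "model_check.pt")
    linker.load_weights(checkpoint_path)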
diff --git a/SuperTagger/utils.py b/utils.py
similarity index 100%
rename from SuperTagger/utils.py
rename to utils.py
-- 
GitLab