From 154eabc1bfdd823c0bd6fba92964fcce766ecac1 Mon Sep 17 00:00:00 2001
From: Caroline DE POURTALES <cdepourt@montana.irit.fr>
Date: Fri, 13 May 2022 14:59:05 +0200
Subject: [PATCH] architecture and main

---
 .../Linker => Linker}/AtomEmbedding.py        |   0
 .../Linker => Linker}/AtomTokenizer.py        |   3 +-
 Linker/Linker.py                              | 221 ++++++++++++++++++
 {SuperTagger/Linker => Linker}/MHA.py         |  13 +-
 {SuperTagger/Linker => Linker}/Sinkhorn.py    |   0
 Linker/__init__.py                            |   0
 {SuperTagger/Linker => Linker}/atom_map.py    |   0
 {SuperTagger => Linker}/eval.py               |   8 +-
 .../Linker/utils.py => Linker/utils_linker.py |   8 +-
 SuperTagger/Linker/Linker.py                  | 130 -----------
 .../__pycache__/AtomEmbedding.cpython-38.pyc  | Bin 867 -> 0 bytes
 .../__pycache__/AtomTokenizer.cpython-38.pyc  | Bin 2292 -> 0 bytes
 .../Linker/__pycache__/Linker.cpython-38.pyc  | Bin 5632 -> 0 bytes
 .../Linker/__pycache__/MHA.cpython-38.pyc     | Bin 4389 -> 0 bytes
 .../__pycache__/Sinkhorn.cpython-38.pyc       | Bin 687 -> 0 bytes
 .../__pycache__/atom_map.cpython-38.pyc       | Bin 483 -> 0 bytes
 .../Linker/__pycache__/utils.cpython-38.pyc   | Bin 8757 -> 0 bytes
 SuperTagger/__init__.py                       |   0
 SuperTagger/__pycache__/eval.cpython-38.pyc   | Bin 1899 -> 0 bytes
 SuperTagger/__pycache__/utils.cpython-38.pyc  | Bin 1851 -> 0 bytes
 main.py                                       |  20 ++
 train.py                                      |  36 +--
 SuperTagger/utils.py => utils.py              |   0
 23 files changed, 272 insertions(+), 167 deletions(-)
 rename {SuperTagger/Linker => Linker}/AtomEmbedding.py (100%)
 rename {SuperTagger/Linker => Linker}/AtomTokenizer.py (95%)
 create mode 100644 Linker/Linker.py
 rename {SuperTagger/Linker => Linker}/MHA.py (93%)
 rename {SuperTagger/Linker => Linker}/Sinkhorn.py (100%)
 create mode 100644 Linker/__init__.py
 rename {SuperTagger/Linker => Linker}/atom_map.py (100%)
 rename {SuperTagger => Linker}/eval.py (81%)
 rename SuperTagger/Linker/utils.py => Linker/utils_linker.py (97%)
 delete mode 100644 SuperTagger/Linker/Linker.py
 delete mode 100644 SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
 delete mode 100644 SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
 create mode 100644 SuperTagger/__init__.py
 delete mode 100644 SuperTagger/__pycache__/eval.cpython-38.pyc
 delete mode 100644 SuperTagger/__pycache__/utils.cpython-38.pyc
 create mode 100644 main.py
 rename SuperTagger/utils.py => utils.py (100%)

diff --git a/SuperTagger/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
similarity index 100%
rename from SuperTagger/Linker/AtomEmbedding.py
rename to Linker/AtomEmbedding.py
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py
similarity index 95%
rename from SuperTagger/Linker/AtomTokenizer.py
rename to Linker/AtomTokenizer.py
index a771eef..568b3a5 100644
--- a/SuperTagger/Linker/AtomTokenizer.py
+++ b/Linker/AtomTokenizer.py
@@ -1,6 +1,5 @@
 import torch
-
-from SuperTagger.utils import pad_sequence
+from utils import pad_sequence
 
 
 class AtomTokenizer(object):
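The renames above pull the linker out of SuperTagger into a top-level Linker package, with the shared helpers promoted to a root-level utils.py. A minimal sketch of the import layout this implies, assuming the repository root is on sys.path (module names taken from the rename list; nothing else is assumed):

    from Linker.Linker import Linker         # was SuperTagger/Linker/Linker.py
    from Linker.utils_linker import FFN      # was SuperTagger/Linker/utils.py
    from Linker.eval import SinkhornLoss     # was SuperTagger/eval.py
    from utils import pad_sequence           # was SuperTagger/utils.py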
diff --git a/Linker/Linker.py b/Linker/Linker.py
new file mode 100644
index 0000000..f65325e
--- /dev/null
+++ b/Linker/Linker.py
@@ -0,0 +1,221 @@
+import torch
+from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Module
+import torch.nn.functional as F
+import sys
+from Configuration import Configuration
+from Linker.AtomEmbedding import AtomEmbedding
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.MHA import AttentionDecoderLayer
+from Linker.atom_map import atom_map
+from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
+from Linker.eval import mesure_accuracy
+from utils import pad_sequence  # root-level utils.py; assumes the repo root is on sys.path
+
+
+class Linker(Module):
+    def __init__(self):
+        super(Linker, self).__init__()
+
+        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
+        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
+        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
+        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        self.dropout = Dropout(0.1)
+        self.device = ""  # FIXME: never set to an actual device
+
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define a proper encoding
+        self.linker_encoder = AttentionDecoderLayer()
+
+        self.pos_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+        self.neg_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+
+    def make_decoder_mask(self, atoms_token):
+        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
+        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
+        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
+
+    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+            atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened category polarities
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask
+        Returns :
+            link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+        '''
+
+        # atoms embedding
+        atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
+
+        # MHA or LSTM over the BERT output
+        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)  # FIXME: debug placeholder overriding the argument
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)  # FIXME: debug placeholder
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_batch_tokenized))
+
+        link_weights = []
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          atoms_polarity_batch[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          not atoms_polarity_batch[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))  # TODO: use self.sinkhorn_iters
+
+        return torch.stack(link_weights)
+
+    def predict(self, categories, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+            categories : (batch_size, len_sentence)
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask
+        Returns :
+            axiom_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat)
+        '''
+        self.eval()
+
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+
+        # get atoms
+        atoms_batch = get_atoms_batch(categories)
+        atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+
+        # get polarities
+        polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
+
+        # atoms embedding
+        atoms_embedding = self.atoms_embedding(atoms_tokenized)
+
+        # MHA or LSTM over the BERT output
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_tokenized))
+
+        link_weights = []
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          polarities[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          not polarities[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))
+
+        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
+        axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+        return axiom_links
+
+    def eval_batch(self, batch, cross_entropy_loss):
+        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+        # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+
+        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
+        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
+        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+
+        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
+        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+
+        return accuracy, loss
+
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""Average the evaluation over all the batches.
+
+        Args:
+            dataloader: contains all the batches, each with the tokenized sentences, their masks and the true symbols
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
+
+    def load_weights(self, model_file):
+        print("#" * 15)
+        try:
+            params = torch.load(model_file, map_location=self.device)
+            args = params['args']
+            self.atom_map = args['atom_map']
+            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
+            self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
+            self.atoms_embedding.load_state_dict(params['atoms_embedding'])
+            self.linker_encoder.load_state_dict(params['linker_encoder'])
+            self.pos_transformation.load_state_dict(params['pos_transformation'])
+            self.neg_transformation.load_state_dict(params['neg_transformation'])
+            print("\nThe checkpoint was loaded successfully!\n")
+        except Exception as e:
+            print("\n/!\\ Can't load checkpoint model /!\\ because:\n\n " + str(e), file=sys.stderr)
+            raise e
+        print("#" * 15)
+
+    def _checkpoint_save(self, path='/linker.pt'):  # single underscore: a double underscore would be name-mangled and unreachable from train.py
+        self.cpu()
+
+        torch.save({
+            'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
+            'atoms_embedding': self.atoms_embedding.state_dict(),
+            'linker_encoder': self.linker_encoder.state_dict(),
+            'pos_transformation': self.pos_transformation.state_dict(),
+            'neg_transformation': self.neg_transformation.state_dict()
+        }, path)
+        self.to(self.device)
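Each per-type score matrix above is normalised with sinkhorn_fn_no_exp before being stacked. Sinkhorn.py is moved unchanged by this patch, so its body is not shown; the following is only a plausible log-space sketch of what such a function computes, not the repository's exact code:

    import torch

    def sinkhorn_log(weights, iters=3):
        # weights: (batch, n, n) raw link scores; alternate row and column
        # normalisation in log space to approach a doubly-stochastic matrix
        log_alpha = weights
        for _ in range(iters):
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=2, keepdim=True)  # rows
            log_alpha = log_alpha - torch.logsumexp(log_alpha, dim=1, keepdim=True)  # columns
        return log_alpha  # log-weights; argmax over the last axis picks the link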
diff --git a/SuperTagger/Linker/MHA.py b/Linker/MHA.py
similarity index 93%
rename from SuperTagger/Linker/MHA.py
rename to Linker/MHA.py
index d85d5e0..c1554f9 100644
--- a/SuperTagger/Linker/MHA.py
+++ b/Linker/MHA.py
@@ -1,13 +1,8 @@
-import copy
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-from Configuration import Configuration
-from torch import Tensor, LongTensor
-from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
-                      ModuleList, Sequential)
+from torch import Tensor
+from torch.nn import (Dropout, LayerNorm, Module, MultiheadAttention)
 
-from SuperTagger.Linker.utils import FFN
+from Configuration import Configuration
+from Linker.utils_linker import FFN
 
 
 class AttentionDecoderLayer(Module):
diff --git a/SuperTagger/Linker/Sinkhorn.py b/Linker/Sinkhorn.py
similarity index 100%
rename from SuperTagger/Linker/Sinkhorn.py
rename to Linker/Sinkhorn.py
diff --git a/Linker/__init__.py b/Linker/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/SuperTagger/Linker/atom_map.py b/Linker/atom_map.py
similarity index 100%
rename from SuperTagger/Linker/atom_map.py
rename to Linker/atom_map.py
diff --git a/SuperTagger/eval.py b/Linker/eval.py
similarity index 81%
rename from SuperTagger/eval.py
rename to Linker/eval.py
index 2731514..1113596 100644
--- a/SuperTagger/eval.py
+++ b/Linker/eval.py
@@ -1,12 +1,6 @@
 import torch
-from torch import Tensor
 from torch.nn import Module
-from torch.nn.functional import nll_loss, cross_entropy
-from SuperTagger.Linker.atom_map import atom_map
-import re
-
-from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.utils import pad_sequence
+from torch.nn.functional import nll_loss
 
 
 class SinkhornLoss(Module):
diff --git a/SuperTagger/Linker/utils.py b/Linker/utils_linker.py
similarity index 97%
rename from SuperTagger/Linker/utils.py
rename to Linker/utils_linker.py
index abd6814..f968984 100644
--- a/SuperTagger/Linker/utils.py
+++ b/Linker/utils_linker.py
@@ -1,12 +1,10 @@
 import re
 import regex
-import numpy as np
 import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
+from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.utils import pad_sequence
+from Linker.atom_map import atom_map
+from utils import pad_sequence
 
 
 class FFN(Module):
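The FFN used by pos_transformation and neg_transformation lives in the renamed utils_linker.py, which is 97% unchanged, so its body does not appear in this diff. Given the imports kept above (Sequential, Linear, Dropout, GELU) and the call sites FFN(dim_embedding_atoms, dim_polarity_transfo, 0.1) followed by a LayerNorm over dim_embedding_atoms, a plausible reconstruction is (an assumption, not the repository's exact code):

    from torch.nn import Sequential, Linear, Dropout, GELU, Module

    class FFN(Module):
        # position-wise feed-forward: d_model -> d_ff -> d_model
        def __init__(self, d_model, d_ff, dropout=0.1):
            super(FFN, self).__init__()
            self.ffn = Sequential(
                Linear(d_model, d_ff),
                GELU(),
                Dropout(dropout),
                Linear(d_ff, d_model)
            )

        def forward(self, x):
            return self.ffn(x)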
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
deleted file mode 100644
index 93028fd..0000000
--- a/SuperTagger/Linker/Linker.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from itertools import chain
-
-import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
-from torch.nn import Module
-import torch.nn.functional as F
-
-from Configuration import Configuration
-from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.MHA import AttentionDecoderLayer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
-from SuperTagger.eval import mesure_accuracy
-from SuperTagger.utils import pad_sequence
-
-
-class Linker(Module):
-    def __init__(self):
-        super(Linker, self).__init__()
-
-        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
-        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
-        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
-        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
-        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
-        self.dropout = Dropout(0.1)
-
-        self.atom_map = atom_map
-        self.padding_id = self.atom_map['[PAD]']
-        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
-
-        # to do : definit un encoding
-        self.linker_encoder = AttentionDecoderLayer()
-
-        self.pos_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-        self.neg_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-
-    def make_decoder_mask(self, atoms_token) :
-        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
-        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
-        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
-
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-        atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-        atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-        sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-        sents_mask
-        Returns :
-        link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
-        '''
-
-        # atoms embedding
-        atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
-        print(atoms_embedding.shape)
-
-        # MHA ou LSTM avec sortie de BERT
-        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
-        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
-        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
-        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized))
-        #atoms_encoding = atoms_embedding
-
-        link_weights = []
-        for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                          if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                              atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                              atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                             for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                          if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                              atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                              not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                             for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
-
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=3))
-
-        return torch.stack(link_weights)
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-        #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-
-        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
-        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
-
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
-
-        return accuracy, loss
-
-    def eval_epoch(self, dataloader, cross_entropy_loss):
-        r"""Average the evaluation of all the batch.
-
-        Args:
-            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
-        """
-        accuracy_average = 0
-        loss_average = 0
-        compt = 0
-        for step, batch in enumerate(dataloader):
-            compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += loss
-
-        return accuracy_average / compt, loss_average / compt
diff --git a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc
deleted file mode 100644
index a6ce66525d13768733269e6a9b3ec0bd2c64a4d7..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
deleted file mode 100644
index e3cc2ea0978bb23bd6c75f39863a19e5f9fc27d6..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
deleted file mode 100644
index facf9eafa213710664a3d7f18c9150693e6c0a2c..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
deleted file mode 100644
index 679c41a8ef96c82cce084ed1d859b6f0933c12f5..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
deleted file mode 100644
index 26d18c0959a2b1114a5564ff8b1ec4304f81acf3..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
deleted file mode 100644
index f189466986ce4136d6ed66a59cd0d49b61b4aad8..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index c4eef1e07886db024496a876b846bd69646a7538..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/__init__.py b/SuperTagger/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc
deleted file mode 100644
index f5cec815f7dbc90d4ab075b8e48682a5c6e119fd..0000000000000000000000000000000000000000
Binary files a/SuperTagger/__pycache__/eval.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/__pycache__/utils.cpython-38.pyc b/SuperTagger/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index 9e66bb4e0c44377cdeed2d562db8eef201427af3..0000000000000000000000000000000000000000
Binary files a/SuperTagger/__pycache__/utils.cpython-38.pyc and /dev/null differ
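The __pycache__ deletions above remove compiled artifacts that had been committed by accident; dropping them is pure cleanup. A .gitignore entry along these lines would keep them from being re-added (assuming the repository does not already have one):

    __pycache__/
    *.pyc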
z=A|oSWLV|qS``<Ab!sx+Vuz~MWCi8zl3T1r`K5U&x0s7dif(a$f(vBoElv=blA2SJ zsL6DTJvAq>pg1)piXX~|2kS{qL5td3+-dnmxrrt5Ak9Xg5CQp(ft823$cBM|L6bQI z9Mwgjl0=iU$PyIrYz2vVDT&2JS|AQ<Nq$js1~_6Nr4MUbPJUtuINZQSAXx}cP`5a2 fa`RJ4b5iX<`LmdZfq{X8iGz`Yk%yT}L`Vz(G)BcO diff --git a/main.py b/main.py new file mode 100644 index 0000000..bdf7bc2 --- /dev/null +++ b/main.py @@ -0,0 +1,20 @@ +import torch.nn.functional as F + +from Configuration import Configuration +from Linker.Linker import Linker + +max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence']) + +# categories tagger +tagger = SuperTagger() +tagger.load_weights("models/model_check.pt") + +# axiom linker +linker = Linker() +linker.load_weights("models/linker.pt") + +# predict categories and links for this sentence +sentence = [[]] +categories, sentence_embedding = tagger.predict(sentence) + +axiom_links = linker.predict(categories, sentence_embedding) diff --git a/train.py b/train.py index f8290f8..a37aaa4 100644 --- a/train.py +++ b/train.py @@ -1,21 +1,20 @@ -import pickle +import os import time +from datetime import datetime import numpy as np import torch -import torch.nn.functional as F -from torch.optim import SGD, Adam, AdamW +from torch.optim import AdamW from torch.utils.data import Dataset, TensorDataset, random_split -from transformers import get_cosine_schedule_with_warmup +from transformers import (get_cosine_schedule_with_warmup) from Configuration import Configuration -from SuperTagger.Linker.AtomTokenizer import AtomTokenizer -from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup) -from SuperTagger.Linker.Linker import Linker -from SuperTagger.Linker.atom_map import atom_map -from SuperTagger.Linker.utils import get_axiom_links, get_atoms_batch, find_pos_neg_idexes -from SuperTagger.eval import SinkhornLoss -from SuperTagger.utils import format_time, read_csv_pgbar +from Linker.AtomTokenizer import AtomTokenizer +from Linker.Linker import Linker +from Linker.atom_map import atom_map +from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes +from Linker.eval import SinkhornLoss +from utils import format_time, read_csv_pgbar torch.cuda.empty_cache() @@ -63,7 +62,6 @@ print("atoms_batch", atoms_batch[14]) print("atoms_polarity_batch", atoms_polarity_batch[14]) print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16]) - # Construction tensor dataset dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch) @@ -82,6 +80,9 @@ validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batc # region Models +# supertagger = SuperTagger() +# supertagger.load_weights("models/model_check.pt") + linker = Linker() # endregion Models @@ -115,6 +116,7 @@ torch.autograd.set_detect_anomaly(True) total_t0 = time.time() validate = True +checkpoint = True def run_epochs(epochs): @@ -141,14 +143,16 @@ def run_epochs(epochs): # For each batch of training data... 
diff --git a/train.py b/train.py
index f8290f8..a37aaa4 100644
--- a/train.py
+++ b/train.py
@@ -1,21 +1,20 @@
-import pickle
+import os
 import time
+from datetime import datetime
 
 import numpy as np
 import torch
-import torch.nn.functional as F
-from torch.optim import SGD, Adam, AdamW
+from torch.optim import AdamW
 from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import get_cosine_schedule_with_warmup
+from transformers import get_cosine_schedule_with_warmup
 
 from Configuration import Configuration
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
-from SuperTagger.Linker.Linker import Linker
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.utils import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.eval import SinkhornLoss
-from SuperTagger.utils import format_time, read_csv_pgbar
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.Linker import Linker
+from Linker.atom_map import atom_map
+from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
+from Linker.eval import SinkhornLoss
+from utils import format_time, read_csv_pgbar
 
 
 torch.cuda.empty_cache()
@@ -63,7 +62,6 @@
 print("atoms_batch", atoms_batch[14])
 print("atoms_polarity_batch", atoms_polarity_batch[14])
 print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16])
-
 # Construction tensor dataset
 dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
 
@@ -82,6 +80,9 @@
 validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size)
 
 # region Models
+# supertagger = SuperTagger()
+# supertagger.load_weights("models/model_check.pt")
+
 linker = Linker()
 
 # endregion Models
@@ -115,6 +116,7 @@
 torch.autograd.set_detect_anomaly(True)
 total_t0 = time.time()
 
 validate = True
+checkpoint = True
 
 def run_epochs(epochs):
@@ -141,14 +143,16 @@
 
         # For each batch of training data...
        for step, batch in enumerate(training_dataloader):
             # Unpack this training batch from our dataloader
-
             batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
 
             optimizer_linker.zero_grad()
 
+            # get the sentence embedding from BERT, which is already trained
+            # sentences_embedding = supertagger(batch_sentences)
+
             # Run the kinker on the categories predictions
             logits_predictions = linker(batch_atoms, batch_polarity, [])
@@ -169,6 +173,11 @@
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
 
+        if checkpoint:
+            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
+            os.makedirs(checkpoint_dir, exist_ok=True)  # the Output directory is not created anywhere else
+            linker._checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
         if validate:
             linker.eval()
             with torch.no_grad():
diff --git a/SuperTagger/utils.py b/utils.py
similarity index 100%
rename from SuperTagger/utils.py
rename to utils.py
-- 
GitLab
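For reference, the decoder mask built by make_decoder_mask (identical in the old and new Linker) zeroes the [PAD] atom positions and tiles the result once per attention head. A toy run, assuming padding_id = 0 and nhead = 8 (both actually come from atom_map and the configuration):

    import torch

    atoms_token = torch.tensor([[4, 7, 0, 0]])  # one sentence, two [PAD] atoms
    mask = torch.ones_like(atoms_token, dtype=torch.float64)
    mask[atoms_token.eq(0)] = 0.0
    mask = mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(8, 1, 1)
    print(mask.shape)  # torch.Size([8, 4, 4]); columns attending to [PAD] positions are zeroed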