diff --git a/SuperTagger/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
similarity index 100%
rename from SuperTagger/Linker/AtomEmbedding.py
rename to Linker/AtomEmbedding.py
diff --git a/SuperTagger/Linker/AtomTokenizer.py b/Linker/AtomTokenizer.py
similarity index 95%
rename from SuperTagger/Linker/AtomTokenizer.py
rename to Linker/AtomTokenizer.py
index a771eef0a83e31fd5e0f77449aec12f09b6e5c3d..568b3a5e3c8fb66058192ab5d005ab5cf41330c4 100644
--- a/SuperTagger/Linker/AtomTokenizer.py
+++ b/Linker/AtomTokenizer.py
@@ -1,6 +1,5 @@
 import torch
-
-from SuperTagger.utils import pad_sequence
+from utils import pad_sequence
 
 
 class AtomTokenizer(object):
diff --git a/Linker/Linker.py b/Linker/Linker.py
new file mode 100644
index 0000000000000000000000000000000000000000..f65325eef04bc224f025f209f99f9d1a6f653207
--- /dev/null
+++ b/Linker/Linker.py
@@ -0,0 +1,221 @@
+import torch
+from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Module
+import torch.nn.functional as F
+import sys
+from Configuration import Configuration
+from Linker.AtomEmbedding import AtomEmbedding
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.MHA import AttentionDecoderLayer
+from Linker.atom_map import atom_map
+from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
+from Linker.utils_linker import find_pos_neg_idexes, get_atoms_batch, FFN
+from Linker.eval import mesure_accuracy
+from utils import pad_sequence
+
+
+class Linker(Module):
+    def __init__(self):
+        super(Linker, self).__init__()
+
+        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
+        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
+        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
+        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
+        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
+        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
+        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
+        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
+        self.dropout = Dropout(0.1)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+        self.atom_map = atom_map
+        self.padding_id = self.atom_map['[PAD]']
+        self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
+        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
+
+        # TODO: define a proper encoder
+        self.linker_encoder = AttentionDecoderLayer()
+
+        self.pos_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+        self.neg_transformation = Sequential(
+            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
+            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
+        )
+
+    def make_decoder_mask(self, atoms_token):
+        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
+        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
+        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
+
+    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+            atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
+            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened category polarities
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask
+        Returns :
+            link_weights : (atom_vocab_size, batch_size, max_atoms_in_one_cat, max_atoms_in_one_cat)
+        '''
+
+        # atom embeddings
+        atoms_embedding = self.atoms_embedding(atoms_batch_tokenized)
+
+        # MHA or LSTM over the BERT output
+        # Placeholder: random context tensors until the SuperTagger embeddings are plugged in
+        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_batch_tokenized))
+
+        link_weights = []
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          atoms_polarity_batch[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
+                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          not atoms_polarity_batch[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))
+
+        return torch.stack(link_weights)
+
+    def predict(self, categories, sents_embedding, sents_mask=None):
+        r'''
+        Parameters :
+            categories : (batch_size, len_sentence)
+            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
+            sents_mask
+        Returns :
+            axiom_links : (atom_vocab_size, batch_size, max_atoms_in_one_cat)
+        '''
+        self.eval()
+
+        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
+
+        # get atoms
+        atoms_batch = get_atoms_batch(categories)
+        atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
+
+        # get polarities
+        polarities = find_pos_neg_idexes(self.max_atoms_in_sentence, categories)
+
+        # atom embeddings
+        atoms_embedding = self.atoms_embedding(atoms_tokenized)
+
+        # MHA or LSTM over the BERT output
+        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
+                                             self.make_decoder_mask(atoms_tokenized))
+
+        link_weights = []
+        for atom_type in list(self.atom_map.keys())[:-1]:
+            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          polarities[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
+                                                      if (self.atom_map[atom_type] in atoms_tokenized[s_idx] and
+                                                          atoms_tokenized[s_idx][i] == self.atom_map[atom_type] and
+                                                          not polarities[s_idx][i])] +
+                                                     [torch.zeros(self.dim_embedding_atoms)])
+                                         for s_idx in range(len(polarities))], padding_value=0,
+                                        max_len=self.max_atoms_in_one_type // 2)
+
+            pos_encoding = self.pos_transformation(pos_encoding)
+            neg_encoding = self.neg_transformation(neg_encoding)
+
+            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
+            link_weights.append(sinkhorn(weights, iters=3))
+
+        logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
+        axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
+        return axiom_links
+
+    def eval_batch(self, batch, cross_entropy_loss):
+        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
+        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
+        # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+
+        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
+        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
+        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
+
+        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
+        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
+
+        return accuracy, loss
+
+    def eval_epoch(self, dataloader, cross_entropy_loss):
+        r"""Average the evaluation over all the batches.
+
+        Args:
+            dataloader: yields batches of tokenized atoms, their polarities and the true axiom links
+        """
+        accuracy_average = 0
+        loss_average = 0
+        compt = 0
+        for step, batch in enumerate(dataloader):
+            compt += 1
+            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
+            accuracy_average += accuracy
+            loss_average += loss
+
+        return accuracy_average / compt, loss_average / compt
+
+    def load_weights(self, model_file):
+        print("#" * 15)
+        try:
+            params = torch.load(model_file, map_location=self.device)
+            args = params['args']
+            self.atom_map = args['atom_map']
+            self.max_atoms_in_sentence = args['max_atoms_in_sentence']
+            self.atoms_tokenizer = AtomTokenizer(self.atom_map, self.max_atoms_in_sentence)
+            self.atoms_embedding.load_state_dict(params['atoms_embedding'])
+            self.linker_encoder.load_state_dict(params['linker_encoder'])
+            self.pos_transformation.load_state_dict(params['pos_transformation'])
+            self.neg_transformation.load_state_dict(params['neg_transformation'])
+            print("\nThe checkpoint was loaded successfully!\n")
+        except Exception as e:
+            print("\n/!\\ Can't load checkpoint model /!\\ because:\n\n " + str(e), file=sys.stderr)
+            raise e
+        print("#" * 15)
+
+    def checkpoint_save(self, path='linker.pt'):
+        self.cpu()
+
+        torch.save({
+            'args': dict(atom_map=self.atom_map, max_atoms_in_sentence=self.max_atoms_in_sentence),
+            'atoms_embedding': self.atoms_embedding.state_dict(),
+            'linker_encoder': self.linker_encoder.state_dict(),
+            'pos_transformation': self.pos_transformation.state_dict(),
+            'neg_transformation': self.neg_transformation.state_dict()
+        }, path)
+        self.to(self.device)
diff --git a/SuperTagger/Linker/MHA.py b/Linker/MHA.py
similarity index 93%
rename from SuperTagger/Linker/MHA.py
rename to Linker/MHA.py
index d85d5e03b29ad33077224bb19f90c44d7b3d630f..c1554f9a3454a8be0ed66917824e49534bb01f6a 100644
--- a/SuperTagger/Linker/MHA.py
+++ b/Linker/MHA.py
@@ -1,13 +1,8 @@
-import copy
-import torch
-import torch.nn.functional as F
-import torch.optim as optim
-from Configuration import Configuration
-from torch import Tensor, LongTensor
-from torch.nn import (GELU, LSTM, Dropout, LayerNorm, Linear, Module, MultiheadAttention,
-                      ModuleList, Sequential)
+from torch import Tensor
+from torch.nn import (Dropout, LayerNorm, Module, MultiheadAttention)
 
-from SuperTagger.Linker.utils import FFN
+from Configuration import Configuration
+from Linker.utils_linker import FFN
 
 
 class AttentionDecoderLayer(Module):
diff --git a/SuperTagger/Linker/Sinkhorn.py b/Linker/Sinkhorn.py
similarity index 100%
rename from SuperTagger/Linker/Sinkhorn.py
rename to Linker/Sinkhorn.py
diff --git a/Linker/__init__.py b/Linker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SuperTagger/Linker/atom_map.py b/Linker/atom_map.py
similarity index 100%
rename from SuperTagger/Linker/atom_map.py
rename to Linker/atom_map.py
diff --git a/SuperTagger/eval.py b/Linker/eval.py
similarity index 81%
rename from SuperTagger/eval.py
rename to Linker/eval.py
index 2731514885b6da2bd84661d8c6c2149ad1645b9d..1113596e276a190edfc49ac50ce511ad64b4e6c8 100644
--- a/SuperTagger/eval.py
+++ b/Linker/eval.py
@@ -1,12 +1,6 @@
 import torch
-from torch import Tensor
 from torch.nn import Module
-from torch.nn.functional import nll_loss, cross_entropy
-from SuperTagger.Linker.atom_map import atom_map
-import re
-
-from SuperTagger.Linker.utils import get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.utils import pad_sequence
+from torch.nn.functional import nll_loss
 
 
 class SinkhornLoss(Module):
diff --git a/SuperTagger/Linker/utils.py b/Linker/utils_linker.py
similarity index 97%
rename from SuperTagger/Linker/utils.py
rename to Linker/utils_linker.py
index abd6814fc0bc8ae839b8efe40d3e50a8921cbfb1..f968984872d4513c0b31ae5ca6e5fc06ced70da0 100644
--- a/SuperTagger/Linker/utils.py
+++ b/Linker/utils_linker.py
@@ -1,12 +1,10 @@
 import re
 import regex
-import numpy as np
 import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
+from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.utils import pad_sequence
+from Linker.atom_map import atom_map
+from utils import pad_sequence
 
 
 class FFN(Module):
diff --git a/SuperTagger/Linker/Linker.py b/SuperTagger/Linker/Linker.py
deleted file mode 100644
index 93028fdeaf6cc7f1cc978e796d56d73fd0ff6b5a..0000000000000000000000000000000000000000
--- a/SuperTagger/Linker/Linker.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from itertools import chain
-
-import torch
-from torch.nn import Sequential, LayerNorm, Linear, Dropout, GELU, MultiheadAttention
-from torch.nn import Module
-import torch.nn.functional as F
-
-from Configuration import Configuration
-from SuperTagger.Linker.AtomEmbedding import AtomEmbedding
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from SuperTagger.Linker.MHA import AttentionDecoderLayer
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from SuperTagger.Linker.utils import find_pos_neg_idexes, get_atoms_batch, FFN
-from SuperTagger.eval import mesure_accuracy
-from SuperTagger.utils import pad_sequence
-
-
-class Linker(Module):
-    def __init__(self):
-        super(Linker, self).__init__()
-
-        self.dim_encoder = int(Configuration.modelEncoderConfig['dim_encoder'])
-        self.dim_polarity_transfo = int(Configuration.modelLinkerConfig['dim_polarity_transfo'])
-        self.dim_embedding_atoms = int(Configuration.modelLinkerConfig['dim_embedding_atoms'])
-        self.sinkhorn_iters = int(Configuration.modelLinkerConfig['sinkhorn_iters'])
-        self.nhead = int(Configuration.modelLinkerConfig['nhead'])
-        self.max_len_sentence = int(Configuration.datasetConfig['max_len_sentence'])
-        self.max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
-        self.max_atoms_in_one_type = int(Configuration.datasetConfig['max_atoms_in_one_type'])
-        self.atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
-        self.dropout = Dropout(0.1)
-
-        self.atom_map = atom_map
-        self.padding_id = self.atom_map['[PAD]']
-        self.atom_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
-        self.atom_embedding = AtomEmbedding(self.dim_embedding_atoms, self.atom_vocab_size, self.padding_id)
-
-        # to do : definit un encoding
-        self.linker_encoder = AttentionDecoderLayer()
-
-        self.pos_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-        self.neg_transformation = Sequential(
-            FFN(self.dim_embedding_atoms, self.dim_polarity_transfo, 0.1),
-            LayerNorm(self.dim_embedding_atoms, eps=1e-12)
-        )
-
-    def make_decoder_mask(self, atoms_token) :
-        decoder_attn_mask = torch.ones_like(atoms_token, dtype=torch.float64)
-        decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
-        return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
-
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
-        r'''
-        Parameters :
-            atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
-            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-            sents_mask
-        Returns :
-            link_weights : atom_vocab_size, batch-size, max_atoms_in_one_cat, max_atoms_in_one_cat)
-        '''
-
-        # atoms embedding
-        atoms_embedding = self.atom_embedding(atoms_batch_tokenized)
-        print(atoms_embedding.shape)
-
-        # MHA ou LSTM avec sortie de BERT
-        sents_embedding = torch.randn(32, self.max_len_sentence, self.dim_encoder)
-        batch_size, len_sentence, sents_embedding_dim = sents_embedding.shape
-        sents_mask = torch.randn(batch_size * self.nhead, self.max_atoms_in_sentence, self.max_len_sentence)
-        atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask, self.make_decoder_mask(atoms_batch_tokenized))
-        #atoms_encoding = atoms_embedding
-
-        link_weights = []
-        for atom_type in list(self.atom_map.keys())[:-1]:
-            pos_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                                          atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            neg_encoding = pad_sequence([torch.stack([x for i, x in enumerate(atoms_encoding[s_idx])
-                                                      if (self.atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                                                          atoms_batch_tokenized[s_idx][i] == self.atom_map[atom_type] and
-                                                          not atoms_polarity_batch[s_idx][i])] + [torch.zeros(self.dim_embedding_atoms)])
-                                         for s_idx in range(len(atoms_polarity_batch))], padding_value=0, max_len=self.max_atoms_in_one_type//2)
-
-            pos_encoding = self.pos_transformation(pos_encoding)
-            neg_encoding = self.neg_transformation(neg_encoding)
-
-            weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-            link_weights.append(sinkhorn(weights, iters=3))
-
-        return torch.stack(link_weights)
-
-    def eval_batch(self, batch, cross_entropy_loss):
-        batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
-        batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-        #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
-
-        logits_axiom_links_pred = self.forward(batch_atoms, batch_polarity, [])
-        logits_axiom_links_pred = logits_axiom_links_pred.permute(1, 0, 2, 3)
-        axiom_links_pred = torch.argmax(F.softmax(logits_axiom_links_pred, dim=3), dim=3)
-
-        accuracy = mesure_accuracy(batch_true_links, axiom_links_pred)
-        loss = cross_entropy_loss(logits_axiom_links_pred, batch_true_links)
-
-        return accuracy, loss
-
-    def eval_epoch(self, dataloader, cross_entropy_loss):
-        r"""Average the evaluation of all the batch.
-
-        Args:
-            dataloader: contains all the batch which contain the tokenized sentences, their masks and the true symbols
-        """
-        accuracy_average = 0
-        loss_average = 0
-        compt = 0
-        for step, batch in enumerate(dataloader):
-            compt += 1
-            accuracy, loss = self.eval_batch(batch, cross_entropy_loss)
-            accuracy_average += accuracy
-            loss_average += loss
-
-        return accuracy_average / compt, loss_average / compt
diff --git a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc
deleted file mode 100644
index a6ce66525d13768733269e6a9b3ec0bd2c64a4d7..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/AtomEmbedding.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc b/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc
deleted file mode 100644
index e3cc2ea0978bb23bd6c75f39863a19e5f9fc27d6..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/AtomTokenizer.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc
deleted file mode 100644
index facf9eafa213710664a3d7f18c9150693e6c0a2c..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/Linker.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc b/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc
deleted file mode 100644
index 679c41a8ef96c82cce084ed1d859b6f0933c12f5..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/MHA.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc b/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc
deleted file mode 100644
index 26d18c0959a2b1114a5564ff8b1ec4304f81acf3..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/Sinkhorn.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc b/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc
deleted file mode 100644
index f189466986ce4136d6ed66a59cd0d49b61b4aad8..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/atom_map.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc b/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index c4eef1e07886db024496a876b846bd69646a7538..0000000000000000000000000000000000000000
Binary files a/SuperTagger/Linker/__pycache__/utils.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/__init__.py b/SuperTagger/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/SuperTagger/__pycache__/eval.cpython-38.pyc b/SuperTagger/__pycache__/eval.cpython-38.pyc
deleted file mode 100644
index f5cec815f7dbc90d4ab075b8e48682a5c6e119fd..0000000000000000000000000000000000000000
Binary files a/SuperTagger/__pycache__/eval.cpython-38.pyc and /dev/null differ
diff --git a/SuperTagger/__pycache__/utils.cpython-38.pyc b/SuperTagger/__pycache__/utils.cpython-38.pyc
deleted file mode 100644
index 9e66bb4e0c44377cdeed2d562db8eef201427af3..0000000000000000000000000000000000000000
Binary files a/SuperTagger/__pycache__/utils.cpython-38.pyc and /dev/null differ
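Before the new entry points below, a quick illustration of the padding mask convention used by Linker.make_decoder_mask above. The snippet copies that logic into a free function so it can be run on its own; the pad id 0, the head count 8 and the token values are made up for the example rather than read from Configuration or atom_map:

import torch

def make_decoder_mask(atoms_token: torch.Tensor, padding_id: int, nhead: int) -> torch.Tensor:
    # 1.0 where the atom is real, 0.0 on padding positions,
    # broadcast to (batch_size * nhead, seq_len, seq_len) as expected by MultiheadAttention.
    mask = torch.ones_like(atoms_token, dtype=torch.float64)
    mask[atoms_token.eq(padding_id)] = 0.0
    return mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(nhead, 1, 1)

tokens = torch.tensor([[5, 3, 0, 0], [7, 2, 9, 0]])   # 0 = [PAD], illustrative ids
mask = make_decoder_mask(tokens, padding_id=0, nhead=8)
print(mask.shape)  # torch.Size([16, 4, 4]) = (batch 2 * 8 heads, 4, 4)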
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdf7bc279b00aeb15d37a2e450e1e36eaf495c42
--- /dev/null
+++ b/main.py
@@ -0,0 +1,20 @@
+import torch.nn.functional as F
+
+from Configuration import Configuration
+from Linker.Linker import Linker
+
+max_atoms_in_sentence = int(Configuration.datasetConfig['max_atoms_in_sentence'])
+
+# category tagger (NOTE: SuperTagger is assumed to be importable from the SuperTagger package; its import is not part of this diff)
+tagger = SuperTagger()
+tagger.load_weights("models/model_check.pt")
+
+# axiom linker
+linker = Linker()
+linker.load_weights("models/linker.pt")
+
+# predict categories and links for this sentence
+sentence = [[]]  # placeholder input
+categories, sentence_embedding = tagger.predict(sentence)
+
+axiom_links = linker.predict(categories, sentence_embedding)
diff --git a/train.py b/train.py
index f8290f8554a28f22598b5a8780abcdb1d3ca1470..a37aaa46a6a12b8f1272914c1b6a85be627c2b5a 100644
--- a/train.py
+++ b/train.py
@@ -1,21 +1,20 @@
-import pickle
+import os
 import time
+from datetime import datetime
 
 import numpy as np
 import torch
-import torch.nn.functional as F
-from torch.optim import SGD, Adam, AdamW
+from torch.optim import AdamW
 from torch.utils.data import Dataset, TensorDataset, random_split
-from transformers import get_cosine_schedule_with_warmup
+from transformers import (get_cosine_schedule_with_warmup)
 
 from Configuration import Configuration
-from SuperTagger.Linker.AtomTokenizer import AtomTokenizer
-from transformers import (AutoTokenizer, get_cosine_schedule_with_warmup)
-from SuperTagger.Linker.Linker import Linker
-from SuperTagger.Linker.atom_map import atom_map
-from SuperTagger.Linker.utils import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
-from SuperTagger.eval import SinkhornLoss
-from SuperTagger.utils import format_time, read_csv_pgbar
+from Linker.AtomTokenizer import AtomTokenizer
+from Linker.Linker import Linker
+from Linker.atom_map import atom_map
+from Linker.utils_linker import get_axiom_links, get_atoms_batch, find_pos_neg_idexes
+from Linker.eval import SinkhornLoss
+from utils import format_time, read_csv_pgbar
 
 torch.cuda.empty_cache()
@@ -63,7 +62,6 @@ print("atoms_batch", atoms_batch[14])
 print("atoms_polarity_batch", atoms_polarity_batch[14])
 print(" truth_links_batch example on a sentence class txt", truth_links_batch[14][16])
 
-
 # Construction tensor dataset
 dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch)
@@ -82,6 +80,9 @@ validation_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batc
 
 # region Models
 
+# supertagger = SuperTagger()
+# supertagger.load_weights("models/model_check.pt")
+
 linker = Linker()
 
 # endregion Models
@@ -115,6 +116,7 @@ torch.autograd.set_detect_anomaly(True)
 total_t0 = time.time()
 
 validate = True
+checkpoint = True
 
 
 def run_epochs(epochs):
@@ -141,14 +143,16 @@ def run_epochs(epochs):
         # For each batch of training data...
         for step, batch in enumerate(training_dataloader):
             # Unpack this training batch from our dataloader
-
             batch_atoms = batch[0].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_polarity = batch[1].to("cuda" if torch.cuda.is_available() else "cpu")
             batch_true_links = batch[2].to("cuda" if torch.cuda.is_available() else "cpu")
-            #batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
+            # batch_sentences = batch[3].to("cuda" if torch.cuda.is_available() else "cpu")
 
             optimizer_linker.zero_grad()
 
+            # get the sentence embeddings from the already-trained BERT supertagger
+            # sentences_embedding = supertagger(batch_sentences)
+
             # Run the linker on the category predictions
             logits_predictions = linker(batch_atoms, batch_polarity, [])
@@ -169,6 +173,10 @@ def run_epochs(epochs):
         # Measure how long this epoch took.
         training_time = format_time(time.time() - t0)
 
+        if checkpoint:
+            checkpoint_dir = os.path.join("Output", 'Training_' + datetime.today().strftime('%d-%m_%H-%M'))
+            os.makedirs(checkpoint_dir, exist_ok=True)
+            linker.checkpoint_save(path=os.path.join(checkpoint_dir, 'model_check.pt'))
+
         if validate:
             linker.eval()
             with torch.no_grad():
diff --git a/SuperTagger/utils.py b/utils.py
similarity index 100%
rename from SuperTagger/utils.py
rename to utils.py
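Finally, a small self-contained sketch of how the stacked link weights are decoded into axiom links, mirroring the permute/softmax/argmax steps of Linker.predict and eval_batch; all sizes and tensors below are random placeholders, not project defaults:

import torch
import torch.nn.functional as F

batch_size, n_types, n = 2, 17, 9                       # illustrative (n = max_atoms_in_one_type // 2)
link_weights = torch.randn(n_types, batch_size, n, n)   # what forward() stacks over atom types

# (n_types, batch, n, n) -> (batch, n_types, n, n), then pick, for every positive atom,
# the index of the negative atom it links to
logits_predictions = link_weights.permute(1, 0, 2, 3)
axiom_links = torch.argmax(F.softmax(logits_predictions, dim=3), dim=3)
print(axiom_links.shape)  # torch.Size([2, 17, 9])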