diff --git a/Configuration/config.ini b/Configuration/config.ini
index 64bba529b04897e70ebaee8328c7ad7275826921..ea8dd6979eeed3bf6cbd6f77724329ce65273c9f 100644
--- a/Configuration/config.ini
+++ b/Configuration/config.ini
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
 dim_encoder = 768
 
 [MODEL_DECODER]
-nhead=8
+nhead=4
 num_layers=1
 dropout=0.1
 dim_feedforward=512
@@ -26,7 +26,7 @@ dropout=0.1
 sinkhorn_iters=3
 
 [MODEL_TRAINING]
-batch_size=16
+batch_size=32
 epoch=30
 seed_val=42
 learning_rate=2e-4
diff --git a/Linker/AtomEmbedding.py b/Linker/AtomEmbedding.py
deleted file mode 100644
index e7be599a0fa145f76a5646b83973a3501ed52d4d..0000000000000000000000000000000000000000
--- a/Linker/AtomEmbedding.py
+++ /dev/null
@@ -1,12 +0,0 @@
-import torch
-from torch.nn import Module, Embedding
-
-
-class AtomEmbedding(Module):
-    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
-        super(AtomEmbedding, self).__init__()
-        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
-                             scale_grad_by_freq=True)
-
-    def forward(self, x):
-        return self.emb(x)
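The deleted AtomEmbedding module was a thin wrapper around torch.nn.Embedding with scale_grad_by_freq=True; the Linker now instantiates the embedding table directly (see the Linker.py hunk below). A minimal sketch of the equivalent construction; the vocabulary size and embedding dimension are illustrative assumptions, only the padding id (16, i.e. atom_map['[PAD]']) comes from the repository:

```python
import torch
from torch.nn import Embedding

# Illustrative sizes; the real values come from config.ini via Configuration.
atom_vocab_size, dim_embedding_atoms, padding_id = 18, 256, 16   # 16 = atom_map['[PAD]']

# What the removed wrapper built internally, now created inline in Linker.__init__:
atoms_embedding = Embedding(num_embeddings=atom_vocab_size,
                            embedding_dim=dim_embedding_atoms,
                            padding_idx=padding_id,
                            scale_grad_by_freq=True)

tokens = torch.tensor([[2, 3, 16, 16]])      # (batch=1, max_atoms_in_sentence=4), 16 = padding
print(atoms_embedding(tokens).shape)         # torch.Size([1, 4, 256])
```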
diff --git a/Linker/Linker.py b/Linker/Linker.py
index f3a4538d67c8f6f3a1300e9f80a793d9c4036e68..611575d7140c48516b4ce3f2cbd55bf995e04610 100644
--- a/Linker/Linker.py
+++ b/Linker/Linker.py
@@ -6,22 +6,20 @@ import datetime
 import time
 
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Sequential, LayerNorm, Dropout, Embedding
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 
 from Configuration import Configuration
-from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.atom_map import atom_map
+from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx
 from Supertagger import *
-from utils import pad_sequence
 
 
 def format_time(elapsed):
@@ -62,18 +60,20 @@ class Linker(Module):
         atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger
 
         self.atom_map = atom_map
-        self.sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']
+        self.sub_atoms_type_list = list(atom_map_redux.keys())
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
+        self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
+                                         padding_idx=self.padding_id,
+                                         scale_grad_by_freq=True)
 
         self.linker_encoder = AttentionDecoderLayer()
 
@@ -90,8 +90,6 @@ class Linker(Module):
 
         self.optimizer = AdamW(self.parameters(), lr=learning_rate)
 
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
         self.to(self.device)
 
     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -110,12 +108,15 @@ class Linker(Module):
         atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
 
+        pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list,
                                             atoms_polarity_batch, df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)
 
         # Construction tensor dataset
-        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
+        dataset = TensorDataset(atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
                                 sentences_mask)
 
         if validation_rate > 0.0:
@@ -136,11 +137,12 @@ class Linker(Module):
         decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
         return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)
 
-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+    def forward(self, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx, sents_embedding, sents_mask=None):
         r"""
         Args:
             atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+            batch_pos_idx : (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2) positions of the positive atoms of each type, padded with -1
+            batch_neg_idx : same shape as batch_pos_idx, positions of the negative atoms of each type
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
             sents_mask : mask from BERT tokenizer
         Returns:
@@ -157,13 +159,8 @@ class Linker(Module):
 
         link_weights = []
         for atom_type in self.sub_atoms_type_list:
-            pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
-                                                                        atoms_polarity_batch, atom_type, s_idx)
-                                        for s_idx in range(len(atoms_polarity_batch))])
-
-            neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
-                                                                        atoms_polarity_batch, atom_type, s_idx)
-                                        for s_idx in range(len(atoms_polarity_batch))])
+            pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
+            neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)
 
             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -172,9 +169,8 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))
 
         total_link_weights = torch.stack(link_weights)
-        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)
 
-        return F.log_softmax(link_weights_per_batch, dim=3)
+        return F.log_softmax(total_link_weights, dim=3)
 
     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20, batch_size=32, checkpoint=True,
                      tensorboard=False):
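With the final permute removed, forward now returns its log-probabilities atom-type-major rather than batch-major: one alignment matrix per entry of atom_map_redux, stacked ahead of the batch dimension, so any direct indexing downstream selects the atom type first and the sentence second. A toy shape check with dummy score matrices (the sizes here are assumptions, not values from the config):

```python
import torch
import torch.nn.functional as F

num_types = 7                      # len(atom_map_redux)
batch_size, half = 2, 5            # half stands in for max_atoms_in_one_type // 2

# One (batch, half, half) score matrix per atom type, as produced by the sinkhorn loop.
link_weights = [torch.randn(batch_size, half, half) for _ in range(num_types)]
total_link_weights = torch.stack(link_weights)

out = F.log_softmax(total_link_weights, dim=3)
print(out.shape)                        # torch.Size([7, 2, 5, 5]) -- atom types first, batch second
print(torch.argmax(out, dim=3).shape)   # torch.Size([7, 2, 5])    -- what eval_batch's argmax sees
```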
@@ -246,10 +242,11 @@ class Linker(Module):
             for batch in tepoch:
                 # Unpack this training batch from our dataloader
                 batch_atoms = batch[0].to(self.device)
-                batch_polarity = batch[1].to(self.device)
-                batch_true_links = batch[2].to(self.device)
-                batch_sentences_tokens = batch[3].to(self.device)
-                batch_sentences_mask = batch[4].to(self.device)
+                batch_pos_idx = batch[1].to(self.device)
+                batch_neg_idx = batch[2].to(self.device)
+                batch_true_links = batch[3].to(self.device)
+                batch_sentences_tokens = batch[4].to(self.device)
+                batch_sentences_mask = batch[5].to(self.device)
 
                 self.optimizer.zero_grad()
 
@@ -257,7 +254,8 @@ class Linker(Module):
                 logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
                 # Run the linker on the categories predictions
-                logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
+                logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
+                                          batch_sentences_mask)
 
                 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
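The dataloader now yields six tensors per batch, and both the training loop above and eval_batch below unpack them purely by position. One way to keep the two call sites in sync would be a named wrapper such as the following; this is a readability suggestion, not part of the patch:

```python
from typing import NamedTuple
import torch

class LinkerBatch(NamedTuple):
    atoms: torch.Tensor        # (batch, max_atoms_in_sentence) atom token ids
    pos_idx: torch.Tensor      # (batch, len(atom_map_redux), max_atoms_in_one_type // 2), -1 padded
    neg_idx: torch.Tensor      # same shape as pos_idx
    true_links: torch.Tensor   # permutation targets per atom type
    sent_tokens: torch.Tensor  # BERT input ids
    sent_mask: torch.Tensor    # BERT attention mask

def unpack(batch, device):
    # batch is the 6-tuple produced by the TensorDataset built in __preprocess_data
    return LinkerBatch(*(t.to(device) for t in batch))
```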
@@ -280,67 +278,24 @@ class Linker(Module):
 
         return avg_train_loss, avg_accuracy_train, training_time
 
-    def predict(self, categories, sents_embedding, sents_mask=None):
-        r"""Prediction from categories output by BERT and hidden_state from BERT
-
-        Args:
-            categories : (batch_size, len_sentence)
-            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-            sents_mask
-        Returns:
-            axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
-        """
-        self.eval()
-        with torch.no_grad():
-            # get atoms
-            atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-
-            # atoms embedding
-            atoms_embedding = self.atoms_embedding(atoms_tokenized)
-
-            # MHA ou LSTM avec sortie de BERT
-            atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
-                                                 self.make_decoder_mask(atoms_tokenized))
-
-            link_weights = []
-            for atom_type in self.sub_atoms_type_list:
-                pos_encoding = pad_sequence(
-                    [self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
-                     for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2)
-
-                neg_encoding = pad_sequence(
-                    [self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
-                     for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2)
-
-                pos_encoding = self.pos_transformation(pos_encoding)
-                neg_encoding = self.neg_transformation(neg_encoding)
-
-                weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-                link_weights.append(sinkhorn(weights, iters=3))
-
-            logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-            axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
-            return axiom_links
-
     def eval_batch(self, batch):
         batch_atoms = batch[0].to(self.device)
-        batch_polarity = batch[1].to(self.device)
-        batch_true_links = batch[2].to(self.device)
-        batch_sentences_tokens = batch[3].to(self.device)
-        batch_sentences_mask = batch[4].to(self.device)
+        batch_pos_idx = batch[1].to(self.device)
+        batch_neg_idx = batch[2].to(self.device)
+        batch_true_links = batch[3].to(self.device)
+        batch_sentences_tokens = batch[4].to(self.device)
+        batch_sentences_mask = batch[5].to(self.device)
 
         logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
 
-        logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
+        logits_axiom_links_pred = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
                                        batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)
         print('\n')
         print("Tokens de la phrase : ", batch_sentences_tokens[1])
         print("Atoms dans la phrase : ", (batch_atoms[1][:50]))
-        print("Polarités des atoms de la phrase : ", batch_polarity[1][:50])
+        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][:50])
+        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][:50])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
         print("Les prédictions : ", axiom_links_pred[1][2][:100])
         print('\n')
@@ -402,34 +357,18 @@ class Linker(Module):
             }, path)
         self.to(self.device)
 
-    def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
-        pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                            bool(re.match(r"" + atom_type + "_?\w*",
-                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                            atoms_polarity_batch[s_idx][i])]
-        if len(pos_encoding) == 0:
-            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        else:
-            len_pos_encoding = len(pos_encoding)
-            pos_encoding += [torch.zeros(self.dim_embedding_atoms,
-                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
-                             range(self.max_atoms_in_one_type//2 - len_pos_encoding)]
-            return torch.stack(pos_encoding)
-
-    def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
-        neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                            bool(re.match(r"" + atom_type + "_?\w*",
-                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                            not atoms_polarity_batch[s_idx][i])]
-        if len(neg_encoding) == 0:
-            return torch.zeros(self.max_atoms_in_one_type//2, self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        else:
-            len_neg_encoding = len(neg_encoding)
-            neg_encoding += [torch.zeros(self.dim_embedding_atoms,
-                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu")) for i in
-                             range(self.max_atoms_in_one_type//2 - len_neg_encoding)]
-            return torch.stack(neg_encoding)
+    def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
+        """
+        :param bsd_tensor:
+            Tensor of shape (batch_size, sequence length, feature dimensionality).
+        :param positional_ids:
+            LongTensor of shape (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2);
+            positional_ids[b, a] holds the positions of the atoms of type a in sentence b, padded with -1.
+        :param atom_type: key of atom_map_redux selecting which row of positional_ids to use
+        :return: tensor of shape (batch_size, max_atoms_in_one_type // 2, dim_embedding_atoms)
+        """
+
+        return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
+                                         if atom != -1 else torch.zeros(self.dim_embedding_atoms, device=self.device)
+                                         for atom in sentence])
+                            for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
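make_sinkhorn_inputs replaces the per-sentence Python filtering of the two deleted helpers with a plain lookup: for one atom type it picks the encoder outputs sitting at the precomputed positions and substitutes zero vectors at the -1 padding slots. The nested loops can also be written as a single torch.gather; the sketch below is an equivalent formulation (an assumption of this note, not code from the patch), kept mainly to document the intended shapes:

```python
import torch
from Linker.atom_map import atom_map_redux

def make_sinkhorn_inputs_gather(bsd_tensor, positional_ids, atom_type):
    """bsd_tensor: (batch, seq_len, dim); positional_ids: LongTensor (batch, n_types, half), -1 = pad.
    Returns (batch, half, dim): encodings of the selected atoms, zero vectors at padded slots."""
    idx = positional_ids[:, atom_map_redux[atom_type], :]        # (batch, half)
    pad_mask = idx.eq(-1)                                        # remember the padded slots
    idx = idx.clamp(min=0).unsqueeze(-1).expand(-1, -1, bsd_tensor.size(-1))
    gathered = torch.gather(bsd_tensor, 1, idx)                  # (batch, half, dim)
    return gathered.masked_fill(pad_mask.unsqueeze(-1), 0.0)
```

Given the same LongTensor indices, this produces the same tensor as the stacked-select version while avoiding the Python-level loops.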
diff --git a/Linker/__init__.py b/Linker/__init__.py
index b9380b473249eb38f1995474414ceb1eb6ea85ca..92c67b3fcaa9d1121107b979ba57a5bbeba043ea 100644
--- a/Linker/__init__.py
+++ b/Linker/__init__.py
@@ -1,4 +1,3 @@
 from .Linker import Linker
 from .atom_map import atom_map
-from .AtomEmbedding import AtomEmbedding
 from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
diff --git a/Linker/atom_map.py b/Linker/atom_map.py
index d45c4b9709ee161960302e485f01a39e08c3fc76..4e0c45e4faed7171fb563685c85f172327dd4295 100644
--- a/Linker/atom_map.py
+++ b/Linker/atom_map.py
@@ -17,3 +17,13 @@ atom_map = \
         's_ppart': 15,
         '[PAD]': 16
     }
+
+atom_map_redux = {
+    'cl_r': 0,
+    'pp': 1,
+    'n': 2,
+    'np': 3,
+    'cl_y': 4,
+    'txt': 5,
+    's': 6
+}
diff --git a/Linker/eval.py b/Linker/eval.py
index 1113596e276a190edfc49ac50ce511ad64b4e6c8..e713120ce61d3a43619559bd2eaadf867a958931 100644
--- a/Linker/eval.py
+++ b/Linker/eval.py
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
 
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths))
+                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))
 
 
 def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms, permuted below to match axiom_links_pred
+    axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
+    batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
     correct_links[batch_true_links == -1] = 1
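The two permutes reconcile the layouts on either side of the loss: truths come batch-first from the TensorDataset, predictions atom-type-first from forward. A toy consistency check mirroring SinkhornLoss.forward (random tensors, assumed sizes):

```python
import torch
from torch.nn.functional import nll_loss

n_types, batch, half = 7, 2, 5
predictions = torch.log_softmax(torch.randn(n_types, batch, half, half), dim=3)  # forward output layout
truths = torch.randint(-1, half, (batch, n_types, half))                         # dataset layout, -1 = pad

# Same expression as SinkhornLoss.forward: iterate per atom type, truths permuted to atom-type-major.
loss = sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
           for link, perm in zip(predictions, truths.permute(1, 0, 2)))
print(loss)   # scalar tensor
```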
diff --git a/Linker/utils_linker.py b/Linker/utils_linker.py
index 0aa6dc25bfac924b64ce38d481787e81b93c980d..a5f0ff261ed94a2cb79908592051afc1e5c9ec27 100644
--- a/Linker/utils_linker.py
+++ b/Linker/utils_linker.py
@@ -3,7 +3,7 @@
 import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
 
-from Linker.atom_map import atom_map
+from Linker.atom_map import atom_map, atom_map_redux
 from utils import pad_sequence
@@ -276,4 +276,19 @@ def get_GOAL(max_atoms_in_sentence, categories_batch):
 
 ################################ Prepare encoding ###############################################
 #########################################################################################
 
+def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
+    inverse_atom_map = {v: k for k, v in atom_map.items()}
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                                              atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type//2, padding_value=-1)
+               for atom_type in list(atom_map_redux.keys())]
+    return torch.stack(pos_idx).permute(1, 0, 2)
+
+
+def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
+    inverse_atom_map = {v: k for k, v in atom_map.items()}
+    neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence) if bool(re.match(r"" + atom_type + "_?\w*", inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])])) and
+                                              not atoms_polarity_batch[s_idx][i]]) for s_idx, sentence in enumerate(atoms_batch_tokenized)], max_len=max_atoms_in_one_type//2, padding_value=-1)
+               for atom_type in list(atom_map_redux.keys())]
+
+    return torch.stack(neg_idx).permute(1, 0, 2)
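get_pos_idx and get_neg_idx differ only in the polarity test, so they could delegate to one parameterised helper; a sketch of that refactor (a suggestion building on the repository's own pad_sequence and atom maps, not part of the patch):

```python
import re
import torch

from Linker.atom_map import atom_map, atom_map_redux
from utils import pad_sequence


def get_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type, positive=True):
    """Positions of the atoms of each reduced type with the requested polarity, padded with -1.
    Returns a tensor of shape (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2)."""
    inverse_atom_map = {v: k for k, v in atom_map.items()}
    idx = [pad_sequence([torch.as_tensor([i for i, _ in enumerate(sentence)
                                          if re.match(atom_type + r"_?\w*", inverse_atom_map[int(sentence[i])])
                                          and bool(atoms_polarity_batch[s_idx][i]) == positive],
                                         dtype=torch.long)     # keep empty rows integer-typed
                         for s_idx, sentence in enumerate(atoms_batch_tokenized)],
                        max_len=max_atoms_in_one_type // 2, padding_value=-1)
           for atom_type in atom_map_redux]
    return torch.stack(idx).permute(1, 0, 2)

# get_pos_idx(...) would become get_idx(..., positive=True), get_neg_idx(...) get_idx(..., positive=False).
```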