Commit cea987d4 authored by Caroline de Pourtalès

Merge branch 'change-preprocess' into 'version-linker'

Change preprocess

See merge request !2
parents f028d834 2889028a
3 related merge requests: !6 Linker with transformer, !5 Linker with transformer, !2 Change preprocess
@@ -12,7 +12,7 @@ max_atoms_in_one_type=510
 dim_encoder = 768

 [MODEL_DECODER]
-nhead=8
+nhead=4
 num_layers=1
 dropout=0.1
 dim_feedforward=512
@@ -26,7 +26,7 @@ dropout=0.1
 sinkhorn_iters=3

 [MODEL_TRAINING]
-batch_size=16
+batch_size=32
 epoch=30
 seed_val=42
 learning_rate=2e-4
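These values are read back through the Configuration module used in Linker.py (Configuration.datasetConfig, Configuration.modelTrainingConfig, and so on). As a minimal sketch of how such an INI-style file is consumed, assuming the standard configparser module and a hypothetical config.ini file name (the actual Configuration implementation is not part of this diff):

# Illustrative only: parse the sections touched by this commit with configparser.
# The file name "config.ini" and the use of configparser are assumptions.
from configparser import ConfigParser

config = ConfigParser()
config.read("config.ini")

nhead = int(config["MODEL_DECODER"]["nhead"])                     # 4 after this change
batch_size = int(config["MODEL_TRAINING"]["batch_size"])          # 32 after this change
learning_rate = float(config["MODEL_TRAINING"]["learning_rate"])  # 2e-4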
Linker/AtomEmbedding.py (deleted):
-import torch
-from torch.nn import Module, Embedding
-
-
-class AtomEmbedding(Module):
-    def __init__(self, dim_linker, atom_vocab_size, padding_idx=None):
-        super(AtomEmbedding, self).__init__()
-        self.emb = Embedding(num_embeddings=atom_vocab_size, embedding_dim=dim_linker, padding_idx=padding_idx,
-                             scale_grad_by_freq=True)
-
-    def forward(self, x):
-        return self.emb(x)
Linker/Linker.py:
@@ -6,22 +6,20 @@ import datetime
 import time
 import torch.nn.functional as F
-from torch.nn import Sequential, LayerNorm, Dropout
+from torch.nn import Sequential, LayerNorm, Dropout, Embedding
 from torch.optim import AdamW
 from torch.utils.data import TensorDataset, random_split
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm

 from Configuration import Configuration
-from Linker.AtomEmbedding import AtomEmbedding
 from Linker.AtomTokenizer import AtomTokenizer
 from Linker.MHA import AttentionDecoderLayer
 from Linker.Sinkhorn import sinkhorn_fn_no_exp as sinkhorn
-from Linker.atom_map import atom_map
+from Linker.atom_map import atom_map, atom_map_redux
 from Linker.eval import mesure_accuracy, SinkhornLoss
-from Linker.utils_linker import FFN, get_axiom_links, get_GOAL
+from Linker.utils_linker import FFN, get_axiom_links, get_GOAL, get_pos_idx, get_neg_idx
 from Supertagger import *
-from utils import pad_sequence


 def format_time(elapsed):
@@ -62,18 +60,20 @@ class Linker(Module):
         atom_vocab_size = int(Configuration.datasetConfig['atom_vocab_size'])
         learning_rate = float(Configuration.modelTrainingConfig['learning_rate'])
         self.dropout = Dropout(0.1)
-        self.device = "cpu"
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

         supertagger = SuperTagger()
         supertagger.load_weights(supertagger_path_model)
         self.Supertagger = supertagger

         self.atom_map = atom_map
-        self.sub_atoms_type_list = ['cl_r', 'pp', 'n', 'np', 'cl_y', 'txt', 's']
+        self.sub_atoms_type_list = list(atom_map_redux.keys())
         self.padding_id = self.atom_map['[PAD]']
         self.atoms_tokenizer = AtomTokenizer(atom_map, self.max_atoms_in_sentence)
         self.inverse_map = self.atoms_tokenizer.inverse_atom_map
-        self.atoms_embedding = AtomEmbedding(self.dim_embedding_atoms, atom_vocab_size, self.padding_id)
+        self.atoms_embedding = Embedding(num_embeddings=atom_vocab_size, embedding_dim=self.dim_embedding_atoms,
+                                         padding_idx=self.padding_id,
+                                         scale_grad_by_freq=True)

         self.linker_encoder = AttentionDecoderLayer()
@@ -90,8 +90,6 @@ class Linker(Module):
         self.optimizer = AdamW(self.parameters(),
                                lr=learning_rate)

-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.to(self.device)

     def __preprocess_data(self, batch_size, df_axiom_links, validation_rate=0.1):
@@ -110,12 +108,15 @@ class Linker(Module):
         atoms_batch, atoms_polarity_batch = get_GOAL(self.max_atoms_in_sentence, df_axiom_links["Z"])
         atoms_batch_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)

+        pos_idx = get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+        neg_idx = get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, self.max_atoms_in_one_type)
+
         truth_links_batch = get_axiom_links(self.max_atoms_in_one_type, self.sub_atoms_type_list, atoms_polarity_batch,
                                             df_axiom_links["Y"])
         truth_links_batch = truth_links_batch.permute(1, 0, 2)

         # Construction tensor dataset
-        dataset = TensorDataset(atoms_batch_tokenized, atoms_polarity_batch, truth_links_batch, sentences_tokens,
+        dataset = TensorDataset(atoms_batch_tokenized, pos_idx, neg_idx, truth_links_batch, sentences_tokens,
                                 sentences_mask)

         if validation_rate > 0.0:
@@ -136,11 +137,12 @@ class Linker(Module):
         decoder_attn_mask[atoms_token.eq(self.padding_id)] = 0.0
         return decoder_attn_mask.unsqueeze(1).repeat(1, atoms_token.shape[1], 1).repeat(self.nhead, 1, 1)

-    def forward(self, atoms_batch_tokenized, atoms_polarity_batch, sents_embedding, sents_mask=None):
+    def forward(self, atoms_batch_tokenized, batch_pos_idx, batch_neg_idx, sents_embedding, sents_mask=None):
         r"""
         Args:
             atoms_batch_tokenized : (batch_size, max_atoms_in_one_sentence) flattened categories
-            atoms_polarity_batch : (batch_size, max_atoms_in_one_sentence) flattened categories polarities
+            batch_pos_idx : (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2) padded indices of the positive atoms, per atom type
+            batch_neg_idx : (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2) padded indices of the negative atoms, per atom type
             sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
             sents_mask : mask from BERT tokenizer
         Returns:
@@ -157,13 +159,8 @@ class Linker(Module):
         link_weights = []
         for atom_type in self.sub_atoms_type_list:
-            pos_encoding = torch.stack([self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
-                                                                        atoms_polarity_batch, atom_type, s_idx)
-                                        for s_idx in range(len(atoms_polarity_batch))])
-            neg_encoding = torch.stack([self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_batch_tokenized,
-                                                                        atoms_polarity_batch, atom_type, s_idx)
-                                        for s_idx in range(len(atoms_polarity_batch))])
+            pos_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_pos_idx, atom_type)
+            neg_encoding = self.make_sinkhorn_inputs(atoms_encoding, batch_neg_idx, atom_type)

             pos_encoding = self.pos_transformation(pos_encoding)
             neg_encoding = self.neg_transformation(neg_encoding)
@@ -172,9 +169,8 @@ class Linker(Module):
             link_weights.append(sinkhorn(weights, iters=self.sinkhorn_iters))

         total_link_weights = torch.stack(link_weights)
-        link_weights_per_batch = total_link_weights.permute(1, 0, 2, 3)

-        return F.log_softmax(link_weights_per_batch, dim=3)
+        return F.log_softmax(total_link_weights, dim=3)

     def train_linker(self, df_axiom_links, validation_rate=0.1, epochs=20,
                      batch_size=32, checkpoint=True, tensorboard=False):
@@ -246,10 +242,11 @@ class Linker(Module):
             for batch in tepoch:
                 # Unpack this training batch from our dataloader
                 batch_atoms = batch[0].to(self.device)
-                batch_polarity = batch[1].to(self.device)
-                batch_true_links = batch[2].to(self.device)
-                batch_sentences_tokens = batch[3].to(self.device)
-                batch_sentences_mask = batch[4].to(self.device)
+                batch_pos_idx = batch[1].to(self.device)
+                batch_neg_idx = batch[2].to(self.device)
+                batch_true_links = batch[3].to(self.device)
+                batch_sentences_tokens = batch[4].to(self.device)
+                batch_sentences_mask = batch[5].to(self.device)

                 self.optimizer.zero_grad()
@@ -257,7 +254,8 @@ class Linker(Module):
                 logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)

                 # Run the linker on the categories predictions
-                logits_predictions = self(batch_atoms, batch_polarity, sentences_embedding, batch_sentences_mask)
+                logits_predictions = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
+                                          batch_sentences_mask)

                 linker_loss = self.cross_entropy_loss(logits_predictions, batch_true_links)
                 # Perform a backward pass to calculate the gradients.
@@ -280,67 +278,24 @@ class Linker(Module):
         return avg_train_loss, avg_accuracy_train, training_time

-    def predict(self, categories, sents_embedding, sents_mask=None):
-        r"""Prediction from categories output by BERT and hidden_state from BERT
-        Args:
-            categories : (batch_size, len_sentence)
-            sents_embedding : (batch_size, len_sentence, dim_encoder) output of BERT for context
-            sents_mask
-        Returns:
-            axiom_links : atom_vocab_size, batch-size, max_atoms_in_one_cat)
-        """
-        self.eval()
-        with torch.no_grad():
-            # get atoms
-            atoms_batch, polarities = get_GOAL(self.max_atoms_in_sentence, categories)
-            atoms_tokenized = self.atoms_tokenizer.convert_batchs_to_ids(atoms_batch)
-
-            # atoms embedding
-            atoms_embedding = self.atoms_embedding(atoms_tokenized)
-
-            # MHA ou LSTM avec sortie de BERT
-            atoms_encoding = self.linker_encoder(atoms_embedding, sents_embedding, sents_mask,
-                                                 self.make_decoder_mask(atoms_tokenized))
-
-            link_weights = []
-            for atom_type in self.sub_atoms_type_list:
-                pos_encoding = pad_sequence(
-                    [self.get_pos_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
-                     for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2)
-                neg_encoding = pad_sequence(
-                    [self.get_neg_encoding_for_s_idx(atoms_encoding, atoms_tokenized, polarities, atom_type, s_idx)
-                     for s_idx in range(len(polarities))], padding_value=0,
-                    max_len=self.max_atoms_in_one_type // 2)
-
-                pos_encoding = self.pos_transformation(pos_encoding)
-                neg_encoding = self.neg_transformation(neg_encoding)
-
-                weights = torch.bmm(pos_encoding, neg_encoding.transpose(2, 1))
-                link_weights.append(sinkhorn(weights, iters=3))
-
-            logits_predictions = torch.stack(link_weights).permute(1, 0, 2, 3)
-            axiom_links = torch.argmax(F.log_softmax(logits_predictions, dim=3), dim=3)
-        return axiom_links
-
     def eval_batch(self, batch):
         batch_atoms = batch[0].to(self.device)
-        batch_polarity = batch[1].to(self.device)
-        batch_true_links = batch[2].to(self.device)
-        batch_sentences_tokens = batch[3].to(self.device)
-        batch_sentences_mask = batch[4].to(self.device)
+        batch_pos_idx = batch[1].to(self.device)
+        batch_neg_idx = batch[2].to(self.device)
+        batch_true_links = batch[3].to(self.device)
+        batch_sentences_tokens = batch[4].to(self.device)
+        batch_sentences_mask = batch[5].to(self.device)

         logits, sentences_embedding = self.Supertagger.forward(batch_sentences_tokens, batch_sentences_mask)
-        logits_axiom_links_pred = self(batch_atoms, batch_polarity, sentences_embedding,
+        logits_axiom_links_pred = self(batch_atoms, batch_pos_idx, batch_neg_idx, sentences_embedding,
                                        batch_sentences_mask)
         axiom_links_pred = torch.argmax(logits_axiom_links_pred, dim=3)

         print('\n')
         print("Tokens de la phrase : ", batch_sentences_tokens[1])
         print("Atoms dans la phrase : ", (batch_atoms[1][:50]))
-        print("Polarités des atoms de la phrase : ", batch_polarity[1][:50])
+        print("Polarités + des atoms de la phrase : ", batch_pos_idx[1][:50])
+        print("Polarités - des atoms de la phrase : ", batch_neg_idx[1][:50])
         print("Les vrais liens de la catégorie n : ", batch_true_links[1][2][:100])
         print("Les prédictions : ", axiom_links_pred[1][2][:100])
         print('\n')
@@ -402,34 +357,18 @@ class Linker(Module):
         }, path)
         self.to(self.device)

-    def get_pos_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
-        pos_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                            bool(re.match(r"" + atom_type + "_?\w*",
-                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                            atoms_polarity_batch[s_idx][i])]
-        if len(pos_encoding) == 0:
-            return torch.zeros(self.max_atoms_in_one_type // 2, self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        else:
-            len_pos_encoding = len(pos_encoding)
-            pos_encoding += [torch.zeros(self.dim_embedding_atoms,
-                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-                             for i in range(self.max_atoms_in_one_type // 2 - len_pos_encoding)]
-            return torch.stack(pos_encoding)
-
-    def get_neg_encoding_for_s_idx(self, atoms_encoding, atoms_batch_tokenized, atoms_polarity_batch, atom_type, s_idx):
-        neg_encoding = [x for i, x in enumerate(atoms_encoding[s_idx])
-                        if (atom_map[atom_type] in atoms_batch_tokenized[s_idx] and
-                            bool(re.match(r"" + atom_type + "_?\w*",
-                                          self.inverse_map[int(atoms_batch_tokenized[s_idx][i])])) and
-                            not atoms_polarity_batch[s_idx][i])]
-        if len(neg_encoding) == 0:
-            return torch.zeros(self.max_atoms_in_one_type // 2, self.dim_embedding_atoms,
-                               device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-        else:
-            len_neg_encoding = len(neg_encoding)
-            neg_encoding += [torch.zeros(self.dim_embedding_atoms,
-                                         device=torch.device("cuda" if torch.cuda.is_available() else "cpu"))
-                             for i in range(self.max_atoms_in_one_type // 2 - len_neg_encoding)]
-            return torch.stack(neg_encoding)
+    def make_sinkhorn_inputs(self, bsd_tensor, positional_ids, atom_type):
+        """
+        :param bsd_tensor: tensor of shape (batch_size, sequence length, feature dimensionality)
+        :param positional_ids: LongTensor of shape (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2);
+            positional_ids[b][a] indexes the locations of atoms of type a in sentence b, padded with -1
+        :param atom_type: the atom type (a key of atom_map_redux) whose positions are selected
+        :return: tensor of shape (batch_size, max_atoms_in_one_type // 2, dim_embedding_atoms)
+        """
+        return torch.stack([torch.stack([bsd_tensor.select(0, index=i).select(0, index=int(atom)).to(self.device)
+                                         if atom != -1 else torch.zeros(self.dim_embedding_atoms, device=self.device)
+                                         for atom in sentence])
+                            for i, sentence in enumerate(positional_ids[:, atom_map_redux[atom_type], :])])
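The net effect of the Linker.py changes is that the per-sentence regex and polarity filtering is replaced by index tensors computed once in __preprocess_data (see get_pos_idx / get_neg_idx in Linker/utils_linker.py below), with -1 marking padded slots. The following toy example (not part of the commit; shapes and values are illustrative) shows the gather that make_sinkhorn_inputs performs for a single atom type:

import torch

# Toy shapes, for illustration only.
batch_size, seq_len, dim = 2, 6, 4
atoms_encoding = torch.randn(batch_size, seq_len, dim)

# Positional indices for one atom type, padded with -1 (as produced by get_pos_idx).
pos_idx_for_type = torch.tensor([[0, 3, -1],
                                 [2, -1, -1]])

# Pick the encoding at each listed position; padded slots become zero vectors.
gathered = torch.stack([
    torch.stack([atoms_encoding[b, i] if i != -1 else torch.zeros(dim)
                 for i in pos_idx_for_type[b]])
    for b in range(batch_size)])

print(gathered.shape)  # torch.Size([2, 3, 4]) -> (batch_size, max_atoms_in_one_type // 2, dim)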
Linker/__init__.py:
 from .Linker import Linker
 from .atom_map import atom_map
-from .AtomEmbedding import AtomEmbedding
 from .AtomTokenizer import AtomTokenizer
\ No newline at end of file
Linker/atom_map.py:
@@ -17,3 +17,13 @@ atom_map = \
     's_ppart': 15,
     '[PAD]': 16
 }
+
+atom_map_redux = {
+    'cl_r': 0,
+    'pp': 1,
+    'n': 2,
+    'np': 3,
+    'cl_y': 4,
+    'txt': 5,
+    's': 6
+}
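atom_map_redux fixes the row order used by get_pos_idx, get_neg_idx and make_sinkhorn_inputs: each coarse atom type owns one row of the (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2) index tensors. A short usage sketch with a dummy tensor standing in for the output of get_pos_idx:

import torch
from Linker.atom_map import atom_map_redux

# Dummy index tensor with the layout produced by get_pos_idx:
# (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2), padded with -1.
pos_idx = torch.full((2, len(atom_map_redux), 5), -1, dtype=torch.long)

row = atom_map_redux['np']                    # 3
np_positive_positions = pos_idx[:, row, :]    # (batch_size, max_atoms_in_one_type // 2)
print(np_positive_positions.shape)            # torch.Size([2, 5])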
Linker/eval.py:
@@ -9,14 +9,15 @@ class SinkhornLoss(Module):
     def forward(self, predictions, truths):
         return sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
-                   for link, perm in zip(predictions, truths))
+                   for link, perm in zip(predictions, truths.permute(1, 0, 2)))


 def mesure_accuracy(batch_true_links, axiom_links_pred):
     r"""
-    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
-    axiom_links_pred : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    batch_true_links : (batch_size, atom_vocab_size, max_atoms_in_one_cat) contains the index of the negative atoms
+    axiom_links_pred : (atom_vocab_size, batch_size, max_atoms_in_one_cat) contains the index of the negative atoms
     """
+    batch_true_links = batch_true_links.permute(1, 0, 2)
     correct_links = torch.ones(axiom_links_pred.size())
     correct_links[axiom_links_pred != batch_true_links] = 0
     correct_links[batch_true_links == -1] = 1
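Since forward now returns link scores shaped (atom types, batch, N, N) while the TensorDataset stores the true links batch-first, the loss permutes the truths before zipping. A standalone shape check with illustrative sizes (7 atom types, batch 32, N = 510 // 2 = 255; random values, not part of the commit):

import torch
from torch.nn.functional import nll_loss

predictions = torch.randn(7, 32, 255, 255).log_softmax(dim=3)   # (atom types, batch, N, N)
truths = torch.randint(-1, 255, (32, 7, 255))                   # (batch, atom types, N), -1 = padding

loss = sum(nll_loss(link.flatten(0, 1), perm.flatten(), reduction='mean', ignore_index=-1)
           for link, perm in zip(predictions, truths.permute(1, 0, 2)))
print(loss)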
Linker/utils_linker.py:
@@ -3,7 +3,7 @@ import regex
 import torch
 from torch.nn import Sequential, Linear, Dropout, GELU
 from torch.nn import Module
-from Linker.atom_map import atom_map
+from Linker.atom_map import atom_map, atom_map_redux
 from utils import pad_sequence

@@ -276,4 +276,19 @@ def get_GOAL(max_atoms_in_sentence, categories_batch):
 ################################ Prepare encoding ###############################################
 #########################################################################################
+
+def get_pos_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
+    inverse_atom_map = {v: k for k, v in atom_map.items()}
+    pos_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence)
+                                              if bool(re.match(r"" + atom_type + "_?\w*",
+                                                               inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])]))
+                                              and atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch_tokenized)],
+                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
+               for atom_type in list(atom_map_redux.keys())]
+    return torch.stack(pos_idx).permute(1, 0, 2)
+
+
+def get_neg_idx(atoms_batch_tokenized, atoms_polarity_batch, max_atoms_in_one_type):
+    inverse_atom_map = {v: k for k, v in atom_map.items()}
+    neg_idx = [pad_sequence([torch.as_tensor([i for i, x in enumerate(sentence)
+                                              if bool(re.match(r"" + atom_type + "_?\w*",
+                                                               inverse_atom_map[int(atoms_batch_tokenized[s_idx][i])]))
+                                              and not atoms_polarity_batch[s_idx][i]])
+                             for s_idx, sentence in enumerate(atoms_batch_tokenized)],
+                            max_len=max_atoms_in_one_type // 2, padding_value=-1)
+               for atom_type in list(atom_map_redux.keys())]
+    return torch.stack(neg_idx).permute(1, 0, 2)
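Both helpers produce one row of padded positions per atom type listed in atom_map_redux, so their output has shape (batch_size, len(atom_map_redux), max_atoms_in_one_type // 2). The sketch below reproduces the inner position-extraction step for a single sentence and atom type, using only the standard re module and the atom_map shown above (the toy sentence and polarities are assumptions for illustration, not data from the commit):

import re
import torch
from Linker.atom_map import atom_map

inverse_atom_map = {v: k for k, v in atom_map.items()}

# One toy tokenized sentence and its polarities (illustrative values only).
sentence = torch.tensor([atom_map['np'], atom_map['n'], atom_map['s'], atom_map['[PAD]']])
polarities = [True, False, True, False]

# Positions of positive atoms of type 'np' in this sentence, before padding to
# max_atoms_in_one_type // 2 with -1 (what get_pos_idx does for every sentence and type).
atom_type = 'np'
positions = [i for i, tok in enumerate(sentence)
             if re.match(atom_type + r"_?\w*", inverse_atom_map[int(tok)]) and polarities[i]]
print(positions)  # [0]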